# Source code for nx_arangodb.classes.dict.adj

from __future__ import annotations

import warnings
from collections import UserDict
from collections.abc import Iterator
from itertools import islice
from typing import Any, Callable, Dict, List, Union

from arango.database import StandardDatabase
from arango.exceptions import DocumentDeleteError
from arango.graph import Graph
from phenolrs.networkx.typings import (
    DiGraphAdjDict,
    GraphAdjDict,
    MultiDiGraphAdjDict,
    MultiGraphAdjDict,
    NodeDict,
)

from nx_arangodb.exceptions import EdgeTypeAmbiguity, MultipleEdgesFound
from nx_arangodb.logger import logger

from ..enum import DIRECTED_GRAPH_TYPES, MULTIGRAPH_TYPES, GraphType, TraversalDirection
from ..function import (
    ArangoDBBatchError,
    aql,
    aql_doc_get_key,
    aql_doc_has_key,
    aql_edge_count_src,
    aql_edge_count_src_dst,
    aql_edge_exists,
    aql_edge_get,
    aql_edge_id,
    aql_fetch_data_edge,
    check_update_list_for_errors,
    doc_insert,
    doc_update,
    edge_get,
    edge_link,
    get_arangodb_graph,
    get_node_id,
    get_node_type,
    get_node_type_and_id,
    get_update_dict,
    json_serializable,
    key_is_adb_id_or_int,
    key_is_not_reserved,
    key_is_string,
    keys_are_not_reserved,
    keys_are_strings,
    separate_edges_by_collections,
    upsert_collection_edges,
)

# Type alias accepting any one of the four adjacency-dict typings imported
# from ``phenolrs.networkx.typings`` above.
AdjDict = Union[GraphAdjDict, DiGraphAdjDict, MultiGraphAdjDict, MultiDiGraphAdjDict]

#############
# Factories #
#############


def edge_attr_dict_factory(
    db: StandardDatabase, graph: Graph
) -> Callable[..., EdgeAttrDict]:
    """Build a zero-argument constructor for EdgeAttrDict.

    Parameters
    ----------
    db : arango.database.StandardDatabase
        Forwarded as the first argument of every EdgeAttrDict created.
    graph : arango.graph.Graph
        Forwarded as the second argument of every EdgeAttrDict created.

    Returns
    -------
    Callable[..., EdgeAttrDict]
        A callable taking no arguments; each call returns a new
        ``EdgeAttrDict(db, graph)``.
    """

    def make_edge_attr_dict() -> EdgeAttrDict:
        return EdgeAttrDict(db, graph)

    return make_edge_attr_dict


def edge_key_dict_factory(
    db: StandardDatabase,
    graph: Graph,
    edge_type_key: str,
    edge_type_func: Callable[[str, str], str],
    is_directed: bool,
    adjlist_inner_dict: AdjListInnerDict | None = None,
) -> Callable[..., EdgeKeyDict]:
    """Build a zero-argument constructor for EdgeKeyDict.

    Every parameter is captured once and handed, positionally and in the
    order listed, to ``EdgeKeyDict`` each time the returned callable runs.

    Returns
    -------
    Callable[..., EdgeKeyDict]
        A callable taking no arguments; each call returns a new EdgeKeyDict.
    """
    ctor_args = (
        db,
        graph,
        edge_type_key,
        edge_type_func,
        is_directed,
        adjlist_inner_dict,
    )

    def make_edge_key_dict() -> EdgeKeyDict:
        return EdgeKeyDict(*ctor_args)

    return make_edge_key_dict


def adjlist_inner_dict_factory(
    db: StandardDatabase,
    graph: Graph,
    default_node_type: str,
    edge_type_key: str,
    edge_type_func: Callable[[str, str], str],
    graph_type: str,
    adjlist_outer_dict: AdjListOuterDict | None = None,
) -> Callable[..., AdjListInnerDict]:
    """Build a zero-argument constructor for AdjListInnerDict.

    Every parameter is captured once and handed, positionally and in the
    order listed, to ``AdjListInnerDict`` each time the returned callable
    runs.

    Returns
    -------
    Callable[..., AdjListInnerDict]
        A callable taking no arguments; each call returns a new
        AdjListInnerDict.
    """
    ctor_args = (
        db,
        graph,
        default_node_type,
        edge_type_key,
        edge_type_func,
        graph_type,
        adjlist_outer_dict,
    )

    def make_adjlist_inner_dict() -> AdjListInnerDict:
        return AdjListInnerDict(*ctor_args)

    return make_adjlist_inner_dict


def adjlist_outer_dict_factory(
    db: StandardDatabase,
    graph: Graph,
    default_node_type: str,
    read_parallelism: int,
    read_batch_size: int,
    edge_type_key: str,
    edge_type_func: Callable[[str, str], str],
    graph_type: str,
    symmetrize_edges_if_directed: bool,
) -> Callable[..., AdjListOuterDict]:
    """Build a zero-argument constructor for AdjListOuterDict.

    Every parameter is captured once and handed, positionally and in the
    order listed, to ``AdjListOuterDict`` each time the returned callable
    runs.

    Returns
    -------
    Callable[..., AdjListOuterDict]
        A callable taking no arguments; each call returns a new
        AdjListOuterDict.
    """
    ctor_args = (
        db,
        graph,
        default_node_type,
        read_parallelism,
        read_batch_size,
        edge_type_key,
        edge_type_func,
        graph_type,
        symmetrize_edges_if_directed,
    )

    def make_adjlist_outer_dict() -> AdjListOuterDict:
        return AdjListOuterDict(*ctor_args)

    return make_adjlist_outer_dict


#############
# Adjacency #
#############


def build_edge_attr_dict_data(
    parent: EdgeAttrDict, data: dict[str, Any]
) -> dict[str, Any | EdgeAttrDict]:
    """Convert a plain attribute mapping into EdgeAttrDict-ready data.

    Each value is passed through ``process_edge_attr_dict_value``, which
    wraps nested dicts in child EdgeAttrDict objects and leaves every
    other value untouched. Keys and their order are preserved.

    Parameters
    ----------
    parent : EdgeAttrDict
        The EdgeAttrDict that will own the resulting data.
    data : dict[str, Any]
        The raw attribute mapping to convert.

    Returns
    -------
    dict[str, Any | EdgeAttrDict]
        A new dict holding the processed values; **data** is not modified.
    """
    return {
        attr_key: process_edge_attr_dict_value(parent, attr_key, attr_value)
        for attr_key, attr_value in data.items()
    }


def process_edge_attr_dict_value(parent: EdgeAttrDict, key: str, value: Any) -> Any:
    """Prepare a single attribute value for storage in an EdgeAttrDict.

    A dict value becomes a child EdgeAttrDict that shares the parent's
    edge ID and records the path of keys leading to it. Any other value
    is handed back unchanged.

    Parameters
    ----------
    parent : EdgeAttrDict
        The EdgeAttrDict the value belongs to.
    key : str
        The key under which the value is stored in **parent**.
    value : Any
        The value to process.

    Returns
    -------
    Any
        A new EdgeAttrDict if **value** is a dict, otherwise **value** itself.
    """
    if isinstance(value, dict):
        nested = parent.edge_attr_dict_factory()
        nested.edge_id = parent.edge_id
        # The key path must be in place before the children are built,
        # because each nested level extends its own parent's path.
        nested.parent_keys = parent.parent_keys + [key]
        nested.data = build_edge_attr_dict_data(nested, value)
        return nested

    return value


@json_serializable
class EdgeAttrDict(UserDict[str, Any]):
    """The innermost-level of the dict of dict (of dict) of dict
    structure representing the Adjacency List of a graph.

    EdgeAttrDict is keyed by the edge attribute key.

    Values that have already been read or written are kept in ``self.data``.
    Keys missing from that local cache are resolved through
    ``aql_doc_has_key`` / ``aql_doc_get_key`` against the document
    identified by ``edge_id``.

    Parameters
    ----------
    db : arango.database.StandardDatabase
        The ArangoDB database.

    graph : arango.graph.Graph
        The ArangoDB graph.

    *args, **kwargs
        Optional initial attributes, forwarded to ``UserDict.__init__``
        (which applies them through ``update``).

    Examples
    --------
    >>> g = nxadb.Graph(name="MyGraph")
    >>> g.add_edge("node/1", "node/2", foo="bar")
    >>> g["node/1"]["node/2"]
    EdgeAttrDict({'foo': 'bar', '_key': ..., '_id': ...})
    """

    def __init__(
        self,
        db: StandardDatabase,
        graph: Graph,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # These attributes must exist before UserDict.__init__ runs:
        # it forwards any initial mapping / kwargs to the overridden
        # update(), which reads edge_id and edge_attr_dict_factory.
        self.db = db
        self.graph = graph
        # Assigned after construction by whoever owns this object
        # (e.g. process_edge_attr_dict_value for nested dicts).
        self.edge_id: str | None = None

        # EdgeAttrDict may be a child of another EdgeAttrDict
        # e.g G._adj['node/1']['node/2']['object']['foo'] = 'bar'
        # In this case, **parent_keys** would be ['object']
        self.parent_keys: list[str] = []
        self.edge_attr_dict_factory = edge_attr_dict_factory(self.db, self.graph)

        # UserDict.__init__ creates the empty ``data`` dict itself.
        self.data: dict[str, Any]
        super().__init__(*args, **kwargs)

    def clear(self) -> None:
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError("Cannot clear EdgeAttrDict")

    def copy(self) -> Any:
        """Return a plain dict of the locally cached attributes.

        Values that expose a ``copy`` method (such as nested
        EdgeAttrDict objects) are copied through it.
        """
        return {
            key: value.copy() if hasattr(value, "copy") else value
            for key, value in self.data.items()
        }

    @key_is_string
    def __contains__(self, key: str) -> bool:
        """'foo' in G._adj['node/1']['node/2']"""
        if key in self.data:
            return True

        assert self.edge_id
        return aql_doc_has_key(self.db, self.edge_id, key, self.parent_keys)

    @key_is_string
    def __getitem__(self, key: str) -> Any:
        """G._adj['node/1']['node/2']['foo']"""
        if key in self.data:
            return self.data[key]

        assert self.edge_id
        result = aql_doc_get_key(self.db, self.edge_id, key, self.parent_keys)

        if result is None:
            raise KeyError(key)

        # Cache the fetched value; nested dicts become child EdgeAttrDicts.
        edge_attr_dict_value = process_edge_attr_dict_value(self, key, result)
        self.data[key] = edge_attr_dict_value

        return edge_attr_dict_value

    @key_is_string
    @key_is_not_reserved
    # @value_is_json_serializable # TODO?
    def __setitem__(self, key: str, value: Any) -> None:
        """G._adj['node/1']['node/2']['foo'] = 'bar'"""
        if value is None:
            # Assigning None is treated as deleting the attribute.
            self.__delitem__(key)
            return

        assert self.edge_id
        edge_attr_dict_value = process_edge_attr_dict_value(self, key, value)
        update_dict = get_update_dict(self.parent_keys, {key: value})
        self.data[key] = edge_attr_dict_value
        doc_update(self.db, self.edge_id, update_dict)

    @key_is_string
    @key_is_not_reserved
    def __delitem__(self, key: str) -> None:
        """del G._adj['node/1']['node/2']['foo']"""
        assert self.edge_id
        # The key may not be cached locally, so a missing key is not an error.
        self.data.pop(key, None)
        update_dict = get_update_dict(self.parent_keys, {key: None})
        doc_update(self.db, self.edge_id, update_dict)

    @keys_are_strings
    @keys_are_not_reserved
    def update(self, attrs: Any) -> None:
        """G._adj['node/1']['node/2'].update({'foo': 'bar'})"""
        if not attrs:
            return

        self.data.update(build_edge_attr_dict_data(self, attrs))

        if not self.edge_id:
            # Without an edge ID only the local cache is updated.
            logger.debug("Edge ID not set, skipping EdgeAttrDict(?).update()")
            return

        update_dict = get_update_dict(self.parent_keys, attrs)
        doc_update(self.db, self.edge_id, update_dict)
class EdgeKeyDict(UserDict[str, EdgeAttrDict]):
    """The (optional) 3rd level of the dict of dict (*of dict*) of dict
    structure representing the Adjacency List of a MultiGraph.

    EdgeKeyDict is keyed by ArangoDB Edge IDs.

    Unique to MultiGraphs, edges are keyed by ArangoDB Edge IDs, allowing
    for multiple edges between the same nodes.

    Alternatively, if an Edge is already fetched, then it can also be
    keyed by a numerical index. However, this is not recommended because
    consistent ordering of edges is not guaranteed.

    ASSUMPTIONS (for now):
    - keys must be ArangoDB Edge IDs
    - key-to-edge mapping is 1-to-1

    Parameters
    ----------
    db : arango.database.StandardDatabase
        The ArangoDB database.

    graph : arango.graph.Graph
        The ArangoDB graph.

    edge_type_key : str
        The key used to store the edge type in the edge attribute dictionary.

    edge_type_func : Callable[[str, str], str]
        The function to generate the edge type from the source and
        destination node types.

    is_directed : bool
        Whether the graph is directed or not.

    adjlist_inner_dict : AdjListInnerDict | None
        The parent AdjListInnerDict.

    Examples
    --------
    >>> g = nxadb.MultiGraph(name="MyGraph")
    >>> edge_id = g.add_edge("node/1", "node/2", foo="bar")
    >>> g["node/1"]["node/2"][edge_id]
    EdgeAttrDict({'foo': 'bar', '_key': ..., '_id': ...})
    """

    def __init__(
        self,
        db: StandardDatabase,
        graph: Graph,
        edge_type_key: str,
        edge_type_func: Callable[[str, str], str],
        is_directed: bool,
        adjlist_inner_dict: AdjListInnerDict | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.data: dict[str, EdgeAttrDict] = {}

        self.is_directed = is_directed
        self.db = db
        self.graph = graph
        self.edge_type_key = edge_type_key
        self.edge_type_func = edge_type_func
        self._default_edge_type: str | None = None
        self.graph_name = graph.name
        self.edge_attr_dict_factory = edge_attr_dict_factory(self.db, self.graph)

        # Both node IDs are assigned by the owner after construction.
        self.src_node_id: str | None = None
        self.dst_node_id: str | None = None
        self.adjlist_inner_dict = adjlist_inner_dict

        # FETCHED_ALL_IDS: every edge ID between the two nodes is a key of
        # self.data (values may still be empty placeholders).
        # FETCHED_ALL_DATA: every edge is fully loaded in self.data.
        self.FETCHED_ALL_DATA = False
        self.FETCHED_ALL_IDS = False

        # Inherit the direction of the parent when there is one.
        self.traversal_direction = (
            adjlist_inner_dict.traversal_direction
            if adjlist_inner_dict is not None
            else (
                TraversalDirection.OUTBOUND
                if self.is_directed
                else TraversalDirection.ANY
            )
        )

        # Pick the edge validator once, based on the traversal direction.
        edge_validation_methods = {
            TraversalDirection.OUTBOUND: self.__is_valid_edge_outbound,
            TraversalDirection.INBOUND: self.__is_valid_edge_inbound,
            TraversalDirection.ANY: self.__is_valid_edge_any,
        }

        self.__is_valid_edge = edge_validation_methods[self.traversal_direction]

    @property
    def default_edge_type(self) -> str:
        """The edge type derived (once) from the source & destination
        node types via **edge_type_func**."""
        if self._default_edge_type is None:
            assert self.src_node_id
            assert self.dst_node_id
            src_node_type = self.src_node_id.split("/")[0]
            dst_node_type = self.dst_node_id.split("/")[0]
            self._default_edge_type = self.edge_type_func(src_node_type, dst_node_type)

        return self._default_edge_type

    def __process_int_edge_key(self, key: int) -> str:
        """Translate a positional index into the edge ID stored at that
        position of the local cache.

        Negative indices count from the end. Raises KeyError if the
        index is out of range of the locally cached edges.
        """
        if key < 0:
            key = len(self.data) + key

        if key < 0 or key >= len(self.data):
            raise KeyError(key)

        return next(islice(self.data.keys(), key, key + 1))

    def __is_valid_edge_outbound(self, edge: dict[str, Any]) -> bool:
        """True if **edge** goes from the source node to the destination node."""
        a = edge["_from"] == self.src_node_id
        b = edge["_to"] == self.dst_node_id
        return bool(a and b)

    def __is_valid_edge_inbound(self, edge: dict[str, Any]) -> bool:
        """True if **edge** goes from the destination node to the source node."""
        a = edge["_from"] == self.dst_node_id
        b = edge["_to"] == self.src_node_id
        return bool(a and b)

    def __is_valid_edge_any(self, edge: dict[str, Any]) -> bool:
        """True if **edge** connects the two nodes in either direction."""
        return self.__is_valid_edge_outbound(edge) or self.__is_valid_edge_inbound(edge)

    def __get_mirrored_edge_attr(self, edge_id: str) -> EdgeAttrDict | None:
        """This method is used to get the EdgeAttrDict of the
        "mirrored" EdgeKeyDict.

        A "mirrored edge" is defined as a reference to an edge that represents
        both the forward and reverse edge between two nodes. This is useful
        because ArangoDB does not need to duplicate edges in both directions
        in the database.

        If the Graph is Undirected:
        - The "mirror" is the same adjlist_outer_dict because
            the adjacency list is the same in both directions (i.e _adj)

        If the Graph is Directed:
        - The "mirror" is the "reverse" adjlist_outer_dict because
            the adjacency list is different in both directions (i.e _pred and _succ)

        Parameters
        ----------
        edge_id : str
            The edge ID.

        Returns
        -------
        EdgeAttrDict | None
            The edge attribute dictionary if it exists.
        """
        if self.adjlist_inner_dict is None:
            return None

        if self.adjlist_inner_dict.adjlist_outer_dict is None:
            return None

        mirror = self.adjlist_inner_dict.adjlist_outer_dict  # fake mirror (i.e G._adj)
        if self.is_directed:
            mirror = mirror.mirror  # real mirror (i.e _pred or _succ)

        # Only the local caches (.data) are inspected, so no DB call is made.
        if self.dst_node_id in mirror.data:
            if self.src_node_id in mirror.data[self.dst_node_id].data:
                if edge_id in mirror.data[self.dst_node_id].data[self.src_node_id].data:
                    return (
                        mirror.data[self.dst_node_id]
                        .data[self.src_node_id]
                        .data[edge_id]
                    )

        return None

    def _create_edge_attr_dict(self, edge: dict[str, Any]) -> EdgeAttrDict:
        """Build an EdgeAttrDict bound to ``edge["_id"]`` from an edge document."""
        edge_attr_dict = self.edge_attr_dict_factory()
        edge_attr_dict.edge_id = edge["_id"]
        edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, edge)

        return edge_attr_dict

    def __repr__(self) -> str:
        if self.FETCHED_ALL_DATA:
            return self.data.__repr__()

        return f"EdgeKeyDict('{self.src_node_id}', '{self.dst_node_id}')"

    def __str__(self) -> str:
        return self.__repr__()

    @key_is_adb_id_or_int
    def __contains__(self, key: str | int) -> bool:
        """
        Examples
        --------
        >>> 'edge/1' in G._adj['node/1']['node/2']
        >>> 0 in G._adj['node/1']['node/2']
        """
        # HACK: This is a workaround for the fact that
        # nxadb.MultiGraph does not yet support custom edge keys
        if key == "-1":
            return False

        if isinstance(key, int):
            # A membership test must answer False (not raise) when the
            # positional index is outside the locally cached edges.
            try:
                key = self.__process_int_edge_key(key)
            except KeyError:
                return False

        if key in self.data:
            return True

        if self.FETCHED_ALL_IDS:
            return False

        edge = edge_get(self.graph, key)

        if edge is None:
            logger.warning(f"Edge '{key}' does not exist in Graph.")
            return False

        if not self.__is_valid_edge(edge):
            m = f"Edge '{key}' exists, but does not match the source & destination nodes."  # noqa
            logger.warning(m)
            return False

        # Contrary to other __contains__ methods, we immediately
        # populate the Dict Data because we had to retrieve
        # the entire edge from the database to check if it is valid.
        edge_attr_dict = self._create_edge_attr_dict(edge)
        self.data[key] = edge_attr_dict

        return True

    @key_is_adb_id_or_int
    def __getitem__(self, key: str | int) -> EdgeAttrDict:
        """
        Examples
        --------
        >>> G._adj['node/1']['node/2']['edge/1']
        >>> G._adj['node/1']['node/2'][0]
        """
        # HACK: This is a workaround for the fact that
        # nxadb.MultiGraph does not yet support custom edge keys
        if key == "-1":
            raise KeyError(key)

        if isinstance(key, int):
            key = self.__process_int_edge_key(key)

        # Notice the use of walrus operator here,
        # because we can return the value immediately
        # given that __contains__ builds EdgeAttrDict.data
        # (an empty placeholder EdgeAttrDict is falsy and falls through)
        if value := self.data.get(key):
            return value

        if result := self.__get_mirrored_edge_attr(key):
            self.data[key] = result
            return result

        if key not in self.data and self.FETCHED_ALL_IDS:
            raise KeyError(key)

        edge = edge_get(self.graph, key)

        if edge is None:
            raise KeyError(key)

        if not self.__is_valid_edge(edge):
            m = f"Edge '{key}' exists, but does not match the source & destination nodes."  # noqa
            raise KeyError(m)

        edge_attr_dict: EdgeAttrDict = self._create_edge_attr_dict(edge)
        self.data[key] = edge_attr_dict

        return edge_attr_dict

    def __setitem__(self, key: int, edge_attr_dict: EdgeAttrDict) -> None:  # type: ignore[override] # noqa
        """G._adj['node/1']['node/2'][0] = {'foo': 'bar'}"""

        self.data[str(key)] = edge_attr_dict

        if edge_attr_dict.edge_id:
            # NOTE: We can get here from L514 in networkx/multigraph.py
            # Assuming that keydict.get(key) did not return None (L513)

            # If the edge_id is already set, it means that the
            # EdgeAttrDict.update() that was just called was
            # able to update the edge in the database.
            # Therefore, we don't need to insert anything.

            if self.edge_type_key in edge_attr_dict.data:
                m = f"Cannot set '{self.edge_type_key}' if edge already exists in DB."
                raise EdgeTypeAmbiguity(m)

            return

        if not self.src_node_id or not self.dst_node_id:
            # We can get here from L521 in networkx/multigraph.py
            logger.debug("Node IDs not set, skipping EdgeKeyDict(?).__setitem__()")
            return

        # NOTE: We can get here from L514 in networkx/multigraph.py
        # Assuming that keydict.get(key) returned None (L513)

        edge_type = edge_attr_dict.data.pop(self.edge_type_key, None)
        if not edge_type:
            edge_type = self.default_edge_type

        edge = edge_link(
            self.graph,
            edge_type,
            self.src_node_id,
            self.dst_node_id,
            edge_attr_dict.data,
        )

        edge_data: dict[str, Any] = {
            **edge_attr_dict.data,
            **edge,
            "_from": self.src_node_id,
            "_to": self.dst_node_id,
        }

        # We have to re-create the EdgeAttrDict because the
        # previous one was created without any **edge_id**
        # TODO: Could we somehow update the existing EdgeAttrDict?
        # i.e edge_attr_dict.data = edge_data
        # + some extra code to set the **edge_id** attribute
        # for any nested EdgeAttrDicts within edge_attr_dict
        edge_id = edge["_id"]
        edge_attr_dict = self._create_edge_attr_dict(edge_data)

        # Re-key the cache entry from the temporary key to the real edge ID.
        self.data[edge_id] = edge_attr_dict
        del self.data[str(key)]

    def __delitem__(self, key: str | int) -> None:
        """
        Examples
        --------
        >>> del G._adj['node/1']['node/2']['edge/1']
        >>> del G._adj['node/1']['node/2'][0]
        """
        if isinstance(key, int):
            key = self.__process_int_edge_key(key)

        self.data.pop(key, None)

        if self.__get_mirrored_edge_attr(key):
            # We're skipping the DB deletion because the
            # edge deletion for mirrored edges is handled
            # twice (once for each direction).
            # i.e the DB edge will be deleted in via the
            # delitem() call on the mirrored edge
            return

        try:
            self.graph.delete_edge(key)
        except DocumentDeleteError as err:
            # TODO: Should we just return here?
            # Chain the driver error so the root cause is not lost.
            raise KeyError(key) from err

    def clear(self) -> None:
        """G._adj['node/1']['node/2'].clear()"""
        # Only the local cache is dropped.
        self.data.clear()
        self.FETCHED_ALL_DATA = False
        self.FETCHED_ALL_IDS = False

    def copy(self) -> Any:
        """G._adj['node/1']['node/2'].copy()"""
        if not self.FETCHED_ALL_DATA:
            self._fetch_all()

        return {key: value.copy() for key, value in self.data.items()}

    @keys_are_strings
    def update(self, edges: Any) -> None:
        """g._adj['node/1']['node/2'].update(
            {'edge/1': {'foo': 'bar'}, 'edge/2': {'baz': 'qux'}}
        )
        """
        raise NotImplementedError("EdgeKeyDict.update()")

    def popitem(self) -> tuple[str, dict[str, Any]]:  # type: ignore
        """G._adj['node/1']['node/2'].popitem()

        Removes the last edge and returns ``(edge_id, edge_data)``.
        Raises KeyError if there is no edge to pop, as mappings do.
        """
        edge_ids = list(self.keys())
        if not edge_ids:
            raise KeyError("popitem(): EdgeKeyDict is empty")

        last_key = edge_ids[-1]
        edge_attr_dict = self.data[last_key]

        assert hasattr(edge_attr_dict, "to_dict")
        edge_data = edge_attr_dict.to_dict()

        self.__delitem__(last_key)
        return (last_key, edge_data)

    def __len__(self) -> int:
        """len(g._adj['node/1']['node/2'])"""
        assert self.src_node_id
        assert self.dst_node_id

        if self.FETCHED_ALL_IDS:
            return len(self.data)

        return aql_edge_count_src_dst(
            self.db,
            self.src_node_id,
            self.dst_node_id,
            self.graph.name,
            self.traversal_direction.name,
        )

    def __iter__(self) -> Iterator[str]:
        """for k in g._adj['node/1']['node/2']"""
        if not (self.FETCHED_ALL_DATA or self.FETCHED_ALL_IDS):
            self._fetch_all()

        yield from self.data.keys()

    def keys(self) -> Any:
        """g._adj['node/1']['node/2'].keys()"""
        if self.FETCHED_ALL_IDS:
            yield from self.data.keys()
            return

        assert self.src_node_id
        assert self.dst_node_id

        edge_ids: list[str] | None = aql_edge_id(
            self.db,
            self.src_node_id,
            self.dst_node_id,
            self.graph.name,
            self.traversal_direction.name,
            can_return_multiple=True,
        )

        if edge_ids is None:
            raise ValueError("Failed to fetch Edge IDs")

        # Register every ID (as an empty placeholder) *before* yielding
        # anything and only then raise the flag. Otherwise a consumer that
        # stops iterating early would leave FETCHED_ALL_IDS set on top of
        # a partially populated cache.
        for edge_id in edge_ids:
            self.data[edge_id] = self.edge_attr_dict_factory()

        self.FETCHED_ALL_IDS = True

        yield from edge_ids

    def values(self) -> Any:
        """g._adj['node/1']['node/2'].values()"""
        if not self.FETCHED_ALL_DATA:
            self._fetch_all()

        yield from self.data.values()

    def items(self) -> Any:
        """g._adj['node/1']['node/2'].items()"""
        if not self.FETCHED_ALL_DATA:
            self._fetch_all()

        yield from self.data.items()

    def _fetch_all(self) -> None:
        """Replace the local cache with every edge between the two nodes."""
        assert self.src_node_id
        assert self.dst_node_id

        self.clear()

        edges: list[dict[str, Any]] | None = aql_edge_get(
            self.db,
            self.src_node_id,
            self.dst_node_id,
            self.graph.name,
            direction=self.traversal_direction.name,
            can_return_multiple=True,
        )

        if edges is None:
            raise ValueError("Failed to fetch edges")

        for edge in edges:
            edge_attr_dict = self._create_edge_attr_dict(edge)
            self.data[edge["_id"]] = edge_attr_dict

        self.FETCHED_ALL_DATA = True
        self.FETCHED_ALL_IDS = True
class AdjListInnerDict(UserDict[str, EdgeAttrDict | EdgeKeyDict]):
    """The 2nd level of the dict of dict (of dict) of dict structure
    representing the Adjacency List of a graph.

    AdjListInnerDict is keyed by the node ID of the destination node.

    Parameters
    ----------
    db : arango.database.StandardDatabase
        The ArangoDB database.

    graph : arango.graph.Graph
        The ArangoDB graph.

    default_node_type : str
        The default node type.

    edge_type_key : str
        The key used to store the edge type in the edge attribute dictionary.

    edge_type_func : Callable[[str, str], str]
        The function to generate the edge type from the source and
        destination node types.

    graph_type : str
        The type of graph (e.g. 'Graph', 'DiGraph', 'MultiGraph', 'MultiDiGraph').

    adjlist_outer_dict : AdjListOuterDict | None
        The parent AdjListOuterDict.

    Examples
    --------
    >>> g = nxadb.Graph(name="MyGraph")
    >>> g.add_edge("node/1", "node/2", foo="bar")
    >>> g['node/1']
    AdjListInnerDict('node/1')
    """

    def __init__(
        self,
        db: StandardDatabase,
        graph: Graph,
        default_node_type: str,
        edge_type_key: str,
        edge_type_func: Callable[[str, str], str],
        graph_type: str,
        adjlist_outer_dict: AdjListOuterDict | None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        if graph_type not in GraphType.__members__:
            raise ValueError(f"**graph_type** not supported: {graph_type}")

        super().__init__(*args, **kwargs)
        self.data: dict[str, EdgeAttrDict | EdgeKeyDict] = {}

        self.graph_type = graph_type
        self.is_directed = graph_type in DIRECTED_GRAPH_TYPES
        self.is_multigraph = graph_type in MULTIGRAPH_TYPES

        self.db = db
        self.graph = graph
        self.edge_type_key = edge_type_key
        self.edge_type_func = edge_type_func
        self.default_node_type = default_node_type
        self.edge_attr_dict_factory = edge_attr_dict_factory(self.db, self.graph)
        self.edge_key_dict_factory = edge_key_dict_factory(
            self.db, self.graph, edge_type_key, edge_type_func, self.is_directed, self
        )

        # The source node ID is assigned by the owner after construction.
        self.src_node_id: str | None = None
        self.__src_node_type: str | None = None
        self.adjlist_outer_dict = adjlist_outer_dict

        # FETCHED_ALL_IDS: every neighbour ID is a key of self.data
        # (values may still be lazy placeholders).
        # FETCHED_ALL_DATA: every adjacent edge is fully loaded in self.data.
        self.FETCHED_ALL_DATA = False
        self.FETCHED_ALL_IDS = False

        # Inherit the direction of the parent when there is one.
        self.traversal_direction: TraversalDirection = (
            adjlist_outer_dict.traversal_direction
            if adjlist_outer_dict is not None
            else (
                TraversalDirection.OUTBOUND
                if self.is_directed
                else TraversalDirection.ANY
            )
        )

        # For each direction: (AQL expression returning the neighbour ID,
        # edge document field holding the neighbour ID or None if it
        # has to be worked out per edge).
        direction_mappings = {
            TraversalDirection.OUTBOUND: ("e._to", "_to"),
            TraversalDirection.INBOUND: ("e._from", "_from"),
            TraversalDirection.ANY: ("e._to == @src_node_id ? e._from : e._to", None),
        }

        k = self.traversal_direction
        self.__iter__return_str, self._fetch_all_dst_node_key = direction_mappings[k]

        self.__getitem_helper_db: Callable[[str, str], EdgeAttrDict | EdgeKeyDict]
        self.__setitem_helper: Callable[[EdgeAttrDict | EdgeKeyDict, str, str], None]
        self.__delitem_helper: Callable[[str | list[str]], None]

        # Bind the Graph or MultiGraph flavour of each helper once.
        if self.is_multigraph:
            self.__contains_helper = self.__contains__multigraph
            self.__getitem_helper_db = self.__getitem__multigraph_db
            self.__getitem_helper_cache = self.__getitem__multigraph_cache
            self.__setitem_helper = self.__setitem__multigraph  # type: ignore[assignment] # noqa
            self.__delitem_helper = self.__delitem__multigraph  # type: ignore[assignment] # noqa
            self.__fetch_all_helper = self.__fetch_all_multigraph
        else:
            self.__contains_helper = self.__contains__graph
            self.__getitem_helper_db = self.__getitem__graph_db
            self.__getitem_helper_cache = self.__getitem__graph_cache
            self.__setitem_helper = self.__setitem__graph  # type: ignore[assignment]
            self.__delitem_helper = self.__delitem__graph  # type: ignore[assignment]
            self.__fetch_all_helper = self.__fetch_all_graph

    @property
    def src_node_type(self) -> str:
        """The collection part of **src_node_id** (computed once)."""
        if self.__src_node_type is None:
            assert self.src_node_id
            self.__src_node_type = self.src_node_id.split("/")[0]

        return self.__src_node_type

    def _create_edge_attr_dict(self, edge: dict[str, Any]) -> EdgeAttrDict:
        """Build an EdgeAttrDict bound to ``edge["_id"]`` from an edge document."""
        edge_attr_dict = self.edge_attr_dict_factory()
        edge_attr_dict.edge_id = edge["_id"]
        edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, edge)

        return edge_attr_dict

    def __get_mirrored_edge_attr_or_key_dict(
        self, dst_node_id: str
    ) -> EdgeAttrDict | EdgeKeyDict | None:
        """This method is used to get the EdgeAttrDict or EdgeKeyDict of the
        "mirrored" AdjListInnerDict.

        A "mirrored edge" is defined as a reference to an edge (or multiple
        edges) that represents both the forward and reverse edge between two
        nodes. This is useful because ArangoDB does not need to duplicate
        edges in both directions in the database.

        If the Graph is Undirected:
        - The "mirror" is the same adjlist_outer_dict because
            the adjacency list is the same in both directions (i.e _adj)

        If the Graph is Directed:
        - The "mirror" is the "reverse" adjlist_outer_dict because
            the adjacency list is different in both directions (i.e _pred and _succ)

        Parameters
        ----------
        dst_node_id : str
            The destination node ID.

        Returns
        -------
        EdgeAttrDict | EdgeKeyDict | None
            The edge attribute dictionary or key dictionary if it exists.
        """
        if self.adjlist_outer_dict is None:
            return None

        mirror = self.adjlist_outer_dict  # fake mirror (i.e G._adj)
        if self.is_directed:
            mirror = mirror.mirror  # real mirror (i.e _pred or _succ)

        # Only the local caches (.data) are inspected, so no DB call is made.
        if dst_node_id in mirror.data:
            if self.src_node_id in mirror.data[dst_node_id].data:
                return mirror.data[dst_node_id].data[self.src_node_id]

        return None

    def __repr__(self) -> str:
        if self.FETCHED_ALL_DATA:
            return self.data.__repr__()

        return f"AdjListInnerDict('{self.src_node_id}')"

    def __str__(self) -> str:
        return self.__repr__()

    @key_is_string
    def __contains__(self, key: str) -> bool:
        """'node/2' in G.adj['node/1']"""
        assert self.src_node_id
        dst_node_id = get_node_id(key, self.default_node_type)

        if dst_node_id in self.data:
            return True

        if self.FETCHED_ALL_IDS:
            return False

        result = aql_edge_exists(
            self.db,
            self.src_node_id,
            dst_node_id,
            self.graph.name,
            direction=self.traversal_direction.name,
        )

        if not result:
            return False

        # Remember the neighbour as a lazy placeholder.
        self.__contains_helper(dst_node_id)

        return True

    def __contains__graph(self, dst_node_id: str) -> None:
        """Helper function for __contains__ in Graphs."""
        empty_edge_attr_dict = self.edge_attr_dict_factory()
        self.data[dst_node_id] = empty_edge_attr_dict

    def __contains__multigraph(self, dst_node_id: str) -> None:
        """Helper function for __contains__ in MultiGraphs."""
        lazy_edge_key_dict = self.edge_key_dict_factory()
        lazy_edge_key_dict.src_node_id = self.src_node_id
        lazy_edge_key_dict.dst_node_id = dst_node_id
        self.data[dst_node_id] = lazy_edge_key_dict

    @key_is_string
    def __getitem__(self, key: str) -> EdgeAttrDict | EdgeKeyDict:
        """g._adj['node/1']['node/2']"""
        dst_node_id = get_node_id(key, self.default_node_type)

        if self.__getitem_helper_cache(dst_node_id):
            return self.data[dst_node_id]

        if result := self.__get_mirrored_edge_attr_or_key_dict(dst_node_id):
            self.data[dst_node_id] = result
            return result

        # The cache is keyed by the normalised node ID, so the lookup must
        # use **dst_node_id**. Checking the raw **key** would wrongly raise
        # for short keys (e.g. '2') whose placeholder is cached as 'node/2'.
        if dst_node_id not in self.data and self.FETCHED_ALL_IDS:
            raise KeyError(key)

        return self.__getitem_helper_db(key, dst_node_id)

    def __getitem__graph_cache(self, dst_node_id: str) -> bool:
        """Cache Helper function for __getitem__ in Graphs."""
        # An empty placeholder EdgeAttrDict is falsy, i.e a cache miss.
        if _ := self.data.get(dst_node_id):
            return True

        return False

    def __getitem__graph_db(self, key: str, dst_node_id: str) -> EdgeAttrDict:
        """DB Helper function for __getitem__ in Graphs."""
        assert self.src_node_id
        edge: dict[str, Any] | None = aql_edge_get(
            self.db,
            self.src_node_id,
            dst_node_id,
            self.graph.name,
            direction=self.traversal_direction.name,
            can_return_multiple=self.is_multigraph,
        )

        if not edge:
            raise KeyError(key)

        edge_attr_dict: EdgeAttrDict = self._create_edge_attr_dict(edge)
        self.data[dst_node_id] = edge_attr_dict

        return edge_attr_dict

    def __getitem__multigraph_cache(self, dst_node_id: str) -> bool:
        """Cache Helper function for __getitem__ in MultiGraphs."""
        # Notice that we're not using the walrus operator here
        # compared to other __getitem__ methods.
        # This is because EdgeKeyDict is lazily populated
        # when the second key is accessed (e.g G._adj["node/1"]["node/2"]['edge/1']).
        # Therefore, there is no actual data in EdgeKeyDict.data
        # when it is first created!
        return dst_node_id in self.data

    def __getitem__multigraph_db(self, key: str, dst_node_id: str) -> EdgeKeyDict:
        """DB Helper function for __getitem__ in MultiGraphs."""
        assert self.src_node_id
        result = aql_edge_exists(
            self.db,
            self.src_node_id,
            dst_node_id,
            self.graph.name,
            direction=self.traversal_direction.name,
        )

        if not result:
            raise KeyError(key)

        lazy_edge_key_dict = self.edge_key_dict_factory()
        lazy_edge_key_dict.src_node_id = self.src_node_id
        lazy_edge_key_dict.dst_node_id = dst_node_id

        self.data[dst_node_id] = lazy_edge_key_dict

        return lazy_edge_key_dict

    @key_is_string
    def __setitem__(self, key: str, value: EdgeAttrDict | EdgeKeyDict) -> None:
        """
        g._adj['node/1']['node/2'] = {'foo': 'bar'}
        g._adj['node/1']['node/2'] = {0: {'foo': 'bar'}}
        """
        assert self.src_node_id
        assert isinstance(value, EdgeKeyDict if self.is_multigraph else EdgeAttrDict)

        dst_node_type, dst_node_id = get_node_type_and_id(key, self.default_node_type)

        # Reuse the object already held by the mirrored adjacency, if any.
        if result := self.__get_mirrored_edge_attr_or_key_dict(dst_node_id):
            self.data[dst_node_id] = result
            return

        self.__setitem_helper(value, dst_node_type, dst_node_id)

    def __setitem__graph(
        self, edge_attr_dict: EdgeAttrDict, dst_node_type: str, dst_node_id: str
    ) -> None:
        """Helper function for __setitem__ in Graphs."""
        if edge_attr_dict.edge_id:
            # If the edge_id is already set, it means that the
            # EdgeAttrDict.update() that was just called was
            # able to update the edge in the database.
            # Therefore, we don't need to insert anything.

            if self.edge_type_key in edge_attr_dict.data:
                m = f"Cannot set '{self.edge_type_key}' if edge already exists in DB."
                raise EdgeTypeAmbiguity(m)

            return

        edge_type = edge_attr_dict.data.pop(self.edge_type_key, None)
        if edge_type is None:
            edge_type = self.edge_type_func(self.src_node_type, dst_node_type)

        assert self.src_node_id
        edge_id = aql_edge_id(
            self.db,
            self.src_node_id,
            dst_node_id,
            self.graph.name,
            direction=self.traversal_direction.name,
            can_return_multiple=False,
        )

        # doc_insert() when an edge ID was found, edge_link() otherwise.
        edge = (
            doc_insert(self.db, edge_type, edge_id, edge_attr_dict.data)
            if edge_id
            else edge_link(
                self.graph,
                edge_type,
                self.src_node_id,
                dst_node_id,
                edge_attr_dict.data,
            )
        )

        edge_data: dict[str, Any] = {
            **edge_attr_dict.data,
            **edge,
            "_from": self.src_node_id,
            "_to": dst_node_id,
        }

        # We have to re-create the EdgeAttrDict because the
        # previous one was created without any **edge_id**
        # TODO: Could we somehow update the existing EdgeAttrDict?
        # i.e edge_attr_dict.data = edge_data
        # + some extra code to set the **edge_id** attribute
        # for any nested EdgeAttrDicts within edge_attr_dict
        edge_attr_dict = self._create_edge_attr_dict(edge_data)
        self.data[dst_node_id] = edge_attr_dict

    def __setitem__multigraph(
        self, edge_key_dict: EdgeKeyDict, dst_node_type: str, dst_node_id: str
    ) -> None:
        """Helper function for __setitem__ in MultiGraphs."""
        # The incoming EdgeKeyDict must hold exactly one edge under the
        # temporary "-1" key and must not be bound to any nodes yet.
        assert len(edge_key_dict.data) == 1
        assert list(edge_key_dict.data.keys())[0] == "-1"
        assert edge_key_dict.src_node_id is None
        assert edge_key_dict.dst_node_id is None
        assert self.src_node_id is not None

        edge_attr_dict = edge_key_dict.data["-1"]

        edge_type = edge_attr_dict.data.pop(self.edge_type_key, None)
        if edge_type is None:
            edge_type = self.edge_type_func(self.src_node_type, dst_node_type)

        edge = edge_link(
            self.graph, edge_type, self.src_node_id, dst_node_id, edge_attr_dict.data
        )

        edge_data: dict[str, Any] = {
            **edge_attr_dict.data,
            **edge,
            "_from": self.src_node_id,
            "_to": dst_node_id,
        }

        # We have to re-create the EdgeAttrDict because the
        # previous one was created without any **edge_id**
        # TODO: Could we somehow update the existing EdgeAttrDict?
        # i.e edge_attr_dict.data = edge_data
        # + some extra code to set the **edge_id** attribute
        # for any nested EdgeAttrDicts within edge_attr_dict
        edge_id = edge["_id"]
        edge_key_dict.data[edge_id] = self._create_edge_attr_dict(edge_data)
        edge_key_dict.src_node_id = self.src_node_id
        edge_key_dict.dst_node_id = dst_node_id
        edge_key_dict.traversal_direction = self.traversal_direction

        self.data[dst_node_id] = edge_key_dict
        del edge_key_dict.data["-1"]

    @key_is_string
    def __delitem__(self, key: str) -> None:
        """del g._adj['node/1']['node/2']"""
        assert self.src_node_id
        dst_node_id = get_node_id(key, self.default_node_type)
        self.data.pop(dst_node_id, None)

        if self.__get_mirrored_edge_attr_or_key_dict(dst_node_id):
            # We're skipping the DB deletion because the
            # edge deletion for mirrored edges is handled
            # twice (once for each direction).
            # i.e the DB edge will be deleted in via the
            # delitem() call on the mirrored edge
            return

        result = aql_edge_id(
            self.db,
            self.src_node_id,
            dst_node_id,
            self.graph.name,
            direction=self.traversal_direction.name,
            can_return_multiple=self.is_multigraph,
        )

        if not result:
            # TODO: Should we raise a KeyError instead?
            return

        self.__delitem_helper(result)

    def __delitem__graph(self, edge_id: str) -> None:
        """Helper function for __delitem__ in Graphs."""
        try:
            self.graph.delete_edge(edge_id)
        except DocumentDeleteError as e:
            m = f"Failed to delete edge '{edge_id}' from Graph: {e}."
            # Chain the driver error so the root cause is not lost.
            raise KeyError(m) from e

    def __delitem__multigraph(self, edge_ids: list[str]) -> None:
        """Helper function for __delitem__ in MultiGraphs."""
        # TODO: Consider separating **edge_ids** by edge collection,
        # and invoking db.collection(...).delete_many() instead of this:
        for edge_id in edge_ids:
            self.__delitem__graph(edge_id)

    def __len__(self) -> int:
        """len(g._adj['node/1'])"""
        assert self.src_node_id

        if self.FETCHED_ALL_IDS:
            return len(self.data)

        return aql_edge_count_src(
            self.db, self.src_node_id, self.graph.name, self.traversal_direction.name
        )

    def __iter__(self) -> Iterator[str]:
        """for k in g._adj['node/1']"""
        if not (self.FETCHED_ALL_DATA or self.FETCHED_ALL_IDS):
            self._fetch_all()

        yield from self.data.keys()

    def keys(self) -> Any:
        """g._adj['node/1'].keys()"""
        if self.FETCHED_ALL_IDS:
            yield from self.data.keys()
            return

        query = f"""
            FOR v, e IN 1..1 {self.traversal_direction.name} @src_node_id
            GRAPH @graph_name
                RETURN {self.__iter__return_str}
        """

        bind_vars = {"src_node_id": self.src_node_id, "graph_name": self.graph.name}

        # Register every neighbour (as a lazy placeholder) *before* yielding
        # anything and only then raise the flag. Otherwise a consumer that
        # stops iterating early would leave FETCHED_ALL_IDS set on top of
        # a partially populated cache.
        dst_node_ids = list(aql(self.db, query, bind_vars))
        for dst_node_id in dst_node_ids:
            self.__contains_helper(dst_node_id)

        self.FETCHED_ALL_IDS = True

        yield from dst_node_ids

    def clear(self) -> None:
        """G._adj['node/1'].clear()"""
        # Only the local cache is dropped.
        self.data.clear()
        self.FETCHED_ALL_DATA = False
        self.FETCHED_ALL_IDS = False

    def copy(self) -> Any:
        """G._adj['node/1'].copy()"""
        if not self.FETCHED_ALL_DATA:
            self._fetch_all()

        return {key: value.copy() for key, value in self.data.items()}

    @keys_are_strings
    def update(self, edges: dict[str, dict[str, Any]]) -> None:
        """g._adj['node/1'].update({'node/2': {'foo': 'bar'}})

        Every edge dictionary must carry an ``_id``. Note that the passed
        edge dictionaries are modified in place (``_from`` & ``_to`` are set).

        Raises ValueError if an edge has no ``_id`` and ArangoDBBatchError
        if the database reports a failure (the local cache is then left
        untouched).
        """
        assert self.src_node_id
        from_col_name = get_node_type(self.src_node_id, self.default_node_type)

        # Edge documents grouped by the collection named in their _id.
        to_upsert: Dict[str, List[Dict[str, Any]]] = {from_col_name: []}

        for dst_node_id, edge_doc in edges.items():
            edge_doc["_from"] = self.src_node_id
            edge_doc["_to"] = dst_node_id

            edge_doc_id = edge_doc.get("_id")
            if not edge_doc_id:
                raise ValueError("Edge _id field is required for update.")

            edge_col_name = get_node_type(edge_doc_id, self.default_node_type)
            to_upsert.setdefault(edge_col_name, []).append(edge_doc)

        # perform write to ArangoDB
        result = upsert_collection_edges(self.db, to_upsert)

        all_good = check_update_list_for_errors(result)
        if all_good:
            # Means no single operation failed, in this case we update the local cache
            self.__set_adj_elements(edges)
        else:
            # In this case some or all documents failed. Right now we will not
            # update the local cache, but raise an error instead.
            # Reason: We cannot set silent to True, because we need as it does
            # not report errors then. We need to update the driver to also pass
            # the errors back to the user, then we can adjust the behavior here.
            # This will also save network traffic and local computation time.
            errors = [
                collection_result
                for collections_results in result
                for collection_result in collections_results
            ]
            logger.warning(
                "Failed to insert at least one edge. Will not update local cache."
            )
            raise ArangoDBBatchError(errors)

    def values(self) -> Any:
        """g._adj['node/1'].values()"""
        if not self.FETCHED_ALL_DATA:
            self._fetch_all()

        yield from self.data.values()

    def items(self) -> Any:
        """g._adj['node/1'].items()"""
        if not self.FETCHED_ALL_DATA:
            self._fetch_all()

        yield from self.data.items()

    def _fetch_all(self) -> None:
        """Replace the local cache with every edge adjacent to the source node."""
        assert self.src_node_id

        self.clear()

        query = f"""
            FOR v, e IN 1..1 {self.traversal_direction.name} @src_node_id
            GRAPH @graph_name
                RETURN UNSET(e, '_rev')
        """

        bind_vars = {"src_node_id": self.src_node_id, "graph_name": self.graph.name}

        for edge in aql(self.db, query, bind_vars):
            edge_attr_dict: EdgeAttrDict = self._create_edge_attr_dict(edge)

            # Without a fixed field for the direction (ANY), the neighbour
            # is whichever end of the edge is not the source node.
            dst_node_id: str = (
                edge[self._fetch_all_dst_node_key]
                if self._fetch_all_dst_node_key
                else edge["_to"] if self.src_node_id == edge["_from"] else edge["_from"]
            )

            self.__fetch_all_helper(edge_attr_dict, dst_node_id)

        self.FETCHED_ALL_DATA = True
        self.FETCHED_ALL_IDS = True

    def __set_adj_elements(self, edges: dict[str, dict[str, Any]]) -> None:
        """Cache the given edge documents, keyed by destination node ID."""
        for dst_node_id, edge in edges.items():
            edge_attr_dict: EdgeAttrDict = self._create_edge_attr_dict(edge)
            self.__fetch_all_helper(edge_attr_dict, dst_node_id, is_update=True)

    def __fetch_all_graph(
        self, edge_attr_dict: EdgeAttrDict, dst_node_id: str, is_update: bool = False
    ) -> None:
        """Helper function for _fetch_all() in Graphs."""
        if dst_node_id in self.data:
            # Don't raise an error if it's a self-loop
            if self.data[dst_node_id] == edge_attr_dict:
                return

            # On update() the already cached entry is kept as-is.
            if is_update:
                return

            m = "Multiple edges between the same nodes are not supported in Graphs."
            m += f" Found 2 edges between {self.src_node_id} & {dst_node_id}."
            m += " Consider using a MultiGraph."
            raise MultipleEdgesFound(m)

        self.data[dst_node_id] = edge_attr_dict

    def __fetch_all_multigraph(
        self, edge_attr_dict: EdgeAttrDict, dst_node_id: str, is_update: bool = False
    ) -> None:
        """Helper function for _fetch_all() in MultiGraphs."""
        edge_key_dict = self.data.get(dst_node_id)
        if edge_key_dict is None:
            # First edge seen for this neighbour: create its EdgeKeyDict,
            # flagged as fully fetched.
            edge_key_dict = self.edge_key_dict_factory()
            edge_key_dict.src_node_id = self.src_node_id
            edge_key_dict.dst_node_id = dst_node_id
            edge_key_dict.FETCHED_ALL_DATA = True
            edge_key_dict.FETCHED_ALL_IDS = True

        assert edge_attr_dict.edge_id
        assert isinstance(edge_key_dict, EdgeKeyDict)
        edge_key_dict.data[edge_attr_dict.edge_id] = edge_attr_dict
        self.data[dst_node_id] = edge_key_dict
class AdjListOuterDict(UserDict[str, AdjListInnerDict]):
    """The 1st level of the dict of dict (of dict) of dict
    representing the Adjacency List of a graph.

    AdjListOuterDict is keyed by the node ID of the source node.

    Parameters
    ----------
    db : arango.database.StandardDatabase
        The ArangoDB database.

    graph : arango.graph.Graph
        The ArangoDB graph.

    default_node_type : str
        The default node type.

    read_parallelism : int
        The number of parallel threads to use for reading data in _fetch_all.

    read_batch_size : int
        The number of documents to read in each batch in _fetch_all.

    edge_type_key : str
        The key used to store the edge type in the edge attribute dictionary.

    edge_type_func : Callable[[str, str], str]
        The function to generate the edge type from the source and
        destination node types.

    graph_type : str
        The type of graph (e.g. 'Graph', 'DiGraph', 'MultiGraph', 'MultiDiGraph').

    symmetrize_edges_if_directed : bool
        Whether to add the reverse edge if the graph is directed.

    Raises
    ------
    ValueError
        If **graph_type** is not a member of GraphType.

    Example
    -------
    >>> g = nxadb.Graph(name="MyGraph")
    >>> g.add_edge("node/1", "node/2", foo="bar")
    >>> g._adj
    AdjListOuterDict('MyGraph')
    """

    def __init__(
        self,
        db: StandardDatabase,
        graph: Graph,
        default_node_type: str,
        read_parallelism: int,
        read_batch_size: int,
        edge_type_key: str,
        edge_type_func: Callable[[str, str], str],
        graph_type: str,
        symmetrize_edges_if_directed: bool,
        *args: Any,
        **kwargs: Any,
    ):
        if graph_type not in GraphType.__members__:
            raise ValueError(f"**graph_type** not supported: {graph_type}")

        super().__init__(*args, **kwargs)
        # The local cache always starts empty; entries are created lazily.
        self.data: dict[str, AdjListInnerDict] = {}

        self.graph_type = graph_type
        self.is_directed = graph_type in DIRECTED_GRAPH_TYPES
        self.is_multigraph = graph_type in MULTIGRAPH_TYPES

        self.db = db
        self.graph = graph
        self.edge_type_key = edge_type_key
        self.edge_type_func = edge_type_func
        self.default_node_type = default_node_type
        self.read_parallelism = read_parallelism
        self.read_batch_size = read_batch_size

        self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(
            db,
            graph,
            default_node_type,
            edge_type_key,
            edge_type_func,
            graph_type,
            self,
        )

        # FETCHED_ALL_IDS: every node ID is present as a key in self.data.
        # FETCHED_ALL_DATA: set by _fetch_all() after the full load.
        self.FETCHED_ALL_DATA = False
        self.FETCHED_ALL_IDS = False

        self.traversal_direction = (
            TraversalDirection.OUTBOUND if self.is_directed else TraversalDirection.ANY
        )

        self.symmetrize_edges_if_directed = (
            symmetrize_edges_if_directed and self.is_directed
        )

        # Declared only; this constructor does not assign it.
        self.mirror: AdjListOuterDict

    def __get_mirrored_adjlist_inner_dict(
        self, node_id: str
    ) -> AdjListInnerDict | None:
        """This method is used to get the AdjListInnerDict of the
        "mirrored" AdjListOuterDict.

        A "mirrored edge" is defined as a reference to an edge that represents
        both the forward and reverse edge between two nodes. This is useful
        because ArangoDB does not need to duplicate edges in both directions
        in the database.

        If the Graph is Undirected:
        - The "mirror" is the same AdjListOuterDict because
            the adjacency list is the same in both directions (i.e _adj).
            Nothing is looked up in that case and None is returned.

        If the Graph is Directed:
        - The "mirror" is the "reverse" AdjListOuterDict because
            the adjacency list is different in both directions
            (i.e _pred and _succ)

        :param node_id: The source node ID.
        :type node_id: str
        :return: The adjacency list inner dictionary if it exists.
        :rtype: AdjListInnerDict | None
        """
        if not self.is_directed:
            return None

        if node_id in self.mirror.data:
            return self.mirror.data[node_id]

        return None

    def __add_lazy_adjlist_inner_dict(self, node_id: str) -> AdjListInnerDict:
        """Cache and return a new, not yet populated AdjListInnerDict.

        :param node_id: The source node ID the inner dict is keyed under.
        :type node_id: str
        :return: The newly cached inner dictionary.
        :rtype: AdjListInnerDict
        """
        lazy_adjlist_inner_dict = self.adjlist_inner_dict_factory()
        lazy_adjlist_inner_dict.src_node_id = node_id
        self.data[node_id] = lazy_adjlist_inner_dict

        return lazy_adjlist_inner_dict

    def __repr__(self) -> str:
        if self.FETCHED_ALL_DATA:
            return repr(self.data)

        return f"AdjListOuterDict('{self.graph.name}')"

    def __str__(self) -> str:
        return self.__repr__()

    @key_is_string
    def __contains__(self, key: str) -> bool:
        """'node/1' in G.adj"""
        node_id = get_node_id(key, self.default_node_type)

        if node_id in self.data:
            return True

        # Every ID is cached already, so a cache miss is authoritative.
        if self.FETCHED_ALL_IDS:
            return False

        if self.graph.has_vertex(node_id):
            self.__add_lazy_adjlist_inner_dict(node_id)
            return True

        return False

    @key_is_string
    def __getitem__(self, key: str) -> AdjListInnerDict:
        """G._adj["node/1"]

        :raises KeyError: If the node is neither cached, nor known to the
            mirror, nor present in the database graph.
        """
        node_id = get_node_id(key, self.default_node_type)

        if node_id in self.data:
            # Notice that we're not using the walrus operator here
            # compared to other __getitem__ methods.
            # This is because AdjListInnerDict is lazily populated
            # when the second key is accessed (e.g G._adj["node/1"]["node/2"]).
            # Therefore, there is no actual data in AdjListInnerDict.data
            # when it is first created!
            return self.data[node_id]

        # Compare against None explicitly so that a mirrored entry counts
        # as present regardless of its truthiness.
        if self.__get_mirrored_adjlist_inner_dict(node_id) is not None:
            return self.__add_lazy_adjlist_inner_dict(node_id)

        if self.FETCHED_ALL_IDS:
            raise KeyError(key)

        if self.graph.has_vertex(node_id):
            return self.__add_lazy_adjlist_inner_dict(node_id)

        raise KeyError(key)

    @key_is_string
    def __setitem__(self, src_key: str, adjlist_inner_dict: AdjListInnerDict) -> None:
        """g._adj['node/1'] = AdjListInnerDict()"""
        assert isinstance(adjlist_inner_dict, AdjListInnerDict)
        assert len(adjlist_inner_dict.data) == 0

        src_node_id = get_node_id(src_key, self.default_node_type)
        adjlist_inner_dict.src_node_id = src_node_id
        adjlist_inner_dict.adjlist_outer_dict = self
        adjlist_inner_dict.traversal_direction = self.traversal_direction
        self.data[src_node_id] = adjlist_inner_dict

    @key_is_string
    def __delitem__(self, key: str) -> None:
        """del G._adj['node/1']"""
        # Nothing else to do here, as this delete is always invoked by
        # G.remove_node(), which already removes all edges via
        # del G._node['node/1']
        node_id = get_node_id(key, self.default_node_type)
        self.data.pop(node_id, None)

    def __len__(self) -> int:
        """len(g._adj)"""
        # Counted in the database, not from the (possibly partial) cache.
        return sum(
            self.graph.vertex_collection(c).count()
            for c in self.graph.vertex_collections()
        )

    def __iter__(self) -> Iterator[str]:
        """for k in g._adj"""
        if not (self.FETCHED_ALL_DATA or self.FETCHED_ALL_IDS):
            self._fetch_all()

        yield from self.data.keys()

    def keys(self) -> Any:
        """g._adj.keys()"""
        if self.FETCHED_ALL_IDS:
            yield from self.data.keys()
            return

        for collection in self.graph.vertex_collections():
            for node_id in self.graph.vertex_collection(collection).ids():
                # Keep inner dicts that are already cached: replacing them
                # would throw away any edge data they hold.
                if node_id not in self.data:
                    self.__add_lazy_adjlist_inner_dict(node_id)

                yield node_id

        # Only mark the ID cache complete once every collection has been
        # enumerated. This is a generator: if the consumer stops early, the
        # cache is partial and a cache miss must still fall back to the
        # database.
        self.FETCHED_ALL_IDS = True

    def clear(self) -> None:
        """g._adj.clear()"""
        self.data.clear()
        self.FETCHED_ALL_DATA = False
        self.FETCHED_ALL_IDS = False

    def copy(self) -> Any:
        """g._adj.copy()"""
        if not self.FETCHED_ALL_DATA:
            self._fetch_all()

        return {key: value.copy() for key, value in self.data.items()}

    @keys_are_strings
    def update(self, edges: Any) -> None:
        """g._adj.update({'node/1': {'node/2': {'_id': 'foo/bar', 'foo': "bar"}}})

        :raises ArangoDBBatchError: If at least one upsert failed. The local
            cache is left untouched in that case.
        """
        separated_by_edge_collection = separate_edges_by_collections(
            edges, graph_type=self.graph_type, default_node_type=self.default_node_type
        )
        result = upsert_collection_edges(self.db, separated_by_edge_collection)

        all_good = check_update_list_for_errors(result)
        if all_good:
            # Means no single operation failed, in this case we update the local cache
            self.__set_adj_elements(edges)
            return

        # In this case some or all documents failed. Right now we will not
        # update the local cache, but raise an error instead.
        # Reason: We cannot set silent to True, because errors are not
        # reported then. We need to update the driver to also pass the
        # errors back to the user, then we can adjust the behavior here.
        # This will also save network traffic and local computation time.
        errors = [
            collection_result
            for collections_results in result
            for collection_result in collections_results
        ]
        warnings.warn(
            "Failed to insert at least one edge. Will not update local cache."
        )
        raise ArangoDBBatchError(errors)

    def values(self) -> Any:
        """g._adj.values()"""
        if not self.FETCHED_ALL_DATA:
            self._fetch_all()

        yield from self.data.values()

    def items(self, data: str | None = None, default: Any | None = None) -> Any:
        """g._adj.items() or G._adj.items(data='foo')"""
        if data is None:
            if not self.FETCHED_ALL_DATA:
                self._fetch_all()

            yield from self.data.items()

        else:
            # With a `data` key, results are streamed from the database
            # instead of the local cache.
            e_cols = [ed["edge_collection"] for ed in self.graph.edge_definitions()]
            yield from aql_fetch_data_edge(self.db, e_cols, data, default)

    def __set_adj_elements(
        self, adj_dict: AdjDict, node_dict: NodeDict | None = None
    ) -> None:
        """Write the given adjacency data into the local cache.

        For every (source, destination) pair in **adj_dict**, inner
        dictionaries are created where missing (here and, for directed
        graphs, in the mirror), the edge is stored under the source node and
        the same edge object is then propagated to the reverse entry.

        :param adj_dict: Adjacency data keyed by source node ID, then by
            destination node ID.
        :type adj_dict: AdjDict
        :param node_dict: Optional node data; only its keys are used, to
            create an inner dictionary for each listed node.
        :type node_dict: NodeDict | None
        """

        def set_edge_graph(
            src_node_id: str, dst_node_id: str, edge: dict[str, Any]
        ) -> EdgeAttrDict:
            edge.pop("_rev", None)

            adjlist_inner_dict = self.data[src_node_id]

            edge_attr_dict: EdgeAttrDict
            edge_attr_dict = adjlist_inner_dict._create_edge_attr_dict(edge)

            if dst_node_id not in adjlist_inner_dict.data:
                adjlist_inner_dict.data[dst_node_id] = edge_attr_dict
            else:
                # Merge into the cached object so existing references to it
                # see the new attributes.
                existing_edge_attr_dict = adjlist_inner_dict.data[dst_node_id]
                existing_edge_attr_dict.data.update(edge_attr_dict.data)

            return adjlist_inner_dict.data[dst_node_id]  # type: ignore # false positive

        def set_edge_multigraph(
            src_node_id: str, dst_node_id: str, edges: dict[int, dict[str, Any]]
        ) -> EdgeKeyDict:
            adjlist_inner_dict = self.data[src_node_id]

            edge_key_dict = adjlist_inner_dict.edge_key_dict_factory()
            edge_key_dict.src_node_id = src_node_id
            edge_key_dict.dst_node_id = dst_node_id
            edge_key_dict.FETCHED_ALL_DATA = True
            edge_key_dict.FETCHED_ALL_IDS = True

            for edge in edges.values():
                edge.pop("_rev", None)

                edge_attr_dict: EdgeAttrDict
                edge_attr_dict = adjlist_inner_dict._create_edge_attr_dict(edge)

                if edge["_id"] not in edge_key_dict.data:
                    edge_key_dict.data[edge["_id"]] = edge_attr_dict
                else:
                    existing_edge_attr_dict = edge_key_dict.data[edge["_id"]]
                    existing_edge_attr_dict.data.update(edge_attr_dict.data)

            adjlist_inner_dict.data[dst_node_id] = edge_key_dict

            return edge_key_dict

        set_edge_func = set_edge_multigraph if self.is_multigraph else set_edge_graph

        def propagate_edge_undirected(
            src_node_id: str,
            dst_node_id: str,
            edge_key_or_attr_dict: EdgeKeyDict | EdgeAttrDict,
        ) -> None:
            self.data[dst_node_id].data[src_node_id] = edge_key_or_attr_dict

        def propagate_edge_directed(
            src_node_id: str,
            dst_node_id: str,
            edge_key_or_attr_dict: EdgeKeyDict | EdgeAttrDict,
        ) -> None:
            self.mirror.data[dst_node_id].data[src_node_id] = edge_key_or_attr_dict

        def propagate_edge_directed_symmetric(
            src_node_id: str,
            dst_node_id: str,
            edge_key_or_attr_dict: EdgeKeyDict | EdgeAttrDict,
        ) -> None:
            propagate_edge_directed(src_node_id, dst_node_id, edge_key_or_attr_dict)
            propagate_edge_undirected(src_node_id, dst_node_id, edge_key_or_attr_dict)
            self.mirror.data[src_node_id].data[dst_node_id] = edge_key_or_attr_dict

        propagate_edge_func = (
            propagate_edge_directed_symmetric
            if self.symmetrize_edges_if_directed
            else (
                propagate_edge_directed
                if self.is_directed
                else propagate_edge_undirected
            )
        )

        # The mirror is only touched for directed graphs.
        set_adj_inner_dict_mirror = (
            self.mirror.__set_adj_inner_dict if self.is_directed else lambda *args: None
        )

        if node_dict is not None:
            for node_id in node_dict.keys():
                self.__set_adj_inner_dict(node_id)
                set_adj_inner_dict_mirror(node_id)

        for src_node_id, inner_dict in adj_dict.items():
            for dst_node_id, edge_or_edges in inner_dict.items():
                self.__set_adj_inner_dict(src_node_id)
                self.__set_adj_inner_dict(dst_node_id)

                set_adj_inner_dict_mirror(src_node_id)
                set_adj_inner_dict_mirror(dst_node_id)

                edge_attr_or_key_dict = set_edge_func(
                    src_node_id, dst_node_id, edge_or_edges
                )

                propagate_edge_func(src_node_id, dst_node_id, edge_attr_or_key_dict)

    def __set_adj_inner_dict(self, node_id: str) -> AdjListInnerDict:
        """Return the cached inner dict for **node_id**, creating it if missing.

        A newly created inner dict is flagged as fully fetched.

        :param node_id: The source node ID.
        :type node_id: str
        :return: The cached (or newly created) inner dictionary.
        :rtype: AdjListInnerDict
        """
        if node_id in self.data:
            return self.data[node_id]

        adj_inner_dict = self.adjlist_inner_dict_factory()
        adj_inner_dict.src_node_id = node_id
        adj_inner_dict.FETCHED_ALL_DATA = True
        adj_inner_dict.FETCHED_ALL_IDS = True
        self.data[node_id] = adj_inner_dict

        return adj_inner_dict

    def _fetch_all(self) -> None:
        """Discard the local cache and reload the adjacency list of the graph.

        For directed graphs the mirror is cleared and flagged as fully
        fetched alongside this dictionary.
        """
        self.clear()
        if self.is_directed:
            self.mirror.clear()

        (
            node_dict,
            adj_dict,
            *_,
        ) = get_arangodb_graph(
            self.graph,
            load_node_dict=True,
            load_adj_dict=True,
            load_coo=False,
            edge_collections_attributes=set(),  # not used
            load_all_vertex_attributes=False,
            load_all_edge_attributes=True,
            is_directed=True,
            is_multigraph=self.is_multigraph,
            symmetrize_edges_if_directed=self.symmetrize_edges_if_directed,
            read_parallelism=self.read_parallelism,
            read_batch_size=self.read_batch_size,
        )

        # Even if the Graph is undirected,
        # we can rely on a "directed load" to get the adjacency list.
        # This prevents the adj_dict loop in __set_adj_elements()
        # from setting the same edge twice in the adjacency list.
        # We still get the benefit of propagating the edge to the "mirror"
        # in the case of an undirected graph, via the `propagate_edge_func`.
        adj_dict = adj_dict["succ"]

        self.__set_adj_elements(adj_dict, node_dict)

        self.FETCHED_ALL_DATA = True
        self.FETCHED_ALL_IDS = True
        if self.is_directed:
            self.mirror.FETCHED_ALL_DATA = True
            self.mirror.FETCHED_ALL_IDS = True