# Source code for nx_arangodb.convert

"""Functions to convert between NetworkX, NetworkX-ArangoDB,
and NetworkX-cuGraph.

Examples
--------
>>> import networkx as nx
>>> import nx_arangodb as nxadb
>>> import nx_cugraph as nxcg
>>>
>>> G = nx.Graph()
>>> G.add_edge(1, 2, weight=3.0)
>>> G.add_edge(2, 3, weight=7.5)
>>>
>>> G_ADB = nxadb.convert._to_nxadb_graph(G)
>>> G_CG = nxadb.convert._to_nxcg_graph(G_ADB)
>>> G_NX = nxadb.convert._to_nx_graph(G_ADB)
"""

from __future__ import annotations

import time
from typing import Any

import networkx as nx

import nx_arangodb as nxadb
from nx_arangodb.classes.dict.adj import AdjListOuterDict
from nx_arangodb.classes.dict.node import NodeDict
from nx_arangodb.logger import logger

# Optional GPU support: both cupy and nx_cugraph must import cleanly for
# GPU_AVAILABLE to be True. The flag is consulted further down to decide which
# variants of the nx-cugraph helpers get defined.
try:
    import cupy as cp
    import nx_cugraph as nxcg

    GPU_AVAILABLE = True
    logger.info("NetworkX-cuGraph is available.")
except Exception as e:
    # Any exception raised while importing is treated as "GPU unavailable";
    # the reason is logged instead of being re-raised.
    GPU_AVAILABLE = False
    logger.info(f"NetworkX-cuGraph is unavailable: {e}.")

# Names exported by a star-import of this module.
__all__ = [
    "_to_nx_graph",
    "_to_nxadb_graph",
    "_to_nxcg_graph",
]


def _to_nx_graph(G: Any, *args: Any, **kwargs: Any) -> nx.Graph:
    """Return ``G`` as a NetworkX graph.

    Parameters
    ----------
    G : Any
        The graph to convert. Supported types:
        - nxadb.Graph : converted through ``nxadb_to_nx``
        - nx.Graph : returned unchanged
    *args, **kwargs : Any
        Accepted for call-compatibility; not used by this function.

    Returns
    -------
    nx.Graph
        The converted graph.

    Raises
    ------
    TypeError
        If ``G`` is neither an nxadb.Graph nor an nx.Graph.
    """
    logger.debug(f"_to_nx_graph for {G.__class__.__name__}")

    # The nxadb.Graph test comes first so that ArangoDB-backed graphs are
    # routed through the dedicated converter rather than passed through.
    if isinstance(G, nxadb.Graph):
        converted = nxadb_to_nx(G)
    elif isinstance(G, nx.Graph):
        converted = G
    else:
        raise TypeError(f"Expected nxadb.Graph or nx.Graph; got {type(G)}")

    return converted


def _to_nxadb_graph(
    G: Any,
    *args: Any,
    as_directed: bool = False,
    **kwargs: Any,
) -> nxadb.Graph:
    """Return ``G`` as a NetworkX-ArangoDB graph.

    Parameters
    ----------
    G : Any
        The graph to convert. Supported types:
        - nxadb.Graph : returned unchanged
        - nx.Graph : converted through ``nx_to_nxadb``
    as_directed : bool, optional
        Forwarded to ``nx_to_nxadb`` when a plain NetworkX graph is
        converted. Default is False.
    *args, **kwargs : Any
        Accepted for call-compatibility; not used by this function.

    Returns
    -------
    nxadb.Graph
        The converted graph.

    Raises
    ------
    TypeError
        If ``G`` is neither an nxadb.Graph nor an nx.Graph.
    """
    logger.debug(f"_to_nxadb_graph for {G.__class__.__name__}")

    if isinstance(G, nxadb.Graph):
        converted = G
    elif isinstance(G, nx.Graph):
        converted = nx_to_nxadb(G, as_directed=as_directed)
    else:
        raise TypeError(f"Expected nxadb.Graph or nx.Graph; got {type(G)}")

    return converted


if GPU_AVAILABLE:

    def _to_nxcg_graph(G: Any, as_directed: bool = False) -> nxcg.Graph:
        """Convert a graph to a NetworkX-cuGraph graph.

        NOTE: Only supported if NetworkX-cuGraph is installed.

        Parameters
        ----------
        G : Any
            The graph to convert. Supported types:
            - nxcg.Graph : returned unchanged
            - nxadb.Graph : converted with ``nxcg.convert.from_networkx``
              when the graph does not exist in the database, otherwise
              with ``nxadb_to_nxcg``
        as_directed : bool, optional
            Whether to convert the graph to a directed graph. Forwarded to
            whichever converter handles the nxadb.Graph. Default is False.

        Returns
        -------
        nxcg.Graph
            The converted graph.

        Raises
        ------
        TypeError
            If ``G`` is neither an nxcg.Graph nor an nxadb.Graph.
        """
        logger.debug(f"_to_nxcg_graph for {G.__class__.__name__}")

        if isinstance(G, nxcg.Graph):
            return G

        if isinstance(G, nxadb.Graph):
            logger.debug("converting nx_arangodb graph to nx_cugraph graph")

            if not G.graph_exists_in_db:
                # Purely local graph: let nx-cugraph convert it directly, and
                # honour ``as_directed`` just like the database path does.
                return nxcg.convert.from_networkx(G, as_directed=as_directed)

            return nxadb_to_nxcg(G, as_directed=as_directed)

        raise TypeError(f"Expected nx_arangodb.Graph or nxcg.Graph; got {type(G)}")

else:

    def _to_nxcg_graph(G: Any, as_directed: bool = False) -> nxcg.Graph:
        """Stand-in defined when NetworkX-cuGraph is unavailable.

        Shares the signature of the GPU-enabled variant so callers can
        reference it unconditionally.

        Raises
        ------
        NotImplementedError
            Always.
        """
        m = "nx-cugraph is not installed; cannot convert to nx-cugraph"
        raise NotImplementedError(m)


def nx_to_nxadb(
    graph: nx.Graph,
    *args: Any,
    as_directed: bool = False,
    **kwargs: Any,
) -> nxadb.Graph:
    """Build a NetworkX-ArangoDB graph from a NetworkX graph.

    The target class is picked from two properties of ``graph``: whether it
    is a multigraph, and whether it is directed (or ``as_directed`` is set).

    Parameters
    ----------
    graph : nx.Graph
        The NetworkX graph to convert.
    as_directed : bool, optional
        Whether to convert the graph to a directed graph. Default is False.
    *args, **kwargs : Any
        Accepted for call-compatibility; not used by this function.

    Returns
    -------
    nxadb.Graph
        A new instance of the selected nxadb class, constructed with
        ``incoming_graph_data=graph``.
    """
    logger.debug(f"from_networkx for {graph.__class__.__name__}")

    multi = graph.is_multigraph()
    directed = graph.is_directed() or as_directed

    klass: type[nxadb.Graph]
    if multi:
        klass = nxadb.MultiDiGraph if directed else nxadb.MultiGraph
    else:
        klass = nxadb.DiGraph if directed else nxadb.Graph

    return klass(incoming_graph_data=graph)


def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
    """Convert a NetworkX-ArangoDB graph to a NetworkX graph.

    ``G`` itself is returned when the graph does not exist in the database,
    or when both its node and adjacency dictionaries report that all data
    has already been fetched. Otherwise the node and adjacency dictionaries
    are pulled from the database and installed on a freshly created graph of
    the class given by ``G.to_networkx_class()``.

    NOTE: The fetched dictionaries are regular Python dictionaries, so the
    custom dictionary classes in nx_arangodb.classes.dict are not used by
    the returned graph. The dictionaries are also not cached, so they are
    fetched again on every conversion, which is currently invoked on
    *every* algorithm call. As a temporary workaround, pull the graph once
    and reuse it::

        import networkx as nx
        import nx_arangodb as nxadb

        G_ADB = nxadb.Graph(name="MyGraph")  # Connect to the graph
        G_NX = nxadb.convert._to_nx_graph(G_ADB)  # Pull the graph

        nx.pagerank(G_NX)
        nx.betweenness_centrality(G_NX)

    Parameters
    ----------
    G : nxadb.Graph
        The NetworkX-ArangoDB graph to convert.

    Returns
    -------
    nx.Graph
        The converted graph.
    """
    if not G.graph_exists_in_db:
        # nxadb.Graph subclasses nx.Graph, so a graph that only lives in
        # memory can be handed back untouched.
        return G

    assert isinstance(G._node, NodeDict)
    assert isinstance(G._adj, AdjListOuterDict)

    if G._node.FETCHED_ALL_DATA and G._adj.FETCHED_ALL_DATA:
        return G

    directed = G.is_directed()
    edge_attributes = G.edge_attributes

    start_time = time.time()

    node_dict, adj_dict, *_ = nxadb.classes.function.get_arangodb_graph(
        adb_graph=G.adb_graph,
        load_node_dict=True,
        load_adj_dict=True,
        load_coo=False,
        edge_collections_attributes=edge_attributes,
        load_all_vertex_attributes=False,
        load_all_edge_attributes=len(edge_attributes) == 0,
        is_directed=directed,
        is_multigraph=G.is_multigraph(),
        symmetrize_edges_if_directed=G.symmetrize_edges if directed else False,
        read_parallelism=G.read_parallelism,
        read_batch_size=G.read_batch_size,
    )

    logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")

    # NOTE: An alternative would be to walk **node_dict** and **adj_dict**
    # (similar to NodeDict._fetch_all() and AdjListOuterDict._fetch_all())
    # and populate the custom dictionary classes from
    # nx_arangodb.classes.dict, then return the updated nxadb.Graph.
    # That needs extra for-loops and would likely be slower than the
    # approach below. A feature flag (e.g. `build_remote_dicts=True/False`)
    # could let users choose between the two.

    G_NX: nx.Graph = G.to_networkx_class()()
    G_NX._node = node_dict

    if isinstance(G_NX, nx.DiGraph):
        successors = adj_dict["succ"]
        G_NX._succ = successors
        G_NX._adj = successors
        G_NX._pred = adj_dict["pred"]
    else:
        G_NX._adj = adj_dict

    return G_NX


if GPU_AVAILABLE:

    def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
        """Convert a NetworkX-ArangoDB graph to a NetworkX-cuGraph graph.

        When ``G.use_nxcg_cache`` is set and ``G.nxcg_graph`` is already
        populated, that cached graph is returned without touching the
        database. Otherwise the graph is pulled in COO format, a new
        nx-cugraph graph is built from it, and the result is stored on
        ``G.nxcg_graph`` before being returned.

        Parameters
        ----------
        G : nxadb.Graph
            The NetworkX-ArangoDB graph to convert.
        as_directed : bool, optional
            Whether to convert the graph to a directed graph. Default is
            False.

        Returns
        -------
        nxcg.Graph
            The converted graph.
        """
        # NOTE(review): the cached graph is returned without consulting
        # ``as_directed`` -- confirm that this is intended.
        if G.use_nxcg_cache and G.nxcg_graph is not None:
            m = "**use_nxcg_cache** is enabled. using cached NXCG Graph. no pull required."  # noqa
            logger.debug(m)
            return G.nxcg_graph

        is_directed = G.is_directed()
        edge_attributes = G.edge_attributes

        start_time = time.time()

        (
            _,
            _,
            src_indices,
            dst_indices,
            edge_indices,
            vertex_ids_to_index,
            edge_values,
        ) = nxadb.classes.function.get_arangodb_graph(
            adb_graph=G.adb_graph,
            load_node_dict=False,
            load_adj_dict=False,
            load_coo=True,
            edge_collections_attributes=edge_attributes,
            load_all_vertex_attributes=False,  # not used
            load_all_edge_attributes=len(edge_attributes) == 0,
            is_directed=is_directed,
            is_multigraph=G.is_multigraph(),
            symmetrize_edges_if_directed=(
                G.symmetrize_edges if is_directed else False
            ),
            read_parallelism=G.read_parallelism,
            read_batch_size=G.read_batch_size,
        )

        logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")

        start_time = time.time()

        # Copy the index arrays into cupy arrays for nx-cugraph.
        src_indices_cp = cp.array(src_indices)
        dst_indices_cp = cp.array(dst_indices)
        edge_indices_cp = cp.array(edge_indices)

        # Keyword arguments common to every ``from_coo`` call below.
        coo_kwargs = {
            "N": len(vertex_ids_to_index),
            "src_indices": src_indices_cp,
            "dst_indices": dst_indices_cp,
            "edge_values": edge_values,
            "key_to_id": vertex_ids_to_index,
        }

        want_directed = G.is_directed() or as_directed
        if G.is_multigraph():
            klass = nxcg.MultiDiGraph if want_directed else nxcg.MultiGraph
            # Only the multigraph constructors receive the edge indices.
            coo_kwargs["edge_indices"] = edge_indices_cp
        else:
            klass = nxcg.DiGraph if want_directed else nxcg.Graph

        G.nxcg_graph = klass.from_coo(**coo_kwargs)

        logger.info(f"NXCG Graph construction took {time.time() - start_time}s")

        return G.nxcg_graph