Added Networkx Support in python driver (#1716)
Co-authored-by: munmud <moontasir042@gmail.com>
diff --git a/.github/workflows/python-driver.yaml b/.github/workflows/python-driver.yaml
index 40eb58e..fb0ad5c 100644
--- a/.github/workflows/python-driver.yaml
+++ b/.github/workflows/python-driver.yaml
@@ -40,4 +40,5 @@
- name: Test
run: |
python test_age_py.py -db "postgres" -u "postgres" -pass "agens"
+ python test_networkx.py -db "postgres" -u "postgres" -pass "agens"
python -m unittest -v test_agtypes.py
diff --git a/drivers/python/README.md b/drivers/python/README.md
index cf3e53b..c7c6802 100644
--- a/drivers/python/README.md
+++ b/drivers/python/README.md
@@ -86,3 +86,74 @@
### License
Apache-2.0 License
+
+
+## Networkx
+### Netowkx Unit test
+```
+python test_networkx.py \
+-host "127.0.0.1" \
+-db "postgres" \
+-u "postgres" \
+-pass "agens" \
+-port 5432
+```
+Here the following value required
+- `-host` : host name (optional)
+- `-db` : database name
+- `-u` : user name
+- `-pass` : password
+- `-port` : port (optional)
+
+### Networkx to AGE
+Insert From networkx directed graph into an Age database.
+#### Parameters
+
+- `connection` (psycopg2.connect): Connection object to the Age database.
+
+- `G` (networkx.DiGraph): Networkx directed graph to be converted and inserted.
+
+- `graphName` (str): Name of the age graph.
+
+#### Returns
+
+None
+
+#### Example
+
+```python
+
+# Create a Networkx DiGraph
+G = nx.DiGraph()
+G.add_node(1)
+G.add_node(2)
+G.add_edge(1, 2)
+
+# Convert and insert the graph into the Age database
+graphName = "sample_graph"
+networkx_to_age(connection, G, graphName)
+```
+
+
+
+### AGE to Netowkx
+
+Converts data from a Apache AGE graph database into a Networkx directed graph.
+
+#### Parameters
+
+- `connection` (psycopg2.connect): Connection object to the PostgreSQL database.
+- `graphName` (str): Name of the graph.
+- `G` (None | nx.DiGraph): Optional Networkx directed graph. If provided, the data will be added to this graph.
+- `query` (str | None): Optional Cypher query to retrieve data from the database.
+
+#### Returns
+
+- `nx.DiGraph`: Networkx directed graph containing the converted data.
+
+#### Example
+
+```python
+# Call the function to convert data into a Networkx graph
+graph = age_to_networkx(connection, graphName="MyGraph" )
+```
diff --git a/drivers/python/age/networkx/__init__.py b/drivers/python/age/networkx/__init__.py
new file mode 100644
index 0000000..17a1d3f
--- /dev/null
+++ b/drivers/python/age/networkx/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from .networkx_to_age import networkx_to_age
+from .age_to_networkx import age_to_networkx
diff --git a/drivers/python/age/networkx/age_to_networkx.py b/drivers/python/age/networkx/age_to_networkx.py
new file mode 100644
index 0000000..3a16aa3
--- /dev/null
+++ b/drivers/python/age/networkx/age_to_networkx.py
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from age import *
+import psycopg2
+import networkx as nx
+from age.models import Vertex, Edge, Path
+from .lib import *
+
+
+def age_to_networkx(connection: psycopg2.connect,
+ graphName: str,
+ G: None | nx.DiGraph = None,
+ query: str | None = None
+ ) -> nx.DiGraph:
+ """
+ @params
+ ---------------------
+ connection - (psycopg2.connect) Connection object
+ graphName - (str) Name of the graph
+ G - (networkx.DiGraph) Networkx directed Graph [optional]
+ query - (str) Cypher query [optional]
+
+ @returns
+ ------------
+ Networkx directed Graph
+
+ """
+
+ # Check if the age graph exists
+ checkIfGraphNameExistInAGE(connection, graphName)
+
+ # Create an empty directed graph
+ if G == None:
+ G = nx.DiGraph()
+
+ def addNodeToNetworkx(node):
+ """Add Nodes in Networkx"""
+ G.add_node(node.id,
+ label=node.label,
+ properties=node.properties)
+
+ def addEdgeToNetworkx(edge):
+ """Add Edge in Networkx"""
+ G.add_edge(edge.start_id,
+ edge.end_id,
+ label=edge.label,
+ properties=edge.properties)
+
+ def addPath(path):
+ """Add Edge in Networkx"""
+ for x in path:
+ if (type(x) == Path):
+ addPath(x)
+ for x in path:
+ if (type(x) == Vertex):
+ addNodeToNetworkx(x)
+ for x in path:
+ if (type(x) == Edge):
+ addEdgeToNetworkx(x)
+
+ # Setting up connection to work with Graph
+ age.setUpAge(connection, graphName)
+
+ if (query == None):
+ addAllNodesIntoNetworkx(connection, graphName, G)
+ addAllEdgesIntoNetworkx(connection, graphName, G)
+ else:
+ with connection.cursor() as cursor:
+ cursor.execute(query)
+ rows = cursor.fetchall()
+ for row in rows:
+ for x in row:
+ if type(x) == Path:
+ addPath(x)
+ elif type(x) == Edge:
+ addEdgeToNetworkx(x)
+ elif type(x) == Vertex:
+ addNodeToNetworkx(x)
+ return G
diff --git a/drivers/python/age/networkx/lib.py b/drivers/python/age/networkx/lib.py
new file mode 100644
index 0000000..a4df088
--- /dev/null
+++ b/drivers/python/age/networkx/lib.py
@@ -0,0 +1,253 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from age import *
+import psycopg2
+import networkx as nx
+from psycopg2 import sql
+from psycopg2.extras import execute_values
+from typing import Dict, Any, List, Set
+from age.models import Vertex, Edge, Path
+
+
+def checkIfGraphNameExistInAGE(connection: psycopg2.connect,
+ graphName: str):
+ """Check if the age graph exists"""
+ with connection.cursor() as cursor:
+ cursor.execute(sql.SQL("""
+ SELECT count(*)
+ FROM ag_catalog.ag_graph
+ WHERE name='%s'
+ """ % (graphName)))
+ if cursor.fetchone()[0] == 0:
+ raise GraphNotFound(graphName)
+
+
+def getOidOfGraph(connection: psycopg2.connect,
+ graphName: str) -> int:
+ """Returns oid of a graph"""
+ try:
+ with connection.cursor() as cursor:
+ cursor.execute(sql.SQL("""
+ SELECT graphid FROM ag_catalog.ag_graph WHERE name='%s' ;
+ """ % (graphName)))
+ oid = cursor.fetchone()[0]
+ return oid
+ except Exception as e:
+ print(e)
+
+
+def get_vlabel(connection: psycopg2.connect,
+ graphName: str):
+ node_label_list = []
+ oid = getOidOfGraph(connection, graphName)
+ try:
+ with connection.cursor() as cursor:
+ cursor.execute(
+ """SELECT name FROM ag_catalog.ag_label WHERE kind='v' AND graph=%s;""" % oid)
+ for row in cursor:
+ node_label_list.append(row[0])
+
+ except Exception as ex:
+ print(type(ex), ex)
+ return node_label_list
+
+
+def create_vlabel(connection: psycopg2.connect,
+ graphName: str,
+ node_label_list: List):
+ """create_vlabels from list if not exist"""
+ try:
+ node_label_set = set(get_vlabel(connection, graphName))
+ crete_label_statement = ''
+ for label in node_label_list:
+ if label in node_label_set:
+ continue
+ crete_label_statement += """SELECT create_vlabel('%s','%s');\n""" % (
+ graphName, label)
+ if crete_label_statement != '':
+ with connection.cursor() as cursor:
+ cursor.execute(crete_label_statement)
+ connection.commit()
+ except Exception as e:
+ raise Exception(e)
+
+
+def get_elabel(connection: psycopg2.connect,
+ graphName: str):
+ edge_label_list = []
+ oid = getOidOfGraph(connection, graphName)
+ try:
+ with connection.cursor() as cursor:
+ cursor.execute(
+ """SELECT name FROM ag_catalog.ag_label WHERE kind='e' AND graph=%s;""" % oid)
+ for row in cursor:
+ edge_label_list.append(row[0])
+ except Exception as ex:
+ print(type(ex), ex)
+ return edge_label_list
+
+
+def create_elabel(connection: psycopg2.connect,
+ graphName: str,
+ edge_label_list: List):
+ """create_vlabels from list if not exist"""
+ try:
+ edge_label_set = set(get_elabel(connection, graphName))
+ crete_label_statement = ''
+ for label in edge_label_list:
+ if label in edge_label_set:
+ continue
+ crete_label_statement += """SELECT create_elabel('%s','%s');\n""" % (
+ graphName, label)
+ if crete_label_statement != '':
+ with connection.cursor() as cursor:
+ cursor.execute(crete_label_statement)
+ connection.commit()
+ except Exception as e:
+ raise Exception(e)
+
+
+def getNodeLabelListAfterPreprocessing(G: nx.DiGraph):
+ """
+ - Add default label if label is missing
+ - Add properties if not exist
+ - return all distinct node label
+ """
+ node_label_list = set()
+ try:
+ for node, data in G.nodes(data=True):
+ if 'label' not in data:
+ data['label'] = '_ag_label_vertex'
+ if 'properties' not in data:
+ data['properties'] = {}
+ if not isinstance(data['label'], str):
+ raise Exception(f"label of node : {node} must be a string")
+ if not isinstance(data['properties'], Dict):
+ raise Exception(f"properties of node : {node} must be a dict")
+ if '__id__' not in data['properties'].keys():
+ data['properties']['__id__'] = node
+ node_label_list.add(data['label'])
+ except Exception as e:
+ raise Exception(e)
+ return node_label_list
+
+
+def getEdgeLabelListAfterPreprocessing(G: nx.DiGraph):
+ """
+ - Add default label if label is missing
+ - Add properties if not exist
+ - return all distinct edge label
+ """
+ edge_label_list = set()
+ try:
+ for u, v, data in G.edges(data=True):
+ if 'label' not in data:
+ data['label'] = '_ag_label_edge'
+ if 'properties' not in data:
+ data['properties'] = {}
+ if not isinstance(data['label'], str):
+ raise Exception(f"label of edge : {u}->{v} must be a string")
+ if not isinstance(data['properties'], Dict):
+ raise Exception(
+ f"properties of edge : {u}->{v} must be a dict")
+ edge_label_list.add(data['label'])
+ except Exception as e:
+ raise Exception(e)
+ return edge_label_list
+
+
+def addAllNodesIntoAGE(connection: psycopg2.connect, graphName: str, G: nx.DiGraph, node_label_list: Set):
+ """Add all node to AGE"""
+ try:
+ queue_data = {label: [] for label in node_label_list}
+ id_data = {}
+
+ for node, data in G.nodes(data=True):
+ json_string = json.dumps(data['properties'])
+ queue_data[data['label']].append((json_string,))
+
+ for label, rows in queue_data.items():
+ table_name = """%s."%s" """ % (graphName, label)
+ insert_query = f"INSERT INTO {table_name} (properties) VALUES %s RETURNING id"
+ cursor = connection.cursor()
+ id_data[label] = execute_values(
+ cursor, insert_query, rows, fetch=True)
+ connection.commit()
+ cursor.close()
+ id_data[label].reverse()
+
+ for node, data in G.nodes(data=True):
+ data['properties']['__gid__'] = id_data[data['label']][-1][0]
+ id_data[data['label']].pop()
+
+ except Exception as e:
+ raise Exception(e)
+
+
+def addAllEdgesIntoAGE(connection: psycopg2.connect, graphName: str, G: nx.DiGraph, edge_label_list: Set):
+ """Add all edge to AGE"""
+ try:
+ queue_data = {label: [] for label in edge_label_list}
+ for u, v, data in G.edges(data=True):
+ json_string = json.dumps(data['properties'])
+ queue_data[data['label']].append(
+ (G.nodes[u]['properties']['__gid__'], G.nodes[v]['properties']['__gid__'], json_string,))
+
+ for label, rows in queue_data.items():
+ table_name = """%s."%s" """ % (graphName, label)
+ insert_query = f"INSERT INTO {table_name} (start_id,end_id,properties) VALUES %s"
+ cursor = connection.cursor()
+ execute_values(cursor, insert_query, rows)
+ connection.commit()
+ cursor.close()
+ except Exception as e:
+ raise Exception(e)
+
+
+def addAllNodesIntoNetworkx(connection: psycopg2.connect, graphName: str, G: nx.DiGraph):
+ """Add all nodes to Networkx"""
+ node_label_list = get_vlabel(connection, graphName)
+ try:
+ for label in node_label_list:
+ with connection.cursor() as cursor:
+ cursor.execute("""
+ SELECT id, CAST(properties AS VARCHAR)
+ FROM %s."%s";
+ """ % (graphName, label))
+ rows = cursor.fetchall()
+ for row in rows:
+ G.add_node(int(row[0]), label=label,
+ properties=json.loads(row[1]))
+ except Exception as e:
+ print(e)
+
+
+def addAllEdgesIntoNetworkx(connection: psycopg2.connect, graphName: str, G: nx.DiGraph):
+ """Add All edges to Networkx"""
+ try:
+ edge_label_list = get_elabel(connection, graphName)
+ for label in edge_label_list:
+ with connection.cursor() as cursor:
+ cursor.execute("""
+ SELECT start_id, end_id, CAST(properties AS VARCHAR)
+ FROM %s."%s";
+ """ % (graphName, label))
+ rows = cursor.fetchall()
+ for row in rows:
+ G.add_edge(int(row[0]), int(
+ row[1]), label=label, properties=json.loads(row[2]))
+ except Exception as e:
+ print(e)
diff --git a/drivers/python/age/networkx/networkx_to_age.py b/drivers/python/age/networkx/networkx_to_age.py
new file mode 100644
index 0000000..90cee97
--- /dev/null
+++ b/drivers/python/age/networkx/networkx_to_age.py
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from age import *
+import psycopg2
+import networkx as nx
+from .lib import *
+
+
+def networkx_to_age(connection: psycopg2.connect,
+ G: nx.DiGraph,
+ graphName: str):
+ """
+ @params
+ -----------
+ connection - (psycopg2.connect) Connection object
+
+ G - (networkx.DiGraph) Networkx directed Graph
+
+ graphName - (str) Name of the graph
+
+ @returns
+ ------------
+ None
+
+ """
+ node_label_list = getNodeLabelListAfterPreprocessing(G)
+ edge_label_list = getEdgeLabelListAfterPreprocessing(G)
+
+ # Setup connection with Graph
+ age.setUpAge(connection, graphName)
+
+ create_vlabel(connection, graphName, node_label_list)
+ create_elabel(connection, graphName, edge_label_list)
+
+ addAllNodesIntoAGE(connection, graphName, G, node_label_list)
+ addAllEdgesIntoAGE(connection, graphName, G, edge_label_list)
diff --git a/drivers/python/requirements.txt b/drivers/python/requirements.txt
index 4951053..81d1ef7 100644
--- a/drivers/python/requirements.txt
+++ b/drivers/python/requirements.txt
Binary files differ
diff --git a/drivers/python/samples/networkx.ipynb b/drivers/python/samples/networkx.ipynb
new file mode 100644
index 0000000..20fc566
--- /dev/null
+++ b/drivers/python/samples/networkx.ipynb
@@ -0,0 +1,183 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import psycopg2\n",
+ "from age.networkx import *\n",
+ "from age import *\n",
+ "import networkx as nx\n",
+ "\n",
+ "conn = psycopg2.connect(\n",
+ " host=\"localhost\",\n",
+ " port=\"5432\",\n",
+ " dbname=\"postgres\",\n",
+ " user=\"moontasir\",\n",
+ " password=\"254826\")\n",
+ "graphName = 'bitnine_global_inic'\n",
+ "\n",
+ "\n",
+ "age.setUpAge(conn, graphName)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Creating a Random Networkx Graph"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "DiGraph with 5 nodes and 9 edges\n"
+ ]
+ }
+ ],
+ "source": [
+ "import random\n",
+ "num_nodes = 5\n",
+ "num_edges = 10\n",
+ "G = nx.DiGraph()\n",
+ "try:\n",
+ " for i in range(num_nodes//2):\n",
+ " G.add_node(i, label='Number',properties={'name' : i})\n",
+ " for i in range(num_nodes//2,num_nodes):\n",
+ " G.add_node(i, label='Integer',properties={'age' : i*2} )\n",
+ " for i in range(num_edges):\n",
+ " source = random.randint(0, num_nodes-1)\n",
+ " target = random.randint(0, num_nodes-1)\n",
+ " G.add_edge(source, target, label='Connection' ,properties={'st' : source , 'ed':target})\n",
+ "except Exception as e:\n",
+ " raise e\n",
+ "print(G)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Networkx to AGE"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "networkx_to_age(conn, G, graphName)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## AGE to Networkx"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### ALL AGE to Networkx\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "DiGraph with 5 nodes and 9 edges\n"
+ ]
+ }
+ ],
+ "source": [
+ "G = age_to_networkx(conn, graphName)\n",
+ "print(G)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Add subgraph of AGE to Networkx using query"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "DiGraph with 5 nodes and 9 edges\n"
+ ]
+ }
+ ],
+ "source": [
+ "G = age_to_networkx(conn, graphName, \n",
+ " query=\"\"\"\n",
+ "SELECT * from cypher('%s', $$\n",
+ " MATCH (V:Integer)\n",
+ " RETURN V\n",
+ "$$) as (V agtype);\n",
+ "\"\"\" % graphName)\n",
+ "\n",
+ "G = age_to_networkx(conn, graphName, G=G,\n",
+ " query=\"\"\"\n",
+ "SELECT * from cypher('%s', $$\n",
+ " MATCH (V:Number)\n",
+ " RETURN V\n",
+ "$$) as (V agtype);\n",
+ "\"\"\" % graphName)\n",
+ "\n",
+ "G = age_to_networkx(conn, graphName, G=G,\n",
+ " query=\"\"\"\n",
+ "SELECT * from cypher('%s', $$\n",
+ " MATCH (V)-[R]->(V2)\n",
+ " RETURN V,R,V2\n",
+ "$$) as (V agtype, R agtype, V2 agtype);\n",
+ "\"\"\" % graphName)\n",
+ "print(G)\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "myenv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.6"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/drivers/python/setup.py b/drivers/python/setup.py
index 6315022..15be9e7 100644
--- a/drivers/python/setup.py
+++ b/drivers/python/setup.py
@@ -30,8 +30,8 @@
url = 'https://github.com/apache/age/tree/master/drivers/python',
download_url = 'https://github.com/apache/age/releases' ,
license = 'Apache2.0',
- install_requires = [ 'psycopg2', 'antlr4_python3_runtime==4.11.1'],
- packages = ['age', 'age.gen'],
+ install_requires = [ 'psycopg2', 'antlr4-python3-runtime==4.11.1'],
+ packages = ['age', 'age.gen','age.networkx'],
keywords = ['Graph Database', 'Apache AGE', 'PostgreSQL'],
python_requires = '>=3.9',
classifiers = [
diff --git a/drivers/python/test_networkx.py b/drivers/python/test_networkx.py
new file mode 100644
index 0000000..290d4c4
--- /dev/null
+++ b/drivers/python/test_networkx.py
@@ -0,0 +1,480 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import age
+import unittest
+import argparse
+import networkx as nx
+from age.models import *
+from age.networkx import *
+from age.exceptions import *
+
+
+TEST_GRAPH_NAME = "test_graph"
+ORIGINAL_GRAPH = "original_graph"
+EXPECTED_GRAPH = "expected_graph"
+
+
+class TestAgeToNetworkx(unittest.TestCase):
+ ag = None
+
+ def setUp(self):
+
+ TEST_DB = self.args.database
+ TEST_USER = self.args.user
+ TEST_PASSWORD = self.args.password
+ TEST_PORT = self.args.port
+ TEST_HOST = self.args.host
+ self.ag = age.connect(graph=TEST_GRAPH_NAME, host=TEST_HOST, port=TEST_PORT,
+ dbname=TEST_DB, user=TEST_USER, password=TEST_PASSWORD)
+
+ def tearDown(self):
+ age.deleteGraph(self.ag.connection, self.ag.graphName)
+ self.ag.close()
+
+ def compare_networkX(self, G, H):
+ if G.number_of_nodes() != H.number_of_nodes():
+ return False
+ if G.number_of_edges() != H.number_of_edges():
+ return False
+ # test nodes
+ nodes_G, nodes_H = G.number_of_nodes(), H.number_of_nodes()
+ markG, markH = [0]*nodes_G, [0]*nodes_H
+ nodes_list_G, nodes_list_H = list(G.nodes), list(H.nodes)
+ for i in range(0, nodes_G):
+ for j in range(0, nodes_H):
+ if markG[i] == 0 and markH[j] == 0:
+ node_id_G = nodes_list_G[i]
+ property_G = G.nodes[node_id_G]
+
+ node_id_H = nodes_list_H[j]
+ property_H = H.nodes[node_id_H]
+ if property_G == property_H:
+ markG[i] = 1
+ markH[j] = 1
+
+ if any(elem == 0 for elem in markG):
+ return False
+ if any(elem == 0 for elem in markH):
+ return False
+
+ # test edges
+ edges_G, edges_H = G.number_of_edges(), H.number_of_edges()
+ markG, markH = [0]*edges_G, [0]*edges_H
+ edges_list_G, edges_list_H = list(G.edges), list(H.edges)
+
+ for i in range(0, edges_G):
+ for j in range(0, edges_H):
+ if markG[i] == 0 and markH[j] == 0:
+ source_G, target_G = edges_list_G[i]
+ property_G = G.edges[source_G, target_G]
+
+ source_H, target_H = edges_list_H[j]
+ property_H = H.edges[source_H, target_H]
+
+ if property_G == property_H:
+ markG[i] = 1
+ markH[j] = 1
+
+ if any(elem == 0 for elem in markG):
+ return False
+ if any(elem == 0 for elem in markH):
+ return False
+
+ return True
+
+ def test_empty_graph(self):
+ print('Testing AGE to Networkx for empty graph')
+ # Expected Graph
+ # Empty Graph
+ G = nx.DiGraph()
+
+ # Convert Apache AGE to NetworkX
+ H = age_to_networkx(self.ag.connection, TEST_GRAPH_NAME)
+
+ self.assertTrue(self.compare_networkX(G, H))
+
+ def test_existing_graph_without_query(self):
+ print('Testing AGE to Networkx for non empty graph without query')
+ ag = self.ag
+ # Create nodes
+ ag.execCypher("CREATE (n:Person {name: %s}) ", params=('Jack',))
+ ag.execCypher("CREATE (n:Person {name: %s}) ", params=('Andy',))
+ ag.execCypher("CREATE (n:Person {name: %s}) ", params=('Smith',))
+ ag.commit()
+
+ # Create Edges
+ ag.execCypher("""MATCH (a:Person), (b:Person)
+ WHERE a.name = 'Andy' AND b.name = 'Jack'
+ CREATE (a)-[r:workWith {weight: 3}]->(b)""")
+ ag.execCypher("""MATCH (a:Person), (b:Person)
+ WHERE a.name = %s AND b.name = %s
+ CREATE p=((a)-[r:workWith {weight: 10}]->(b)) """, params=('Jack', 'Smith',))
+ ag.commit()
+
+ G = age_to_networkx(self.ag.connection, TEST_GRAPH_NAME)
+
+ # Check that the G has the expected properties
+ self.assertIsInstance(G, nx.DiGraph)
+
+ # Check that the G has the correct number of nodes and edges
+ self.assertEqual(len(G.nodes), 3)
+ self.assertEqual(len(G.edges), 2)
+
+ # Check that the node properties are correct
+ for node in G.nodes:
+ self.assertEqual(int, type(node))
+ self.assertEqual(G.nodes[node]['label'], 'Person')
+ self.assertIn('name', G.nodes[node]['properties'])
+ self.assertIn('properties', G.nodes[node])
+ self.assertEqual(str, type(G.nodes[node]['label']))
+
+ # Check that the edge properties are correct
+ for edge in G.edges:
+ self.assertEqual(tuple, type(edge))
+ self.assertEqual(int, type(edge[0]) and type(edge[1]))
+ self.assertEqual(G.edges[edge]['label'], 'workWith')
+ self.assertIn('weight', G.edges[edge]['properties'])
+ self.assertEqual(int, type(G.edges[edge]['properties']['weight']))
+
+ def test_existing_graph_with_query(self):
+ print('Testing AGE to Networkx for non empty graph with query')
+
+ ag = self.ag
+ # Create nodes
+ ag.execCypher("CREATE (n:Person {name: %s}) ", params=('Jack',))
+ ag.execCypher("CREATE (n:Person {name: %s}) ", params=('Andy',))
+ ag.execCypher("CREATE (n:Person {name: %s}) ", params=('Smith',))
+ ag.commit()
+
+ # Create Edges
+ ag.execCypher("""MATCH (a:Person), (b:Person)
+ WHERE a.name = 'Andy' AND b.name = 'Jack'
+ CREATE (a)-[r:workWith {weight: 3}]->(b)""")
+ ag.execCypher("""MATCH (a:Person), (b:Person)
+ WHERE a.name = %s AND b.name = %s
+ CREATE p=((a)-[r:workWith {weight: 10}]->(b)) """, params=('Jack', 'Smith',))
+ ag.commit()
+
+ query = """SELECT * FROM cypher('%s', $$ MATCH (a:Person)-[r:workWith]->(b:Person)
+ WHERE a.name = 'Andy'
+ RETURN a, r, b $$) AS (a agtype, r agtype, b agtype);
+ """ % (TEST_GRAPH_NAME)
+
+ G = age_to_networkx(self.ag.connection,
+ graphName=TEST_GRAPH_NAME, query=query)
+
+ # Check that the G has the expected properties
+ self.assertIsInstance(G, nx.DiGraph)
+
+ # Check that the G has the correct number of nodes and edges
+ self.assertEqual(len(G.nodes), 2)
+ self.assertEqual(len(G.edges), 1)
+
+ # Check that the node properties are correct
+ for node in G.nodes:
+ self.assertEqual(int, type(node))
+ self.assertEqual(G.nodes[node]['label'], 'Person')
+ self.assertIn('name', G.nodes[node]['properties'])
+ self.assertIn('properties', G.nodes[node])
+ self.assertEqual(str, type(G.nodes[node]['label']))
+
+ # Check that the edge properties are correct
+ for edge in G.edges:
+ self.assertEqual(tuple, type(edge))
+ self.assertEqual(int, type(edge[0]) and type(edge[1]))
+ self.assertEqual(G.edges[edge]['label'], 'workWith')
+ self.assertIn('weight', G.edges[edge]['properties'])
+ self.assertEqual(int, type(G.edges[edge]['properties']['weight']))
+
+ def test_existing_graph(self):
+ print("Testing AGE to NetworkX for non-existing graph")
+ ag = self.ag
+
+ non_existing_graph = "non_existing_graph"
+ # Check that the function raises an exception for non existing graph
+ with self.assertRaises(GraphNotFound) as context:
+ age_to_networkx(ag.connection, graphName=non_existing_graph)
+ # Check the raised exception has the expected error message
+ self.assertEqual(str(context.exception), non_existing_graph)
+
+
+class TestNetworkxToAGE(unittest.TestCase):
+ ag = None
+ ag1 = None
+ ag2 = None
+
+ def setUp(self):
+ TEST_DB = self.args.database
+ TEST_USER = self.args.user
+ TEST_PASSWORD = self.args.password
+ TEST_PORT = self.args.port
+ TEST_HOST = self.args.host
+ self.ag = age.connect(graph=TEST_GRAPH_NAME, host=TEST_HOST, port=TEST_PORT,
+ dbname=TEST_DB, user=TEST_USER, password=TEST_PASSWORD)
+ self.ag1 = age.connect(graph=ORIGINAL_GRAPH, host=TEST_HOST, port=TEST_PORT,
+ dbname=TEST_DB, user=TEST_USER, password=TEST_PASSWORD)
+ self.ag2 = age.connect(graph=EXPECTED_GRAPH, host=TEST_HOST, port=TEST_PORT,
+ dbname=TEST_DB, user=TEST_USER, password=TEST_PASSWORD)
+ self.graph = nx.DiGraph()
+
+ def tearDown(self):
+ age.deleteGraph(self.ag1.connection, self.ag1.graphName)
+ age.deleteGraph(self.ag2.connection, self.ag2.graphName)
+ age.deleteGraph(self.ag.connection, self.ag.graphName)
+ self.ag.close()
+ self.ag1.close()
+ self.ag2.close()
+
+ def compare_age(self, age1, age2):
+ cursor = age1.execCypher("MATCH (v) RETURN v")
+ g_nodes = cursor.fetchall()
+
+ cursor = age1.execCypher("MATCH ()-[r]->() RETURN r")
+ g_edges = cursor.fetchall()
+
+ cursor = age2.execCypher("MATCH (v) RETURN v")
+ h_nodes = cursor.fetchall()
+
+ cursor = age2.execCypher("MATCH ()-[r]->() RETURN r")
+ h_edges = cursor.fetchall()
+
+ if len(g_nodes) != len(h_nodes) or len(g_edges) != len(h_edges):
+ return False
+
+ # test nodes
+ nodes_G, nodes_H = len(g_nodes), len(h_nodes)
+ markG, markH = [0]*nodes_G, [0]*nodes_H
+
+ # return True
+ for i in range(0, nodes_G):
+ for j in range(0, nodes_H):
+ if markG[i] == 0 and markH[j] == 0:
+ property_G = g_nodes[i][0].properties
+ property_G['label'] = g_nodes[i][0].label
+ property_G.pop('__id__')
+
+ property_H = h_nodes[j][0].properties
+ property_H['label'] = h_nodes[j][0].label
+
+ if property_G == property_H:
+ markG[i] = 1
+ markH[j] = 1
+
+ if any(elem == 0 for elem in markG):
+ return False
+ if any(elem == 0 for elem in markH):
+ return False
+
+ # test edges
+ edges_G, edges_H = len(g_edges), len(h_edges)
+ markG, markH = [0]*edges_G, [0]*edges_H
+
+ for i in range(0, edges_G):
+ for j in range(0, edges_H):
+ if markG[i] == 0 and markH[j] == 0:
+ property_G = g_edges[i][0].properties
+ property_G['label'] = g_edges[i][0].label
+
+ property_H = h_edges[j][0].properties
+ property_H['label'] = h_edges[j][0].label
+ if property_G == property_H:
+ markG[i] = 1
+ markH[j] = 1
+
+ if any(elem == 0 for elem in markG):
+ return False
+ if any(elem == 0 for elem in markH):
+ return False
+
+ return True
+
+ def test_number_of_nodes_and_edges(self):
+ print("Testing Networkx To AGE for number of nodes and edges")
+ ag = self.ag
+
+ # Create NetworkX graph
+ self.graph.add_node(1, label='Person', properties={
+ 'name': 'Moontasir', 'age': '26', 'id': 1})
+ self.graph.add_node(2, label='Person', properties={
+ 'name': 'Austen', 'age': '26', 'id': 2})
+ self.graph.add_edge(1, 2, label='KNOWS', properties={
+ 'since': '1997', 'start_id': 1, 'end_id': 2})
+ self.graph.add_node(3, label='Person', properties={
+ 'name': 'Eric', 'age': '28', 'id': 3})
+
+ networkx_to_age(self.ag.connection, self.graph, TEST_GRAPH_NAME)
+
+ # Check that node(s) were created
+ cursor = ag.execCypher('MATCH (n) RETURN n')
+ result = cursor.fetchall()
+ # Check number of vertices created
+ self.assertEqual(len(result), 3)
+ # Checks if type of property in query output is a Vertex
+ self.assertEqual(Vertex, type(result[0][0]))
+ self.assertEqual(Vertex, type(result[1][0]))
+
+ # Check that edge(s) was created
+ cursor = ag.execCypher('MATCH ()-[e]->() RETURN e')
+ result = cursor.fetchall()
+ # Check number of edge(s) created
+ self.assertEqual(len(result), 1)
+ # Checks if type of property in query output is an Edge
+ self.assertEqual(Edge, type(result[0][0]))
+
+ def test_empty_graph(self):
+ print("Testing Networkx To AGE for empty Graph")
+ # Expected Graph
+ # Empty Graph
+
+ # NetworkX Graph
+ G = nx.DiGraph()
+
+ # Convert Apache AGE to NetworkX
+ networkx_to_age(self.ag1.connection, G, ORIGINAL_GRAPH)
+
+ self.assertTrue(self.compare_age(self.ag1, self.ag2))
+
+ def test_non_empty_graph(self):
+ print("Testing Networkx To AGE for non-empty Graph")
+ # Expected Graph
+ self.ag2.execCypher("CREATE (:l1 {name: 'n1', weight: '5'})")
+ self.ag2.execCypher("CREATE (:l1 {name: 'n2', weight: '4'})")
+ self.ag2.execCypher("CREATE (:l1 {name: 'n3', weight: '9'})")
+
+ self.ag2.execCypher("""MATCH (a:l1), (b:l1)
+ WHERE a.name = 'n1' AND b.name = 'n2'
+ CREATE (a)-[e:e1 {property:'graph'}]->(b)""")
+ self.ag2.execCypher("""MATCH (a:l1), (b:l1)
+ WHERE a.name = 'n2' AND b.name = 'n3'
+ CREATE (a)-[e:e2 {property:'node'}]->(b)""")
+
+ # NetworkX Graph
+ G = nx.DiGraph()
+
+ G.add_node('1',
+ label='l1',
+ properties={'name': 'n1',
+ 'weight': '5'})
+ G.add_node('2',
+ label='l1',
+ properties={'name': 'n2',
+ 'weight': '4'})
+ G.add_node('3',
+ label='l1',
+ properties={'name': 'n3',
+ 'weight': '9'})
+ G.add_edge('1', '2', label='e1', properties={'property': 'graph'})
+ G.add_edge('2', '3', label='e2', properties={'property': 'node'})
+
+ # Convert Apache AGE to NetworkX
+ networkx_to_age(self.ag1.connection, G, ORIGINAL_GRAPH)
+
+ self.assertTrue(self.compare_age(self.ag1, self.ag2))
+
+ def test_invalid_node_label(self):
+ print("Testing Networkx To AGE for invalid node label")
+ self.graph.add_node(4, label=123, properties={
+ 'name': 'Mason', 'age': '24', 'id': 4})
+
+ # Check that the function raises an exception for the invalid node label
+ with self.assertRaises(Exception) as context:
+ networkx_to_age(self.ag.connection, G=self.graph,
+ graphName=TEST_GRAPH_NAME)
+ # Check the raised exception has the expected error message
+ self.assertEqual(str(context.exception),
+ "label of node : 4 must be a string")
+
+ def test_invalid_node_properties(self):
+ print("Testing Networkx To AGE for invalid node properties")
+ self.graph.add_node(4, label='Person', properties="invalid")
+
+ # Check that the function raises an exception for the invalid node properties
+ with self.assertRaises(Exception) as context:
+ networkx_to_age(self.ag.connection, G=self.graph,
+ graphName=TEST_GRAPH_NAME)
+ # Check the raised exception has the expected error message
+ self.assertEqual(str(context.exception),
+ "properties of node : 4 must be a dict")
+
+ def test_invalid_edge_label(self):
+ print("Testing Networkx To AGE for invalid edge label")
+ self.graph.add_edge(1, 2, label=123, properties={
+ 'since': '1997', 'start_id': 1, 'end_id': 2})
+
+ # Check that the function raises an exception for the invalid edge label
+ with self.assertRaises(Exception) as context:
+ networkx_to_age(self.ag.connection, G=self.graph,
+ graphName=TEST_GRAPH_NAME)
+ # Check the raised exception has the expected error message
+ self.assertEqual(str(context.exception),
+ "label of edge : 1->2 must be a string")
+
+ def test_invalid_edge_properties(self):
+ print("Testing Networkx To AGE for invalid edge properties")
+ self.graph.add_edge(1, 2, label='KNOWS', properties="invalid")
+
+ # Check that the function raises an exception for the invalid edge properties
+ with self.assertRaises(Exception) as context:
+ networkx_to_age(self.ag.connection, G=self.graph,
+ graphName=TEST_GRAPH_NAME)
+ # Check the raised exception has the expected error message
+ self.assertEqual(str(context.exception),
+ "properties of edge : 1->2 must be a dict")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('-host',
+ '--host',
+ help='Optional Host Name. Default Host is "127.0.0.1" ',
+ default="127.0.0.1")
+ parser.add_argument('-port',
+ '--port',
+ help='Optional Port Number. Default port no is 5432',
+ default=5432)
+ parser.add_argument('-db',
+ '--database',
+ help='Required Database Name',
+ required=True)
+ parser.add_argument('-u',
+ '--user',
+ help='Required Username Name',
+ required=True)
+ parser.add_argument('-pass',
+ '--password',
+ help='Required Password for authentication',
+ required=True)
+
+ args = parser.parse_args()
+ suite = unittest.TestSuite()
+
+ suite.addTest(TestAgeToNetworkx('test_empty_graph'))
+ suite.addTest(TestAgeToNetworkx('test_existing_graph_without_query'))
+ suite.addTest(TestAgeToNetworkx('test_existing_graph_with_query'))
+ suite.addTest(TestAgeToNetworkx('test_existing_graph'))
+ TestAgeToNetworkx.args = args
+
+ suite.addTest(TestNetworkxToAGE('test_number_of_nodes_and_edges'))
+ suite.addTest(TestNetworkxToAGE('test_empty_graph'))
+ suite.addTest(TestNetworkxToAGE('test_non_empty_graph'))
+ suite.addTest(TestNetworkxToAGE('test_invalid_node_label'))
+ suite.addTest(TestNetworkxToAGE('test_invalid_node_properties'))
+ suite.addTest(TestNetworkxToAGE('test_invalid_edge_label'))
+ suite.addTest(TestNetworkxToAGE('test_invalid_edge_properties'))
+ TestNetworkxToAGE.args = args
+
+ unittest.TextTestRunner().run(suite)