Added Networkx Support in python driver (#1716)

Co-authored-by: munmud <moontasir042@gmail.com>
diff --git a/.github/workflows/python-driver.yaml b/.github/workflows/python-driver.yaml
index 40eb58e..fb0ad5c 100644
--- a/.github/workflows/python-driver.yaml
+++ b/.github/workflows/python-driver.yaml
@@ -40,4 +40,5 @@
     - name: Test
       run: |
         python test_age_py.py -db "postgres" -u "postgres" -pass "agens"
+        python test_networkx.py -db "postgres" -u "postgres" -pass "agens"
         python -m unittest -v test_agtypes.py
diff --git a/drivers/python/README.md b/drivers/python/README.md
index cf3e53b..c7c6802 100644
--- a/drivers/python/README.md
+++ b/drivers/python/README.md
@@ -86,3 +86,74 @@
 

 ### License

 Apache-2.0 License

+

+

+## Networkx

+### Netowkx Unit test

+```

+python test_networkx.py \

+-host "127.0.0.1" \

+-db "postgres" \

+-u "postgres" \

+-pass "agens" \

+-port 5432

+```

+Here the following value required

+- `-host` : host name (optional)

+- `-db` : database name

+- `-u` : user name

+- `-pass` : password

+- `-port` : port (optional)

+

+### Networkx to AGE

+Insert From networkx directed graph into an Age database.

+#### Parameters

+

+- `connection` (psycopg2.connect): Connection object to the Age database.

+

+- `G` (networkx.DiGraph): Networkx directed graph to be converted and inserted.

+

+- `graphName` (str): Name of the age graph.

+

+#### Returns

+

+None

+

+#### Example

+

+```python

+

+# Create a Networkx DiGraph

+G = nx.DiGraph()

+G.add_node(1)

+G.add_node(2)

+G.add_edge(1, 2)

+

+# Convert and insert the graph into the Age database

+graphName = "sample_graph"

+networkx_to_age(connection, G, graphName)

+```

+

+

+

+### AGE to Netowkx

+

+Converts data from a Apache AGE graph database into a Networkx directed graph.

+

+#### Parameters

+

+- `connection` (psycopg2.connect): Connection object to the PostgreSQL database.

+- `graphName` (str): Name of the graph.

+- `G` (None | nx.DiGraph): Optional Networkx directed graph. If provided, the data will be added to this graph.

+- `query` (str | None): Optional Cypher query to retrieve data from the database.

+

+#### Returns

+

+- `nx.DiGraph`: Networkx directed graph containing the converted data.

+

+#### Example

+

+```python

+# Call the function to convert data into a Networkx graph

+graph = age_to_networkx(connection, graphName="MyGraph" )

+```

diff --git a/drivers/python/age/networkx/__init__.py b/drivers/python/age/networkx/__init__.py
new file mode 100644
index 0000000..17a1d3f
--- /dev/null
+++ b/drivers/python/age/networkx/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from .networkx_to_age import networkx_to_age
+from .age_to_networkx import age_to_networkx
diff --git a/drivers/python/age/networkx/age_to_networkx.py b/drivers/python/age/networkx/age_to_networkx.py
new file mode 100644
index 0000000..3a16aa3
--- /dev/null
+++ b/drivers/python/age/networkx/age_to_networkx.py
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from age import *
+import psycopg2
+import networkx as nx
+from age.models import Vertex, Edge, Path
+from .lib import *
+
+
+def age_to_networkx(connection: psycopg2.connect,
+                    graphName: str,
+                    G: None | nx.DiGraph = None,
+                    query: str | None = None
+                    ) -> nx.DiGraph:
+    """
+    @params
+    ---------------------
+    connection - (psycopg2.connect) Connection object
+    graphName - (str) Name of the graph
+    G - (networkx.DiGraph) Networkx directed Graph [optional]
+    query - (str) Cypher query [optional]
+
+        @returns
+    ------------
+    Networkx directed Graph
+
+    """
+
+    # Check if the age graph exists
+    checkIfGraphNameExistInAGE(connection, graphName)
+
+    # Create an empty directed graph
+    if G == None:
+        G = nx.DiGraph()
+
+    def addNodeToNetworkx(node):
+        """Add Nodes in Networkx"""
+        G.add_node(node.id,
+                   label=node.label,
+                   properties=node.properties)
+
+    def addEdgeToNetworkx(edge):
+        """Add Edge in Networkx"""
+        G.add_edge(edge.start_id,
+                   edge.end_id,
+                   label=edge.label,
+                   properties=edge.properties)
+
+    def addPath(path):
+        """Add Edge in Networkx"""
+        for x in path:
+            if (type(x) == Path):
+                addPath(x)
+        for x in path:
+            if (type(x) == Vertex):
+                addNodeToNetworkx(x)
+        for x in path:
+            if (type(x) == Edge):
+                addEdgeToNetworkx(x)
+
+    # Setting up connection to work with Graph
+    age.setUpAge(connection, graphName)
+
+    if (query == None):
+        addAllNodesIntoNetworkx(connection, graphName, G)
+        addAllEdgesIntoNetworkx(connection, graphName, G)
+    else:
+        with connection.cursor() as cursor:
+            cursor.execute(query)
+            rows = cursor.fetchall()
+            for row in rows:
+                for x in row:
+                    if type(x) == Path:
+                        addPath(x)
+                    elif type(x) == Edge:
+                        addEdgeToNetworkx(x)
+                    elif type(x) == Vertex:
+                        addNodeToNetworkx(x)
+    return G
diff --git a/drivers/python/age/networkx/lib.py b/drivers/python/age/networkx/lib.py
new file mode 100644
index 0000000..a4df088
--- /dev/null
+++ b/drivers/python/age/networkx/lib.py
@@ -0,0 +1,253 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from age import *
+import psycopg2
+import networkx as nx
+from psycopg2 import sql
+from psycopg2.extras import execute_values
+from typing import Dict, Any, List, Set
+from age.models import Vertex, Edge, Path
+
+
+def checkIfGraphNameExistInAGE(connection: psycopg2.connect,
+                               graphName: str):
+    """Check if the age graph exists"""
+    with connection.cursor() as cursor:
+        cursor.execute(sql.SQL("""
+                    SELECT count(*) 
+                    FROM ag_catalog.ag_graph 
+                    WHERE name='%s'
+                """ % (graphName)))
+        if cursor.fetchone()[0] == 0:
+            raise GraphNotFound(graphName)
+
+
+def getOidOfGraph(connection: psycopg2.connect,
+                  graphName: str) -> int:
+    """Returns oid of a graph"""
+    try:
+        with connection.cursor() as cursor:
+            cursor.execute(sql.SQL("""
+                        SELECT graphid FROM ag_catalog.ag_graph WHERE name='%s' ;
+                    """ % (graphName)))
+            oid = cursor.fetchone()[0]
+            return oid
+    except Exception as e:
+        print(e)
+
+
+def get_vlabel(connection: psycopg2.connect,
+               graphName: str):
+    node_label_list = []
+    oid = getOidOfGraph(connection, graphName)
+    try:
+        with connection.cursor() as cursor:
+            cursor.execute(
+                """SELECT name FROM ag_catalog.ag_label WHERE kind='v' AND graph=%s;""" % oid)
+            for row in cursor:
+                node_label_list.append(row[0])
+
+    except Exception as ex:
+        print(type(ex), ex)
+    return node_label_list
+
+
+def create_vlabel(connection: psycopg2.connect,
+                  graphName: str,
+                  node_label_list: List):
+    """create_vlabels from list if not exist"""
+    try:
+        node_label_set = set(get_vlabel(connection, graphName))
+        crete_label_statement = ''
+        for label in node_label_list:
+            if label in node_label_set:
+                continue
+            crete_label_statement += """SELECT create_vlabel('%s','%s');\n""" % (
+                graphName, label)
+        if crete_label_statement != '':
+            with connection.cursor() as cursor:
+                cursor.execute(crete_label_statement)
+                connection.commit()
+    except Exception as e:
+        raise Exception(e)
+
+
+def get_elabel(connection: psycopg2.connect,
+               graphName: str):
+    edge_label_list = []
+    oid = getOidOfGraph(connection, graphName)
+    try:
+        with connection.cursor() as cursor:
+            cursor.execute(
+                """SELECT name FROM ag_catalog.ag_label WHERE kind='e' AND graph=%s;""" % oid)
+            for row in cursor:
+                edge_label_list.append(row[0])
+    except Exception as ex:
+        print(type(ex), ex)
+    return edge_label_list
+
+
+def create_elabel(connection: psycopg2.connect,
+                  graphName: str,
+                  edge_label_list: List):
+    """create_vlabels from list if not exist"""
+    try:
+        edge_label_set = set(get_elabel(connection, graphName))
+        crete_label_statement = ''
+        for label in edge_label_list:
+            if label in edge_label_set:
+                continue
+            crete_label_statement += """SELECT create_elabel('%s','%s');\n""" % (
+                graphName, label)
+        if crete_label_statement != '':
+            with connection.cursor() as cursor:
+                cursor.execute(crete_label_statement)
+                connection.commit()
+    except Exception as e:
+        raise Exception(e)
+
+
+def getNodeLabelListAfterPreprocessing(G: nx.DiGraph):
+    """
+        - Add default label if label is missing 
+        - Add properties if not exist 
+        - return all distinct node label
+    """
+    node_label_list = set()
+    try:
+        for node, data in G.nodes(data=True):
+            if 'label' not in data:
+                data['label'] = '_ag_label_vertex'
+            if 'properties' not in data:
+                data['properties'] = {}
+            if not isinstance(data['label'], str):
+                raise Exception(f"label of node : {node} must be a string")
+            if not isinstance(data['properties'], Dict):
+                raise Exception(f"properties of node : {node} must be a dict")
+            if '__id__' not in data['properties'].keys():
+                data['properties']['__id__'] = node
+            node_label_list.add(data['label'])
+    except Exception as e:
+        raise Exception(e)
+    return node_label_list
+
+
+def getEdgeLabelListAfterPreprocessing(G: nx.DiGraph):
+    """
+        - Add default label if label is missing 
+        - Add properties if not exist 
+        - return all distinct edge label
+    """
+    edge_label_list = set()
+    try:
+        for u, v, data in G.edges(data=True):
+            if 'label' not in data:
+                data['label'] = '_ag_label_edge'
+            if 'properties' not in data:
+                data['properties'] = {}
+            if not isinstance(data['label'], str):
+                raise Exception(f"label of edge : {u}->{v} must be a string")
+            if not isinstance(data['properties'], Dict):
+                raise Exception(
+                    f"properties of edge : {u}->{v} must be a dict")
+            edge_label_list.add(data['label'])
+    except Exception as e:
+        raise Exception(e)
+    return edge_label_list
+
+
+def addAllNodesIntoAGE(connection: psycopg2.connect, graphName: str, G: nx.DiGraph, node_label_list: Set):
+    """Add all node to AGE"""
+    try:
+        queue_data = {label: [] for label in node_label_list}
+        id_data = {}
+
+        for node, data in G.nodes(data=True):
+            json_string = json.dumps(data['properties'])
+            queue_data[data['label']].append((json_string,))
+
+        for label, rows in queue_data.items():
+            table_name = """%s."%s" """ % (graphName, label)
+            insert_query = f"INSERT INTO {table_name} (properties) VALUES %s RETURNING id"
+            cursor = connection.cursor()
+            id_data[label] = execute_values(
+                cursor, insert_query, rows, fetch=True)
+            connection.commit()
+            cursor.close()
+            id_data[label].reverse()
+
+        for node, data in G.nodes(data=True):
+            data['properties']['__gid__'] = id_data[data['label']][-1][0]
+            id_data[data['label']].pop()
+
+    except Exception as e:
+        raise Exception(e)
+
+
+def addAllEdgesIntoAGE(connection: psycopg2.connect, graphName: str, G: nx.DiGraph, edge_label_list: Set):
+    """Add all edge to AGE"""
+    try:
+        queue_data = {label: [] for label in edge_label_list}
+        for u, v, data in G.edges(data=True):
+            json_string = json.dumps(data['properties'])
+            queue_data[data['label']].append(
+                (G.nodes[u]['properties']['__gid__'], G.nodes[v]['properties']['__gid__'], json_string,))
+
+        for label, rows in queue_data.items():
+            table_name = """%s."%s" """ % (graphName, label)
+            insert_query = f"INSERT INTO {table_name} (start_id,end_id,properties) VALUES %s"
+            cursor = connection.cursor()
+            execute_values(cursor, insert_query, rows)
+            connection.commit()
+            cursor.close()
+    except Exception as e:
+        raise Exception(e)
+
+
+def addAllNodesIntoNetworkx(connection: psycopg2.connect, graphName: str, G: nx.DiGraph):
+    """Add all nodes to Networkx"""
+    node_label_list = get_vlabel(connection, graphName)
+    try:
+        for label in node_label_list:
+            with connection.cursor() as cursor:
+                cursor.execute("""
+                SELECT id, CAST(properties AS VARCHAR) 
+                FROM %s."%s";
+                """ % (graphName, label))
+                rows = cursor.fetchall()
+                for row in rows:
+                    G.add_node(int(row[0]), label=label,
+                               properties=json.loads(row[1]))
+    except Exception as e:
+        print(e)
+
+
+def addAllEdgesIntoNetworkx(connection: psycopg2.connect, graphName: str, G: nx.DiGraph):
+    """Add All edges to Networkx"""
+    try:
+        edge_label_list = get_elabel(connection, graphName)
+        for label in edge_label_list:
+            with connection.cursor() as cursor:
+                cursor.execute("""
+                               SELECT start_id, end_id, CAST(properties AS VARCHAR) 
+                               FROM %s."%s";
+                               """ % (graphName, label))
+                rows = cursor.fetchall()
+                for row in rows:
+                    G.add_edge(int(row[0]), int(
+                        row[1]), label=label, properties=json.loads(row[2]))
+    except Exception as e:
+        print(e)
diff --git a/drivers/python/age/networkx/networkx_to_age.py b/drivers/python/age/networkx/networkx_to_age.py
new file mode 100644
index 0000000..90cee97
--- /dev/null
+++ b/drivers/python/age/networkx/networkx_to_age.py
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from age import *
+import psycopg2
+import networkx as nx
+from .lib import *
+
+
+def networkx_to_age(connection: psycopg2.connect,
+                    G: nx.DiGraph,
+                    graphName: str):
+    """
+    @params
+    -----------
+    connection - (psycopg2.connect) Connection object
+
+    G - (networkx.DiGraph) Networkx directed Graph
+
+    graphName - (str) Name of the graph
+
+    @returns
+    ------------
+    None
+
+    """
+    node_label_list = getNodeLabelListAfterPreprocessing(G)
+    edge_label_list = getEdgeLabelListAfterPreprocessing(G)
+
+    # Setup connection with Graph
+    age.setUpAge(connection, graphName)
+
+    create_vlabel(connection, graphName, node_label_list)
+    create_elabel(connection, graphName, edge_label_list)
+
+    addAllNodesIntoAGE(connection, graphName, G, node_label_list)
+    addAllEdgesIntoAGE(connection, graphName, G, edge_label_list)
diff --git a/drivers/python/requirements.txt b/drivers/python/requirements.txt
index 4951053..81d1ef7 100644
--- a/drivers/python/requirements.txt
+++ b/drivers/python/requirements.txt
Binary files differ
diff --git a/drivers/python/samples/networkx.ipynb b/drivers/python/samples/networkx.ipynb
new file mode 100644
index 0000000..20fc566
--- /dev/null
+++ b/drivers/python/samples/networkx.ipynb
@@ -0,0 +1,183 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import psycopg2\n",
+    "from age.networkx import *\n",
+    "from age import *\n",
+    "import networkx as nx\n",
+    "\n",
+    "conn = psycopg2.connect(\n",
+    "    host=\"localhost\",\n",
+    "    port=\"5432\",\n",
+    "    dbname=\"postgres\",\n",
+    "    user=\"moontasir\",\n",
+    "    password=\"254826\")\n",
+    "graphName = 'bitnine_global_inic'\n",
+    "\n",
+    "\n",
+    "age.setUpAge(conn, graphName)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Creating a Random Networkx Graph"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "DiGraph with 5 nodes and 9 edges\n"
+     ]
+    }
+   ],
+   "source": [
+    "import random\n",
+    "num_nodes = 5\n",
+    "num_edges = 10\n",
+    "G = nx.DiGraph()\n",
+    "try:\n",
+    "    for i in range(num_nodes//2):\n",
+    "        G.add_node(i, label='Number',properties={'name' : i})\n",
+    "    for i in range(num_nodes//2,num_nodes):\n",
+    "        G.add_node(i, label='Integer',properties={'age' : i*2} )\n",
+    "    for i in range(num_edges):\n",
+    "        source = random.randint(0, num_nodes-1)\n",
+    "        target = random.randint(0, num_nodes-1)\n",
+    "        G.add_edge(source, target, label='Connection' ,properties={'st' : source , 'ed':target})\n",
+    "except Exception as e:\n",
+    "    raise e\n",
+    "print(G)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Networkx to AGE"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "networkx_to_age(conn, G, graphName)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## AGE to Networkx"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### ALL AGE to Networkx\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "DiGraph with 5 nodes and 9 edges\n"
+     ]
+    }
+   ],
+   "source": [
+    "G = age_to_networkx(conn, graphName)\n",
+    "print(G)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Add subgraph of AGE to Networkx using query"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "DiGraph with 5 nodes and 9 edges\n"
+     ]
+    }
+   ],
+   "source": [
+    "G = age_to_networkx(conn, graphName, \n",
+    "                    query=\"\"\"\n",
+    "SELECT * from cypher('%s', $$\n",
+    "        MATCH (V:Integer)\n",
+    "        RETURN V\n",
+    "$$) as (V agtype);\n",
+    "\"\"\" % graphName)\n",
+    "\n",
+    "G = age_to_networkx(conn, graphName, G=G,\n",
+    "                    query=\"\"\"\n",
+    "SELECT * from cypher('%s', $$\n",
+    "        MATCH (V:Number)\n",
+    "        RETURN V\n",
+    "$$) as (V agtype);\n",
+    "\"\"\" % graphName)\n",
+    "\n",
+    "G = age_to_networkx(conn, graphName, G=G,\n",
+    "                    query=\"\"\"\n",
+    "SELECT * from cypher('%s', $$\n",
+    "        MATCH (V)-[R]->(V2)\n",
+    "        RETURN V,R,V2\n",
+    "$$) as (V agtype, R agtype, V2 agtype);\n",
+    "\"\"\" % graphName)\n",
+    "print(G)\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "myenv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.6"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/drivers/python/setup.py b/drivers/python/setup.py
index 6315022..15be9e7 100644
--- a/drivers/python/setup.py
+++ b/drivers/python/setup.py
@@ -30,8 +30,8 @@
     url              = 'https://github.com/apache/age/tree/master/drivers/python',

     download_url     = 'https://github.com/apache/age/releases' ,

     license          = 'Apache2.0',

-    install_requires = [ 'psycopg2', 'antlr4_python3_runtime==4.11.1'],

-    packages         = ['age', 'age.gen'],

+    install_requires = [ 'psycopg2', 'antlr4-python3-runtime==4.11.1'],

+    packages         = ['age', 'age.gen','age.networkx'],

     keywords         = ['Graph Database', 'Apache AGE', 'PostgreSQL'],

     python_requires  = '>=3.9',

     classifiers      = [

diff --git a/drivers/python/test_networkx.py b/drivers/python/test_networkx.py
new file mode 100644
index 0000000..290d4c4
--- /dev/null
+++ b/drivers/python/test_networkx.py
@@ -0,0 +1,480 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import age
+import unittest
+import argparse
+import networkx as nx
+from age.models import *
+from age.networkx import *
+from age.exceptions import *
+
+
+TEST_GRAPH_NAME = "test_graph"
+ORIGINAL_GRAPH = "original_graph"
+EXPECTED_GRAPH = "expected_graph"
+
+
+class TestAgeToNetworkx(unittest.TestCase):
+    ag = None
+
+    def setUp(self):
+
+        TEST_DB = self.args.database
+        TEST_USER = self.args.user
+        TEST_PASSWORD = self.args.password
+        TEST_PORT = self.args.port
+        TEST_HOST = self.args.host
+        self.ag = age.connect(graph=TEST_GRAPH_NAME, host=TEST_HOST, port=TEST_PORT,
+                              dbname=TEST_DB, user=TEST_USER, password=TEST_PASSWORD)
+
+    def tearDown(self):
+        age.deleteGraph(self.ag.connection, self.ag.graphName)
+        self.ag.close()
+
+    def compare_networkX(self, G, H):
+        if G.number_of_nodes() != H.number_of_nodes():
+            return False
+        if G.number_of_edges() != H.number_of_edges():
+            return False
+        # test nodes
+        nodes_G, nodes_H = G.number_of_nodes(), H.number_of_nodes()
+        markG, markH = [0]*nodes_G, [0]*nodes_H
+        nodes_list_G, nodes_list_H = list(G.nodes), list(H.nodes)
+        for i in range(0, nodes_G):
+            for j in range(0, nodes_H):
+                if markG[i] == 0 and markH[j] == 0:
+                    node_id_G = nodes_list_G[i]
+                    property_G = G.nodes[node_id_G]
+
+                    node_id_H = nodes_list_H[j]
+                    property_H = H.nodes[node_id_H]
+                    if property_G == property_H:
+                        markG[i] = 1
+                        markH[j] = 1
+
+        if any(elem == 0 for elem in markG):
+            return False
+        if any(elem == 0 for elem in markH):
+            return False
+
+        # test edges
+        edges_G, edges_H = G.number_of_edges(), H.number_of_edges()
+        markG, markH = [0]*edges_G, [0]*edges_H
+        edges_list_G, edges_list_H = list(G.edges), list(H.edges)
+
+        for i in range(0, edges_G):
+            for j in range(0, edges_H):
+                if markG[i] == 0 and markH[j] == 0:
+                    source_G, target_G = edges_list_G[i]
+                    property_G = G.edges[source_G, target_G]
+
+                    source_H, target_H = edges_list_H[j]
+                    property_H = H.edges[source_H, target_H]
+
+                    if property_G == property_H:
+                        markG[i] = 1
+                        markH[j] = 1
+
+        if any(elem == 0 for elem in markG):
+            return False
+        if any(elem == 0 for elem in markH):
+            return False
+
+        return True
+
+    def test_empty_graph(self):
+        print('Testing AGE to Networkx for empty graph')
+        # Expected Graph
+        # Empty Graph
+        G = nx.DiGraph()
+
+        # Convert Apache AGE to NetworkX
+        H = age_to_networkx(self.ag.connection, TEST_GRAPH_NAME)
+
+        self.assertTrue(self.compare_networkX(G, H))
+
+    def test_existing_graph_without_query(self):
+        print('Testing AGE to Networkx for non empty graph without query')
+        ag = self.ag
+        # Create nodes
+        ag.execCypher("CREATE (n:Person {name: %s}) ", params=('Jack',))
+        ag.execCypher("CREATE (n:Person {name: %s}) ", params=('Andy',))
+        ag.execCypher("CREATE (n:Person {name: %s}) ", params=('Smith',))
+        ag.commit()
+
+        # Create Edges
+        ag.execCypher("""MATCH (a:Person), (b:Person)
+                    WHERE a.name = 'Andy' AND b.name = 'Jack'
+                    CREATE (a)-[r:workWith {weight: 3}]->(b)""")
+        ag.execCypher("""MATCH (a:Person), (b:Person) 
+                    WHERE  a.name = %s AND b.name = %s 
+                    CREATE p=((a)-[r:workWith {weight: 10}]->(b)) """, params=('Jack', 'Smith',))
+        ag.commit()
+
+        G = age_to_networkx(self.ag.connection, TEST_GRAPH_NAME)
+
+        # Check that the G has the expected properties
+        self.assertIsInstance(G, nx.DiGraph)
+
+        # Check that the G has the correct number of nodes and edges
+        self.assertEqual(len(G.nodes), 3)
+        self.assertEqual(len(G.edges), 2)
+
+        # Check that the node properties are correct
+        for node in G.nodes:
+            self.assertEqual(int, type(node))
+            self.assertEqual(G.nodes[node]['label'], 'Person')
+            self.assertIn('name', G.nodes[node]['properties'])
+            self.assertIn('properties', G.nodes[node])
+            self.assertEqual(str, type(G.nodes[node]['label']))
+
+        # Check that the edge properties are correct
+        for edge in G.edges:
+            self.assertEqual(tuple, type(edge))
+            self.assertEqual(int, type(edge[0]) and type(edge[1]))
+            self.assertEqual(G.edges[edge]['label'], 'workWith')
+            self.assertIn('weight', G.edges[edge]['properties'])
+            self.assertEqual(int, type(G.edges[edge]['properties']['weight']))
+
+    def test_existing_graph_with_query(self):
+        print('Testing AGE to Networkx for non empty graph with query')
+
+        ag = self.ag
+        # Create nodes
+        ag.execCypher("CREATE (n:Person {name: %s}) ", params=('Jack',))
+        ag.execCypher("CREATE (n:Person {name: %s}) ", params=('Andy',))
+        ag.execCypher("CREATE (n:Person {name: %s}) ", params=('Smith',))
+        ag.commit()
+
+        # Create Edges
+        ag.execCypher("""MATCH (a:Person), (b:Person)
+                    WHERE a.name = 'Andy' AND b.name = 'Jack'
+                    CREATE (a)-[r:workWith {weight: 3}]->(b)""")
+        ag.execCypher("""MATCH (a:Person), (b:Person) 
+                    WHERE  a.name = %s AND b.name = %s 
+                    CREATE p=((a)-[r:workWith {weight: 10}]->(b)) """, params=('Jack', 'Smith',))
+        ag.commit()
+
+        query = """SELECT * FROM cypher('%s', $$ MATCH (a:Person)-[r:workWith]->(b:Person)
+        WHERE a.name = 'Andy'
+        RETURN a, r, b $$) AS (a agtype, r agtype, b agtype);
+        """ % (TEST_GRAPH_NAME)
+
+        G = age_to_networkx(self.ag.connection,
+                            graphName=TEST_GRAPH_NAME, query=query)
+
+        # Check that the G has the expected properties
+        self.assertIsInstance(G, nx.DiGraph)
+
+        # Check that the G has the correct number of nodes and edges
+        self.assertEqual(len(G.nodes), 2)
+        self.assertEqual(len(G.edges), 1)
+
+        # Check that the node properties are correct
+        for node in G.nodes:
+            self.assertEqual(int, type(node))
+            self.assertEqual(G.nodes[node]['label'], 'Person')
+            self.assertIn('name', G.nodes[node]['properties'])
+            self.assertIn('properties', G.nodes[node])
+            self.assertEqual(str, type(G.nodes[node]['label']))
+
+        # Check that the edge properties are correct
+        for edge in G.edges:
+            self.assertEqual(tuple, type(edge))
+            self.assertEqual(int, type(edge[0]) and type(edge[1]))
+            self.assertEqual(G.edges[edge]['label'], 'workWith')
+            self.assertIn('weight', G.edges[edge]['properties'])
+            self.assertEqual(int, type(G.edges[edge]['properties']['weight']))
+
+    def test_existing_graph(self):
+        print("Testing AGE to NetworkX for non-existing graph")
+        ag = self.ag
+
+        non_existing_graph = "non_existing_graph"
+        # Check that the function raises an exception for non existing graph
+        with self.assertRaises(GraphNotFound) as context:
+            age_to_networkx(ag.connection, graphName=non_existing_graph)
+        # Check the raised exception has the expected error message
+        self.assertEqual(str(context.exception), non_existing_graph)
+
+
+class TestNetworkxToAGE(unittest.TestCase):
+    ag = None
+    ag1 = None
+    ag2 = None
+
+    def setUp(self):
+        TEST_DB = self.args.database
+        TEST_USER = self.args.user
+        TEST_PASSWORD = self.args.password
+        TEST_PORT = self.args.port
+        TEST_HOST = self.args.host
+        self.ag = age.connect(graph=TEST_GRAPH_NAME, host=TEST_HOST, port=TEST_PORT,
+                              dbname=TEST_DB, user=TEST_USER, password=TEST_PASSWORD)
+        self.ag1 = age.connect(graph=ORIGINAL_GRAPH, host=TEST_HOST, port=TEST_PORT,
+                               dbname=TEST_DB, user=TEST_USER, password=TEST_PASSWORD)
+        self.ag2 = age.connect(graph=EXPECTED_GRAPH, host=TEST_HOST, port=TEST_PORT,
+                               dbname=TEST_DB, user=TEST_USER, password=TEST_PASSWORD)
+        self.graph = nx.DiGraph()
+
+    def tearDown(self):
+        age.deleteGraph(self.ag1.connection, self.ag1.graphName)
+        age.deleteGraph(self.ag2.connection, self.ag2.graphName)
+        age.deleteGraph(self.ag.connection, self.ag.graphName)
+        self.ag.close()
+        self.ag1.close()
+        self.ag2.close()
+
+    def compare_age(self, age1, age2):
+        cursor = age1.execCypher("MATCH (v) RETURN v")
+        g_nodes = cursor.fetchall()
+
+        cursor = age1.execCypher("MATCH ()-[r]->() RETURN r")
+        g_edges = cursor.fetchall()
+
+        cursor = age2.execCypher("MATCH (v) RETURN v")
+        h_nodes = cursor.fetchall()
+
+        cursor = age2.execCypher("MATCH ()-[r]->() RETURN r")
+        h_edges = cursor.fetchall()
+
+        if len(g_nodes) != len(h_nodes) or len(g_edges) != len(h_edges):
+            return False
+
+        # test nodes
+        nodes_G, nodes_H = len(g_nodes), len(h_nodes)
+        markG, markH = [0]*nodes_G, [0]*nodes_H
+
+        # return True
+        for i in range(0, nodes_G):
+            for j in range(0, nodes_H):
+                if markG[i] == 0 and markH[j] == 0:
+                    property_G = g_nodes[i][0].properties
+                    property_G['label'] = g_nodes[i][0].label
+                    property_G.pop('__id__')
+
+                    property_H = h_nodes[j][0].properties
+                    property_H['label'] = h_nodes[j][0].label
+
+                    if property_G == property_H:
+                        markG[i] = 1
+                        markH[j] = 1
+
+        if any(elem == 0 for elem in markG):
+            return False
+        if any(elem == 0 for elem in markH):
+            return False
+
+        # test edges
+        edges_G, edges_H = len(g_edges), len(h_edges)
+        markG, markH = [0]*edges_G, [0]*edges_H
+
+        for i in range(0, edges_G):
+            for j in range(0, edges_H):
+                if markG[i] == 0 and markH[j] == 0:
+                    property_G = g_edges[i][0].properties
+                    property_G['label'] = g_edges[i][0].label
+
+                    property_H = h_edges[j][0].properties
+                    property_H['label'] = h_edges[j][0].label
+                    if property_G == property_H:
+                        markG[i] = 1
+                        markH[j] = 1
+
+        if any(elem == 0 for elem in markG):
+            return False
+        if any(elem == 0 for elem in markH):
+            return False
+
+        return True
+
+    def test_number_of_nodes_and_edges(self):
+        print("Testing Networkx To AGE for number of nodes and edges")
+        ag = self.ag
+
+        # Create NetworkX graph
+        self.graph.add_node(1, label='Person', properties={
+                            'name': 'Moontasir', 'age': '26', 'id': 1})
+        self.graph.add_node(2, label='Person', properties={
+                            'name': 'Austen', 'age': '26', 'id': 2})
+        self.graph.add_edge(1, 2, label='KNOWS', properties={
+                            'since': '1997', 'start_id': 1, 'end_id': 2})
+        self.graph.add_node(3, label='Person', properties={
+                            'name': 'Eric', 'age': '28', 'id': 3})
+
+        networkx_to_age(self.ag.connection, self.graph, TEST_GRAPH_NAME)
+
+        # Check that node(s) were created
+        cursor = ag.execCypher('MATCH (n) RETURN n')
+        result = cursor.fetchall()
+        # Check number of vertices created
+        self.assertEqual(len(result), 3)
+        # Checks if type of property in query output is a Vertex
+        self.assertEqual(Vertex, type(result[0][0]))
+        self.assertEqual(Vertex, type(result[1][0]))
+
+        # Check that edge(s) was created
+        cursor = ag.execCypher('MATCH ()-[e]->() RETURN e')
+        result = cursor.fetchall()
+        # Check number of edge(s) created
+        self.assertEqual(len(result), 1)
+        # Checks if type of property in query output is an Edge
+        self.assertEqual(Edge, type(result[0][0]))
+
+    def test_empty_graph(self):
+        print("Testing Networkx To AGE for empty Graph")
+        # Expected Graph
+        # Empty Graph
+
+        # NetworkX Graph
+        G = nx.DiGraph()
+
+        # Convert Apache AGE to NetworkX
+        networkx_to_age(self.ag1.connection, G, ORIGINAL_GRAPH)
+
+        self.assertTrue(self.compare_age(self.ag1, self.ag2))
+
+    def test_non_empty_graph(self):
+        print("Testing Networkx To AGE for non-empty Graph")
+        # Expected Graph
+        self.ag2.execCypher("CREATE (:l1 {name: 'n1', weight: '5'})")
+        self.ag2.execCypher("CREATE (:l1 {name: 'n2', weight: '4'})")
+        self.ag2.execCypher("CREATE (:l1 {name: 'n3', weight: '9'})")
+
+        self.ag2.execCypher("""MATCH (a:l1), (b:l1)
+                            WHERE a.name = 'n1' AND b.name = 'n2'
+                            CREATE (a)-[e:e1 {property:'graph'}]->(b)""")
+        self.ag2.execCypher("""MATCH (a:l1), (b:l1)
+                            WHERE a.name = 'n2' AND b.name = 'n3'
+                            CREATE (a)-[e:e2 {property:'node'}]->(b)""")
+
+        # NetworkX Graph
+        G = nx.DiGraph()
+
+        G.add_node('1',
+                   label='l1',
+                   properties={'name': 'n1',
+                               'weight': '5'})
+        G.add_node('2',
+                   label='l1',
+                   properties={'name': 'n2',
+                               'weight': '4'})
+        G.add_node('3',
+                   label='l1',
+                   properties={'name': 'n3',
+                               'weight': '9'})
+        G.add_edge('1', '2', label='e1', properties={'property': 'graph'})
+        G.add_edge('2', '3', label='e2', properties={'property': 'node'})
+
+        # Convert Apache AGE to NetworkX
+        networkx_to_age(self.ag1.connection, G, ORIGINAL_GRAPH)
+
+        self.assertTrue(self.compare_age(self.ag1, self.ag2))
+
+    def test_invalid_node_label(self):
+        print("Testing Networkx To AGE for invalid node label")
+        self.graph.add_node(4, label=123, properties={
+                            'name': 'Mason', 'age': '24', 'id': 4})
+
+        # Check that the function raises an exception for the invalid node label
+        with self.assertRaises(Exception) as context:
+            networkx_to_age(self.ag.connection, G=self.graph,
+                            graphName=TEST_GRAPH_NAME)
+        # Check the raised exception has the expected error message
+        self.assertEqual(str(context.exception),
+                         "label of node : 4 must be a string")
+
+    def test_invalid_node_properties(self):
+        print("Testing Networkx To AGE for invalid node properties")
+        self.graph.add_node(4, label='Person', properties="invalid")
+
+        # Check that the function raises an exception for the invalid node properties
+        with self.assertRaises(Exception) as context:
+            networkx_to_age(self.ag.connection, G=self.graph,
+                            graphName=TEST_GRAPH_NAME)
+        # Check the raised exception has the expected error message
+        self.assertEqual(str(context.exception),
+                         "properties of node : 4 must be a dict")
+
+    def test_invalid_edge_label(self):
+        print("Testing Networkx To AGE for invalid edge label")
+        self.graph.add_edge(1, 2, label=123, properties={
+                            'since': '1997', 'start_id': 1, 'end_id': 2})
+
+        # Check that the function raises an exception for the invalid edge label
+        with self.assertRaises(Exception) as context:
+            networkx_to_age(self.ag.connection, G=self.graph,
+                            graphName=TEST_GRAPH_NAME)
+        # Check the raised exception has the expected error message
+        self.assertEqual(str(context.exception),
+                         "label of edge : 1->2 must be a string")
+
+    def test_invalid_edge_properties(self):
+        print("Testing Networkx To AGE for invalid edge properties")
+        self.graph.add_edge(1, 2, label='KNOWS', properties="invalid")
+
+        # Check that the function raises an exception for the invalid edge properties
+        with self.assertRaises(Exception) as context:
+            networkx_to_age(self.ag.connection, G=self.graph,
+                            graphName=TEST_GRAPH_NAME)
+        # Check the raised exception has the expected error message
+        self.assertEqual(str(context.exception),
+                         "properties of edge : 1->2 must be a dict")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('-host',
+                        '--host',
+                        help='Optional Host Name. Default Host is "127.0.0.1" ',
+                        default="127.0.0.1")
+    parser.add_argument('-port',
+                        '--port',
+                        help='Optional Port Number. Default port no is 5432',
+                        default=5432)
+    parser.add_argument('-db',
+                        '--database',
+                        help='Required Database Name',
+                        required=True)
+    parser.add_argument('-u',
+                        '--user',
+                        help='Required Username Name',
+                        required=True)
+    parser.add_argument('-pass',
+                        '--password',
+                        help='Required Password for authentication',
+                        required=True)
+
+    args = parser.parse_args()
+    suite = unittest.TestSuite()
+
+    suite.addTest(TestAgeToNetworkx('test_empty_graph'))
+    suite.addTest(TestAgeToNetworkx('test_existing_graph_without_query'))
+    suite.addTest(TestAgeToNetworkx('test_existing_graph_with_query'))
+    suite.addTest(TestAgeToNetworkx('test_existing_graph'))
+    TestAgeToNetworkx.args = args
+
+    suite.addTest(TestNetworkxToAGE('test_number_of_nodes_and_edges'))
+    suite.addTest(TestNetworkxToAGE('test_empty_graph'))
+    suite.addTest(TestNetworkxToAGE('test_non_empty_graph'))
+    suite.addTest(TestNetworkxToAGE('test_invalid_node_label'))
+    suite.addTest(TestNetworkxToAGE('test_invalid_node_properties'))
+    suite.addTest(TestNetworkxToAGE('test_invalid_edge_label'))
+    suite.addTest(TestNetworkxToAGE('test_invalid_edge_properties'))
+    TestNetworkxToAGE.args = args
+
+    unittest.TextTestRunner().run(suite)