blob: a4df088637bd29b66c7bdad8904b7101e9830312 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from age import *
import psycopg2
import networkx as nx
from psycopg2 import sql
from psycopg2.extras import execute_values
from typing import Dict, Any, List, Set
from age.models import Vertex, Edge, Path
def checkIfGraphNameExistInAGE(connection: psycopg2.connect,
                               graphName: str):
    """Raise ``GraphNotFound`` unless an AGE graph named *graphName* exists.

    The name is looked up in ``ag_catalog.ag_graph``.

    Args:
        connection: psycopg2 connection used to run the lookup.
        graphName: name of the graph to look for.

    Raises:
        GraphNotFound: if no row of ``ag_catalog.ag_graph`` has that name.
    """
    with connection.cursor() as cursor:
        # Bind the name as a query parameter rather than formatting it into
        # the SQL text, so quotes in graphName cannot change the statement.
        cursor.execute(
            "SELECT count(*) FROM ag_catalog.ag_graph WHERE name = %s",
            (graphName,))
        if cursor.fetchone()[0] == 0:
            raise GraphNotFound(graphName)
def getOidOfGraph(connection: psycopg2.connect,
                  graphName: str) -> int:
    """Return the ``graphid`` of the AGE graph named *graphName*.

    This is a best-effort lookup: errors are printed rather than raised.

    Args:
        connection: psycopg2 connection used to run the lookup.
        graphName: name of the graph whose id is wanted.

    Returns:
        The ``graphid`` value from ``ag_catalog.ag_graph``, or ``None`` when
        the graph does not exist or the query fails.
    """
    try:
        with connection.cursor() as cursor:
            # Parameterised so that graphName is never spliced into the SQL.
            cursor.execute(
                "SELECT graphid FROM ag_catalog.ag_graph WHERE name = %s",
                (graphName,))
            row = cursor.fetchone()
    except Exception as e:
        print(e)
        return None
    if row is None:
        # No such graph: report it clearly instead of failing on row[0].
        print(f"graph '{graphName}' not found in ag_catalog.ag_graph")
        return None
    return row[0]
def get_vlabel(connection: psycopg2.connect,
               graphName: str):
    """Return the names of the vertex labels registered for a graph.

    Reads the rows of ``ag_catalog.ag_label`` whose ``kind`` is ``'v'`` and
    whose ``graph`` equals the id returned by ``getOidOfGraph``.

    Args:
        connection: psycopg2 connection used to run the query.
        graphName: name of the graph whose vertex labels are listed.

    Returns:
        A list of label names. Query errors are printed, not raised, and
        whatever was collected before the error (possibly nothing) is returned.
    """
    node_label_list = []
    oid = getOidOfGraph(connection, graphName)
    try:
        with connection.cursor() as cursor:
            # Bind the graph id as a parameter. If the lookup above returned
            # None this compares against NULL and simply yields no rows,
            # instead of sending malformed SQL to the server.
            cursor.execute(
                "SELECT name FROM ag_catalog.ag_label "
                "WHERE kind = 'v' AND graph = %s",
                (oid,))
            node_label_list.extend(row[0] for row in cursor)
    except Exception as ex:
        print(type(ex), ex)
    return node_label_list
def create_vlabel(connection: psycopg2.connect,
                  graphName: str,
                  node_label_list: List):
    """Create every vertex label in *node_label_list* that the graph lacks.

    Labels already reported by ``get_vlabel`` are skipped, and a label that
    appears several times in the input is created only once.

    Args:
        connection: psycopg2 connection; committed after the labels are created.
        graphName: name of the graph that receives the labels.
        node_label_list: iterable of vertex label names to ensure.

    Raises:
        Exception: wraps any error raised while creating the labels.
    """
    try:
        seen = set(get_vlabel(connection, graphName))
        missing = []
        for label in node_label_list:
            if label in seen:
                continue
            seen.add(label)  # guards against duplicates in the input
            missing.append(label)
        if missing:
            with connection.cursor() as cursor:
                for label in missing:
                    # One parameterised call per label, so neither the graph
                    # name nor the label is ever concatenated into SQL text.
                    cursor.execute(
                        "SELECT create_vlabel(%s, %s)", (graphName, label))
            connection.commit()
    except Exception as e:
        raise Exception(e) from e
def get_elabel(connection: psycopg2.connect,
               graphName: str):
    """Return the names of the edge labels registered for a graph.

    Reads the rows of ``ag_catalog.ag_label`` whose ``kind`` is ``'e'`` and
    whose ``graph`` equals the id returned by ``getOidOfGraph``.

    Args:
        connection: psycopg2 connection used to run the query.
        graphName: name of the graph whose edge labels are listed.

    Returns:
        A list of label names. Query errors are printed, not raised, and
        whatever was collected before the error (possibly nothing) is returned.
    """
    edge_label_list = []
    oid = getOidOfGraph(connection, graphName)
    try:
        with connection.cursor() as cursor:
            # Bind the graph id as a parameter. If the lookup above returned
            # None this compares against NULL and simply yields no rows,
            # instead of sending malformed SQL to the server.
            cursor.execute(
                "SELECT name FROM ag_catalog.ag_label "
                "WHERE kind = 'e' AND graph = %s",
                (oid,))
            edge_label_list.extend(row[0] for row in cursor)
    except Exception as ex:
        print(type(ex), ex)
    return edge_label_list
def create_elabel(connection: psycopg2.connect,
                  graphName: str,
                  edge_label_list: List):
    """Create every edge label in *edge_label_list* that the graph lacks.

    Labels already reported by ``get_elabel`` are skipped, and a label that
    appears several times in the input is created only once.

    Args:
        connection: psycopg2 connection; committed after the labels are created.
        graphName: name of the graph that receives the labels.
        edge_label_list: iterable of edge label names to ensure.

    Raises:
        Exception: wraps any error raised while creating the labels.
    """
    try:
        seen = set(get_elabel(connection, graphName))
        missing = []
        for label in edge_label_list:
            if label in seen:
                continue
            seen.add(label)  # guards against duplicates in the input
            missing.append(label)
        if missing:
            with connection.cursor() as cursor:
                for label in missing:
                    # One parameterised call per label, so neither the graph
                    # name nor the label is ever concatenated into SQL text.
                    cursor.execute(
                        "SELECT create_elabel(%s, %s)", (graphName, label))
            connection.commit()
    except Exception as e:
        raise Exception(e) from e
def getNodeLabelListAfterPreprocessing(G: nx.DiGraph):
    """
    Normalise the attributes of every node of G in place and collect labels.

    - a node without 'label' gets '_ag_label_vertex'
    - a node without 'properties' gets an empty dict
    - properties without '__id__' get the node key stored under '__id__'
    - the set of distinct node labels is returned

    Raises Exception when a label is not a string or properties is not a dict.
    """
    distinct_labels = set()
    try:
        for node_key, attrs in G.nodes(data=True):
            # Fill in the defaults first, then validate what is there.
            attrs.setdefault('label', '_ag_label_vertex')
            attrs.setdefault('properties', {})
            label = attrs['label']
            props = attrs['properties']
            if not isinstance(label, str):
                raise Exception(f"label of node : {node_key} must be a string")
            if not isinstance(props, Dict):
                raise Exception(f"properties of node : {node_key} must be a dict")
            props.setdefault('__id__', node_key)
            distinct_labels.add(label)
    except Exception as err:
        raise Exception(err)
    return distinct_labels
def getEdgeLabelListAfterPreprocessing(G: nx.DiGraph):
    """
    Normalise the attributes of every edge of G in place and collect labels.

    - an edge without 'label' gets '_ag_label_edge'
    - an edge without 'properties' gets an empty dict
    - the set of distinct edge labels is returned

    Raises Exception when a label is not a string or properties is not a dict.
    """
    distinct_labels = set()
    try:
        for src, dst, attrs in G.edges(data=True):
            # Fill in the defaults first, then validate what is there.
            attrs.setdefault('label', '_ag_label_edge')
            attrs.setdefault('properties', {})
            label = attrs['label']
            props = attrs['properties']
            if not isinstance(label, str):
                raise Exception(f"label of edge : {src}->{dst} must be a string")
            if not isinstance(props, Dict):
                raise Exception(
                    f"properties of edge : {src}->{dst} must be a dict")
            distinct_labels.add(label)
    except Exception as err:
        raise Exception(err)
    return distinct_labels
def addAllNodesIntoAGE(connection: psycopg2.connect, graphName: str, G: nx.DiGraph, node_label_list: Set):
    """Insert every node of G into its label table and record the new ids.

    Nodes are grouped by ``data['label']`` and each group is bulk-inserted
    into the table ``graphName.label``; the work is committed after each
    label. The ids returned by the INSERT are then written back to the nodes
    as ``data['properties']['__gid__']``, pairing returned rows with nodes in
    ``G.nodes`` iteration order per label.

    Args:
        connection: psycopg2 connection used for the inserts.
        graphName: name of the graph (schema) holding the label tables.
        G: graph whose nodes carry 'label' and 'properties' attributes.
        node_label_list: every label that occurs on a node of G.

    Raises:
        Exception: wraps any error raised while inserting or mapping ids.
    """
    # Imported locally: the module's visible import block has no explicit
    # `import json`, so do not rely on a wildcard import to provide it.
    import json

    try:
        rows_by_label = {label: [] for label in node_label_list}
        for _, data in G.nodes(data=True):
            rows_by_label[data['label']].append(
                (json.dumps(data['properties']),))

        ids_by_label = {}
        for label, rows in rows_by_label.items():
            # Quote schema and table as identifiers instead of formatting
            # them into the statement by hand.
            insert_query = sql.SQL(
                "INSERT INTO {}.{} (properties) VALUES %s RETURNING id"
            ).format(sql.Identifier(graphName), sql.Identifier(label))
            # The context manager closes the cursor even if the insert fails.
            with connection.cursor() as cursor:
                # execute_values needs the query as text, not a Composed object.
                returned = execute_values(
                    cursor, insert_query.as_string(cursor), rows, fetch=True)
            connection.commit()
            ids_by_label[label] = iter(returned)

        # Hand the returned ids out in the same per-label order in which the
        # rows were queued above.
        for _, data in G.nodes(data=True):
            data['properties']['__gid__'] = next(ids_by_label[data['label']])[0]
    except Exception as e:
        raise Exception(e) from e
def addAllEdgesIntoAGE(connection: psycopg2.connect, graphName: str, G: nx.DiGraph, edge_label_list: Set):
    """Insert every edge of G into its label table.

    Edges are grouped by ``data['label']`` and each group is bulk-inserted
    into the table ``graphName.label`` as ``(start_id, end_id, properties)``;
    the work is committed after each label. The endpoint ids are read from
    ``G.nodes[n]['properties']['__gid__']``, so every endpoint must already
    carry a ``'__gid__'`` property.

    Args:
        connection: psycopg2 connection used for the inserts.
        graphName: name of the graph (schema) holding the label tables.
        G: graph whose edges carry 'label' and 'properties' attributes.
        edge_label_list: every label that occurs on an edge of G.

    Raises:
        Exception: wraps any error raised while building or inserting rows.
    """
    # Imported locally: the module's visible import block has no explicit
    # `import json`, so do not rely on a wildcard import to provide it.
    import json

    try:
        rows_by_label = {label: [] for label in edge_label_list}
        for u, v, data in G.edges(data=True):
            rows_by_label[data['label']].append((
                G.nodes[u]['properties']['__gid__'],
                G.nodes[v]['properties']['__gid__'],
                json.dumps(data['properties']),
            ))

        for label, rows in rows_by_label.items():
            # Quote schema and table as identifiers instead of formatting
            # them into the statement by hand.
            insert_query = sql.SQL(
                "INSERT INTO {}.{} (start_id,end_id,properties) VALUES %s"
            ).format(sql.Identifier(graphName), sql.Identifier(label))
            # The context manager closes the cursor even if the insert fails.
            with connection.cursor() as cursor:
                # execute_values needs the query as text, not a Composed object.
                execute_values(cursor, insert_query.as_string(cursor), rows)
            connection.commit()
    except Exception as e:
        raise Exception(e) from e
def addAllNodesIntoNetworkx(connection: psycopg2.connect, graphName: str, G: nx.DiGraph):
    """Copy every vertex of an AGE graph into G.

    For each vertex label returned by ``get_vlabel`` the table
    ``graphName.label`` is read, and each row is added to G as a node keyed
    by ``int(id)`` with attributes ``label`` and ``properties`` (the
    properties column decoded from JSON).

    This is best-effort: an error is printed and stops the copy, leaving in G
    whatever nodes were added before it.

    Args:
        connection: psycopg2 connection used for the reads.
        graphName: name of the graph (schema) holding the label tables.
        G: graph that receives the nodes.
    """
    # Imported locally: the module's visible import block has no explicit
    # `import json`, so do not rely on a wildcard import to provide it.
    import json

    node_label_list = get_vlabel(connection, graphName)
    try:
        for label in node_label_list:
            # Quote schema and table as identifiers instead of formatting
            # them into the statement by hand.
            query = sql.SQL(
                "SELECT id, CAST(properties AS VARCHAR) FROM {}.{}"
            ).format(sql.Identifier(graphName), sql.Identifier(label))
            with connection.cursor() as cursor:
                cursor.execute(query)
                rows = cursor.fetchall()
            for row in rows:
                G.add_node(int(row[0]), label=label,
                           properties=json.loads(row[1]))
    except Exception as e:
        print(e)
def addAllEdgesIntoNetworkx(connection: psycopg2.connect, graphName: str, G: nx.DiGraph):
    """Copy every edge of an AGE graph into G.

    For each edge label returned by ``get_elabel`` the table
    ``graphName.label`` is read, and each row is added to G as an edge from
    ``int(start_id)`` to ``int(end_id)`` with attributes ``label`` and
    ``properties`` (the properties column decoded from JSON).

    This is best-effort: an error is printed and stops the copy, leaving in G
    whatever edges were added before it.

    Args:
        connection: psycopg2 connection used for the reads.
        graphName: name of the graph (schema) holding the label tables.
        G: graph that receives the edges.
    """
    # Imported locally: the module's visible import block has no explicit
    # `import json`, so do not rely on a wildcard import to provide it.
    import json

    try:
        edge_label_list = get_elabel(connection, graphName)
        for label in edge_label_list:
            # Quote schema and table as identifiers instead of formatting
            # them into the statement by hand.
            query = sql.SQL(
                "SELECT start_id, end_id, CAST(properties AS VARCHAR) "
                "FROM {}.{}"
            ).format(sql.Identifier(graphName), sql.Identifier(label))
            with connection.cursor() as cursor:
                cursor.execute(query)
                rows = cursor.fetchall()
            for row in rows:
                G.add_edge(int(row[0]), int(row[1]),
                           label=label, properties=json.loads(row[2]))
    except Exception as e:
        print(e)