blob: 817cc6e5aabdd598da342d5088f99951681935cf [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
from typing import Any

import psycopg
from psycopg.types import TypeInfo
from psycopg.adapt import Loader
from psycopg import sql
from psycopg.client_cursor import ClientCursor

from .exceptions import *
from .builder import parseAgeValue
# Exception *classes*, not instances: `raise _EXCEPTION_X` instantiates a fresh
# exception on every raise. Re-raising one shared instance would keep chaining
# new traceback frames onto the same object (and keep those frames alive).
_EXCEPTION_NoConnection = NoConnection
_EXCEPTION_GraphNotSet = GraphNotSet
# Matches any whitespace character. buildCypher uses it to tell a bare column
# name ("n") from a full column definition ("n agtype"). Raw string so that
# "\s" is not treated as an (invalid) string escape.
WHITESPACE = re.compile(r'\s')
class AgeDumper(psycopg.adapt.Dumper):
    """Placeholder dumper for Python -> agtype conversion (not implemented).

    Nothing in this module registers this class; ``setUpAge`` only registers
    loaders.
    """

    def dump(self, obj: Any) -> bytes | bytearray | memoryview:
        # NOTE(review): stub -- returns None despite the annotated return type.
        return None
class AgeLoader(psycopg.adapt.Loader):
    """Loader that converts raw ``agtype`` column data via ``parseAgeValue``.

    ``setUpAge`` registers it for the ``agtype`` OID and its array OID.
    """

    def load(self, data: bytes | bytearray | memoryview) -> Any | None:
        """Decode *data* as UTF-8 and return ``parseAgeValue`` of the text."""
        # bytes() accepts every buffer kind psycopg may pass (bytes, bytearray,
        # memoryview) and normalizes it to a bytes object we can decode.
        raw = bytes(data)
        text = raw.decode('utf-8')
        return parseAgeValue(text)
def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=False):
    """Prepare *conn* for Apache AGE and optionally make sure a graph exists.

    Steps: ``LOAD`` the AGE library, put ``ag_catalog`` first on
    ``search_path``, register ``AgeLoader`` for the ``agtype`` type and its
    array type, and, when *graphName* is given, create that graph if it is
    missing (via ``checkGraphCreated``).

    :param conn: open psycopg connection.
    :param graphName: graph to check/create, or None to skip that step.
    :param load_from_plugins: load ``'$libdir/plugins/age'`` instead of
        ``'age'`` (presumably for installs that keep AGE in the plugins
        directory -- confirm).
    :raises AgeNotSet: if the ``agtype`` type cannot be found.
    """
    with conn.cursor() as cursor:
        if load_from_plugins:
            cursor.execute("LOAD '$libdir/plugins/age';")
        else:
            cursor.execute("LOAD 'age';")
        # NOTE(review): a non-LOCAL SET issued inside a transaction that is
        # later rolled back is undone. Unless the connection is in autocommit
        # mode or checkGraphCreated commits, an early rollback (e.g. from a
        # failed execCypher) may drop ag_catalog from search_path -- confirm.
        cursor.execute("SET search_path = ag_catalog, '$user', public;")

    # Looked up by name, so this relies on the search_path set above.
    ag_info = TypeInfo.fetch(conn, 'agtype')
    if ag_info is None:
        raise AgeNotSet()

    conn.adapters.register_loader(ag_info.oid, AgeLoader)
    # NOTE(review): the array type reuses the scalar loader, so array text is
    # passed to parseAgeValue as-is -- confirm that parser handles it.
    conn.adapters.register_loader(ag_info.array_oid, AgeLoader)

    # Check the graph exists (and create it if not).
    if graphName is not None:
        checkGraphCreated(conn, graphName)
def checkGraphCreated(conn:psycopg.connection, graphName:str):
    """Ensure graph *graphName* exists, creating it with ``create_graph`` if not.

    Existence is checked by counting ``ag_graph`` rows whose ``name`` matches.
    The transaction is committed only after a graph has been created.
    """
    nameLiteral = sql.Literal(graphName)
    countQuery = sql.SQL("SELECT count(*) FROM ag_graph WHERE name={graphName}").format(graphName=nameLiteral)
    with conn.cursor() as cursor:
        cursor.execute(countQuery)
        graphCount = cursor.fetchone()[0]
        if graphCount != 0:
            # Graph already present: nothing to create, nothing to commit.
            return
        createQuery = sql.SQL("SELECT create_graph({graphName});").format(graphName=nameLiteral)
        cursor.execute(createQuery)
        conn.commit()
def deleteGraph(conn:psycopg.connection, graphName:str):
    """Drop graph *graphName* with ``drop_graph(name, true)`` and commit.

    The ``true`` flag is presumably AGE's cascade option -- confirm.
    """
    dropQuery = sql.SQL("SELECT drop_graph({graphName}, true);").format(graphName=sql.Literal(graphName))
    with conn.cursor() as cursor:
        cursor.execute(dropQuery)
    conn.commit()
def buildCypher(graphName:str, cypherStmt:str, columns:list) ->str:
    """Build the SQL that runs a Cypher statement through ``cypher(NULL, NULL)``.

    The Cypher text is not embedded in the generated SQL. Callers in this
    module register it first with ``age_prepare_cypher``, so *cypherStmt* is
    unused here.

    Each non-blank entry of *columns* (after stripping) becomes one output
    column:

    - A bare name (``"n"``) is typed as ``agtype``.
    - An entry that still contains whitespace (``"n agtype"``) is used
      verbatim as a column definition.

    If no usable column remains (None, empty, or all blank), a single
    ``v agtype`` column is used, so the column list is never empty.

    :raises GraphNotSet: if *graphName* is None.
    """
    if graphName is None:
        raise _EXCEPTION_GraphNotSet

    columnExp = []
    for col in columns or ():
        # Strip first so stray leading/trailing spaces neither get mistaken
        # for a "name type" definition nor leak into the SQL.
        col = col.strip()
        if not col:
            continue
        if WHITESPACE.search(col) is None:
            col += " agtype"
        columnExp.append(col)

    # An empty "as ()" list is invalid SQL; fall back to the default column.
    if not columnExp:
        columnExp.append('v agtype')

    return "SELECT * from cypher(NULL,NULL) as (" + ','.join(columnExp) + ");"
def execSql(conn:psycopg.connection, stmt:str, commit:bool=False, params:tuple=None) -> psycopg.cursor :
    """Execute *stmt* with *params* on a new cursor and return that cursor.

    :param conn: open psycopg connection.
    :param stmt: SQL statement to execute.
    :param commit: commit the transaction after a successful execute.
    :param params: query parameters passed to ``cursor.execute``.
    :returns: the cursor holding the statement's result; the caller owns it.
    :raises NoConnection: if *conn* is None or closed.
    :raises SqlExecutionError: wrapping any other failure, after the cursor is
        closed and the transaction rolled back.
    """
    if conn is None or conn.closed:
        raise _EXCEPTION_NoConnection
    cursor = conn.cursor()
    try:
        cursor.execute(stmt, params)
        if commit:
            conn.commit()
        return cursor
    # NOTE(review): ``SyntaxError`` may be psycopg's SQL syntax error if
    # ``.exceptions`` re-exports ``psycopg.errors`` via ``*`` -- confirm. Either
    # way it is re-raised unwrapped.
    except SyntaxError:
        cursor.close()  # never handed to the caller, so release it here
        conn.rollback()
        raise
    except Exception as cause:
        cursor.close()
        conn.rollback()
        # f-string (not "+") so a non-str stmt cannot raise TypeError here and
        # hide the original error.
        raise SqlExecutionError(f"Execution ERR[{cause}]({stmt})", cause) from cause
def querySql(conn:psycopg.connection, stmt:str, params:tuple=None) -> psycopg.cursor :
    """Execute *stmt* through ``execSql`` without committing; return the cursor."""
    return execSql(conn, stmt, commit=False, params=params)
def _runCypherStep(conn:psycopg.connection, cursor:psycopg.cursor, query, label:str):
    """Execute one query for execCypher; on failure close *cursor*, roll back, and raise.

    *label* is the text shown in the SqlExecutionError message.
    """
    try:
        cursor.execute(query)
    # NOTE(review): may be psycopg's SyntaxError via the ``.exceptions`` star
    # import -- confirm. It is re-raised unwrapped either way.
    except SyntaxError:
        cursor.close()
        conn.rollback()
        raise
    except Exception as cause:
        cursor.close()
        conn.rollback()
        raise SqlExecutionError(f"Execution ERR[{cause}]({label})", cause) from cause


def execCypher(conn:psycopg.connection, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor :
    """Run a Cypher statement against *graphName* and return the result cursor.

    The statement is first registered with ``age_prepare_cypher`` and then
    executed through the ``cypher(NULL, NULL)`` query built by ``buildCypher``.

    This function never commits. If the statement changes data (CREATE, SET,
    REMOVE, ...), commit the session yourself (e.g. ``ag.commit()``);
    otherwise the change has no lasting effect.

    :param conn: open psycopg connection. Its cursors must provide
        ``mogrify`` (e.g. ``ClientCursor``, the default in ``Age.connect``).
    :param graphName: graph to run against.
    :param cypherStmt: Cypher text, with psycopg placeholders filled from
        *params*.
    :param cols: result column names/definitions, as accepted by
        ``buildCypher``.
    :param params: values interpolated client-side by ``cursor.mogrify``.
    :returns: the cursor positioned on the query result; the caller owns it.
    :raises NoConnection: if *conn* is None or closed.
    :raises GraphNotSet: if *graphName* is None.
    :raises SqlExecutionError: if either server round trip fails (the cursor is
        closed and the transaction rolled back first).
    """
    if conn is None or conn.closed:
        raise _EXCEPTION_NoConnection
    if graphName is None:
        # Fail before any round trip, with the same error buildCypher raises.
        raise _EXCEPTION_GraphNotSet

    # Turn line breaks and tabs into spaces rather than deleting them:
    # deleting fuses tokens separated only by a newline or tab
    # (e.g. "RETURN n\nLIMIT 1" -> "RETURN nLIMIT 1").
    cypherStmt = cypherStmt.replace("\n", " ").replace("\t", " ")
    # Client-side parameter interpolation only needs a short-lived cursor.
    with conn.cursor() as mogrifier:
        cypherText = str(mogrifier.mogrify(cypherStmt, params)).strip()

    preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})"
    # buildCypher is pure; build the SQL before touching the server.
    stmt = buildCypher(graphName, cypherText, cols)

    # A single cursor serves both steps and is returned to the caller.
    cursor = conn.cursor()
    _runCypherStep(
        conn,
        cursor,
        sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName), cypherStmt=sql.Literal(cypherText)),
        preparedStmt,
    )
    _runCypherStep(conn, cursor, stmt, stmt)
    return cursor
def cypher(cursor:psycopg.cursor, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor :
    """Run a Cypher statement on a caller-supplied *cursor* and return that cursor.

    Uses the same two steps as ``execCypher`` (``age_prepare_cypher``, then
    the ``cypher(NULL, NULL)`` query from ``buildCypher``). Errors propagate
    unchanged: nothing is rolled back and *cursor* stays open for the caller
    to manage. Nothing is committed.

    :param cursor: cursor providing ``mogrify`` (e.g. ``ClientCursor``).
    :param graphName: graph to run against.
    :param cypherStmt: Cypher text, with psycopg placeholders filled from
        *params*.
    :param cols: result column names/definitions, as accepted by
        ``buildCypher``.
    :param params: values interpolated client-side by ``cursor.mogrify``.
    :returns: *cursor*, positioned on the query result.
    :raises GraphNotSet: if *graphName* is None.
    """
    if graphName is None:
        # Fail before any round trip, with the same error buildCypher raises.
        raise _EXCEPTION_GraphNotSet
    # Turn line breaks and tabs into spaces rather than deleting them, so
    # tokens separated only by a newline/tab are not fused together.
    cypherStmt = cypherStmt.replace("\n", " ").replace("\t", " ")
    cypherText = str(cursor.mogrify(cypherStmt, params)).strip()

    prepareQuery = sql.SQL("SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})").format(
        graphName=sql.Literal(graphName), cypherStmt=sql.Literal(cypherText))
    cursor.execute(prepareQuery)
    cursor.execute(buildCypher(graphName, cypherText, cols))
    # Returned so callers (e.g. Age.cypher) get the cursor the annotation promises.
    return cursor
# def execCypherWithReturn(conn:psycopg.connection, graphName:str, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor :
# stmt = buildCypher(graphName, cypherStmt, columns)
# return execSql(conn, stmt, False, params)
# def queryCypher(conn:psycopg.connection, graphName:str, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor :
# return execCypherWithReturn(conn, graphName, cypherStmt, columns, params)
class Age:
    """Holds one AGE-enabled psycopg connection plus the current graph name.

    Most methods delegate to the module-level function of the same name,
    passing ``self.connection`` and ``self.graphName``.
    """

    def __init__(self):
        self.connection = None  # psycopg connection; set by connect()
        self.graphName = None   # current graph; set by connect()/setGraph()

    # Connect to PostgreSQL Server and establish session and type extension environment.
    def connect(self, graph:str=None, dsn:str=None, connection_factory=None, cursor_factory=ClientCursor,
                load_from_plugins:bool=False, **kwargs):
        """Open a connection, prepare it with ``setUpAge``, and return ``self``.

        :param graph: graph to use (created if missing), or None.
        :param dsn: connection string passed to ``psycopg.connect``.
        :param connection_factory: NOTE(review): accepted but currently unused.
        :param cursor_factory: cursor class for the connection. The default
            ``ClientCursor`` provides the ``mogrify`` that
            ``execCypher``/``cypher`` call.
        :param load_from_plugins: forwarded to ``setUpAge``.
        :param kwargs: extra keyword arguments for ``psycopg.connect``.
        :returns: ``self``, to allow chaining.
        """
        conn = psycopg.connect(dsn, cursor_factory=cursor_factory, **kwargs)
        try:
            setUpAge(conn, graph, load_from_plugins)
        except BaseException:
            # Don't leak the freshly opened connection if setup fails.
            conn.close()
            raise
        self.connection = conn
        self.graphName = graph
        return self

    def close(self):
        """Close the underlying connection; a no-op if connect() was never called."""
        if self.connection is not None:
            self.connection.close()

    def setGraph(self, graph:str):
        """Make *graph* current, creating it first if it does not exist; return ``self``."""
        checkGraphCreated(self.connection, graph)
        self.graphName = graph
        return self

    def commit(self):
        """Commit the connection's current transaction."""
        self.connection.commit()

    def rollback(self):
        """Roll back the connection's current transaction."""
        self.connection.rollback()

    def execCypher(self, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor :
        """Run Cypher on the current graph via a new cursor (module-level ``execCypher``)."""
        return execCypher(self.connection, self.graphName, cypherStmt, cols=cols, params=params)

    def cypher(self, cursor:psycopg.cursor, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor :
        """Run Cypher on the current graph using the caller's *cursor* (module-level ``cypher``)."""
        return cypher(cursor, self.graphName, cypherStmt, cols=cols, params=params)