blob: 7ba4a27f764f4b14b4d321e9ad6e510f124bb3cd [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
import psycopg2
from psycopg2 import errors
from psycopg2 import extensions as ext
from psycopg2 import sql
from .exceptions import *
from .builder import ResultHandler , parseAgeValue, newResultHandler
# Shared, pre-built exception instances raised by the module-level helpers
# when no usable connection exists or no graph name has been set.
_EXCEPTION_NoConnection, _EXCEPTION_GraphNotSet = NoConnection(), GraphNotSet()
# Matches any single whitespace character. buildCypher uses it to tell a bare
# column name ("n") from a spec that already carries a type ("n agtype").
# Raw string: '\s' in a plain literal is an invalid escape sequence
# (DeprecationWarning, and SyntaxWarning since Python 3.12).
WHITESPACE = re.compile(r'\s')
def setUpAge(conn:ext.connection, graphName:str):
    """Prepare ``conn`` for Apache AGE and optionally make sure a graph exists.

    Loads the ``age`` library into the session, puts ``ag_catalog`` first on
    the search path, and registers ``parseAgeValue`` as the psycopg2
    typecaster for the element type of ``_agtype`` (looked up in ``pg_type``).

    Args:
        conn: an open psycopg2 connection.
        graphName: if not None, passed to ``checkGraphCreated``.

    Raises:
        AgeNotSet: if ``pg_type`` has no ``_agtype`` entry (or its element
            type is NULL).
    """
    with conn.cursor() as cursor:
        cursor.execute("LOAD 'age';")
        cursor.execute("SET search_path = ag_catalog, '$user', public;")

        cursor.execute("SELECT typelem FROM pg_type WHERE typname='_agtype'")
        row = cursor.fetchone()
        # fetchone() returns None when no row matches; indexing it directly
        # would raise TypeError instead of the intended AgeNotSet.
        if row is None or row[0] is None:
            raise AgeNotSet()
        oid = row[0]

        AGETYPE = ext.new_type((oid,), 'AGETYPE', parseAgeValue)
        # Registered without a scope argument, i.e. globally per psycopg2 docs.
        # NOTE(review): per-connection registration may be preferable if
        # several databases with different agtype OIDs are used at once.
        ext.register_type(AGETYPE)
        # ext.register_adapter(Path, marshalAgtValue)

        # Check graph exists
        if graphName is not None:
            checkGraphCreated(conn, graphName)
# Create the graph, if it does not exist
def checkGraphCreated(conn:ext.connection, graphName:str):
    """Ensure a graph named ``graphName`` exists, creating it if needed.

    Counts rows in ``ag_graph`` whose ``name`` equals ``graphName`` and calls
    ``create_graph`` when the count is zero. ``graphName`` is always passed as
    a SQL literal, never interpolated as raw text.

    Args:
        conn: an open psycopg2 connection on which AGE has been set up.
        graphName: name of the graph to look up / create.
    """
    nameLiteral = sql.Literal(graphName)
    with conn.cursor() as cursor:
        cursor.execute(sql.SQL("SELECT count(*) FROM ag_graph WHERE name={graphName}").format(graphName=nameLiteral))
        if cursor.fetchone()[0] == 0:
            cursor.execute(sql.SQL("SELECT create_graph({graphName});").format(graphName=nameLiteral))
        # Commit on both paths: the existence check itself opens a transaction,
        # and ending it keeps the connection from sitting idle-in-transaction
        # when the graph already existed.
        conn.commit()
def deleteGraph(conn:ext.connection, graphName:str):
    """Drop the graph ``graphName`` and commit.

    Runs ``SELECT drop_graph(<name>, true);`` with the name bound as a SQL
    literal. (The ``true`` flag presumably requests cascading removal of the
    graph's contents, per AGE's ``drop_graph`` signature — confirm.)
    """
    dropQuery = sql.SQL("SELECT drop_graph({graphName}, true);").format(
        graphName=sql.Literal(graphName)
    )
    with conn.cursor() as cur:
        cur.execute(dropQuery)
        conn.commit()
def buildCypher(graphName:str, cypherStmt:str, columns:list) ->str:
    """Build the SQL that executes the Cypher query registered for this session.

    The generated statement calls ``cypher(NULL,NULL)``; the Cypher text is
    not embedded. In this module, ``execCypher`` and ``cypher`` first pass the
    query to ``age_prepare_cypher`` and then execute the SQL returned here.

    Args:
        graphName: must not be None (it is only checked, not embedded).
        cypherStmt: accepted for interface compatibility; not used.
        columns: result column specs. Each entry is trimmed; blank entries are
            skipped; an entry containing whitespace (``"n agtype"``) is used
            verbatim; a bare name (``"n"``) becomes ``"n agtype"``. When no
            usable entry remains, a single ``v agtype`` column is used.

    Returns:
        e.g. ``"SELECT * from cypher(NULL,NULL) as (n agtype,m agtype);"``

    Raises:
        GraphNotSet: if ``graphName`` is None.
    """
    if graphName is None:
        raise _EXCEPTION_GraphNotSet

    columnExp = []
    for col in columns or ():
        # Trim first so " n" is treated as a bare name, not as a typed spec.
        col = col.strip()
        if not col:
            continue
        if WHITESPACE.search(col) is not None:
            columnExp.append(col)
        else:
            columnExp.append(col + " agtype")

    # An all-blank column list would otherwise yield "... as ();",
    # which is invalid SQL; fall back to the default column instead.
    if not columnExp:
        columnExp.append('v agtype')

    return "SELECT * from cypher(NULL,NULL) as (" + ','.join(columnExp) + ");"
def execSql(conn:ext.connection, stmt:str, commit:bool=False, params:tuple=None) -> ext.cursor :
    """Execute ``stmt`` with ``params`` on a new cursor and return that cursor.

    Args:
        conn: an open psycopg2 connection.
        stmt: SQL text (or a psycopg2 ``sql`` composable).
        commit: if True, commit the connection after a successful execute.
        params: optional query parameters passed to ``cursor.execute``.

    Returns:
        The cursor the statement was executed on (caller reads/closes it).

    Raises:
        NoConnection: if ``conn`` is None or closed.
        SyntaxError: re-raised unchanged (after rollback).
        SqlExecutionError: wrapping any other failure (after rollback).
    """
    if conn is None or conn.closed:
        raise _EXCEPTION_NoConnection

    cursor = conn.cursor()
    try:
        cursor.execute(stmt, params)
        if commit:
            conn.commit()
        return cursor
    except SyntaxError:
        # The caller never receives this cursor, so close it here.
        cursor.close()
        conn.rollback()
        raise
    except Exception as cause:
        cursor.close()
        conn.rollback()
        # str(stmt): concatenating a composable directly would raise
        # TypeError here and hide the real database error.
        raise SqlExecutionError("Execution ERR[" + str(cause) + "](" + str(stmt) + ")", cause) from cause
def querySql(conn:ext.connection, stmt:str, params:tuple=None) -> ext.cursor :
    """Execute ``stmt`` through ``execSql`` without committing; return the cursor."""
    return execSql(conn, stmt, commit=False, params=params)
# Execute a Cypher statement and return the cursor.
# If the Cypher statement changes data (CREATE, SET, REMOVE, ...),
# you must commit the session (ag.commit());
# otherwise the changes will not take effect.
def _execOrRollback(conn, cursor, query, errStmt):
    """Execute ``query`` on ``cursor``; on failure close the cursor, roll back
    ``conn`` and re-raise (non-SyntaxError errors wrapped in SqlExecutionError)."""
    try:
        cursor.execute(query)
    except SyntaxError:
        cursor.close()
        conn.rollback()
        raise
    except Exception as cause:
        cursor.close()
        conn.rollback()
        raise SqlExecutionError("Execution ERR[" + str(cause) + "](" + errStmt + ")", cause) from cause


def execCypher(conn:ext.connection, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> ext.cursor :
    """Run a Cypher statement on ``graphName`` and return the cursor with its results.

    ``params`` are bound into ``cypherStmt`` with psycopg2's ``mogrify``; the
    resulting text is registered through ``age_prepare_cypher`` and then run
    via the statement produced by ``buildCypher``. Nothing is committed here.

    Args:
        conn: an open psycopg2 connection on which AGE has been set up.
        graphName: graph to run the query against.
        cypherStmt: Cypher text, optionally with psycopg2 placeholders.
        cols: result column specs, see ``buildCypher``.
        params: values for the placeholders in ``cypherStmt``.

    Raises:
        NoConnection: if ``conn`` is None or closed.
        GraphNotSet: if ``graphName`` is None.
        SyntaxError: re-raised unchanged (after rollback).
        SqlExecutionError: if either SQL statement fails (after rollback).
    """
    if conn is None or conn.closed:
        raise _EXCEPTION_NoConnection
    # Fail fast before any server round trip; buildCypher enforces the same rule.
    if graphName is None:
        raise _EXCEPTION_GraphNotSet

    # A single cursor serves mogrify, the prepare call and the query itself,
    # so no intermediate cursor is left unclosed.
    cursor = conn.cursor()
    try:
        # Turn line breaks/tabs into spaces instead of deleting them, so tokens
        # on adjacent lines don't fuse (e.g. "= 1\nRETURN" -> "= 1 RETURN").
        flatStmt = cypherStmt.replace("\n", " ").replace("\t", " ")
        # mogrify() returns bytes in the connection's encoding. Decode them with
        # that codec; str(bytes) would give the repr, which escapes backslashes
        # and non-ASCII bytes and can escape quotes.
        cypherText = cursor.mogrify(flatStmt, params).decode(ext.encodings[conn.encoding])
    except BaseException:
        cursor.close()
        raise

    preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})"
    _execOrRollback(
        conn,
        cursor,
        sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName), cypherStmt=sql.Literal(cypherText)),
        preparedStmt,
    )

    stmt = buildCypher(graphName, cypherText, cols)
    _execOrRollback(conn, cursor, stmt, stmt)
    return cursor
def cypher(cursor:ext.cursor, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> ext.cursor :
    """Run a Cypher statement on a caller-supplied cursor and return that cursor.

    Same flow as ``execCypher`` (mogrify, ``age_prepare_cypher``, then the
    ``buildCypher`` statement) but with no connection check, rollback or error
    wrapping: failures propagate and transaction handling is up to the caller.

    Args:
        cursor: an open psycopg2 cursor whose connection has AGE set up.
        graphName: graph to run the query against.
        cypherStmt: Cypher text, optionally with psycopg2 placeholders.
        cols: result column specs, see ``buildCypher``.
        params: values for the placeholders in ``cypherStmt``.

    Raises:
        GraphNotSet: if ``graphName`` is None.
    """
    if graphName is None:
        raise _EXCEPTION_GraphNotSet

    # Spaces instead of deletion keep tokens on adjacent lines separate.
    flatStmt = cypherStmt.replace("\n", " ").replace("\t", " ")
    # Decode mogrify()'s bytes with the connection's codec rather than slicing
    # str(bytes), whose repr escapes backslashes and non-ASCII bytes.
    codec = ext.encodings[cursor.connection.encoding]
    cypherText = cursor.mogrify(flatStmt, params).decode(codec)

    preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})"
    cursor.execute(sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName), cypherStmt=sql.Literal(cypherText)))
    cursor.execute(buildCypher(graphName, cypherText, cols))
    # The signature promises a cursor (and Age.cypher forwards this value),
    # so return it instead of falling off the end with None.
    return cursor
# def execCypherWithReturn(conn:ext.connection, graphName:str, cypherStmt:str, columns:list=None , params:tuple=None) -> ext.cursor :
# stmt = buildCypher(graphName, cypherStmt, columns)
# return execSql(conn, stmt, False, params)
# def queryCypher(conn:ext.connection, graphName:str, cypherStmt:str, columns:list=None , params:tuple=None) -> ext.cursor :
# return execCypherWithReturn(conn, graphName, cypherStmt, columns, params)
class Age:
    """Holds a psycopg2 connection plus the current graph name and forwards
    to the module-level helpers (setUpAge, checkGraphCreated, execCypher, cypher)."""

    def __init__(self):
        self.connection = None  # psycopg2 connection
        self.graphName = None   # graph used by execCypher / cypher

    # Connect to PostgreSQL Server and establish session and type extension environment.
    def connect(self, graph:str=None, dsn:str=None, connection_factory=None, cursor_factory=None, **kwargs):
        """Open a connection, set up AGE on it, and optionally create/select ``graph``.

        Returns ``self`` for chaining. If the AGE setup fails, the freshly
        opened connection is closed before the error propagates.
        """
        conn = psycopg2.connect(dsn, connection_factory, cursor_factory, **kwargs)
        try:
            setUpAge(conn, graph)
        except BaseException:
            # Don't leak the connection when setup fails.
            conn.close()
            raise
        self.connection = conn
        self.graphName = graph
        return self

    def close(self):
        """Close the connection; a no-op when ``connect`` was never called."""
        if self.connection is not None:
            self.connection.close()

    def setGraph(self, graph:str):
        """Ensure ``graph`` exists (creating it if needed) and make it current.

        Raises:
            NoConnection: if not connected or the connection is closed.
        """
        if self.connection is None or self.connection.closed:
            raise _EXCEPTION_NoConnection
        checkGraphCreated(self.connection, graph)
        self.graphName = graph
        return self

    def commit(self):
        """Commit the current transaction on the held connection."""
        self.connection.commit()

    def rollback(self):
        """Roll back the current transaction on the held connection."""
        self.connection.rollback()

    def execCypher(self, cypherStmt:str, cols:list=None, params:tuple=None) -> ext.cursor :
        """Run ``cypherStmt`` on the current graph via module-level ``execCypher``."""
        return execCypher(self.connection, self.graphName, cypherStmt, cols=cols, params=params)

    def cypher(self, cursor:ext.cursor, cypherStmt:str, cols:list=None, params:tuple=None) -> ext.cursor :
        """Run ``cypherStmt`` on the current graph using the given ``cursor``."""
        return cypher(cursor, self.graphName, cypherStmt, cols=cols, params=params)

    # def execSql(self, stmt:str, commit:bool=False, params:tuple=None) -> ext.cursor :
    # return execSql(self.connection, stmt, commit, params)
    # def execCypher(self, cypherStmt:str, commit:bool=False, params:tuple=None) -> ext.cursor :
    # return execCypher(self.connection, self.graphName, cypherStmt, commit, params)
    # def execCypherWithReturn(self, cypherStmt:str, columns:list=None , params:tuple=None) -> ext.cursor :
    # return execCypherWithReturn(self.connection, self.graphName, cypherStmt, columns, params)
    # def queryCypher(self, cypherStmt:str, columns:list=None , params:tuple=None) -> ext.cursor :
    # return queryCypher(self.connection, self.graphName, cypherStmt, columns, params)