Remove sqlalchemy-mode time-column rewriting from Cursor; quote and schema-qualify IoTDB dialect introspection
diff --git a/iotdb-client/client-py/iotdb/dbapi/Cursor.py b/iotdb-client/client-py/iotdb/dbapi/Cursor.py index 018ade6..8897720 100644 --- a/iotdb-client/client-py/iotdb/dbapi/Cursor.py +++ b/iotdb-client/client-py/iotdb/dbapi/Cursor.py
@@ -110,25 +110,6 @@ else: sql = operation % parameters - time_index = [] - time_names = [] - if self.__sqlalchemy_mode: - sql_seqs = [] - seqs = sql.split("\n") - for seq in seqs: - if seq.find("FROM Time Index") >= 0: - time_index = [ - int(index) - for index in seq.replace("FROM Time Index", "").split() - ] - elif seq.find("FROM Time Name") >= 0: - time_names = [ - name for name in seq.replace("FROM Time Name", "").split() - ] - else: - sql_seqs.append(seq) - sql = "\n".join(sql_seqs) - data_set = self.__session.execute_statement(sql) col_names = None col_types = None @@ -136,14 +117,6 @@ if data_set: data = data_set.todf() - - if self.__sqlalchemy_mode and time_index: - time_column = data.columns[0] - time_column_value = data.Time - del data[time_column] - for i in range(len(time_index)): - data.insert(time_index[i], time_names[i], time_column_value) - col_names = data.columns.tolist() col_types = data_set.get_column_types() rows = data.values.tolist()
diff --git a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDialect.py b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDialect.py index 44f9e86..cc32649 100644 --- a/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDialect.py +++ b/iotdb-client/client-py/iotdb/sqlalchemy/IoTDBDialect.py
@@ -37,7 +37,7 @@ "INT32": types.Integer, "INT64": types.BigInteger, "FLOAT": types.Float, - "DOUBLE": types.Float, + "DOUBLE": types.Float(precision=53), "STRING": types.String, "TEXT": types.Text, "BLOB": types.LargeBinary, @@ -86,7 +86,9 @@ def create_connect_args(self, url): opts = url.translate_connect_args() opts.update(url.query) + opts["sqlalchemy_mode"] = True opts["sql_dialect"] = "table" + self._url_database = opts.get("database") return ([], opts) def initialize(self, connection): @@ -96,7 +98,7 @@ return None def _get_default_schema_name(self, connection): - return None + return getattr(self, "_url_database", None) def has_schema(self, connection, schema_name, **kw): return schema_name in self.get_schema_names(connection) @@ -110,16 +112,20 @@ def get_table_names(self, connection, schema=None, **kw): if schema: - connection.execute(text("USE %s" % schema)) - cursor = connection.execute(text("SHOW TABLES")) + quoted = self.identifier_preparer.quote_identifier(schema) + cursor = connection.execute(text("SHOW TABLES FROM %s" % quoted)) + else: + cursor = connection.execute(text("SHOW TABLES")) return [row[0] for row in cursor.fetchall()] def get_columns(self, connection, table_name, schema=None, **kw): + quoted_table = self.identifier_preparer.quote_identifier(table_name) if schema: - connection.execute(text("USE %s" % schema)) - cursor = connection.execute( - text("SHOW COLUMNS FROM %s" % table_name) - ) + quoted_schema = self.identifier_preparer.quote_identifier(schema) + qualified = "%s.%s" % (quoted_schema, quoted_table) + else: + qualified = quoted_table + cursor = connection.execute(text("DESC %s" % qualified)) columns = [] for row in cursor.fetchall(): col_name = row[0]