| #!/usr/bin/env python |
| |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| from __future__ import print_function |
| |
| import argparse |
| import base64 |
| import collections |
| import csv |
| import errno |
| import json |
| import numbers |
| import os |
| import re |
| import readline |
| import ssl |
| import sys |
| import time |
| import unicodedata |
| import urllib2 |
| |
class DruidSqlException(Exception):
    """Error raised when a Druid SQL request fails.

    When constructed with exactly one truthy argument, that argument is the
    text shown to the user; otherwise a generic message is used.
    """

    def friendly_message(self):
        """Return the user-facing error text, or "Query failed" if none was given."""
        # BaseException.message is deprecated (PEP 352) and does not exist on
        # Python 3. It mirrored args[0] only when exactly one argument was
        # supplied, so read args directly and keep those semantics.
        if len(self.args) == 1 and self.args[0]:
            return self.args[0]
        return "Query failed"

    def write_to(self, f):
        """Write the friendly message to file-like object *f* in red, then flush."""
        f.write('\x1b[31m')
        f.write(self.friendly_message())
        f.write('\x1b[0m')
        f.write('\n')
        f.flush()
| |
def do_query_with_args(url, sql, context, args):
    """Call do_query() with connection settings taken from the parsed *args*.

    Returns whatever do_query() returns for the given url, sql and context.
    """
    return do_query(
        url,
        sql,
        context,
        timeout=args.timeout,
        user=args.user,
        ignore_ssl_verification=args.ignore_ssl_verification,
        ca_file=args.cafile,
        ca_path=args.capath,
        cert_chain=args.certchain,
        key_file=args.keyfile,
        key_pass=args.keypass
    )
| |
def do_query(url, sql, context, timeout, user, ignore_ssl_verification, ca_file, ca_path, cert_chain, key_file, key_pass):
    """POST *sql* to *url* as JSON and lazily yield the decoded result objects.

    This is a generator: the HTTP request is only issued once iteration starts.
    JSON objects in the response are decoded into collections.OrderedDict so
    key order is preserved.

    Parameters:
      url, sql   -- endpoint and query text.
      context    -- dict of query context options; never mutated.
      timeout    -- socket timeout in seconds; <= 0 means no timeout.
      user       -- optional credentials string sent base64-encoded as HTTP
                    Basic auth.
      ignore_ssl_verification, ca_file, ca_path, cert_chain, key_file,
      key_pass   -- SSL settings; an SSLContext is only built when at least
                    one of ignore_ssl_verification / ca_file / ca_path /
                    cert_chain is set.

    Raises:
      ValueError if the response cannot be parsed as JSON. urllib2.URLError is
      handed to raise_friendly_error().
    """
    json_decoder = json.JSONDecoder(object_pairs_hook=collections.OrderedDict)
    try:
        # Default to the caller's context. Previously this name stayed unbound
        # when a positive timeout was given and the context timeout was already
        # at least as long, which crashed with UnboundLocalError.
        query_context = context
        if timeout <= 0:
            timeout = None
        elif int(context.get('timeout', 0)) / 1000. < timeout:
            query_context = context.copy()
            query_context['timeout'] = timeout * 1000

        sql_json = json.dumps({'query' : sql, 'context' : query_context})

        # SSL stuff
        ssl_context = None
        if ignore_ssl_verification or ca_file is not None or ca_path is not None or cert_chain is not None:
            ssl_context = ssl.create_default_context()
            if ignore_ssl_verification:
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
            elif ca_file is not None or ca_path is not None:
                # Either a CA file or a CA directory (or both) may be given.
                ssl_context.load_verify_locations(cafile=ca_file, capath=ca_path)
            # The client certificate is independent of how the server is
            # verified, so load it whenever one was supplied (and only then:
            # load_cert_chain() rejects certfile=None).
            if cert_chain is not None:
                ssl_context.load_cert_chain(certfile=cert_chain, keyfile=key_file, password=key_pass)

        req = urllib2.Request(url, sql_json, {'Content-Type' : 'application/json'})

        if user:
            req.add_header("Authorization", "Basic %s" % base64.b64encode(user))

        response = urllib2.urlopen(req, None, timeout, context=ssl_context)
        try:
            first_chunk = True
            eof = False
            buf = ''

            while not eof or len(buf) > 0:
                while True:
                    try:
                        # Remove starting ','
                        buf = buf.lstrip(',')
                        obj, sz = json_decoder.raw_decode(buf)
                        yield obj
                        buf = buf[sz:]
                    except ValueError:
                        # Maybe invalid JSON, maybe partial object; it's hard to tell with this library.
                        if eof and buf.rstrip() == ']':
                            # Stream done and all objects read.
                            buf = ''
                            break
                        elif eof or len(buf) > 256 * 1024:
                            # If we read more than 256KB or if it's eof then report the parse error.
                            raise
                        else:
                            # Stop reading objects, get more from the stream instead.
                            break

                # Read more from the http stream
                if not eof:
                    chunk = response.read(8192)
                    if chunk:
                        buf = buf + chunk
                        if first_chunk:
                            # Remove starting '[' (only the array opener at the
                            # very start of the stream).
                            buf = buf.lstrip('[')
                            first_chunk = False
                    else:
                        # Stream done. Keep reading objects out of buf though.
                        eof = True
        finally:
            # Runs when the generator is exhausted, closed or fails.
            response.close()

    except urllib2.URLError as e:
        raise_friendly_error(e)
| |
def raise_friendly_error(e):
    """Raise a DruidSqlException describing the urllib2 error *e*.

    For an HTTPError the response body is read; when it is a JSON object from
    a 500 response containing 'errorMessage', the message is assembled from its
    'error', 'errorClass', 'errorMessage' and 'host' fields (missing or empty
    fields are skipped). Other HTTP errors report the status, reason and raw
    body. Non-HTTP errors use str(e). This function never returns normally.
    """
    if not isinstance(e, urllib2.HTTPError):
        raise DruidSqlException(str(e))

    text = e.read().strip()
    error_obj = {}
    try:
        error_obj = dict(json.loads(text))
    except (TypeError, ValueError):
        # Body is not a JSON object; fall back to reporting the raw text.
        pass

    if e.code == 500 and 'errorMessage' in error_obj:
        error_text = ''
        # Use .get() throughout: only 'errorMessage' is known to be present,
        # and a KeyError here would hide the real server error.
        error = error_obj.get('error')
        if error and error != 'Unknown exception':
            error_text = error_text + error + ': '
        error_class = error_obj.get('errorClass')
        if error_class:
            error_text = error_text + str(error_class) + ': '
        error_text = error_text + str(error_obj['errorMessage'])
        host = error_obj.get('host')
        if host:
            error_text = error_text + ' (' + str(host) + ')'
        raise DruidSqlException(error_text)
    elif e.code == 405:
        error_text = 'HTTP Error {0}: {1}\n{2}'.format(e.code, e.reason + " - Are you using the correct broker URL and " +\
                                                        "is druid.sql.enabled set to true on your broker?", text)
        raise DruidSqlException(error_text)
    else:
        raise DruidSqlException("HTTP Error {0}: {1}\n{2}".format(e.code, e.reason, text))
| |
def to_utf8(value):
    """Return *value* as a byte string.

    None becomes the empty string, unicode objects are encoded as UTF-8, and
    anything else goes through str().
    """
    if value is None:
        return ""
    if isinstance(value, unicode):
        return value.encode("utf-8")
    return str(value)
| |
def to_tsv(values, delimiter):
    """Join *values* into one delimiter-separated line.

    Each value is converted with to_utf8() and any occurrence of the delimiter
    inside a value is removed so it cannot split the field.
    """
    fields = [to_utf8(item).replace(delimiter, '') for item in values]
    return delimiter.join(fields)
| |
def print_csv(rows, header):
    """Write *rows* (an iterable of mappings) to stdout as CSV.

    When *header* is true, the keys of the first row are written as a header
    line before any data.
    """
    writer = csv.writer(sys.stdout)
    header_pending = True
    for record in rows:
        if header_pending and header:
            writer.writerow([to_utf8(name) for name in record.keys()])
            header_pending = False

        writer.writerow([to_utf8(cell) for _, cell in record.iteritems()])
| |
def print_tsv(rows, header, tsv_delimiter):
    """Print *rows* (an iterable of mappings) as delimiter-separated lines.

    When *header* is true, the keys of the first row are printed first.
    """
    header_pending = True
    for record in rows:
        if header_pending and header:
            print(to_tsv(record.keys(), tsv_delimiter))
            header_pending = False

        cells = [cell for _, cell in record.iteritems()]
        print(to_tsv(cells, tsv_delimiter))
| |
def print_json(rows):
    """Print each row as one JSON document per line."""
    for encoded in (json.dumps(record) for record in rows):
        print(encoded)
| |
def table_to_printable_value(value):
    """Return *value* as a unicode string suitable for a table cell.

    None is shown as u"NULL". Other values are converted with to_utf8(),
    stripped of surrounding whitespace, decoded as UTF-8, and have the
    code points 0-31 removed.
    """
    if value is None:
        return u"NULL"
    text = to_utf8(value).strip().decode('utf-8')
    # Mapping a code point to None makes translate() delete it.
    control_chars = dict.fromkeys(range(32))
    return text.translate(control_chars)
| |
def table_compute_string_width(v):
    """Return the number of terminal columns the unicode string *v* occupies.

    The string is NFC-normalized first. Format characters (category 'Cf')
    count as zero, East Asian full-width and wide characters count as two,
    and every other character counts as one.
    """
    columns = 0
    for ch in unicodedata.normalize('NFC', v):
        category = unicodedata.category(ch)
        east_asian = unicodedata.east_asian_width(ch)
        if category == 'Cf':
            # Formatting control, zero width
            continue
        columns += 2 if east_asian in ('F', 'W') else 1
    return columns
| |
def table_compute_column_widths(row_buffer):
    """Return, per column, the widest display width found in *row_buffer*.

    *row_buffer* is an iterable of rows, each a sequence of unicode strings.
    Returns None when there are no rows.
    """
    widest = None
    for cells in row_buffer:
        cell_widths = [table_compute_string_width(cell) for cell in cells]
        if not widest:
            widest = cell_widths
            continue
        for idx, cell_width in enumerate(cell_widths):
            widest[idx] = max(widest[idx], cell_width)
    return widest
| |
def table_print_row(values, column_widths, column_types):
    """Print one table row, padding each cell to its column width.

    Cells whose entry in *column_types* is 'n' are right-aligned; all others
    (and every cell when *column_types* is empty or None) are left-aligned.
    """
    bar = u'\u2502'.encode('utf-8')
    for idx, cell in enumerate(values):
        gap = ' ' * max(0, column_widths[idx] - table_compute_string_width(cell))
        encoded = cell.encode('utf-8')
        if column_types and column_types[idx] == 'n':
            body = gap + encoded
        else:
            body = encoded + gap
        print(bar + ' ' + body + ' ', end="")
    print(bar)
| |
def table_print_header(values, column_widths):
    """Print the table's top border, the header row, and the separator line."""
    rule = u'\u2500'.encode('utf-8')
    # One horizontal run per column, two columns wider than the content to
    # cover the single space of padding on each side of a cell.
    segments = [rule * max(0, column_widths[idx] + 2) for idx in range(len(column_widths))]

    # Line 1: top border
    top_left = u'\u250C'.encode('utf-8')
    top_tee = u'\u252C'.encode('utf-8')
    top_right = u'\u2510'.encode('utf-8')
    print(top_left + top_tee.join(segments) + top_right)

    # Line 2: column names
    table_print_row(values, column_widths, None)

    # Line 3: separator between header and data
    left_tee = u'\u251C'.encode('utf-8')
    cross = u'\u253C'.encode('utf-8')
    right_tee = u'\u2524'.encode('utf-8')
    print(left_tee + cross.join(segments) + right_tee)
| |
def table_print_bottom(column_widths):
    """Print the table's bottom border for the given column widths."""
    rule = u'\u2500'.encode('utf-8')
    bottom_left = u'\u2514'.encode('utf-8')
    bottom_tee = u'\u2534'.encode('utf-8')
    bottom_right = u'\u2518'.encode('utf-8')
    # Each column's run is two wider than its content, matching cell padding.
    segments = [rule * max(0, column_widths[idx] + 2) for idx in range(len(column_widths))]
    print(bottom_left + bottom_tee.join(segments) + bottom_right)
| |
def table_print_row_buffer(row_buffer, column_widths, column_types):
    """Print buffered rows, treating the first entry as the header row."""
    for position, values in enumerate(row_buffer):
        if position == 0:
            table_print_header(values, column_widths)
        else:
            table_print_row(values, column_widths, column_types)
| |
def print_table(rows):
    """Print *rows* (an iterable of mappings) as a box-drawn table.

    Up to 500 rows are buffered to size the columns; once the buffer is full
    it is flushed and later rows are printed one at a time using those widths.
    A column is typed numeric ('n') when its value in the first row is a
    numbers.Number, otherwise string ('s'). A summary line with the row count
    and elapsed time is printed at the end.
    """
    started_at = time.time()
    row_count = 0

    # Buffer some rows before printing.
    buffer_slots = 500
    pending = []
    column_types = []
    column_widths = None

    for row in rows:
        row_count += 1

        if row_count == 1:
            # First row: its keys become the header and its values fix the
            # column types.
            pending.append([table_to_printable_value(name) for name in row.keys()])
            for name in row.keys():
                column_types.append('n' if isinstance(row[name], numbers.Number) else 's')

        values = [table_to_printable_value(cell) for _, cell in row.iteritems()]
        if buffer_slots > 0:
            pending.append(values)
            buffer_slots -= 1
            continue

        # Buffer is full: flush it once, then stream this and later rows.
        if pending:
            column_widths = table_compute_column_widths(pending)
            table_print_row_buffer(pending, column_widths, column_types)
            del pending[:]
        table_print_row(values, column_widths, column_types)

    if pending:
        column_widths = table_compute_column_widths(pending)
        table_print_row_buffer(pending, column_widths, column_types)

    if column_widths:
        table_print_bottom(column_widths)

    plural = 's' if row_count != 1 else ''
    print("Retrieved {0:,d} row{1:s} in {2:.2f}s.".format(row_count, plural, time.time() - started_at))
    print("")
| |
def display_query(url, sql, context, args):
    """Run *sql* and print the results in the format named by args.format.

    Recognized formats are 'csv', 'tsv', 'json' and 'table'; any other value
    prints nothing.
    """
    rows = do_query_with_args(url, sql, context, args)
    output_format = args.format

    if output_format == 'csv':
        print_csv(rows, args.header)
        return
    if output_format == 'tsv':
        print_tsv(rows, args.header, args.tsv_delimiter)
        return
    if output_format == 'json':
        print_json(rows)
        return
    if output_format == 'table':
        print_table(rows)
| |
def sql_literal_escape(s):
    """Return *s* as a SQL Unicode string literal of the form U&'...'.

    None yields the empty literal ''. Letters, numbers and the space character
    are emitted as-is; every other character is written as a backslash escape
    with exactly four hex digits, so quotes and backslashes can never end or
    alter the literal.
    """
    if s is None:
        return "''"
    elif isinstance(s, unicode):
        ustr = s
    else:
        ustr = str(s).decode('utf-8')

    escaped = [u"U&'"]

    for c in ustr:
        ccategory = unicodedata.category(c)
        if ccategory.startswith('L') or ccategory.startswith('N') or c == ' ':
            escaped.append(c)
            continue

        codepoint = ord(c)
        if codepoint > 0xFFFF:
            # A \XXXX escape holds exactly four hex digits, so '%04x' on an
            # astral code point produced a malformed five-digit escape. Emit
            # the UTF-16 surrogate pair instead, which is also what iterating
            # the string on a narrow Python build already yields.
            offset = codepoint - 0x10000
            escaped.append(u'\\%04x' % (0xD800 + (offset >> 10)))
            escaped.append(u'\\%04x' % (0xDC00 + (offset & 0x3FF)))
        else:
            escaped.append(u'\\%04x' % codepoint)

    escaped.append(u"'")
    return u''.join(escaped)
| |
def make_readline_completer(url, context, args):
    """Print the connection banner and return a readline completer function.

    The completer offers statement-starting keywords when the word being
    completed begins at column 0 of the line, and clause keywords otherwise.
    Matching is case-insensitive on the typed prefix and each suggestion
    carries a trailing space. *url* and *context* are accepted but unused.
    """
    starters = [
        'EXPLAIN PLAN FOR',
        'SELECT'
    ]

    middlers = [
        'FROM',
        'WHERE',
        'GROUP BY',
        'ORDER BY',
        'LIMIT'
    ]

    def readline_completer(text, state):
        """Return the state-th completion for *text*, or None when exhausted."""
        candidates = starters if readline.get_begidx() == 0 else middlers
        prefix = text.upper()
        matches = [word for word in candidates if word.startswith(prefix)]
        # readline calls with state = 0, 1, 2, ... until it gets None. The old
        # code added " " to a None sentinel, which raised TypeError instead of
        # returning None.
        if state < len(matches):
            return matches[state] + " "
        return None

    print("Connected to [" + args.host + "].")
    print("")

    return readline_completer
| |
def main():
    """Entry point: parse command-line options, then run one query or a shell.

    With --execute, the given SQL is run once and printed. Otherwise an
    interactive prompt is started: input is accumulated until a line ends with
    ';', and backslash commands (\\d, \\dS, \\d table, \\h, \\q) are handled
    when entered on a fresh prompt. End-of-input exits with status 1.

    Raises ValueError for a --context-option that is not of the form key=value.
    """
    parser = argparse.ArgumentParser(description='Druid SQL command-line client.')
    parser_cnn = parser.add_argument_group('Connection options')
    parser_fmt = parser.add_argument_group('Formatting options')
    parser_oth = parser.add_argument_group('Other options')
    parser_cnn.add_argument('--host', '-H', type=str, default='http://localhost:8082/', help='Druid query host or url, like https://localhost:8282/')
    parser_cnn.add_argument('--user', '-u', type=str, help='HTTP basic authentication credentials, like user:password')
    parser_cnn.add_argument('--timeout', type=int, default=0, help='Timeout in seconds')
    parser_cnn.add_argument('--cafile', type=str, help='Path to SSL CA file for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    parser_cnn.add_argument('--capath', type=str, help='SSL CA path for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    parser_cnn.add_argument('--ignore-ssl-verification', '-k', action='store_true', default=False, help='Skip verification of SSL certificates.')
    parser_fmt.add_argument('--format', type=str, default='table', choices=('csv', 'tsv', 'json', 'table'), help='Result format')
    parser_fmt.add_argument('--header', action='store_true', help='Include header row for formats "csv" and "tsv"')
    parser_fmt.add_argument('--tsv-delimiter', type=str, default='\t', help='Delimiter for format "tsv"')
    parser_oth.add_argument('--context-option', '-c', type=str, action='append', help='Set context option for this connection, see https://druid.apache.org/docs/latest/querying/sql.html#connection-context for options')
    parser_oth.add_argument('--execute', '-e', type=str, help='Execute single SQL query')
    parser_cnn.add_argument('--certchain', type=str, help='Path to SSL certificate used to connect to server. See load_cert_chain() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    parser_cnn.add_argument('--keyfile', type=str, help='Path to private SSL key used to connect to server. See load_cert_chain() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    parser_cnn.add_argument('--keypass', type=str, help='Password to private SSL key file used to connect to server. See load_cert_chain() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    args = parser.parse_args()

    # Build broker URL
    url = args.host.rstrip('/') + '/druid/v2/sql/'
    if not url.startswith('http:') and not url.startswith('https:'):
        url = 'http://' + url

    # Build context; all-digit values are sent as numbers, the rest as strings.
    context = {}
    if args.context_option:
        for opt in args.context_option:
            kv = opt.split("=", 1)
            if len(kv) != 2:
                raise ValueError('Invalid context option, should be key=value: ' + opt)
            if re.match(r"^\d+$", kv[1]):
                context[kv[0]] = long(kv[1])
            else:
                context[kv[0]] = kv[1]

    if args.execute:
        display_query(url, args.execute, context, args)
        return

    # interactive mode
    print("Welcome to dsql, the command-line client for Druid SQL.")

    readline_history_file = os.path.expanduser("~/.dsql_history")
    readline.parse_and_bind('tab: complete')
    readline.set_history_length(500)
    readline.set_completer(make_readline_completer(url, context, args))

    try:
        readline.read_history_file(readline_history_file)
    except IOError:
        # IOError can happen if the file doesn't exist.
        pass

    print("Type \"\\h\" for help.")

    while True:
        sql = ''
        while not sql.endswith(';'):
            prompt = "dsql> " if sql == '' else 'more> '
            try:
                more_sql = raw_input(prompt)
            except EOFError:
                sys.stdout.write('\n')
                sys.exit(1)
            if sql == '' and more_sql.startswith('\\'):
                # backslash command
                dmatch = re.match(r'^\\d(S?)(\+?)(\s+.*?|)\s*$', more_sql)
                if dmatch:
                    include_system = dmatch.group(1)
                    arg = dmatch.group(3).strip()
                    if arg:
                        sql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = " + sql_literal_escape(arg)
                        if not include_system:
                            sql = sql + " AND TABLE_SCHEMA = 'druid'"
                    else:
                        sql = "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES"
                        if not include_system:
                            sql = sql + " WHERE TABLE_SCHEMA = 'druid'"
                    # break to execute sql
                    break

                hmatch = re.match(r'^\\h\s*$', more_sql)
                if hmatch:
                    print("Commands:")
                    print(" \\d show tables")
                    print(" \\dS show tables, including system tables")
                    print(" \\d table_name describe table")
                    print(" \\h show this help")
                    print(" \\q exit this program")
                    print("Or enter a SQL query ending with a semicolon (;).")
                    continue

                qmatch = re.match(r'^\\q\s*$', more_sql)
                if qmatch:
                    sys.exit(0)

                print("No such command: " + more_sql)
            else:
                sql = (sql + ' ' + more_sql).strip()

        try:
            try:
                readline.write_history_file(readline_history_file)
            except IOError:
                # History is best-effort: an unwritable history file must not
                # stop the query from running (reading it is already guarded).
                pass
            display_query(url, sql.rstrip(';'), context, args)
        except DruidSqlException as e:
            e.write_to(sys.stdout)
        except KeyboardInterrupt:
            sys.stdout.write("Query interrupted\n")
            sys.stdout.flush()
| |
# Run the client. Query errors are reported on stderr; query errors, Ctrl-C
# and a closed output pipe all exit with status 1. Any other IOError is
# re-raised.
try:
    main()
except DruidSqlException as err:
    err.write_to(sys.stderr)
    sys.exit(1)
except KeyboardInterrupt:
    sys.exit(1)
except IOError as err:
    if err.errno != errno.EPIPE:
        raise
    sys.exit(1)