| #!/usr/bin/env python |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import sasl |
| import time |
| |
| from beeswaxd import BeeswaxService |
| from beeswaxd.BeeswaxService import QueryState |
| from ExecStats.ttypes import TExecStats |
| from ImpalaService import ImpalaService |
| from ErrorCodes.ttypes import TErrorCode |
| from Status.ttypes import TStatus |
| from thrift.protocol import TBinaryProtocol |
| from thrift_sasl import TSaslClientTransport |
| from thrift.transport.TSocket import TSocket |
| from thrift.transport.TTransport import TBufferedTransport, TTransportException |
| from thrift.Thrift import TApplicationException |
| |
class RpcStatus:
    """Named constants describing the outcome of an RPC."""
    OK, ERROR = 0, 1
| |
class RPCException(Exception):
    """Raised when an RPC to the impalad fails or reports an error.

    The message given to the constructor is available as ``value`` and is what
    ``str()`` of the exception returns.
    """

    def __init__(self, value=""):
        # Forward the message to Exception so that args (and therefore repr and
        # pickling) carry it as well, not only the custom attribute.
        super(RPCException, self).__init__(value)
        self.value = value

    def __str__(self):
        # __str__ has to return a string; coerce so that a non-string value
        # cannot make printing the exception fail with a TypeError.
        return str(self.value)
| |
class QueryStateException(Exception):
    """Raised when a query is in an error state or its handle is stale.

    The message given to the constructor is available as ``value`` and is what
    ``str()`` of the exception returns.
    """

    def __init__(self, value=""):
        # Forward the message to Exception so that args (and therefore repr and
        # pickling) carry it as well, not only the custom attribute.
        super(QueryStateException, self).__init__(value)
        self.value = value

    def __str__(self):
        # __str__ has to return a string; coerce so that a non-string value
        # cannot make printing the exception fail with a TypeError.
        return str(self.value)
| |
class DisconnectedException(Exception):
    """Raised when there is no usable connection to the impalad.

    The message given to the constructor is available as ``value`` and is what
    ``str()`` of the exception returns.
    """

    def __init__(self, value=""):
        # Forward the message to Exception so that args (and therefore repr and
        # pickling) carry it as well, not only the custom attribute.
        super(DisconnectedException, self).__init__(value)
        self.value = value

    def __str__(self):
        # __str__ has to return a string; coerce so that a non-string value
        # cannot make printing the exception fail with a TypeError.
        return str(self.value)
| |
| class ImpalaClient(object): |
| |
def __init__(self, impalad, use_kerberos=False, kerberos_service_name="impala",
             use_ssl=False, ca_cert=None, user=None, ldap_password=None,
             use_ldap=False):
    """Store the connection settings. No connection is opened here; see connect().

    impalad: indexable pair; _get_transport() reads [0] as the host and [1] as the port.
    use_kerberos / kerberos_service_name: Kerberos (GSSAPI) authentication settings.
    use_ssl / ca_cert: SSL settings; with no ca_cert the certificate is not validated.
    user / ldap_password / use_ldap: LDAP (PLAIN) authentication settings. user is
        also sent as the hadoop_user of every query.
    """
    # Where to connect.
    self.impalad = impalad

    # Authentication and encryption settings.
    self.use_kerberos = use_kerberos
    self.kerberos_service_name = kerberos_service_name
    self.use_ldap = use_ldap
    self.user = user
    self.ldap_password = ldap_password
    self.use_ssl = use_ssl
    self.ca_cert = ca_cert

    # Connection state, filled in by connect().
    self.connected = False
    self.transport = None
    self.imp_service = None

    # Query-related state.
    self.default_query_options = {}
    # Maps query state names (e.g. "FINISHED", "EXCEPTION") to their values.
    self.query_state = QueryState._NAMES_TO_VALUES
    # Number of rows requested per fetch RPC.
    self.fetch_batch_size = 1024
| |
def _options_to_string_list(self, set_query_options):
    """Render a mapping of query options as a list of "name=value" strings.

    set_query_options: mapping of option name to option value.

    Returns a list with one "name=value" string per entry, in the mapping's
    iteration order.
    """
    # items() instead of iteritems(): the latter only exists on Python 2 dicts.
    return ["%s=%s" % (name, value)
            for name, value in set_query_options.items()]
| |
def build_default_query_options_dict(self):
    """The default query options are retrieved from a rpc call, and are dependent
    on the impalad to which a connection has been established. They need to be
    refreshed each time a connection is made. This is particularly helpful when
    there is a version mismatch between the shell and the impalad.

    The retrieved options are stored in self.default_query_options, keyed by the
    upper-cased option name. If the options cannot be requested at all the method
    returns without changing anything.

    Raises RPCException if the result is reported with a non-OK status.
    """
    try:
        default_options = self.imp_service.get_default_configuration(False)
    except Exception:
        # Best effort: a failure to fetch the defaults is not fatal. Only
        # Exception is caught so that KeyboardInterrupt and SystemExit still
        # propagate instead of being silently dropped.
        return
    # The call has already been made above; _do_rpc is only used here to apply
    # its connection check and status classification to the result.
    options, status = self._do_rpc(lambda: default_options)
    if status != RpcStatus.OK:
        raise RPCException("Unable to retrieve default query options")
    for option in options:
        self.default_query_options[option.key.upper()] = option.value
| |
def build_summary_table(self, summary, idx, is_fragment_root, indent_level,
                        new_indent_level, output):
    """Direct translation of Coordinator::PrintExecSummary() to recursively build a list
    of rows of summary statistics, one per exec node

    summary: the TExecSummary object that contains all the summary data

    idx: the index of the node to print

    is_fragment_root: true if the node to print is the root of a fragment (and therefore
    feeds into an exchange)

    indent_level: the number of spaces to print before writing the node's label, to give
    the appearance of a tree. The 0th child of a node has the same indent_level as its
    parent. All other children have an indent_level of one greater than their parent.

    new_indent_level: if true (and indent_level > 0) the label prefix ends in "--",
    marking the node as the start of a newly indented branch; otherwise the prefix is
    padded instead.

    output: the list of rows into which to append the rows produced for this node and its
    children.

    Returns the index of the next exec node in summary.exec_nodes that should be
    processed, used internally to this method only.
    """
    attrs = ("latency_ns", "cpu_time_ns", "cardinality", "memory_used")

    # Sum and maximum of every statistic over all instances of this node. Plain dicts
    # are enough: these are private scratch values that never leave this method.
    agg_stats = dict.fromkeys(attrs, 0)
    max_stats = dict.fromkeys(attrs, 0)

    node = summary.nodes[idx]
    # exec_stats may be unset; treat that exactly like "no instances".
    exec_stats = node.exec_stats or []
    for stats in exec_stats:
        for attr in attrs:
            val = getattr(stats, attr)
            if val is not None:
                agg_stats[attr] += val
                max_stats[attr] = max(max_stats[attr], val)

    hosts = len(exec_stats)
    avg_time = agg_stats["latency_ns"] / hosts if hosts else 0

    # If the node is a broadcast-receiving exchange node, the cardinality of rows produced
    # is the max over all instances (which should all have received the same number of
    # rows). Otherwise, the cardinality is the sum over all instances which process
    # disjoint partitions.
    if node.is_broadcast and is_fragment_root:
        cardinality = max_stats["cardinality"]
    else:
        cardinality = agg_stats["cardinality"]

    est_stats = node.estimated_stats

    # NOTE(review): the padding strings below are reproduced as found; they may have
    # lost whitespace at some point ("--" is wider than " ") - confirm the intended
    # widths against the rendered table.
    label_prefix = ""
    if indent_level > 0:
        label_prefix = "|"
        label_prefix += " |" * (indent_level - 1)
        if new_indent_level:
            label_prefix += "--"
        else:
            label_prefix += " "

    def prettyprint(val, units, divisor):
        """Scale val down by divisor until it fits a unit and format it with that unit."""
        last = len(units) - 1
        for pos, unit in enumerate(units):
            # A value too large even for the biggest unit is still printed in that
            # unit, rather than falling off the end of the loop and yielding None.
            if val < divisor or pos == last:
                if pos == 0:
                    # The smallest unit is never fractional.
                    return "%d%s" % (val, unit)
                return "%3.2f%s" % (val, unit)
            val /= divisor

    def prettyprint_bytes(byte_val):
        return prettyprint(byte_val, [' B', ' KB', ' MB', ' GB', ' TB'], 1024.0)

    def prettyprint_units(unit_val):
        return prettyprint(unit_val, ["", "K", "M", "B"], 1000.0)

    def prettyprint_time(time_val):
        return prettyprint(time_val, ["ns", "us", "ms", "s"], 1000.0)

    row = [label_prefix + node.label,
           hosts,
           prettyprint_time(avg_time),
           prettyprint_time(max_stats["latency_ns"]),
           prettyprint_units(cardinality),
           prettyprint_units(est_stats.cardinality),
           prettyprint_bytes(max_stats["memory_used"]),
           prettyprint_bytes(est_stats.memory_used),
           node.label_detail]

    output.append(row)
    try:
        sender_idx = summary.exch_to_sender_map[idx]
        # This is an exchange node, so the sender is a fragment root, and should be printed
        # next.
        # NOTE: the recursive call sits inside the try, so a KeyError/TypeError raised
        # while building the sender's rows is swallowed here as well.
        self.build_summary_table(summary, sender_idx, True, indent_level, False, output)
    except (KeyError, TypeError):
        # Fall through if idx not in map, or if exch_to_sender_map itself is not set
        pass

    idx += 1
    if node.num_children > 0:
        # The first child keeps the parent's indentation but is listed after all of its
        # siblings, so its rows are collected separately and appended at the end.
        first_child_output = []
        idx = self.build_summary_table(
            summary, idx, False, indent_level, False, first_child_output)
        for _ in range(1, node.num_children):
            # All other children are indented (we only have 0, 1 or 2 children for every
            # exec node at the moment)
            idx = self.build_summary_table(
                summary, idx, False, indent_level + 1, True, output)
        output.extend(first_child_output)
    return idx
| |
def test_connection(self):
    """Ping the impalad if a connection is believed to be established.

    Does nothing when self.connected is false. Otherwise any exception raised by
    the ping propagates to the caller, signalling that the connection is dead.
    """
    if not self.connected:
        return
    self.imp_service.PingImpalaService()
| |
def connect(self):
    """Open a fresh connection to the impalad and return its version.

    An existing transport is closed first. A new transport is then created and
    opened, a service client is built on top of it and the service is pinged.
    self.connected only becomes True once the ping has succeeded.

    Returns the `version` attribute of the ping response.
    """
    stale_transport = self.transport
    if stale_transport is not None:
        stale_transport.close()
        self.transport = None

    # Not connected until the ping below has gone through.
    self.connected = False
    self.transport = self._get_transport()
    self.transport.open()
    self.imp_service = ImpalaService.Client(
        TBinaryProtocol.TBinaryProtocol(self.transport))
    ping_response = self.ping_impala_service()
    self.connected = True
    return ping_response.version
| |
def ping_impala_service(self):
    """Ping the impalad and return whatever PingImpalaService() responds with."""
    service = self.imp_service
    return service.PingImpalaService()
| |
def close_connection(self):
    """Close the transport, provided one has been created."""
    transport = self.transport
    if not transport:
        return
    transport.close()
| |
def _get_transport(self):
    """Build the (unopened) transport used to talk to the impalad.

    The socket is a TSSLSocketWithWildcardSAN when SSL is enabled and a plain
    TSocket otherwise. Without LDAP or Kerberos the socket is wrapped in a simple
    buffered transport; with either of them a SASL transport is returned instead,
    using the GSSAPI mechanism for Kerberos and PLAIN otherwise.
    """
    if self.use_ssl:
        # TSSLSocket needs the ssl module, which may not be standard on all Operating
        # Systems. Only attempt to import TSSLSocket if the user wants an SSL connection.
        from TSSLSocketWithWildcardSAN import TSSLSocketWithWildcardSAN

    # sasl does not accept unicode strings, explicitly encode the string into ascii.
    host = self.impalad[0].encode('ascii', 'ignore')
    port = int(self.impalad[1])

    if not self.use_ssl:
        sock = TSocket(host, port)
    elif self.ca_cert is None:
        # No CA cert means don't try to verify the certificate
        sock = TSSLSocketWithWildcardSAN(host, port, validate=False)
    else:
        sock = TSSLSocketWithWildcardSAN(host, port, validate=True,
                                         ca_certs=self.ca_cert)

    if not self.use_ldap and not self.use_kerberos:
        return TBufferedTransport(sock)

    def sasl_factory():
        """Create and initialise the sasl client used by the SASL transport."""
        client = sasl.Client()
        client.setAttr("host", host)
        if self.use_ldap:
            client.setAttr("username", self.user)
            client.setAttr("password", self.ldap_password)
        else:
            client.setAttr("service", self.kerberos_service_name)
        client.init()
        return client

    # GSSAPI is the underlying mechanism used by kerberos to authenticate.
    mechanism = "GSSAPI" if self.use_kerberos else "PLAIN"
    return TSaslClientTransport(sasl_factory, mechanism, sock)
def create_beeswax_query(self, query_str, set_query_options):
    """Build a BeeswaxService.Query for the given statement.

    query_str: the text of the query.
    set_query_options: mapping of query options, sent as "name=value" strings.

    Returns the populated query object, with hadoop_user set to self.user.
    """
    beeswax_query = BeeswaxService.Query()
    beeswax_query.query = query_str
    beeswax_query.hadoop_user = self.user
    beeswax_query.configuration = self._options_to_string_list(set_query_options)
    return beeswax_query
| |
def execute_query(self, query):
    """Submit `query` to the impalad and return the resulting query handle.

    Raises RPCException if the RPC is reported with a non-OK status.
    """
    handle, status = self._do_rpc(lambda: self.imp_service.query(query))
    if status == RpcStatus.OK:
        return handle
    raise RPCException("Error executing the query")
| |
def wait_to_finish(self, last_query_handle, periodic_callback=None):
    """Poll the query's state until it reaches FINISHED.

    last_query_handle: handle of the query to wait for.
    periodic_callback: optional callable invoked once per poll that did not end
        the wait, just before sleeping.

    Raises QueryStateException (carrying the warning log) if the query reaches the
    EXCEPTION state while connected, and DisconnectedException if it does so while
    not connected.
    """
    started_at = time.time()
    while True:
        state = self.get_query_state(last_query_handle)
        if state == self.query_state["FINISHED"]:
            return
        if state == self.query_state["EXCEPTION"]:
            if not self.connected:
                raise DisconnectedException("Not connected to impalad.")
            raise QueryStateException(self.get_warning_log(last_query_handle))

        if periodic_callback is not None:
            periodic_callback()
        # Back off gradually the longer the query has been running.
        time.sleep(self._get_sleep_interval(started_at))
| |
def fetch(self, query_handle):
    """Generate the results of a query, one list of rows at a time.

    Rows are requested from the impalad in batches of self.fetch_batch_size. Each
    yielded item is a list of rows, and each row is the list of its tab-separated
    fields. The final item is yielded once the server reports that there is nothing
    more to fetch; it may be an empty list.

    Raises RPCException if a fetch RPC is reported with a non-OK status.
    """
    pending = []
    has_more = True
    while has_more:
        batch, status = self._do_rpc(
            lambda: self.imp_service.fetch(query_handle, False,
                                           self.fetch_batch_size))
        if status != RpcStatus.OK:
            raise RPCException()

        pending.extend(batch.data)
        has_more = batch.has_more

        # Hand the rows over once enough have accumulated, or when this was the
        # last batch.
        if len(pending) >= self.fetch_batch_size or not has_more:
            rows = [line.split('\t') for line in pending]
            pending = []
            yield rows
| |
def close_insert(self, last_query_handle):
    """Close an INSERT query and return the total number of rows it appended.

    The total is the sum of the values of the insert result's rows_appended
    mapping, each converted to int.

    Raises RPCException if the RPC is reported with a non-OK status.
    """
    insert_result, status = self._do_rpc(
        lambda: self.imp_service.CloseInsert(last_query_handle))
    if status != RpcStatus.OK:
        raise RPCException()
    return sum(int(count) for count in insert_result.rows_appended.values())
| |
def close_query(self, last_query_handle, query_handle_closed=False):
    """Close the query handle.

    query_handle_closed: pass a true value if the handle is already closed; no RPC
        is made then, which makes closing idempotent.

    Returns True if the handle was already closed or the close RPC succeeded.
    """
    if not query_handle_closed:
        _, status = self._do_rpc(lambda: self.imp_service.close(last_query_handle))
        return status == RpcStatus.OK
    return True
| |
def cancel_query(self, last_query_handle, query_handle_closed=False):
    """Cancel a query on a keyboard interrupt from the shell.

    query_handle_closed: pass a true value if the handle is already closed; no RPC
        is made then.

    Returns True if the handle was already closed or the cancel RPC succeeded.
    """
    if not query_handle_closed:
        # Cancel sets query_state to EXCEPTION before calling cancel() in the
        # co-ordinator, so we don't need to wait.
        _, status = self._do_rpc(lambda: self.imp_service.Cancel(last_query_handle))
        return status == RpcStatus.OK
    return True
| |
def get_query_state(self, last_query_handle):
    """Return the current state of the query.

    If the state RPC is reported with a non-OK status, the EXCEPTION state is
    returned instead.
    """
    state, status = self._do_rpc(
        lambda: self.imp_service.get_state(last_query_handle))
    if status == RpcStatus.OK:
        return state
    return self.query_state["EXCEPTION"]
| |
def get_runtime_profile(self, last_query_handle):
    """Return the runtime profile of the query.

    Returns None if the RPC is reported with a non-OK status or the profile is empty.
    """
    profile, status = self._do_rpc(
        lambda: self.imp_service.GetRuntimeProfile(last_query_handle))
    if status != RpcStatus.OK or not profile:
        return None
    return profile
| |
def get_summary(self, last_query_handle):
    """Return the result of GetExecSummary() for the query handle.

    Returns None if the RPC is reported with a non-OK status or the summary is empty.
    """
    summary, status = self._do_rpc(
        lambda: self.imp_service.GetExecSummary(last_query_handle))
    succeeded = status == RpcStatus.OK
    return summary if succeeded and summary else None
| |
def _do_rpc(self, rpc):
    """Execute the provided callable and classify its outcome.

    rpc: zero-argument callable that performs the RPC and returns its result.

    Returns a (result, status) tuple. status is RpcStatus.OK, unless the result is
    a TStatus with a non-OK status code and no error messages, in which case it is
    RpcStatus.ERROR.

    Raises:
      DisconnectedException: if not connected, or on a TTransportException.
      QueryStateException: if the query handle is not known to the server.
      RPCException: on a BeeswaxException or TApplicationException, or when the
          result is a TStatus with a non-OK status code that carries error messages.
    """
    if not self.connected:
        raise DisconnectedException(
            "Not connected (use CONNECT to establish a connection)")
    try:
        ret = rpc()
        status = RpcStatus.OK
        # TODO: In the future more advanced error detection/handling can be done based on
        # the TStatus return value. For now, just print any error(s) that were encountered
        # and validate the result of the operation was a success.
        if isinstance(ret, TStatus) and ret.status_code != TErrorCode.OK:
            if ret.error_msgs:
                raise RPCException('RPC Error: %s' % '\n'.join(ret.error_msgs))
            status = RpcStatus.ERROR
        return ret, status
    except BeeswaxService.QueryNotFoundException:
        raise QueryStateException('Error: Stale query handle')
    # beeswaxException prints out the entire object, printing
    # just the message is far more readable/helpful.
    except BeeswaxService.BeeswaxException as b:
        raise RPCException("ERROR: %s" % b.message)
    except TTransportException as e:
        # issue with the connection with the impalad
        raise DisconnectedException("Error communicating with impalad: %s" % e)
    except TApplicationException as t:
        raise RPCException("Application Exception : %s" % t)
| |
def _get_sleep_interval(self, start_time):
    """Return how long to sleep, in seconds, before polling again.

    The interval is a step function of the time elapsed since start_time: 0.1s
    during the first 10 seconds, 0.5s up to one minute and 1s after that.
    """
    elapsed = time.time() - start_time
    # (upper bound on the elapsed time, interval to use below that bound)
    for limit, interval in ((10.0, 0.1), (60.0, 0.5)):
        if elapsed < limit:
            return interval
    return 1.0
| |
def get_column_names(self, last_query_handle):
    """Return the column names of the query's result set.

    The names are taken from the field schemas of the results metadata. Returns
    None if no metadata was returned. The status of the RPC is not inspected.
    """
    metadata, _ = self._do_rpc(
        lambda: self.imp_service.get_results_metadata(last_query_handle))
    if metadata is None:
        return None
    return [field.name for field in metadata.schema.fieldSchemas]
| |
def expect_result_metadata(self, query_str):
    """Given a query string, return True if impalad expects result metadata.

    Statements that start with USE, ALTER or DROP are treated as producing no
    result metadata. The keyword is matched case-insensitively and leading
    whitespace is ignored, so "DROP TABLE t" and " use db" are recognised just like
    "drop table t" and "use db".
    """
    excluded_query_types = ('use', 'alter', 'drop')
    # str.startswith accepts a tuple of prefixes, so one call covers all of them.
    return not query_str.lstrip().lower().startswith(excluded_query_types)
| |
def get_warning_log(self, last_query_handle):
    """Return the warning log of a query as a printable string.

    Returns a fixed message if there is no query handle, a failure message if the
    log RPC is reported with a non-OK status, the log prefixed with "WARNINGS: " if
    it contains anything but whitespace, and an empty string otherwise.
    """
    if last_query_handle is None:
        return "Query could not be executed"
    log, status = self._do_rpc(
        lambda: self.imp_service.get_log(last_query_handle.log_context))
    if status != RpcStatus.OK:
        return "Failed to get error log: %s" % status
    has_text = log and log.strip()
    return "WARNINGS: %s" % log if has_text else ""