# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import division

import hashlib

import numpy as np

from petastorm import utils
from petastorm.cache import NullCache
from petastorm.workers_pool.worker_base import WorkerBase
from petastorm.py_dict_reader_worker import PyDictReaderWorkerResultsQueueReader
from petastorm.py_dict_reader_worker import _select_cols, _merge_two_dicts


class PyDictCarbonReaderWorker(WorkerBase):
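    """Worker that loads rows from a single CarbonData blocklet and publishes them as Python dictionaries.

    The worker is constructed by a reader/worker pool with a 7-tuple of arguments (filesystem,
    dataset path, schema, ngram, split pieces, local cache, transform spec) and is driven through
    :func:`process`. A rough usage sketch (the ``results_queue`` name is illustrative and not part
    of this module)::

        args = (filesystem, dataset_path, schema, ngram, split_pieces, local_cache, transform_spec)
        worker = PyDictCarbonReaderWorker(worker_id=0, publish_func=results_queue.put, args=args)
        worker.process(piece_index=0, worker_predicate=None, shuffle_row_drop_partition=(0, 1))
    """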
    def __init__(self, worker_id, publish_func, args):
        super(PyDictCarbonReaderWorker, self).__init__(worker_id, publish_func, args)

        self._filesystem = args[0]
        self._dataset_path = args[1]
        self._schema = args[2]
        self._ngram = args[3]
        self._split_pieces = args[4]
        self._local_cache = args[5]
        self._transform_spec = args[6]

        # We create datasets lazily in the first invocation of 'def process'. This speeds up startup time since
        # all Worker constructors are serialized
        self._dataset = None

    @staticmethod
    def new_results_queue_reader():
        return PyDictReaderWorkerResultsQueueReader()

    # pylint: disable=arguments-differ
    def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
        """Main worker function. Loads and publishes all rows matching the predicate from a blocklet.

        Looks up the requested piece (a single blocklet/row-group in a carbon file). If a predicate is
        specified, the columns needed by the predicate are loaded first. If no rows in the blocklet match
        the predicate criteria, the rest of the columns are not loaded.

        :param piece_index: index into ``self._split_pieces`` of the piece (blocklet) to load.
        :param worker_predicate: an optional predicate used to filter rows; only the columns returned by
            its ``get_fields()`` method are read before the predicate is evaluated.
        :param shuffle_row_drop_partition: a tuple of (current row drop partition, total number
            of partitions).
        :return: ``None``. The loaded rows are passed to ``publish_func``.
        """
        piece = self._split_pieces[piece_index]

        if not isinstance(self._local_cache, NullCache):
            if worker_predicate:
                raise RuntimeError('Local cache is not supported together with predicates, '
                                   'unless the dataset is partitioned by the column the predicate operates on.')
            if shuffle_row_drop_partition[1] != 1:
                raise RuntimeError('Local cache is not supported together with shuffle_row_drop_partitions > 1')

        if worker_predicate:
            all_cols = self._load_rows_with_predicate(piece, worker_predicate, shuffle_row_drop_partition)
        else:
            # Using a hash of the dataset path combined with the relative path in order to:
            # 1. Make sure that if a common cache serves multiple processes (e.g. redis), we don't have conflicts
            # 2. The dataset path is hashed to make sure we don't create overly long keys, which may be
            #    incompatible with some cache implementations
            # 3. Still leave the relative path and the piece_index in plain text to make debugging easier
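            # The resulting key therefore looks like '<md5 of dataset path>:<piece path>:<piece index>',
            # e.g. '5d41402abc4b2a76b9719d911017c592:part-0-0.carbondata:12' (illustrative values only)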
            cache_key = '{}:{}:{}'.format(hashlib.md5(self._dataset_path.encode('utf-8')).hexdigest(),
                                          piece.path, piece_index)
            all_cols = self._local_cache.get(cache_key,
                                             lambda: self._load_rows(piece, shuffle_row_drop_partition))

        if self._ngram:
            all_cols = self._ngram.form_ngram(data=all_cols, schema=self._schema)

        if all_cols:
            self.publish_func(all_cols)

    def _load_rows_with_predicate(self, piece, worker_predicate, shuffle_row_drop_partition):
        """Loads all rows that match a predicate from a piece"""

        # 1. Read all columns needed by the predicate and decode
        # 2. Apply the predicate. If nothing matches, exit early
        # 3. Read the remaining columns and decode
        # 4. Combine with the columns already decoded for the predicate

        # Split all column names into those needed by the predicate and the rest.
        predicate_column_names = set(worker_predicate.get_fields())

        if not predicate_column_names:
            raise ValueError('At least one field name must be returned by the predicate\'s get_fields() method')

        all_schema_names = set(field.name for field in self._schema.fields.values())

        invalid_column_names = predicate_column_names - all_schema_names
        if invalid_column_names:
            raise ValueError('Some column names requested by the predicate ({}) '
                             'are not valid schema names: ({})'.format(', '.join(invalid_column_names),
                                                                       ', '.join(all_schema_names)))

        other_column_names = all_schema_names - predicate_column_names
        other_column_names_list = list(other_column_names)

        predicate_column_names_list = list(predicate_column_names)
        # Read columns needed for the predicate
        predicate_rows = self._read_with_shuffle_row_drop(piece, predicate_column_names_list,
                                                          shuffle_row_drop_partition)

        # Decode values
        transform_func = self._transform_spec.func if self._transform_spec else (lambda x: x)
        decoded_predicate_rows = [
            transform_func(utils.decode_row(_select_cols(row, predicate_column_names), self._schema))
            for row in predicate_rows]

        # Use the predicate to filter
        match_predicate_mask = [worker_predicate.do_include(row) for row in decoded_predicate_rows]

        # Don't have anything left after filtering? Exit early.
        if not any(match_predicate_mask):
            return []

        # Remove rows that were filtered out by the predicate
        filtered_decoded_predicate_rows = [row for i, row in enumerate(decoded_predicate_rows) if
                                           match_predicate_mask[i]]

        if other_column_names:
            # Read the remaining columns
            other_rows = self._read_with_shuffle_row_drop(piece, other_column_names_list,
                                                          shuffle_row_drop_partition)

            # Remove rows that were filtered out by the predicate
            filtered_other_rows = [row for i, row in enumerate(other_rows) if match_predicate_mask[i]]

            # Decode the remaining columns
            decoded_other_rows = [utils.decode_row(row, self._schema) for row in filtered_other_rows]

            # Merge the columns needed by the predicate with the remaining columns
            all_cols = [_merge_two_dicts(a, b) for a, b in zip(decoded_other_rows, filtered_decoded_predicate_rows)]
            return all_cols
        else:
            return filtered_decoded_predicate_rows

    def _load_rows(self, piece, shuffle_row_drop_range):
        """Loads all rows from a piece"""

        # pyarrow would fail if we requested column names that the dataset is partitioned by, so we strip
        # them from the `columns` argument.
        column_names = list(field.name for field in self._schema.fields.values())

        all_rows = self._read_with_shuffle_row_drop(piece, column_names, shuffle_row_drop_range)

        transform_func = self._transform_spec.func if self._transform_spec else (lambda x: x)
        return [transform_func(utils.decode_row(row, self._schema)) for row in all_rows]

    def _read_with_shuffle_row_drop(self, piece, column_names, shuffle_row_drop_partition):
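        """Reads the requested columns from a piece and returns only the rows that fall into the
        requested shuffle-row-drop partition.

        The rows of the piece are split into ``num_partitions`` contiguous, roughly equal partitions,
        and only the rows belonging to ``this_partition`` are returned as a list of per-row dictionaries.
        """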
        data_frame = piece.read_all(columns=column_names)
        data_frame = data_frame.to_pandas()

        num_rows = len(data_frame)
        num_partitions = shuffle_row_drop_partition[1]
        this_partition = shuffle_row_drop_partition[0]

        # Assign each row index to one of min(num_rows, num_partitions) contiguous, roughly equal partitions
        partition_indexes = np.floor(np.arange(num_rows) / (float(num_rows) / min(num_rows, num_partitions)))

        if self._ngram:
            # If we have an ngram we need to take elements from the next partition to build the sequence,
            # so reassign the first (ngram.length - 1) rows of the next partition to this one
            next_partition_indexes = np.where(partition_indexes >= this_partition + 1)
            if next_partition_indexes[0].size:
                next_partition_to_add = next_partition_indexes[0][0:self._ngram.length - 1]
                partition_indexes[next_partition_to_add] = this_partition

        # Keep only the rows assigned to this partition and return them as a list of per-row dictionaries
        selected_dataframe = data_frame.loc[partition_indexes == this_partition]
        return selected_dataframe.to_dict('records')