blob: 60ef98ee0efce84e40fbed64124987496b21318d [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import hashlib
import numpy as np
from petastorm import utils
from petastorm.cache import NullCache
from petastorm.workers_pool.worker_base import WorkerBase
from petastorm.py_dict_reader_worker import PyDictReaderWorkerResultsQueueReader
from petastorm.py_dict_reader_worker import _select_cols, _merge_two_dicts
class PyDictCarbonReaderWorker(WorkerBase):
    """Worker that loads rows of a single carbon piece and publishes them as a list of row dictionaries.

    Each call to :meth:`process` reads one piece, optionally filters it with a predicate, decodes the rows
    with ``utils.decode_row``, applies the optional transform function and ngram assembly, and publishes the result.
    """

    def __init__(self, worker_id, publish_func, args):
        """Stores the worker configuration.

        :param worker_id: Forwarded to ``WorkerBase.__init__``.
        :param publish_func: Forwarded to ``WorkerBase.__init__``.
        :param args: Indexable of seven items, in order: filesystem, dataset path, schema, ngram,
          split pieces, local cache, transform spec.
        """
        super(PyDictCarbonReaderWorker, self).__init__(worker_id, publish_func, args)

        self._filesystem = args[0]
        self._dataset_path = args[1]
        self._schema = args[2]
        self._ngram = args[3]
        self._split_pieces = args[4]
        self._local_cache = args[5]
        self._transform_spec = args[6]

        # Placeholder for a lazily created dataset object.
        # NOTE(review): nothing in this class assigns it after construction -- confirm whether it is still needed.
        self._dataset = None

    @staticmethod
    def new_results_queue_reader():
        """Returns a new ``PyDictReaderWorkerResultsQueueReader`` instance."""
        return PyDictReaderWorkerResultsQueueReader()

    # pylint: disable=arguments-differ
    def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
        """Main worker function. Loads and publishes all rows matching the predicate from a blocklet.

        Looks up the requested piece. If a predicate is specified, columns needed by the predicate are loaded
        first. If no rows in the blocklet match the predicate criteria the rest of the columns are not loaded.
        Nothing is published when the resulting row list is empty.

        :param piece_index: Index into the list of split pieces this worker was constructed with.
        :param worker_predicate: Optional predicate object exposing ``get_fields()`` and ``do_include(row)``.
          A falsy value means that all rows are loaded (through the local cache).
        :param shuffle_row_drop_partition: A tuple 2 of the current row drop partition and the total number
          of partitions.
        :return: None. Rows are handed to ``self.publish_func``.
        :raises RuntimeError: If a cache other than ``NullCache`` is combined with a predicate or with more than
          one shuffle row drop partition.
        """
        piece = self._split_pieces[piece_index]

        if not isinstance(self._local_cache, NullCache):
            if worker_predicate:
                raise RuntimeError('Local cache is not supported together with predicates, '
                                   'unless the dataset is partitioned by the column the predicate operates on.')
            if shuffle_row_drop_partition[1] != 1:
                raise RuntimeError('Local cache is not supported together with shuffle_row_drop_partitions > 1')

        if worker_predicate:
            all_cols = self._load_rows_with_predicate(piece, worker_predicate, shuffle_row_drop_partition)
        else:
            # Using hash of the dataset path with the relative path in order to:
            #  1. Make sure if a common cache serves multiple processes (e.g. redis), we don't have conflicts
            #  2. Dataset path is hashed, to make sure we don't create too long keys, which may be incompatible
            #     with some cache implementations
            #  3. Still leave relative path and the piece_index in plain text to make it easier to debug
            cache_key = '{}:{}:{}'.format(hashlib.md5(self._dataset_path.encode('utf-8')).hexdigest(),
                                          piece.path, piece_index)
            all_cols = self._local_cache.get(cache_key,
                                             lambda: self._load_rows(piece, shuffle_row_drop_partition))

        if self._ngram:
            all_cols = self._ngram.form_ngram(data=all_cols, schema=self._schema)

        if all_cols:
            self.publish_func(all_cols)

    def _apply_transform(self, rows):
        """Applies the transform spec function (if any) to every row and returns the resulting list.

        Rows are returned unchanged when no transform spec is configured or when its ``func`` is None.
        """
        transform_func = self._transform_spec.func if self._transform_spec else None
        if transform_func is None:
            return rows
        return [transform_func(row) for row in rows]

    def _load_rows_with_predicate(self, piece, worker_predicate, shuffle_row_drop_partition):
        """Loads all rows that match a predicate from a piece.

        Steps:
          1. Read all columns needed by the predicate and decode them.
          2. Apply the predicate. If nothing matches, exit early.
          3. Read the remaining columns and decode them.
          4. Combine with the columns already decoded for the predicate.
          5. Apply the transform function to the complete rows, the same way ``_load_rows`` does.

        :param piece: The piece to read from.
        :param worker_predicate: Predicate object exposing ``get_fields()`` and ``do_include(row)``.
        :param shuffle_row_drop_partition: A tuple 2 of the current row drop partition and the total number
          of partitions.
        :return: A list of decoded (and transformed) row dictionaries; empty if no row matches.
        :raises ValueError: If the predicate requests no fields, or fields that are not part of the schema.
        """
        # Split all column names into ones that are needed by the predicate and the rest.
        predicate_column_names = set(worker_predicate.get_fields())

        if not predicate_column_names:
            raise ValueError('At least one field name must be returned by predicate\'s get_fields() method')

        all_schema_names = set(field.name for field in self._schema.fields.values())

        invalid_column_names = predicate_column_names - all_schema_names
        if invalid_column_names:
            # Names are sorted so that the message does not depend on set iteration order.
            raise ValueError('At least some column names requested by the predicate ({}) '
                             'are not valid schema names: ({})'.format(', '.join(sorted(invalid_column_names)),
                                                                       ', '.join(sorted(all_schema_names))))

        other_column_names = all_schema_names - predicate_column_names

        # Read and decode only the columns needed for the predicate
        predicate_rows = self._read_with_shuffle_row_drop(piece, list(predicate_column_names),
                                                          shuffle_row_drop_partition)
        decoded_predicate_rows = [
            utils.decode_row(_select_cols(row, predicate_column_names), self._schema)
            for row in predicate_rows]

        # Use the predicate to filter
        match_predicate_mask = [worker_predicate.do_include(row) for row in decoded_predicate_rows]

        # Don't have anything left after filtering? Exit early.
        if not any(match_predicate_mask):
            return []

        # Remove rows that were filtered out by the predicate
        matching_predicate_rows = [row for row, keep in zip(decoded_predicate_rows, match_predicate_mask) if keep]

        if other_column_names:
            # Read remaining columns. The same piece and partition are read, so rows line up with the mask.
            other_rows = self._read_with_shuffle_row_drop(piece, list(other_column_names),
                                                          shuffle_row_drop_partition)

            # Remove rows that were filtered out by the predicate, then decode what is left
            decoded_other_rows = [utils.decode_row(row, self._schema)
                                  for row, keep in zip(other_rows, match_predicate_mask) if keep]

            # Merge predicate needed columns with the remaining
            result = [_merge_two_dicts(a, b) for a, b in zip(decoded_other_rows, matching_predicate_rows)]
        else:
            result = matching_predicate_rows

        # The transform runs on complete rows only, after filtering, so that it sees the same input
        # as in the predicate-less code path.
        return self._apply_transform(result)

    def _load_rows(self, piece, shuffle_row_drop_range):
        """Loads all rows from a piece.

        :param piece: The piece to read from.
        :param shuffle_row_drop_range: A tuple 2 of the current row drop partition and the total number
          of partitions.
        :return: A list of decoded (and transformed) row dictionaries containing every schema field.
        """
        column_names = [field.name for field in self._schema.fields.values()]

        all_rows = self._read_with_shuffle_row_drop(piece, column_names, shuffle_row_drop_range)
        decoded_rows = [utils.decode_row(row, self._schema) for row in all_rows]
        return self._apply_transform(decoded_rows)

    def _read_with_shuffle_row_drop(self, piece, column_names, shuffle_row_drop_partition):
        """Reads the requested columns of a piece and keeps only the rows of the current row drop partition.

        :param piece: Object exposing ``read_all(columns=...)`` whose result can be converted with ``to_pandas()``.
        :param column_names: List of column names to read.
        :param shuffle_row_drop_partition: A tuple 2 of the current row drop partition and the total number
          of partitions.
        :return: A list of row dictionaries (``DataFrame.to_dict('records')``); empty if the piece has no rows.
        """
        data_frame = piece.read_all(columns=column_names).to_pandas()

        num_rows = len(data_frame)
        if num_rows == 0:
            # Nothing to partition. Without this guard the partition width below would be 0.0 / 0.
            return []

        this_partition = shuffle_row_drop_partition[0]
        num_partitions = shuffle_row_drop_partition[1]

        # Assign every row a partition number in [0, min(num_rows, num_partitions)) using equally wide ranges.
        partition_width = float(num_rows) / min(num_rows, num_partitions)
        partition_indexes = np.floor(np.arange(num_rows) / partition_width)

        if self._ngram:
            # If we have an ngram we need to take elements from the next partition to build the sequence
            next_partition_indexes = np.where(partition_indexes >= this_partition + 1)
            if next_partition_indexes[0].size:
                next_partition_to_add = next_partition_indexes[0][0:self._ngram.length - 1]
                partition_indexes[next_partition_to_add] = this_partition

        selected_dataframe = data_frame.loc[partition_indexes == this_partition]
        return selected_dataframe.to_dict('records')