# blob: 03cfd235983d1eb6d1fea50e1b0e9ec177f2c6aa [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import hashlib
import operator
import numpy as np
import pandas as pd
import pyarrow as pa
from petastorm.cache import NullCache
from petastorm.workers_pool.worker_base import WorkerBase
from petastorm.arrow_reader_worker import ArrowReaderWorkerResultsQueueReader
class ArrowCarbonReaderWorker(WorkerBase):
    """Worker that loads rows from carbon split pieces and publishes them as ``pyarrow.Table`` objects.

    Published tables are consumed through the reader returned by ``new_results_queue_reader``.
    """

    def __init__(self, worker_id, publish_func, args):
        """Stores the worker configuration.

        :param worker_id: identifier of this worker, forwarded to ``WorkerBase``.
        :param publish_func: callback forwarded to ``WorkerBase``; loaded tables are published through it.
        :param args: sequence of worker arguments, in order: filesystem, dataset path (a ``str``), schema,
            ngram, split pieces, local cache, transform spec.
        :raises NotImplementedError: if an ngram is requested.
        """
        super(ArrowCarbonReaderWorker, self).__init__(worker_id, publish_func, args)

        self._filesystem = args[0]
        self._dataset_path = args[1]
        self._schema = args[2]
        self._ngram = args[3]
        self._split_pieces = args[4]
        self._local_cache = args[5]
        self._transform_spec = args[6]

        if self._ngram:
            raise NotImplementedError('ngrams are not supported by ArrowReaderWorker')

        # NOTE(review): placeholder for a lazily created dataset; no method in this class currently
        # reads or assigns it.
        self._dataset = None

    @staticmethod
    def new_results_queue_reader():
        """Returns the reader that consumes the tables published by this worker.

        Declared static because it takes no ``self``: without the decorator, calling it on an
        instance raises ``TypeError``.
        """
        return ArrowReaderWorkerResultsQueueReader()

    # pylint: disable=arguments-differ
    def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
        """Main worker function. Loads all rows matching the predicate from a blocklet and publishes them.

        Looks up the requested piece (a single row-group in a carbon file). If a predicate is specified,
        columns needed by the predicate are loaded first. If no rows in the blocklet match the predicate
        criteria, the rest of the columns are not loaded.

        :param piece_index: index into the split pieces this worker was configured with.
        :param worker_predicate: optional predicate; when set, only matching rows are loaded.
        :param shuffle_row_drop_partition: a 2-tuple of the current row drop partition and the total
            number of partitions.
        :return: None. A non-empty result is handed to ``self.publish_func``.
        :raises RuntimeError: if a non-null local cache is combined with a predicate or with more than
            one row drop partition.
        """
        piece = self._split_pieces[piece_index]

        if not isinstance(self._local_cache, NullCache):
            if worker_predicate:
                raise RuntimeError('Local cache is not supported together with predicates, '
                                   'unless the dataset is partitioned by the column the predicate operates on.')
            if shuffle_row_drop_partition[1] != 1:
                raise RuntimeError('Local cache is not supported together with shuffle_row_drop_partitions > 1')

        if worker_predicate:
            all_cols = self._load_rows_with_predicate(piece, worker_predicate, shuffle_row_drop_partition)
        else:
            # Using hash of the dataset path with the relative path in order to:
            #  1. Make sure if a common cache serves multiple processes (e.g. redis), we don't have conflicts
            #  2. Dataset path is hashed, to make sure we don't create too long keys, which may be
            #     incompatible with some cache implementations
            #  3. Still leave relative path and the piece_index in plain text to make it easier to debug
            cache_key = '{}:{}:{}'.format(hashlib.md5(self._dataset_path.encode('utf-8')).hexdigest(),
                                          piece.path, piece_index)
            all_cols = self._local_cache.get(cache_key,
                                             lambda: self._load_rows(piece, shuffle_row_drop_partition))

        # ``[]`` (nothing matched) and zero-row tables are falsy, so nothing is published for them.
        # NOTE(review): this call was lost from the mangled copy; WorkerBase is assumed to keep the
        # callback as ``self.publish_func`` -- confirm.
        if all_cols:
            self.publish_func(all_cols)

    def _load_rows(self, piece, shuffle_row_drop_range):
        """Loads all schema columns of a piece, applying the row drop partition and the transform spec.

        :return: a ``pyarrow.Table``.
        """
        # Request every field declared by the schema (``fields`` maps to objects carrying ``name``).
        column_names_in_schema = [field.name for field in self._schema.fields.values()]

        result = self._read_with_shuffle_row_drop(piece, column_names_in_schema, shuffle_row_drop_range)

        if self._transform_spec:
            result = pa.Table.from_pandas(self._transform_spec.func(result.to_pandas()), preserve_index=False)

        return result

    def _load_rows_with_predicate(self, piece, worker_predicate, shuffle_row_drop_partition):
        """Loads the rows of a piece that match a predicate.

        1. Read all columns needed by the predicate.
        2. Apply the predicate. If nothing matches, exit early.
        3. Read the remaining columns and keep only the matching rows.

        :param worker_predicate: object exposing ``get_fields()`` (names of the columns it needs) and
            ``do_include(data_frame)`` returning a boolean pandas Series with one entry per row.
        :return: ``[]`` when no row matches, otherwise a ``pyarrow.Table`` of the matching rows.
        :raises ValueError: if the predicate names no fields, or names fields missing from the schema.
        """
        # Split all column names into ones that are needed by the predicate and the rest.
        predicate_column_names = set(worker_predicate.get_fields())

        if not predicate_column_names:
            raise ValueError('At least one field name must be returned by predicate\'s get_field() method')

        all_schema_names = set(field.name for field in self._schema.fields.values())

        invalid_column_names = predicate_column_names - all_schema_names
        if invalid_column_names:
            raise ValueError('At least some column names requested by the predicate ({}) '
                             'are not valid schema names: ({})'.format(', '.join(invalid_column_names),
                                                                       ', '.join(all_schema_names)))

        # 'Other columns' are loaded only if at least one row in the blocklet matched the predicate.
        other_column_names = all_schema_names - predicate_column_names

        predicates_table = self._read_with_shuffle_row_drop(piece, list(predicate_column_names),
                                                            shuffle_row_drop_partition)
        predicates_data_frame = predicates_table.to_pandas()

        match_predicate_mask = worker_predicate.do_include(predicates_data_frame)

        # Don't have anything left after filtering? Exit early without reading the other columns.
        if not match_predicate_mask.any():
            return []

        # Filter by the mask instead of overwriting rejected rows with None: assigning None upcasts
        # integer/bool columns (e.g. int64 -> float64), which made predicate results disagree in type
        # with unfiltered results.
        result_data_frame = predicates_data_frame[match_predicate_mask]

        if other_column_names:
            other_table = self._read_with_shuffle_row_drop(piece, list(other_column_names),
                                                           shuffle_row_drop_partition)
            # Both frames come from the same piece with the same row drop partition, so the mask's
            # index lines up with this frame's index as well.
            other_data_frame = other_table.to_pandas()[match_predicate_mask]

            # If a column comes back in both frames (e.g. partition-by columns), keep the copy from the
            # other frame.
            columns_from_predicates = result_data_frame.columns.difference(other_data_frame.columns)
            result_data_frame = pd.merge(result_data_frame[columns_from_predicates], other_data_frame,
                                         copy=False, left_index=True, right_index=True)

        if self._transform_spec:
            result_data_frame = self._transform_spec.func(result_data_frame)

        return pa.Table.from_pandas(result_data_frame, preserve_index=False)

    def _read_with_shuffle_row_drop(self, piece, column_names, shuffle_row_drop_partition):
        """Reads the given columns of a piece and keeps only the rows of this worker's drop partition.

        Rows are split into ``min(num_rows, num_partitions)`` contiguous, near-equal buckets, and
        bucket ``shuffle_row_drop_partition[0]`` is kept.

        :param column_names: list of column names to read.
        :param shuffle_row_drop_partition: a 2-tuple of (this partition, total number of partitions).
        :return: a ``pyarrow.Table``.
        """
        # NOTE(review): the arguments of this call were lost from the mangled copy; ``columns=`` is
        # assumed -- confirm against the piece class's ``read_all`` signature.
        table = piece.read_all(columns=column_names)

        num_rows = len(table)
        num_partitions = shuffle_row_drop_partition[1]
        this_partition = shuffle_row_drop_partition[0]

        # An empty piece has nothing to partition; without the ``num_rows > 0`` guard the bucket size
        # below is float(0) / min(0, num_partitions), which raises ZeroDivisionError.
        if num_partitions > 1 and num_rows > 0:
            data_frame_pandas = table.to_pandas()
            bucket_size = float(num_rows) / min(num_rows, num_partitions)
            partition_indexes = np.floor(np.arange(num_rows) / bucket_size)

            table = pa.Table.from_pandas(data_frame_pandas.loc[partition_indexes == this_partition],
                                         preserve_index=False)

        return table