# blob: 81917cfff54f895eb85abda69e46afb03709ac68 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
import unittest
import pandas as pd
import apache_beam as beam
from apache_beam.dataframe import expressions
from apache_beam.dataframe import frame_base
from apache_beam.dataframe import transforms
from apache_beam.testing.util import assert_that
class TransformTest(unittest.TestCase):
  """Checks dataframe transforms against plain pandas evaluation.

  Each scenario computes the expected answer by calling a function on
  in-memory pandas data, then requires the same answer both from evaluating
  the deferred expression directly and from running the function inside a
  Beam pipeline through ``transforms.DataframeTransform``.
  """

  def run_scenario(self, input, func):
    """Verifies that ``func`` agrees on eager, deferred and pipeline inputs.

    Args:
      input: the pandas object handed to ``func`` (the tests in this class
        pass DataFrames and Series). It is sliced to build an empty proxy
        and to split the data into two pipeline elements.
      func: callable applied to ``input`` directly, to a deferred frame
        wrapping a placeholder for it, and inside ``DataframeTransform``.

    Raises:
      AssertionError: if the directly evaluated deferred result differs from
        ``func(input)``. A mismatch in the pipeline output is raised by the
        matcher handed to ``assert_that``.
    """
    expected = func(input)

    # A zero-row slice of the input acts as the proxy, both for the deferred
    # placeholder and for the pipeline transform below.
    proxy = input[0:0]
    placeholder = expressions.PlaceholderExpression(proxy)
    deferred_input = frame_base.DeferredFrame.wrap(placeholder)
    deferred_expr = func(deferred_input)._expr
    session = expressions.Session({placeholder: input})
    in_memory_result = deferred_expr.evaluate_at(session)

    def combine(pieces):
      """Joins the output elements into one object; None when there are none."""
      count = len(pieces)
      if count == 0:
        return None
      if count == 1:
        return pieces[0]
      return pd.concat(pieces)

    def verify(result):
      """Raises AssertionError unless ``result`` matches ``expected``."""
      if result is None:
        raise AssertionError('Empty frame but expected: \n\n%s' % expected)
      if not isinstance(expected, pd.core.generic.NDFrame):
        # Non-pandas results are compared as plain values.
        if result != expected:
          raise AssertionError(
              'Scalars not equal: %s != %s' % (result, expected))
        return
      # Sort both sides by index so that row order does not affect equality.
      ordered_result = result.sort_index()
      ordered_expected = expected.sort_index()
      if not ordered_result.equals(ordered_expected):
        raise AssertionError(
            'Dataframes not equal: \n\n%s\n\n%s' %
            (ordered_result, ordered_expected))

    verify(in_memory_result)

    # Feed the even- and odd-positioned rows as two separate elements.
    with beam.Pipeline() as pipeline:
      partitions = pipeline | beam.Create([input[::2], input[1::2]])
      transformed = partitions | transforms.DataframeTransform(
          func, proxy=proxy)
      assert_that(transformed, lambda actual: verify(combine(actual)))

  def test_identity(self):
    # The identity function must hand the frame back unchanged.
    species = ['Falcon', 'Falcon', 'Parrot', 'Parrot']
    speeds = [380., 370., 24., 26.]
    animals = pd.DataFrame({'Animal': species, 'Speed': speeds})
    self.run_scenario(animals, lambda x: x)

  def test_sum_mean(self):
    # Grouped sum and mean over the 'Animal' column.
    species = ['Falcon', 'Falcon', 'Parrot', 'Parrot']
    speeds = [380., 370., 24., 26.]
    animals = pd.DataFrame({'Animal': species, 'Speed': speeds})
    scenarios = [
        lambda df: df.groupby('Animal').sum(),
        lambda df: df.groupby('Animal').mean(),
    ]
    for scenario in scenarios:
      self.run_scenario(animals, scenario)

  def test_filter(self):
    # Filtering by explicit items, by a regex, and by a regex applied to the
    # index after 'Animal' has been made the index.
    animals = pd.DataFrame({
        'Animal': ['Aardvark', 'Ant', 'Elephant', 'Zebra'],
        'Speed': [5, 2, 35, 40],
    })
    scenarios = [
        lambda df: df.filter(items=['Animal']),
        lambda df: df.filter(regex='A.*'),
        lambda df: df.set_index('Animal').filter(regex='F.*', axis='index'),
    ]
    for scenario in scenarios:
      self.run_scenario(animals, scenario)

  def test_aggregate(self):
    # Aggregation with a single callable and with a list of named functions.
    numbers = pd.DataFrame({'col': [1, 2, 3]})
    scenarios = [
        lambda a: a.agg(sum),
        lambda a: a.agg(['mean', 'min', 'max']),
    ]
    for scenario in scenarios:
      self.run_scenario(numbers, scenario)

  def test_scalar(self):
    # A scalar as the final result, then as an operand of a later expression.
    values = pd.Series([1, 2, 6])
    scenarios = [
        lambda a: a.agg(sum),
        lambda a: a / a.agg(sum),
    ]
    for scenario in scenarios:
      self.run_scenario(values, scenario)

    # Tests scalar being used as an input to a downstream stage.
    keyed = pd.DataFrame({'key': ['a', 'a', 'b'], 'val': [1, 2, 6]})
    self.run_scenario(
        keyed, lambda df: df.groupby('key').sum().val / df.val.agg(sum))

  def test_input_output_polymorphism(self):
    # Exercises each input shape (single PCollection, tuple, dict) and each
    # output shape (single, tuple, dict) of DataframeTransform.
    one_series, two_series, three_series = [
        pd.Series([value]) for value in (1, 2, 3)
    ]
    empty_series = one_series[:0]

    def series_matcher(expected):
      """Builds a matcher that raises unless the output equals ``expected``."""
      def matcher(actual):
        combined = pd.concat(actual)
        if expected.equals(combined):
          return
        raise AssertionError(
            'Series not equal: \n%s\n%s\n' % (expected, combined))

      return matcher

    with beam.Pipeline() as pipeline:
      one = pipeline | 'One' >> beam.Create([one_series])
      two = pipeline | 'Two' >> beam.Create([two_series])

      # Single PCollection in, single PCollection out.
      tripled = one | 'PcollInPcollOut' >> transforms.DataframeTransform(
          lambda x: 3 * x, proxy=empty_series)
      assert_that(
          tripled, series_matcher(three_series), label='CheckPcollInPcollOut')

      # Tuple of PCollections in.
      tuple_sum = (one, two) | 'TupleIn' >> transforms.DataframeTransform(
          lambda x, y: (x + y), (empty_series, empty_series))
      assert_that(
          tuple_sum, series_matcher(three_series), label='CheckTupleIn')

      # Dict of PCollections in, keyed by the function's argument names.
      keyed_inputs = {'x': one, 'y': two}
      keyed_proxies = {'x': empty_series, 'y': empty_series}
      dict_sum = keyed_inputs | 'DictIn' >> transforms.DataframeTransform(
          lambda x, y: (x + y), proxy=keyed_proxies)
      assert_that(
          dict_sum, series_matcher(three_series), label='CheckDictIn')

      # Tuple of results out.
      twice, thrice = one | 'TupleOut' >> transforms.DataframeTransform(
          lambda x: (2*x, 3*x), empty_series)
      assert_that(twice, series_matcher(two_series), label='CheckTupleOut0')
      assert_that(thrice, series_matcher(three_series), label='CheckTupleOut1')

      # Dict of results out.
      by_name = one | 'DictOut' >> transforms.DataframeTransform(
          lambda x: {'res': 3 * x}, empty_series)
      assert_that(
          by_name['res'], series_matcher(three_series), label='CheckDictOut')
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
  unittest.main()