Simplify common patterns for pandas methods.
diff --git a/sdks/python/apache_beam/dataframe/frame_base.py b/sdks/python/apache_beam/dataframe/frame_base.py
index 7780cd3..0b1c296 100644
--- a/sdks/python/apache_beam/dataframe/frame_base.py
+++ b/sdks/python/apache_beam/dataframe/frame_base.py
@@ -16,7 +16,9 @@
from __future__ import absolute_import
+import functools
import inspect
+import sys
from typing import Any
from typing import Callable
from typing import Dict
@@ -29,6 +31,18 @@
from apache_beam.dataframe import expressions
from apache_beam.dataframe import partitionings
+# Signature helpers that behave the same on Python 2 and Python 3.
+# pylint: disable=deprecated-method
+if sys.version_info >= (3, ):
+  # Python 3 provides both helpers in the inspect module.
+  _getargspec = inspect.getfullargspec
+  _unwrap = inspect.unwrap
+else:
+  _getargspec = inspect.getargspec
+
+  def _unwrap(func):
+    """Follows the __wrapped__ chain and returns the innermost callable."""
+    innermost = func
+    while hasattr(innermost, '__wrapped__'):
+      innermost = innermost.__wrapped__
+    return innermost
+
class DeferredBase(object):
@@ -146,8 +160,7 @@
value = kwargs[key]
else:
try:
- # pylint: disable=deprecated-method
- ix = inspect.getargspec(func).args.index(key)
+ ix = _getargspec(func).args.index(key)
except ValueError:
# TODO: fix for delegation?
continue
@@ -226,6 +239,68 @@
return wrapper
+def maybe_inplace(func):
+  """Adds an ``inplace`` keyword to a method that returns a new object.
+
+  The wrapped method is always called without ``inplace``. When ``inplace``
+  is true, the ``_expr`` of the result is assigned to ``self._expr`` and
+  nothing is returned; otherwise the result itself is returned.
+  """
+  @functools.wraps(func)
+  def wrapper(self, inplace=False, **kwargs):
+    outcome = func(self, **kwargs)
+    if not inplace:
+      return outcome
+    # In-place mode: adopt the new expression and return None.
+    self._expr = outcome._expr
+    return None
+
+  return wrapper
+
+
+def args_to_kwargs(base_type):
+  """Returns a decorator that turns positional arguments into keyword ones.
+
+  The positional parameter names are taken from the signature of the method
+  of ``base_type`` that has the same name as the decorated function (after
+  following any ``__wrapped__`` chain), including the leading ``self`` when
+  the base method declares one. The decorated function is then invoked with
+  keyword arguments only.
+
+  Args:
+    base_type: the type whose same-named method supplies the parameter names.
+
+  Returns:
+    A decorator for functions that accept all of their inputs as keywords.
+
+  Raises:
+    TypeError: at call time, if more positional arguments are given than the
+      base method names, or if an argument is given both positionally and
+      by keyword.
+  """
+  def wrap(func):
+    arg_names = _getargspec(_unwrap(getattr(base_type, func.__name__))).args
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+      # zip() below stops at the shorter sequence, so surplus positional
+      # arguments would otherwise be dropped without any error.
+      if len(args) > len(arg_names):
+        raise TypeError(
+            "%s() takes %d positional arguments but %d were given" %
+            (func.__name__, len(arg_names), len(args)))
+      for name, value in zip(arg_names, args):
+        if name in kwargs:
+          raise TypeError(
+              "%s() got multiple values for argument '%s'" %
+              (func.__name__, name))
+        kwargs[name] = value
+      return func(**kwargs)
+
+    return wrapper
+
+  return wrap
+
+
+def populate_defaults(base_type):
+  """Returns a decorator that fills in defaults taken from ``base_type``.
+
+  For every parameter that the decorated function declares without a default
+  but that has a default in the same-named method of ``base_type``, the base
+  default is supplied whenever the caller omits that parameter. The resulting
+  wrapper accepts keyword arguments only. If the base method declares no
+  defaults, the function is returned unchanged.
+  """
+  def wrap(func):
+    base_spec = _getargspec(_unwrap(getattr(base_type, func.__name__)))
+    base_defaults = base_spec.defaults
+    if not base_defaults:
+      return func
+
+    # Defaults line up with the trailing positional parameters.
+    trailing_names = base_spec.args[-len(base_defaults):]
+    arg_to_default = dict(zip(trailing_names, base_defaults))
+
+    # Parameters that func itself declares without a default.
+    own_spec = _getargspec(_unwrap(func))
+    own_defaults = own_spec.defaults or ()
+    required_args = own_spec.args[:len(own_spec.args) - len(own_defaults)]
+    defaults_to_populate = set(required_args).intersection(
+        arg_to_default.keys())
+
+    @functools.wraps(func)
+    def wrapper(**kwargs):
+      for name in defaults_to_populate:
+        kwargs.setdefault(name, arg_to_default[name])
+      return func(**kwargs)
+
+    return wrapper
+
+  return wrap
+
+
class WontImplementError(NotImplementedError):
"""An subclass of NotImplementedError to raise indicating that implementing
the given method is infeasible.
diff --git a/sdks/python/apache_beam/dataframe/frame_base_test.py b/sdks/python/apache_beam/dataframe/frame_base_test.py
index 392272c..b527da0 100644
--- a/sdks/python/apache_beam/dataframe/frame_base_test.py
+++ b/sdks/python/apache_beam/dataframe/frame_base_test.py
@@ -41,6 +41,59 @@
self.assertTrue(sub(x, b)._expr.evaluate_at(session).equals(a - b))
self.assertTrue(sub(a, y)._expr.evaluate_at(session).equals(a - b))
+ def test_maybe_inplace(self):
+ """Checks that maybe_inplace replaces _expr only when inplace=True."""
+ @frame_base.maybe_inplace
+ def add_one(frame):
+ return frame + 1
+
+ # NOTE(review): add_one is attached to DeferredSeries itself and is not
+ # removed again within this test.
+ frames.DeferredSeries.add_one = add_one
+ original_expr = expressions.PlaceholderExpression(pd.Series([1, 2, 3]))
+ x = frames.DeferredSeries(original_expr)
+ # Omitting inplace must leave the frame's expression object untouched.
+ x.add_one()
+ self.assertIs(x._expr, original_expr)
+ # An explicit inplace=False behaves the same way.
+ x.add_one(inplace=False)
+ self.assertIs(x._expr, original_expr)
+ # inplace=True swaps in a different expression object.
+ x.add_one(inplace=True)
+ self.assertIsNot(x._expr, original_expr)
+
+ def test_args_to_kwargs(self):
+ """Checks that positional arguments are named after Base.func's parameters."""
+ # Base only supplies the signature that the decorator inspects.
+ class Base(object):
+ def func(self, a=1, b=2, c=3):
+ pass
+
+ # Proxy.func returns the keyword arguments it received (without self).
+ class Proxy(object):
+ @frame_base.args_to_kwargs(Base)
+ def func(self, **kwargs):
+ return kwargs
+
+ proxy = Proxy()
+ # pylint: disable=too-many-function-args
+ # Base's defaults are not filled in here: omitted arguments stay absent.
+ self.assertEqual(proxy.func(), {})
+ self.assertEqual(proxy.func(100), {'a': 100})
+ self.assertEqual(proxy.func(2, 4, 6), {'a': 2, 'b': 4, 'c': 6})
+ # Positional and keyword arguments can be mixed.
+ self.assertEqual(proxy.func(2, c=6), {'a': 2, 'c': 6})
+ self.assertEqual(proxy.func(c=6, a=2), {'a': 2, 'c': 6})
+
+ def test_args_to_kwargs_populates_defaults(self):
+ """Checks args_to_kwargs stacked on top of populate_defaults."""
+ # Base supplies both the parameter names and the defaults to borrow.
+ class Base(object):
+ def func(self, a=1, b=2, c=3):
+ pass
+
+ # Proxy.func declares a without a default and c with its own default.
+ class Proxy(object):
+ @frame_base.args_to_kwargs(Base)
+ @frame_base.populate_defaults(Base)
+ def func(self, a, c=1000, **kwargs):
+ return dict(kwargs, a=a, c=c)
+
+ proxy = Proxy()
+ # pylint: disable=too-many-function-args
+ # a falls back to Base's default (1); c keeps Proxy's own default (1000),
+ # and b is not populated because Proxy.func does not name it.
+ self.assertEqual(proxy.func(), {'a': 1, 'c': 1000})
+ self.assertEqual(proxy.func(100), {'a': 100, 'c': 1000})
+ self.assertEqual(proxy.func(2, 4, 6), {'a': 2, 'b': 4, 'c': 6})
+ self.assertEqual(proxy.func(2, c=6), {'a': 2, 'c': 6})
+ self.assertEqual(proxy.func(c=6, a=2), {'a': 2, 'c': 6})
+ self.assertEqual(proxy.func(c=6), {'a': 1, 'c': 6})
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/dataframe/frames.py b/sdks/python/apache_beam/dataframe/frames.py
index 89e9154..6690537 100644
--- a/sdks/python/apache_beam/dataframe/frames.py
+++ b/sdks/python/apache_beam/dataframe/frames.py
@@ -258,7 +258,9 @@
else:
return result
- def reset_index(self, level=None, drop=False, inplace=False, *args, **kwargs):
+ @frame_base.args_to_kwargs(pd.DataFrame)
+ @frame_base.maybe_inplace
+ def reset_index(self, level=None, **kwargs):
if level is not None and not isinstance(level, (tuple, list)):
level = [level]
if level is None or len(level) == len(self._expr.proxy().index.levels):
@@ -266,17 +268,13 @@
requires_partition_by = partitionings.Singleton()
else:
requires_partition_by = partitionings.Nothing()
- result = frame_base.DeferredFrame.wrap(
+ return frame_base.DeferredFrame.wrap(
expressions.ComputedExpression(
'reset_index',
- lambda df: df.reset_index(level, drop, False, *args, **kwargs),
+ lambda df: df.reset_index(level=level, **kwargs),
[self._expr],
preserves_partition_by=partitionings.Singleton(),
requires_partition_by=requires_partition_by))
- if inplace:
- self._expr = result._expr
- else:
- return result
round = frame_base._elementwise_method('round')
select_dtypes = frame_base._elementwise_method('select_dtypes')