# blob: 1f4242933cd472e89d168eeec4f27caad4a9d528 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""A ValueProvider class to implement templates with both statically
and dynamically provided values.
"""
from __future__ import absolute_import
from builtins import object
from functools import wraps
from apache_beam import error
# Public names of this module; these are what a star-import exposes.
__all__ = [
'ValueProvider',
'StaticValueProvider',
'RuntimeValueProvider',
'check_accessible',
]
class ValueProvider(object):
    """Interface shared by all value providers.

    The base class carries no state and no behaviour of its own: both
    methods raise ``NotImplementedError`` and are expected to be
    overridden by subclasses.
    """

    def is_accessible(self):
        """Tell whether the value can currently be obtained via ``get``.

        Raises:
          NotImplementedError: always, in this base class.
        """
        message = 'ValueProvider.is_accessible implemented in derived classes'
        raise NotImplementedError(message)

    def get(self):
        """Return the provided value.

        Raises:
          NotImplementedError: always, in this base class.
        """
        message = 'ValueProvider.get implemented in derived classes'
        raise NotImplementedError(message)
class StaticValueProvider(ValueProvider):
    """ValueProvider whose value is known when the object is constructed.

    The raw value is converted once with ``value_type`` and stored; the
    provider is therefore always accessible.
    """

    def __init__(self, value_type, value):
        """Store ``value`` converted by ``value_type``.

        Args:
          value_type: callable applied to ``value`` (for example ``int`` or
            ``str``); also kept for equality and hashing.
          value: the raw value; ``value_type(value)`` is what ``get`` returns.
        """
        self.value_type = value_type
        self.value = value_type(value)

    def is_accessible(self):
        """Return True: a static value is always available."""
        return True

    def get(self):
        """Return the converted value stored at construction time."""
        return self.value

    def __str__(self):
        """Return the string form of the wrapped value."""
        return str(self.value)

    def __eq__(self, other):
        """Compare with another provider or with a raw value.

        Two ``StaticValueProvider`` objects are equal only when both their
        ``value_type`` and their ``value`` are equal. Any other object is
        compared directly against the wrapped value.
        """
        # The provider-to-provider case must be handled first. Evaluating
        # ``self.value == other`` with a provider on the right-hand side makes
        # Python fall back to the reflected ``other.__eq__(self.value)``, which
        # succeeds on value equality alone and would silently skip the
        # ``value_type`` check (e.g. int 1 vs. float 1.0), leaving equal
        # objects with different hashes.
        if isinstance(other, StaticValueProvider):
            return bool(self.value_type == other.value_type and
                        self.value == other.value)
        # Convenience: a provider also compares equal to its raw value.
        return bool(self.value == other)

    def __hash__(self):
        """Hash on (class, value_type, value), matching provider equality.

        Note that equality with a *raw* value is not reflected in the hash,
        because the hash also includes the class and ``value_type``.
        """
        return hash((type(self), self.value_type, self.value))
class RuntimeValueProvider(ValueProvider):
    """ValueProvider that looks its value up in class-level runtime options.

    The options are stored on the class itself (``runtime_options``) by
    ``set_runtime_options`` and are shared by every instance. Until they are
    set, instances are not accessible and ``get`` raises.
    """

    # Options mapping installed by set_runtime_options(); None until then.
    runtime_options = None
    # Value of the 'experiments' option as a set; refreshed by
    # set_runtime_options().
    experiments = set()

    def __init__(self, option_name, value_type, default_value):
        """Remember which option to read and how to interpret it.

        Args:
          option_name: key looked up in the runtime options.
          value_type: callable applied to the option value that was found.
          default_value: returned as-is when no usable option value exists.
        """
        self.option_name = option_name
        self.value_type = value_type
        self.default_value = default_value

    def is_accessible(self):
        """Return True once runtime options have been installed."""
        return not (RuntimeValueProvider.runtime_options is None)

    @classmethod
    def get_value(cls, option_name, value_type, default_value):
        """Look up ``option_name`` in the runtime options.

        Returns ``value_type(candidate)`` when the options hold a truthy
        value under ``option_name``; otherwise returns ``default_value``
        unchanged (it is not passed through ``value_type``).
        """
        # NOTE(review): both checks are truthiness-based, so a falsy option
        # value (0, False, '', empty container) falls back to default_value
        # -- confirm this is intended.
        if RuntimeValueProvider.runtime_options:
            found = RuntimeValueProvider.runtime_options.get(option_name)
            if found:
                return value_type(found)
        return default_value

    def get(self):
        """Return the runtime value of this provider's option.

        Raises:
          error.RuntimeValueProviderError: if no runtime options have been
            installed yet.
        """
        if not self.is_accessible_without_override():
            raise error.RuntimeValueProviderError(
                '%s.get() not called from a runtime context' % self)
        return RuntimeValueProvider.get_value(
            self.option_name, self.value_type, self.default_value)

    def is_accessible_without_override(self):
        """Private-style helper: same test ``get`` always used, immune to
        subclasses overriding ``is_accessible``."""
        return RuntimeValueProvider.runtime_options is not None

    @classmethod
    def set_runtime_options(cls, pipeline_options):
        """Install ``pipeline_options`` for all providers and refresh
        ``experiments`` from its 'experiments' entry."""
        # Deliberately written against RuntimeValueProvider rather than cls,
        # so the state lives on this class even when called via a subclass.
        RuntimeValueProvider.runtime_options = pipeline_options
        RuntimeValueProvider.experiments = RuntimeValueProvider.get_value(
            'experiments', set, set())

    def __str__(self):
        """Describe the provider: class name, option, type and default."""
        fields = (
            self.__class__.__name__,
            self.option_name,
            self.value_type.__name__,
            repr(self.default_value),
        )
        return '%s(option: %s, type: %s, default_value: %s)' % fields
def check_accessible(value_provider_list):
    """Build a method decorator that guards on ValueProvider accessibility.

    Args:
      value_provider_list: a ``list`` of attribute names. On each call of the
        decorated method, every name is read from ``self`` and the resulting
        object's ``is_accessible()`` is consulted.

    Returns:
      A decorator for methods (functions taking ``self`` first).

    Raises:
      AssertionError: if ``value_provider_list`` is not a list.
      error.RuntimeValueProviderError: from the decorated method, when one of
        the named providers reports it is not accessible.
    """
    assert isinstance(value_provider_list, list)

    def _decorate(fnc):
        @wraps(fnc)
        def _guarded(self, *args, **kwargs):
            # Resolve every attribute before checking any of them, so a
            # missing attribute surfaces before an accessibility error.
            providers = [getattr(self, attr) for attr in value_provider_list]
            for provider in providers:
                if provider.is_accessible():
                    continue
                raise error.RuntimeValueProviderError(
                    '%s not accessible' % provider)
            return fnc(self, *args, **kwargs)
        return _guarded

    return _decorate