#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for apache_beam.ml.base."""
import math
import os
import pickle
import time
import unittest
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Union
import pytest
import apache_beam as beam
from apache_beam.examples.inference import run_inference_side_inputs
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.ml.inference import base
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms import trigger
from apache_beam.transforms import window
from apache_beam.transforms.periodicsequence import TimestampedValue
from apache_beam.utils import multi_process_shared
class FakeModel:
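  """A trivial fake model whose predict() returns the example plus one."""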
def predict(self, example: int) -> int:
return example + 1
class FakeStatefulModel:
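  """A fake model with mutable state; predict() returns the current state,
  increment_state() mutates it, and constructing it with state == 100 raises
  so tests can confirm a cached model is reused rather than reloaded."""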
def __init__(self, state: int):
if state == 100:
raise Exception('Oh no')
self._state = state
def predict(self, example: int) -> int:
return self._state
def increment_state(self, amount: int):
self._state += amount
class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
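  """Model handler for FakeModel with configurable test behavior: an optional
  fake clock for deterministic latency metrics, min/max batch sizes, optional
  multi-process sharing, an optional stateful model, and an optional
  per-element byte size for get_num_bytes()."""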
def __init__(
self,
clock=None,
min_batch_size=1,
max_batch_size=9999,
multi_process_shared=False,
state=None,
num_bytes_per_element=None,
**kwargs):
self._fake_clock = clock
self._min_batch_size = min_batch_size
self._max_batch_size = max_batch_size
self._env_vars = kwargs.get('env_vars', {})
self._multi_process_shared = multi_process_shared
self._state = state
self._num_bytes_per_element = num_bytes_per_element
def load_model(self):
if self._fake_clock:
self._fake_clock.current_time_ns += 500_000_000 # 500ms
if self._state is not None:
return FakeStatefulModel(self._state)
return FakeModel()
def run_inference(
self,
batch: Sequence[int],
model: FakeModel,
inference_args=None) -> Iterable[int]:
multi_process_shared_loaded = "multi_process_shared" in str(type(model))
if self._multi_process_shared != multi_process_shared_loaded:
raise Exception(
f'Loaded model of type {type(model)}, was' +
f'{"" if self._multi_process_shared else " not"} ' +
'expecting multi_process_shared_model')
if self._fake_clock:
self._fake_clock.current_time_ns += 3_000_000 # 3 milliseconds
for example in batch:
yield model.predict(example)
def update_model_path(self, model_path: Optional[str] = None):
pass
def batch_elements_kwargs(self):
return {
'min_batch_size': self._min_batch_size,
'max_batch_size': self._max_batch_size
}
def share_model_across_processes(self):
return self._multi_process_shared
def get_num_bytes(self, batch: Sequence[int]) -> int:
if self._num_bytes_per_element:
return self._num_bytes_per_element * len(batch)
return super().get_num_bytes(batch)
class FakeModelHandlerReturnsPredictionResult(
base.ModelHandler[int, base.PredictionResult, FakeModel]):
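  """Model handler that wraps each inference in a PredictionResult tagged
  with the current model_id; when state is set, it loads a FakeStatefulModel
  starting at 0 and increments it after every batch."""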
def __init__(
self,
clock=None,
model_id='fake_model_id_default',
multi_process_shared=False,
state=None):
self.model_id = model_id
self._fake_clock = clock
self._env_vars = {}
self._multi_process_shared = multi_process_shared
self._state = state
def load_model(self):
if self._state is not None:
return FakeStatefulModel(0)
return FakeModel()
def run_inference(
self,
batch: Sequence[int],
model: Union[FakeModel, FakeStatefulModel],
inference_args=None) -> Iterable[base.PredictionResult]:
multi_process_shared_loaded = "multi_process_shared" in str(type(model))
if self._multi_process_shared != multi_process_shared_loaded:
raise Exception(
f'Loaded model of type {type(model)}, was' +
f'{"" if self._multi_process_shared else " not"} ' +
'expecting multi_process_shared_model')
for example in batch:
yield base.PredictionResult(
model_id=self.model_id,
example=example,
inference=model.predict(example))
if self._state is not None:
model.increment_state(1) # type: ignore[union-attr]
def update_model_path(self, model_path: Optional[str] = None):
self.model_id = model_path if model_path else self.model_id
def share_model_across_processes(self):
return self._multi_process_shared
class FakeModelHandlerNoEnvVars(base.ModelHandler[int, int, FakeModel]):
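  """Same as FakeModelHandler but never sets self._env_vars, covering model
  handlers that do not declare that attribute."""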
def __init__(
self, clock=None, min_batch_size=1, max_batch_size=9999, **kwargs):
self._fake_clock = clock
self._min_batch_size = min_batch_size
self._max_batch_size = max_batch_size
def load_model(self):
if self._fake_clock:
self._fake_clock.current_time_ns += 500_000_000 # 500ms
return FakeModel()
def run_inference(
self,
batch: Sequence[int],
model: FakeModel,
inference_args=None) -> Iterable[int]:
if self._fake_clock:
self._fake_clock.current_time_ns += 3_000_000 # 3 milliseconds
for example in batch:
yield model.predict(example)
def update_model_path(self, model_path: Optional[str] = None):
pass
def batch_elements_kwargs(self):
return {
'min_batch_size': self._min_batch_size,
'max_batch_size': self._max_batch_size
}
class FakeClock:
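  """A manually advanced clock so latency metrics are deterministic."""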
def __init__(self):
# Start at 10 seconds.
self.current_time_ns = 10_000_000_000
def time_ns(self) -> int:
return self.current_time_ns
class ExtractInferences(beam.DoFn):
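  """Maps each PredictionResult to its bare inference value."""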
def process(self, prediction_result):
yield prediction_result.inference
class FakeModelHandlerNeedsBigBatch(FakeModelHandler):
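  """Fails on batches smaller than 100 elements, verifying that
  batch_elements_kwargs() is honored when batching inputs."""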
def run_inference(self, batch, unused_model, inference_args=None):
if len(batch) < 100:
raise ValueError('Unexpectedly small batch')
return batch
def batch_elements_kwargs(self):
return {'min_batch_size': 9999}
class FakeModelHandlerFailsOnInferenceArgs(FakeModelHandler):
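  """Always fails in run_inference; unexpected inference_args should already
  have been rejected by the default validate_inference_args check."""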
def run_inference(self, batch, unused_model, inference_args=None):
raise ValueError(
'run_inference should not be called because error should already be '
'thrown from the validate_inference_args check.')
class FakeModelHandlerExpectedInferenceArgs(FakeModelHandler):
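  """Requires inference_args to be passed through and accepts any of them."""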
def run_inference(self, batch, unused_model, inference_args=None):
if not inference_args:
raise ValueError('inference_args should exist')
return batch
def validate_inference_args(self, inference_args):
pass
class RunInferenceBaseTest(unittest.TestCase):
def test_run_inference_impl_simple_examples(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
expected = [example + 1 for example in examples]
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(FakeModelHandler())
assert_that(actual, equal_to(expected), label='assert:inferences')
def test_run_inference_impl_simple_examples_multi_process_shared(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
expected = [example + 1 for example in examples]
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(
FakeModelHandler(multi_process_shared=True))
assert_that(actual, equal_to(expected), label='assert:inferences')
def test_run_inference_impl_with_keyed_examples(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
keyed_examples = [(i, example) for i, example in enumerate(examples)]
expected = [(i, example + 1) for i, example in enumerate(examples)]
pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
actual = pcoll | base.RunInference(
base.KeyedModelHandler(FakeModelHandler()))
assert_that(actual, equal_to(expected), label='assert:inferences')
def test_run_inference_impl_with_keyed_examples_many_model_handlers(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
keyed_examples = [(i, example) for i, example in enumerate(examples)]
expected = [(i, example + 1) for i, example in enumerate(examples)]
expected[0] = (0, 200)
pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
mhs = [
base.KeyModelMapping([0],
FakeModelHandler(
state=200, multi_process_shared=True)),
base.KeyModelMapping([1, 2, 3],
FakeModelHandler(multi_process_shared=True))
]
actual = pcoll | base.RunInference(base.KeyedModelHandler(mhs))
assert_that(actual, equal_to(expected), label='assert:inferences')
def test_run_inference_impl_with_keyed_examples_many_model_handlers_metrics(
self):
pipeline = TestPipeline()
examples = [1, 5, 3, 10]
metrics_namespace = 'test_namespace'
keyed_examples = [(i, example) for i, example in enumerate(examples)]
pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
mhs = [
base.KeyModelMapping([0],
FakeModelHandler(
state=200, multi_process_shared=True)),
base.KeyModelMapping([1, 2, 3],
FakeModelHandler(multi_process_shared=True))
]
_ = pcoll | base.RunInference(
base.KeyedModelHandler(mhs), metrics_namespace=metrics_namespace)
result = pipeline.run()
result.wait_until_finish()
metrics_filter = MetricsFilter().with_namespace(namespace=metrics_namespace)
metrics = result.metrics().query(metrics_filter)
assert len(metrics['counters']) != 0
assert len(metrics['distributions']) != 0
metrics_filter = MetricsFilter().with_name('0-_num_inferences')
metrics = result.metrics().query(metrics_filter)
num_inferences_counter_key_0 = metrics['counters'][0]
self.assertEqual(num_inferences_counter_key_0.committed, 1)
metrics_filter = MetricsFilter().with_name('1-_num_inferences')
metrics = result.metrics().query(metrics_filter)
num_inferences_counter_key_1 = metrics['counters'][0]
self.assertEqual(num_inferences_counter_key_1.committed, 3)
metrics_filter = MetricsFilter().with_name('num_inferences')
metrics = result.metrics().query(metrics_filter)
num_inferences_counter_aggregate = metrics['counters'][0]
self.assertEqual(num_inferences_counter_aggregate.committed, 4)
metrics_filter = MetricsFilter().with_name('0-_failed_batches_counter')
metrics = result.metrics().query(metrics_filter)
failed_batches_counter_key_0 = metrics['counters']
self.assertEqual(len(failed_batches_counter_key_0), 0)
metrics_filter = MetricsFilter().with_name('failed_batches_counter')
metrics = result.metrics().query(metrics_filter)
failed_batches_counter_aggregate = metrics['counters']
self.assertEqual(len(failed_batches_counter_aggregate), 0)
metrics_filter = MetricsFilter().with_name(
'0-_load_model_latency_milli_secs')
metrics = result.metrics().query(metrics_filter)
load_latency_dist_key_0 = metrics['distributions'][0]
self.assertEqual(load_latency_dist_key_0.committed.count, 1)
metrics_filter = MetricsFilter().with_name('load_model_latency_milli_secs')
metrics = result.metrics().query(metrics_filter)
load_latency_dist_aggregate = metrics['distributions'][0]
self.assertEqual(load_latency_dist_aggregate.committed.count, 2)
def test_run_inference_impl_with_keyed_examples_many_mhs_max_models_hint(
self):
pipeline = TestPipeline()
examples = [1, 5, 3, 10, 2, 4, 6, 8, 9, 7, 1, 5, 3, 10, 2, 4, 6, 8, 9, 7]
metrics_namespace = 'test_namespace'
keyed_examples = [(i, example) for i, example in enumerate(examples)]
pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
mhs = [
base.KeyModelMapping([0, 2, 4, 6, 8],
FakeModelHandler(
state=200, multi_process_shared=True)),
base.KeyModelMapping(
[1, 3, 5, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
FakeModelHandler(multi_process_shared=True))
]
_ = pcoll | base.RunInference(
base.KeyedModelHandler(mhs, max_models_per_worker_hint=1),
metrics_namespace=metrics_namespace)
result = pipeline.run()
result.wait_until_finish()
metrics_filter = MetricsFilter().with_namespace(namespace=metrics_namespace)
metrics = result.metrics().query(metrics_filter)
assert len(metrics['counters']) != 0
assert len(metrics['distributions']) != 0
metrics_filter = MetricsFilter().with_name('load_model_latency_milli_secs')
metrics = result.metrics().query(metrics_filter)
load_latency_dist_aggregate = metrics['distributions'][0]
# We should flip back and forth between models a bit since
# max_models_per_worker_hint=1, but we shouldn't thrash forever
# since most examples belong to the second ModelMapping
self.assertGreater(load_latency_dist_aggregate.committed.count, 2)
self.assertLess(load_latency_dist_aggregate.committed.count, 12)
def test_keyed_many_model_handlers_validation(self):
def mult_two(example: str) -> int:
return int(example) * 2
mhs = [
base.KeyModelMapping(
[0],
FakeModelHandler(
state=200,
multi_process_shared=True).with_preprocess_fn(mult_two)),
base.KeyModelMapping([1, 2, 3],
FakeModelHandler(multi_process_shared=True))
]
with self.assertRaises(ValueError):
base.KeyedModelHandler(mhs)
mhs = [
base.KeyModelMapping(
[0],
FakeModelHandler(
state=200,
multi_process_shared=True).with_postprocess_fn(mult_two)),
base.KeyModelMapping([1, 2, 3],
FakeModelHandler(multi_process_shared=True))
]
with self.assertRaises(ValueError):
base.KeyedModelHandler(mhs)
mhs = [
base.KeyModelMapping([0],
FakeModelHandler(
state=200, multi_process_shared=True)),
base.KeyModelMapping([0, 1, 2, 3],
FakeModelHandler(multi_process_shared=True))
]
with self.assertRaises(ValueError):
base.KeyedModelHandler(mhs)
mhs = [
base.KeyModelMapping([],
FakeModelHandler(
state=200, multi_process_shared=True)),
base.KeyModelMapping([0, 1, 2, 3],
FakeModelHandler(multi_process_shared=True))
]
with self.assertRaises(ValueError):
base.KeyedModelHandler(mhs)
def test_keyed_model_handler_get_num_bytes(self):
mh = base.KeyedModelHandler(FakeModelHandler(num_bytes_per_element=10))
batch = [('key1', 1), ('key2', 2), ('key1', 3)]
expected = len(pickle.dumps(('key1', 'key2', 'key1'))) + 30
actual = mh.get_num_bytes(batch)
self.assertEqual(expected, actual)
def test_keyed_model_handler_multiple_models_get_num_bytes(self):
mhs = [
base.KeyModelMapping(['key1'],
FakeModelHandler(num_bytes_per_element=10)),
base.KeyModelMapping(['key2'],
FakeModelHandler(num_bytes_per_element=20))
]
mh = base.KeyedModelHandler(mhs)
batch = [('key1', 1), ('key2', 2), ('key1', 3)]
expected = len(pickle.dumps(('key1', 'key2', 'key1'))) + 40
actual = mh.get_num_bytes(batch)
self.assertEqual(expected, actual)
def test_run_inference_impl_with_maybe_keyed_examples(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
keyed_examples = [(i, example) for i, example in enumerate(examples)]
expected = [example + 1 for example in examples]
keyed_expected = [(i, example + 1) for i, example in enumerate(examples)]
model_handler = base.MaybeKeyedModelHandler(FakeModelHandler())
pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
actual = pcoll | 'RunUnkeyed' >> base.RunInference(model_handler)
assert_that(actual, equal_to(expected), label='CheckUnkeyed')
keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
model_handler)
assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
def test_run_inference_impl_with_keyed_examples_multi_process_shared(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
keyed_examples = [(i, example) for i, example in enumerate(examples)]
expected = [(i, example + 1) for i, example in enumerate(examples)]
pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
actual = pcoll | base.RunInference(
base.KeyedModelHandler(FakeModelHandler(multi_process_shared=True)))
assert_that(actual, equal_to(expected), label='assert:inferences')
def test_run_inference_impl_with_maybe_keyed_examples_multi_process_shared(
self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
keyed_examples = [(i, example) for i, example in enumerate(examples)]
expected = [example + 1 for example in examples]
keyed_expected = [(i, example + 1) for i, example in enumerate(examples)]
model_handler = base.MaybeKeyedModelHandler(
FakeModelHandler(multi_process_shared=True))
pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
actual = pcoll | 'RunUnkeyed' >> base.RunInference(model_handler)
assert_that(actual, equal_to(expected), label='CheckUnkeyed')
keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
model_handler)
assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
def test_run_inference_preprocessing(self):
def mult_two(example: str) -> int:
return int(example) * 2
with TestPipeline() as pipeline:
examples = ["1", "5", "3", "10"]
expected = [int(example) * 2 + 1 for example in examples]
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(
FakeModelHandler().with_preprocess_fn(mult_two))
assert_that(actual, equal_to(expected), label='assert:inferences')
def test_run_inference_preprocessing_multiple_fns(self):
def add_one(example: str) -> int:
return int(example) + 1
def mult_two(example: int) -> int:
return example * 2
with TestPipeline() as pipeline:
examples = ["1", "5", "3", "10"]
expected = [(int(example) + 1) * 2 + 1 for example in examples]
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(
FakeModelHandler().with_preprocess_fn(mult_two).with_preprocess_fn(
add_one))
assert_that(actual, equal_to(expected), label='assert:inferences')
def test_run_inference_postprocessing(self):
def mult_two(example: int) -> str:
return str(example * 2)
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
expected = [str((example + 1) * 2) for example in examples]
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(
FakeModelHandler().with_postprocess_fn(mult_two))
assert_that(actual, equal_to(expected), label='assert:inferences')
def test_run_inference_postprocessing_multiple_fns(self):
def add_one(example: int) -> str:
return str(int(example) + 1)
def mult_two(example: int) -> int:
return example * 2
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
expected = [str(((example + 1) * 2) + 1) for example in examples]
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(
FakeModelHandler().with_postprocess_fn(mult_two).with_postprocess_fn(
add_one))
assert_that(actual, equal_to(expected), label='assert:inferences')
def test_run_inference_preprocessing_dlq(self):
def mult_two(example: str) -> int:
if example == "5":
raise Exception("TEST")
return int(example) * 2
with TestPipeline() as pipeline:
examples = ["1", "5", "3", "10"]
expected = [3, 7, 21]
expected_bad = ["5"]
pcoll = pipeline | 'start' >> beam.Create(examples)
main, other = pcoll | base.RunInference(
FakeModelHandler().with_preprocess_fn(mult_two)
).with_exception_handling()
assert_that(main, equal_to(expected), label='assert:inferences')
assert_that(
other.failed_inferences, equal_to([]), label='assert:bad_infer')
# bad will be in form [element, error]. Just pull out bad element.
bad_without_error = other.failed_preprocessing[0] | beam.Map(
lambda x: x[0])
assert_that(
bad_without_error, equal_to(expected_bad), label='assert:failures')
def test_run_inference_postprocessing_dlq(self):
def mult_two(example: int) -> str:
if example == 6:
raise Exception("TEST")
return str(example * 2)
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
expected = ["4", "8", "22"]
expected_bad = [6]
pcoll = pipeline | 'start' >> beam.Create(examples)
main, other = pcoll | base.RunInference(
FakeModelHandler().with_postprocess_fn(mult_two)
).with_exception_handling()
assert_that(main, equal_to(expected), label='assert:inferences')
assert_that(
other.failed_inferences, equal_to([]), label='assert:bad_infer')
# bad will be in form [element, error]. Just pull out bad element.
bad_without_error = other.failed_postprocessing[0] | beam.Map(
lambda x: x[0])
assert_that(
bad_without_error, equal_to(expected_bad), label='assert:failures')
def test_run_inference_pre_and_post_processing_dlq(self):
def mult_two_pre(example: str) -> int:
if example == "5":
raise Exception("TEST")
return int(example) * 2
def mult_two_post(example: int) -> str:
if example == 7:
raise Exception("TEST")
return str(example * 2)
with TestPipeline() as pipeline:
examples = ["1", "5", "3", "10"]
expected = ["6", "42"]
expected_bad_pre = ["5"]
expected_bad_post = [7]
pcoll = pipeline | 'start' >> beam.Create(examples)
main, other = pcoll | base.RunInference(
FakeModelHandler().with_preprocess_fn(
mult_two_pre
).with_postprocess_fn(
mult_two_post
)).with_exception_handling()
assert_that(main, equal_to(expected), label='assert:inferences')
assert_that(
other.failed_inferences, equal_to([]), label='assert:bad_infer')
# bad will be in form [elements, error]. Just pull out bad element.
bad_without_error_pre = other.failed_preprocessing[0] | beam.Map(
lambda x: x[0])
assert_that(
bad_without_error_pre,
equal_to(expected_bad_pre),
label='assert:failures_pre')
# bad will be in form [elements, error]. Just pull out bad element.
bad_without_error_post = other.failed_postprocessing[0] | beam.Map(
lambda x: x[0])
assert_that(
bad_without_error_post,
equal_to(expected_bad_post),
label='assert:failures_post')
def test_run_inference_keyed_pre_and_post_processing(self):
def mult_two(element):
return (element[0], element[1] * 2)
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
keyed_examples = [(i, example) for i, example in enumerate(examples)]
expected = [
(i, ((example * 2) + 1) * 2) for i, example in enumerate(examples)
]
pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
actual = pcoll | base.RunInference(
base.KeyedModelHandler(FakeModelHandler()).with_preprocess_fn(
mult_two).with_postprocess_fn(mult_two))
assert_that(actual, equal_to(expected), label='assert:inferences')
def test_run_inference_maybe_keyed_pre_and_post_processing(self):
def mult_two(element):
return element * 2
def mult_two_keyed(element):
return (element[0], element[1] * 2)
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
keyed_examples = [(i, example) for i, example in enumerate(examples)]
expected = [((2 * example) + 1) * 2 for example in examples]
keyed_expected = [
(i, ((2 * example) + 1) * 2) for i, example in enumerate(examples)
]
model_handler = base.MaybeKeyedModelHandler(FakeModelHandler())
pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
actual = pcoll | 'RunUnkeyed' >> base.RunInference(
model_handler.with_preprocess_fn(mult_two).with_postprocess_fn(
mult_two))
assert_that(actual, equal_to(expected), label='CheckUnkeyed')
keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
model_handler.with_preprocess_fn(mult_two_keyed).with_postprocess_fn(
mult_two_keyed))
assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
def test_run_inference_impl_dlq(self):
with TestPipeline() as pipeline:
examples = [1, 'TEST', 3, 10, 'TEST2']
expected_good = [2, 4, 11]
expected_bad = ['TEST', 'TEST2']
pcoll = pipeline | 'start' >> beam.Create(examples)
main, other = pcoll | base.RunInference(
FakeModelHandler(
min_batch_size=1,
max_batch_size=1
)).with_exception_handling()
assert_that(main, equal_to(expected_good), label='assert:inferences')
      # other.failed_inferences will be in the form [batch[elements], error].
      # Just pull out the bad element.
bad_without_error = other.failed_inferences | beam.Map(lambda x: x[0][0])
assert_that(
bad_without_error, equal_to(expected_bad), label='assert:failures')
def test_run_inference_impl_inference_args(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
pcoll = pipeline | 'start' >> beam.Create(examples)
inference_args = {'key': True}
actual = pcoll | base.RunInference(
FakeModelHandlerExpectedInferenceArgs(),
inference_args=inference_args)
assert_that(actual, equal_to(examples), label='assert:inferences')
def test_run_inference_metrics_with_custom_namespace(self):
metrics_namespace = 'my_custom_namespace'
pipeline = TestPipeline()
examples = [1, 5, 3, 10]
pcoll = pipeline | 'start' >> beam.Create(examples)
_ = pcoll | base.RunInference(
FakeModelHandler(), metrics_namespace=metrics_namespace)
result = pipeline.run()
result.wait_until_finish()
metrics_filter = MetricsFilter().with_namespace(namespace=metrics_namespace)
metrics = result.metrics().query(metrics_filter)
assert len(metrics['counters']) != 0
assert len(metrics['distributions']) != 0
metrics_filter = MetricsFilter().with_namespace(namespace='fake_namespace')
metrics = result.metrics().query(metrics_filter)
assert len(metrics['counters']) == len(metrics['distributions']) == 0
def test_unexpected_inference_args_passed(self):
with self.assertRaisesRegex(ValueError, r'inference_args were provided'):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
pcoll = pipeline | 'start' >> beam.Create(examples)
inference_args = {'key': True}
_ = pcoll | base.RunInference(
FakeModelHandlerFailsOnInferenceArgs(),
inference_args=inference_args)
def test_increment_failed_batches_counter(self):
with self.assertRaises(ValueError):
with TestPipeline() as pipeline:
examples = [7]
pcoll = pipeline | 'start' >> beam.Create(examples)
_ = pcoll | base.RunInference(FakeModelHandlerExpectedInferenceArgs())
run_result = pipeline.run()
run_result.wait_until_finish()
metric_results = (
run_result.metrics().query(
MetricsFilter().with_name('failed_batches_counter')))
num_failed_batches_counter = metric_results['counters'][0]
self.assertEqual(num_failed_batches_counter.committed, 3)
# !!!: The above will need to be updated if retry behavior changes
def test_failed_batches_counter_no_failures(self):
pipeline = TestPipeline()
examples = [7]
pcoll = pipeline | 'start' >> beam.Create(examples)
inference_args = {'key': True}
_ = pcoll | base.RunInference(
FakeModelHandlerExpectedInferenceArgs(), inference_args=inference_args)
run_result = pipeline.run()
run_result.wait_until_finish()
metric_results = (
run_result.metrics().query(
MetricsFilter().with_name('failed_batches_counter')))
self.assertEqual(len(metric_results['counters']), 0)
def test_counted_metrics(self):
pipeline = TestPipeline()
examples = [1, 5, 3, 10]
pcoll = pipeline | 'start' >> beam.Create(examples)
_ = pcoll | base.RunInference(FakeModelHandler())
run_result = pipeline.run()
run_result.wait_until_finish()
metric_results = (
run_result.metrics().query(MetricsFilter().with_name('num_inferences')))
num_inferences_counter = metric_results['counters'][0]
self.assertEqual(num_inferences_counter.committed, 4)
inference_request_batch_size = run_result.metrics().query(
MetricsFilter().with_name('inference_request_batch_size'))
self.assertTrue(inference_request_batch_size['distributions'])
self.assertEqual(
inference_request_batch_size['distributions'][0].result.sum, 4)
inference_request_batch_byte_size = run_result.metrics().query(
MetricsFilter().with_name('inference_request_batch_byte_size'))
self.assertTrue(inference_request_batch_byte_size['distributions'])
self.assertGreaterEqual(
inference_request_batch_byte_size['distributions'][0].result.sum,
len(pickle.dumps(examples)))
inference_request_batch_byte_size = run_result.metrics().query(
MetricsFilter().with_name('model_byte_size'))
self.assertTrue(inference_request_batch_byte_size['distributions'])
def test_timing_metrics(self):
pipeline = TestPipeline()
examples = [1, 5, 3, 10]
pcoll = pipeline | 'start' >> beam.Create(examples)
fake_clock = FakeClock()
_ = pcoll | base.RunInference(
FakeModelHandler(clock=fake_clock), clock=fake_clock)
res = pipeline.run()
res.wait_until_finish()
metric_results = (
res.metrics().query(
MetricsFilter().with_name('inference_batch_latency_micro_secs')))
batch_latency = metric_results['distributions'][0]
self.assertEqual(batch_latency.result.count, 3)
self.assertEqual(batch_latency.result.mean, 3000)
metric_results = (
res.metrics().query(
MetricsFilter().with_name('load_model_latency_milli_secs')))
load_model_latency = metric_results['distributions'][0]
self.assertEqual(load_model_latency.result.count, 1)
self.assertEqual(load_model_latency.result.mean, 500)
def test_forwards_batch_args(self):
examples = list(range(100))
with TestPipeline() as pipeline:
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(FakeModelHandlerNeedsBigBatch())
assert_that(actual, equal_to(examples), label='assert:inferences')
def test_run_inference_unkeyed_examples_with_keyed_model_handler(self):
pipeline = TestPipeline()
with self.assertRaises(TypeError):
examples = [1, 3, 5]
model_handler = base.KeyedModelHandler(FakeModelHandler())
_ = (
pipeline | 'Unkeyed' >> beam.Create(examples)
| 'RunUnkeyed' >> base.RunInference(model_handler))
pipeline.run()
def test_run_inference_keyed_examples_with_unkeyed_model_handler(self):
pipeline = TestPipeline()
examples = [1, 3, 5]
keyed_examples = [(i, example) for i, example in enumerate(examples)]
model_handler = FakeModelHandler()
with self.assertRaises(TypeError):
_ = (
pipeline | 'keyed' >> beam.Create(keyed_examples)
| 'RunKeyed' >> base.RunInference(model_handler))
pipeline.run()
def test_model_handler_compatibility(self):
# ** IMPORTANT ** Do not change this test to make your PR pass without
# first reading below.
# Be certain that the modification will not break third party
# implementations of ModelHandler.
# See issue https://github.com/apache/beam/issues/23484
# If this test fails, likely third party implementations of
# ModelHandler will break.
class ThirdPartyHandler(base.ModelHandler[int, int, FakeModel]):
def __init__(self, custom_parameter=None):
pass
def load_model(self) -> FakeModel:
return FakeModel()
def run_inference(
self,
batch: Sequence[int],
model: FakeModel,
inference_args: Optional[Dict[str, Any]] = None) -> Iterable[int]:
yield 0
def get_num_bytes(self, batch: Sequence[int]) -> int:
return 1
def get_metrics_namespace(self) -> str:
return 'ThirdParty'
def get_resource_hints(self) -> dict:
return {}
def batch_elements_kwargs(self) -> Mapping[str, Any]:
return {}
def validate_inference_args(
self, inference_args: Optional[Dict[str, Any]]):
pass
# This test passes if calling these methods does not cause
# any runtime exceptions.
third_party_model_handler = ThirdPartyHandler(custom_parameter=0)
fake_model = third_party_model_handler.load_model()
third_party_model_handler.run_inference([], fake_model)
fake_inference_args = {'some_arg': 1}
third_party_model_handler.run_inference([],
fake_model,
inference_args=fake_inference_args)
third_party_model_handler.get_num_bytes([1, 2, 3])
third_party_model_handler.get_metrics_namespace()
third_party_model_handler.get_resource_hints()
third_party_model_handler.batch_elements_kwargs()
third_party_model_handler.validate_inference_args({})
def test_run_inference_prediction_result_with_model_id(self):
examples = [1, 5, 3, 10]
expected = [
base.PredictionResult(
example=example,
inference=example + 1,
model_id='fake_model_id_default') for example in examples
]
with TestPipeline() as pipeline:
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(
FakeModelHandlerReturnsPredictionResult())
assert_that(actual, equal_to(expected), label='assert:inferences')
def test_run_inference_with_iterable_side_input(self):
test_pipeline = TestPipeline()
side_input = (
test_pipeline | "CreateDummySideInput" >> beam.Create(
[base.ModelMetadata(1, 1), base.ModelMetadata(2, 2)])
| "ApplySideInputWindow" >> beam.WindowInto(
window.GlobalWindows(),
trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
accumulation_mode=trigger.AccumulationMode.DISCARDING))
test_pipeline.options.view_as(StandardOptions).streaming = True
with self.assertRaises(ValueError) as e:
_ = (
test_pipeline
| beam.Create([1, 2, 3, 4])
| base.RunInference(
FakeModelHandler(), model_metadata_pcoll=side_input))
test_pipeline.run()
self.assertTrue(
'PCollection of size 2 with more than one element accessed as a '
'singleton view. First two elements encountered are' in str(
e.exception))
def test_run_inference_with_iterable_side_input_multi_process_shared(self):
test_pipeline = TestPipeline()
side_input = (
test_pipeline | "CreateDummySideInput" >> beam.Create(
[base.ModelMetadata(1, 1), base.ModelMetadata(2, 2)])
| "ApplySideInputWindow" >> beam.WindowInto(
window.GlobalWindows(),
trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
accumulation_mode=trigger.AccumulationMode.DISCARDING))
test_pipeline.options.view_as(StandardOptions).streaming = True
with self.assertRaises(ValueError) as e:
_ = (
test_pipeline
| beam.Create([1, 2, 3, 4])
| base.RunInference(
FakeModelHandler(multi_process_shared=True),
model_metadata_pcoll=side_input))
test_pipeline.run()
self.assertTrue(
'PCollection of size 2 with more than one element accessed as a '
'singleton view. First two elements encountered are' in str(
e.exception))
def test_run_inference_empty_side_input(self):
model_handler = FakeModelHandlerReturnsPredictionResult()
main_input_elements = [1, 2]
with TestPipeline(is_integration_test=False) as pipeline:
side_pcoll = pipeline | "side" >> beam.Create([])
result_pcoll = (
pipeline
| beam.Create(main_input_elements)
| base.RunInference(model_handler, model_metadata_pcoll=side_pcoll))
expected = [
base.PredictionResult(ele, ele + 1, 'fake_model_id_default')
for ele in main_input_elements
]
assert_that(result_pcoll, equal_to(expected))
def test_run_inference_side_input_in_batch(self):
first_ts = math.floor(time.time()) - 30
interval = 7
sample_main_input_elements = ([
first_ts - 2,
first_ts + 1,
first_ts + 8,
first_ts + 15,
first_ts + 22,
])
sample_side_input_elements = [
(first_ts + 1, base.ModelMetadata(model_id='', model_name='')),
# if model_id is empty string, we use the default model
# handler model URI.
(
first_ts + 8,
base.ModelMetadata(
model_id='fake_model_id_1', model_name='fake_model_id_1')),
(
first_ts + 15,
base.ModelMetadata(
model_id='fake_model_id_2', model_name='fake_model_id_2'))
]
model_handler = FakeModelHandlerReturnsPredictionResult()
    # Apply GroupByKey so the windowing takes effect, per
    # https://beam.apache.org/documentation/programming-guide/#windowing-bounded-collections
class _EmitElement(beam.DoFn):
def process(self, element):
for e in element:
yield e
with TestPipeline() as pipeline:
side_input = (
pipeline
|
"CreateSideInputElements" >> beam.Create(sample_side_input_elements)
| beam.Map(lambda x: TimestampedValue(x[1], x[0]))
| beam.WindowInto(
window.FixedWindows(interval),
accumulation_mode=trigger.AccumulationMode.DISCARDING)
| beam.Map(lambda x: ('key', x))
| beam.GroupByKey()
| beam.Map(lambda x: x[1])
| "EmitSideInput" >> beam.ParDo(_EmitElement()))
result_pcoll = (
pipeline
| beam.Create(sample_main_input_elements)
| "MapTimeStamp" >> beam.Map(lambda x: TimestampedValue(x, x))
| "ApplyWindow" >> beam.WindowInto(window.FixedWindows(interval))
| beam.Map(lambda x: ('key', x))
| "MainInputGBK" >> beam.GroupByKey()
| beam.Map(lambda x: x[1])
| beam.ParDo(_EmitElement())
| "RunInference" >> base.RunInference(
model_handler, model_metadata_pcoll=side_input))
expected_model_id_order = [
'fake_model_id_default',
'fake_model_id_default',
'fake_model_id_1',
'fake_model_id_2',
'fake_model_id_2'
]
expected_result = [
base.PredictionResult(
example=sample_main_input_elements[i],
inference=sample_main_input_elements[i] + 1,
model_id=expected_model_id_order[i]) for i in range(5)
]
assert_that(result_pcoll, equal_to(expected_result))
def test_run_inference_side_input_in_batch_per_key_models(self):
first_ts = math.floor(time.time()) - 30
interval = 7
sample_main_input_elements = ([
('key1', first_ts - 2),
('key2', first_ts + 1),
('key2', first_ts + 8),
('key1', first_ts + 15),
('key2', first_ts + 22),
('key1', first_ts + 29),
])
sample_side_input_elements = [
(
first_ts + 1,
[
base.KeyModelPathMapping(
keys=['key1'], update_path='fake_model_id_default'),
base.KeyModelPathMapping(
keys=['key2'], update_path='fake_model_id_default')
]),
# if model_id is empty string, we use the default model
# handler model URI.
(
first_ts + 8,
[
base.KeyModelPathMapping(
keys=['key1'], update_path='fake_model_id_1'),
base.KeyModelPathMapping(
keys=['key2'], update_path='fake_model_id_default')
]),
(
first_ts + 15,
[
base.KeyModelPathMapping(
keys=['key1'], update_path='fake_model_id_1'),
base.KeyModelPathMapping(
keys=['key2'], update_path='fake_model_id_2')
]),
]
model_handler = base.KeyedModelHandler([
base.KeyModelMapping(['key1'],
FakeModelHandlerReturnsPredictionResult(
multi_process_shared=True, state=True)),
base.KeyModelMapping(['key2'],
FakeModelHandlerReturnsPredictionResult(
multi_process_shared=True, state=True))
])
class _EmitElement(beam.DoFn):
def process(self, element):
for e in element:
yield e
with TestPipeline() as pipeline:
side_input = (
pipeline
|
"CreateSideInputElements" >> beam.Create(sample_side_input_elements)
| beam.Map(lambda x: TimestampedValue(x[1], x[0]))
| beam.WindowInto(
window.FixedWindows(interval),
accumulation_mode=trigger.AccumulationMode.DISCARDING)
| beam.Map(lambda x: ('key', x))
| beam.GroupByKey()
| beam.Map(lambda x: x[1])
| "EmitSideInput" >> beam.ParDo(_EmitElement()))
result_pcoll = (
pipeline
| beam.Create(sample_main_input_elements)
| "MapTimeStamp" >> beam.Map(lambda x: TimestampedValue(x, x[1]))
| "ApplyWindow" >> beam.WindowInto(window.FixedWindows(interval))
| "RunInference" >> base.RunInference(
model_handler, model_metadata_pcoll=side_input)
| beam.Map(lambda x: x[1]))
expected_model_id_order = [
'fake_model_id_default',
'fake_model_id_default',
'fake_model_id_default',
'fake_model_id_1',
'fake_model_id_2',
'fake_model_id_1',
]
expected_inferences = [
0,
0,
1,
0,
0,
1,
]
expected_result = [
base.PredictionResult(
example=sample_main_input_elements[i][1],
inference=expected_inferences[i],
model_id=expected_model_id_order[i])
for i in range(len(expected_inferences))
]
assert_that(result_pcoll, equal_to(expected_result))
def test_run_inference_side_input_in_batch_per_key_models_split_cohort(self):
first_ts = math.floor(time.time()) - 30
interval = 7
sample_main_input_elements = ([
('key1', first_ts - 2),
('key2', first_ts + 1),
('key1', first_ts + 8),
('key2', first_ts + 15),
('key1', first_ts + 22),
])
sample_side_input_elements = [
(
first_ts + 1,
[
base.KeyModelPathMapping(
keys=['key1', 'key2'], update_path='fake_model_id_default')
]),
# if model_id is empty string, we use the default model
# handler model URI.
(
first_ts + 8,
[
base.KeyModelPathMapping(
keys=['key1'], update_path='fake_model_id_1'),
base.KeyModelPathMapping(
keys=['key2'], update_path='fake_model_id_default')
]),
(
first_ts + 15,
[
base.KeyModelPathMapping(
keys=['key1'], update_path='fake_model_id_1'),
base.KeyModelPathMapping(
keys=['key2'], update_path='fake_model_id_2')
]),
]
model_handler = base.KeyedModelHandler([
base.KeyModelMapping(
['key1', 'key2'],
FakeModelHandlerReturnsPredictionResult(multi_process_shared=True))
])
class _EmitElement(beam.DoFn):
def process(self, element):
for e in element:
yield e
with TestPipeline() as pipeline:
side_input = (
pipeline
|
"CreateSideInputElements" >> beam.Create(sample_side_input_elements)
| beam.Map(lambda x: TimestampedValue(x[1], x[0]))
| beam.WindowInto(
window.FixedWindows(interval),
accumulation_mode=trigger.AccumulationMode.DISCARDING)
| beam.Map(lambda x: ('key', x))
| beam.GroupByKey()
| beam.Map(lambda x: x[1])
| "EmitSideInput" >> beam.ParDo(_EmitElement()))
result_pcoll = (
pipeline
| beam.Create(sample_main_input_elements)
| "MapTimeStamp" >> beam.Map(lambda x: TimestampedValue(x, x[1]))
| "ApplyWindow" >> beam.WindowInto(window.FixedWindows(interval))
| "RunInference" >> base.RunInference(
model_handler, model_metadata_pcoll=side_input)
| beam.Map(lambda x: x[1]))
expected_model_id_order = [
'fake_model_id_default',
'fake_model_id_default',
'fake_model_id_1',
'fake_model_id_2',
'fake_model_id_1'
]
expected_result = [
base.PredictionResult(
example=sample_main_input_elements[i][1],
inference=sample_main_input_elements[i][1] + 1,
model_id=expected_model_id_order[i]) for i in range(5)
]
assert_that(result_pcoll, equal_to(expected_result))
def test_run_inference_side_input_in_batch_multi_process_shared(self):
first_ts = math.floor(time.time()) - 30
interval = 7
sample_main_input_elements = ([
first_ts - 2,
first_ts + 1,
first_ts + 8,
first_ts + 15,
first_ts + 22,
])
sample_side_input_elements = [
(first_ts + 1, base.ModelMetadata(model_id='', model_name='')),
# if model_id is empty string, we use the default model
# handler model URI.
(
first_ts + 8,
base.ModelMetadata(
model_id='fake_model_id_1', model_name='fake_model_id_1')),
(
first_ts + 15,
base.ModelMetadata(
model_id='fake_model_id_2', model_name='fake_model_id_2'))
]
model_handler = FakeModelHandlerReturnsPredictionResult(
multi_process_shared=True)
    # Apply GroupByKey so the windowing takes effect, per
    # https://beam.apache.org/documentation/programming-guide/#windowing-bounded-collections
class _EmitElement(beam.DoFn):
def process(self, element):
for e in element:
yield e
with TestPipeline() as pipeline:
side_input = (
pipeline
|
"CreateSideInputElements" >> beam.Create(sample_side_input_elements)
| beam.Map(lambda x: TimestampedValue(x[1], x[0]))
| beam.WindowInto(
window.FixedWindows(interval),
accumulation_mode=trigger.AccumulationMode.DISCARDING)
| beam.Map(lambda x: ('key', x))
| beam.GroupByKey()
| beam.Map(lambda x: x[1])
| "EmitSideInput" >> beam.ParDo(_EmitElement()))
result_pcoll = (
pipeline
| beam.Create(sample_main_input_elements)
| "MapTimeStamp" >> beam.Map(lambda x: TimestampedValue(x, x))
| "ApplyWindow" >> beam.WindowInto(window.FixedWindows(interval))
| beam.Map(lambda x: ('key', x))
| "MainInputGBK" >> beam.GroupByKey()
| beam.Map(lambda x: x[1])
| beam.ParDo(_EmitElement())
| "RunInference" >> base.RunInference(
model_handler, model_metadata_pcoll=side_input))
expected_model_id_order = [
'fake_model_id_default',
'fake_model_id_default',
'fake_model_id_1',
'fake_model_id_2',
'fake_model_id_2'
]
expected_result = [
base.PredictionResult(
example=sample_main_input_elements[i],
inference=sample_main_input_elements[i] + 1,
model_id=expected_model_id_order[i]) for i in range(5)
]
assert_that(result_pcoll, equal_to(expected_result))
@unittest.skipIf(
not TestPipeline().get_pipeline_options().view_as(
StandardOptions).streaming,
"SideInputs to RunInference are only supported in streaming mode.")
@pytest.mark.it_postcommit
@pytest.mark.sickbay_direct
@pytest.mark.it_validatesrunner
  def test_run_inference_with_side_input_in_streaming(self):
test_pipeline = TestPipeline(is_integration_test=True)
test_pipeline.options.view_as(StandardOptions).streaming = True
run_inference_side_inputs.run(
test_pipeline.get_full_options_as_args(), save_main_session=False)
def test_env_vars_set_correctly(self):
handler_with_vars = FakeModelHandler(env_vars={'FOO': 'bar'})
os.environ.pop('FOO', None)
    self.assertNotIn('FOO', os.environ)
with TestPipeline() as pipeline:
examples = [1, 2, 3]
_ = (
pipeline
| 'start' >> beam.Create(examples)
| base.RunInference(handler_with_vars))
pipeline.run()
    self.assertIn('FOO', os.environ)
    self.assertEqual(os.environ['FOO'], 'bar')
def test_child_class_without_env_vars(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
expected = [example + 1 for example in examples]
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars())
assert_that(actual, equal_to(expected), label='assert:inferences')
def test_model_manager_loads_shared_model(self):
mhs = {
'key1': FakeModelHandler(state=1),
'key2': FakeModelHandler(state=2),
'key3': FakeModelHandler(state=3)
}
mm = base._ModelManager(mh_map=mhs)
tag1 = mm.load('key1').model_tag
# Use bad_mh's load function to make sure we're actually loading the
# version already stored
bad_mh = FakeModelHandler(state=100)
model1 = multi_process_shared.MultiProcessShared(
bad_mh.load_model, tag=tag1).acquire()
self.assertEqual(1, model1.predict(10))
tag2 = mm.load('key2').model_tag
tag3 = mm.load('key3').model_tag
model2 = multi_process_shared.MultiProcessShared(
bad_mh.load_model, tag=tag2).acquire()
model3 = multi_process_shared.MultiProcessShared(
bad_mh.load_model, tag=tag3).acquire()
self.assertEqual(2, model2.predict(10))
self.assertEqual(3, model3.predict(10))
def test_model_manager_evicts_models(self):
mh1 = FakeModelHandler(state=1)
mh2 = FakeModelHandler(state=2)
mh3 = FakeModelHandler(state=3)
mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
mm = base._ModelManager(mh_map=mhs)
mm.increment_max_models(2)
tag1 = mm.load('key1').model_tag
sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
model1 = sh1.acquire()
self.assertEqual(1, model1.predict(10))
model1.increment_state(5)
tag2 = mm.load('key2').model_tag
tag3 = mm.load('key3').model_tag
sh2 = multi_process_shared.MultiProcessShared(mh2.load_model, tag=tag2)
model2 = sh2.acquire()
sh3 = multi_process_shared.MultiProcessShared(mh3.load_model, tag=tag3)
model3 = sh3.acquire()
model2.increment_state(5)
model3.increment_state(5)
self.assertEqual(7, model2.predict(10))
self.assertEqual(8, model3.predict(10))
sh2.release(model2)
sh3.release(model3)
# model1 should have retained a valid reference to the model until released
self.assertEqual(6, model1.predict(10))
sh1.release(model1)
    # This should now have been garbage collected, so now it should get
    # recreated and shouldn't have the state updates.
model1 = multi_process_shared.MultiProcessShared(
mh1.load_model, tag=tag1).acquire()
self.assertEqual(1, model1.predict(10))
# These should not get recreated, so they should have the state updates
model2 = multi_process_shared.MultiProcessShared(
mh2.load_model, tag=tag2).acquire()
self.assertEqual(7, model2.predict(10))
model3 = multi_process_shared.MultiProcessShared(
mh3.load_model, tag=tag3).acquire()
self.assertEqual(8, model3.predict(10))
def test_model_manager_evicts_models_after_update(self):
mh1 = FakeModelHandler(state=1)
mhs = {'key1': mh1}
mm = base._ModelManager(mh_map=mhs)
tag1 = mm.load('key1').model_tag
sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
model1 = sh1.acquire()
self.assertEqual(1, model1.predict(10))
model1.increment_state(5)
self.assertEqual(6, model1.predict(10))
mm.update_model_handler('key1', 'fake/path', 'key1')
self.assertEqual(6, model1.predict(10))
sh1.release(model1)
tag1 = mm.load('key1').model_tag
sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
model1 = sh1.acquire()
self.assertEqual(1, model1.predict(10))
model1.increment_state(5)
self.assertEqual(6, model1.predict(10))
sh1.release(model1)
# Shouldn't evict if path is the same as last update
mm.update_model_handler('key1', 'fake/path', 'key1')
tag1 = mm.load('key1').model_tag
sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
model1 = sh1.acquire()
self.assertEqual(6, model1.predict(10))
sh1.release(model1)
def test_model_manager_evicts_correct_num_of_models_after_being_incremented(
self):
mh1 = FakeModelHandler(state=1)
mh2 = FakeModelHandler(state=2)
mh3 = FakeModelHandler(state=3)
mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
mm = base._ModelManager(mh_map=mhs)
mm.increment_max_models(1)
mm.increment_max_models(1)
tag1 = mm.load('key1').model_tag
sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
model1 = sh1.acquire()
self.assertEqual(1, model1.predict(10))
model1.increment_state(5)
self.assertEqual(6, model1.predict(10))
sh1.release(model1)
tag2 = mm.load('key2').model_tag
tag3 = mm.load('key3').model_tag
sh2 = multi_process_shared.MultiProcessShared(mh2.load_model, tag=tag2)
model2 = sh2.acquire()
sh3 = multi_process_shared.MultiProcessShared(mh3.load_model, tag=tag3)
model3 = sh3.acquire()
model2.increment_state(5)
model3.increment_state(5)
self.assertEqual(7, model2.predict(10))
self.assertEqual(8, model3.predict(10))
sh2.release(model2)
sh3.release(model3)
# This should get recreated, so it shouldn't have the state updates
model1 = multi_process_shared.MultiProcessShared(
mh1.load_model, tag=tag1).acquire()
self.assertEqual(1, model1.predict(10))
# These should not get recreated, so they should have the state updates
model2 = multi_process_shared.MultiProcessShared(
mh2.load_model, tag=tag2).acquire()
self.assertEqual(7, model2.predict(10))
model3 = multi_process_shared.MultiProcessShared(
mh3.load_model, tag=tag3).acquire()
self.assertEqual(8, model3.predict(10))
def test_run_inference_watch_file_pattern_side_input_label(self):
pipeline = TestPipeline()
    # Label of the WatchFilePattern transform.
side_input_str = 'WatchFilePattern/ApplyGlobalWindow'
from apache_beam.ml.inference.utils import WatchFilePattern
file_pattern_side_input = (
pipeline
| 'WatchFilePattern' >> WatchFilePattern(file_pattern='fake/path/*'))
pcoll = pipeline | 'start' >> beam.Create([1, 2, 3])
result_pcoll = pcoll | base.RunInference(
FakeModelHandler(), model_metadata_pcoll=file_pattern_side_input)
assert side_input_str in str(result_pcoll.producer.side_inputs[0])
def test_run_inference_watch_file_pattern_keyword_arg_side_input_label(self):
    # Label of the WatchFilePattern transform.
side_input_str = 'WatchFilePattern/ApplyGlobalWindow'
pipeline = TestPipeline()
pcoll = pipeline | 'start' >> beam.Create([1, 2, 3])
result_pcoll = pcoll | base.RunInference(
FakeModelHandler(), watch_model_pattern='fake/path/*')
assert side_input_str in str(result_pcoll.producer.side_inputs[0])
if __name__ == '__main__':
unittest.main()