blob: e0a0c09a7c6a6344e7b608b87f152e1273de882d [file]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Unit tests for table row inference pipeline."""
import json
import os
import pickle
import tempfile
import unittest
import apache_beam as beam
import numpy as np
from apache_beam.examples.inference.table_row_inference import FormatTableOutput
from apache_beam.examples.inference.table_row_inference import TableRowModelHandler
from apache_beam.examples.inference.table_row_inference import build_output_schema
from apache_beam.examples.inference.table_row_inference import (
parse_json_to_table_row)
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
# Module-level matcher for assert_that (must be picklable; no closure over
# self).
REQUIRED_OUTPUT_KEYS = (
'row_key',
'prediction',
'input_feature1',
'input_feature2',
'input_feature3')
def _assert_table_inference_outputs(outputs):
"""Asserts pipeline output has expected structure. Used in assert_that."""
if len(outputs) != 2:
raise AssertionError(f'Expected 2 outputs, got {len(outputs)}')
for output in outputs:
for key in REQUIRED_OUTPUT_KEYS:
if key not in output:
raise AssertionError(f'Missing key {key!r} in output {output}')
# sklearn is an optional test dependency: probe for it once at import time
# and record the result so the sklearn-backed test case below can be
# skipped cleanly (via unittest.skipIf) when it is not installed.
try:
  from sklearn.linear_model import LinearRegression
  SKLEARN_AVAILABLE = True
except ImportError:
  SKLEARN_AVAILABLE = False
class SimpleLinearModel:
  """Minimal sklearn-free stand-in that exposes a ``predict`` method.

  Each prediction is the sum of a row's features — the same relationship
  the pickled LinearRegression in setUp is trained to learn, so the two
  models are interchangeable for these tests.
  """
  def predict(self, X):
    # Row-wise sum across the feature axis of the 2-D input.
    row_totals = np.sum(X, axis=1)
    return row_totals
@unittest.skipIf(not SKLEARN_AVAILABLE, 'sklearn is not available')
class TableRowInferenceTest(unittest.TestCase):
  """Unit and integration tests that need a real sklearn model on disk."""

  def setUp(self):
    """Create a temp dir and pickle a trained LinearRegression into it.

    Fix: the original used ``tempfile.mkdtemp()`` with no matching cleanup,
    leaking one directory (and pickled model) per test.  TemporaryDirectory
    plus ``addCleanup`` guarantees removal even when a test fails.
    """
    tmp = tempfile.TemporaryDirectory()
    self.addCleanup(tmp.cleanup)
    self.tmp_dir = tmp.name
    self.model_path = os.path.join(self.tmp_dir, 'test_model.pkl')
    model = LinearRegression()
    # Targets equal the row sums, so expected predictions are trivial to
    # reason about in the assertions below (e.g. [1,2,3] -> 6).
    X_train = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    y_train = np.array([6, 15, 24])
    model.fit(X_train, y_train)
    with open(self.model_path, 'wb') as f:
      pickle.dump(model, f)

  def test_parse_json_to_table_row(self):
    """A JSON record parses into a (key, row) pair exposing schema fields."""
    json_data = json.dumps({
        'id': 'test_1', 'feature1': 1.0, 'feature2': 2.0, 'feature3': 3.0
    }).encode('utf-8')
    key, row = parse_json_to_table_row(
        json_data, schema_fields=['feature1', 'feature2', 'feature3'])
    self.assertEqual(key, 'test_1')
    self.assertEqual(row.feature1, 1.0)
    self.assertEqual(row.feature2, 2.0)
    self.assertEqual(row.feature3, 3.0)

  def test_build_output_schema(self):
    """Schema contains key, prediction, model id, and input_* columns."""
    feature_cols = ['feature1', 'feature2', 'feature3']
    schema = build_output_schema(feature_cols)
    expected_fields = [
        'row_key',
        'prediction',
        'model_id',
        'input_feature1',
        'input_feature2',
        'input_feature3'
    ]
    for field in expected_fields:
      self.assertIn(field, schema)

  def test_table_row_model_handler(self):
    """Handler loads the pickled model and yields one result per row."""
    model_handler = TableRowModelHandler(
        model_uri=self.model_path, feature_columns=['f1', 'f2', 'f3'])
    model = model_handler.load_model()
    test_rows = [
        beam.Row(f1=1.0, f2=2.0, f3=3.0),
        beam.Row(f1=4.0, f2=5.0, f3=6.0),
    ]
    results = list(model_handler.run_inference(test_rows, model))
    self.assertEqual(len(results), 2)
    self.assertIsInstance(results[0], PredictionResult)
    # Each PredictionResult keeps its originating example alongside the
    # inference value.
    self.assertEqual(results[0].example, test_rows[0])
    self.assertIsNotNone(results[0].inference)

  def test_format_table_output(self):
    """FormatTableOutput flattens a keyed PredictionResult into a dict."""
    row = beam.Row(feature1=1.0, feature2=2.0, feature3=3.0)
    prediction_result = PredictionResult(
        example=row, inference=6.0, model_id='test_model')
    keyed_result = ('test_key', prediction_result)
    feature_columns = ['feature1', 'feature2', 'feature3']
    formatter = FormatTableOutput(feature_columns=feature_columns)
    outputs = list(formatter.process(keyed_result))
    self.assertEqual(len(outputs), 1)
    output = outputs[0]
    self.assertEqual(output['row_key'], 'test_key')
    self.assertEqual(output['prediction'], 6.0)
    self.assertEqual(output['model_id'], 'test_model')
    self.assertEqual(output['input_feature1'], 1.0)
    self.assertEqual(output['input_feature2'], 2.0)
    self.assertEqual(output['input_feature3'], 3.0)

  def test_pipeline_integration(self):
    """End-to-end: parse JSON -> RunInference -> format, in a TestPipeline."""
    test_data = [
        json.dumps({
            'id': 'row_1', 'feature1': 1.0, 'feature2': 2.0, 'feature3': 3.0
        }),
        json.dumps({
            'id': 'row_2', 'feature1': 4.0, 'feature2': 5.0, 'feature3': 6.0
        }),
    ]
    feature_columns = ['feature1', 'feature2', 'feature3']
    model_handler = TableRowModelHandler(
        model_uri=self.model_path, feature_columns=feature_columns)
    with TestPipeline() as p:
      input_data = (
          p
          | beam.Create(test_data)
          | beam.Map(
              lambda line: parse_json_to_table_row(
                  line.encode('utf-8'), feature_columns)))
      predictions = (
          input_data
          | RunInference(KeyedModelHandler(model_handler))
          | beam.ParDo(FormatTableOutput(feature_columns=feature_columns)))
      # Matcher must be module-level so the runner can pickle it.
      assert_that(predictions, _assert_table_inference_outputs)
class TableRowInferenceNoSklearnTest(unittest.TestCase):
  """Parsing-only tests that run even when sklearn is not installed."""

  def test_parse_json_without_schema(self):
    """Omitting schema_fields still yields the key and all record fields."""
    payload = {'id': 'test', 'value': 123}
    encoded = json.dumps(payload).encode('utf-8')
    key, row = parse_json_to_table_row(encoded)
    self.assertEqual(key, 'test')
    self.assertTrue(hasattr(row, 'value'))
# Allow running this test module directly: `python <this_file>.py`.
if __name__ == '__main__':
  unittest.main()