################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from pyflink.common import Types
from pyflink.common.time import Time, Instant
from pyflink.java_gateway import get_gateway
from pyflink.table import Schema
from pyflink.table.types import DataTypes
from pyflink.table.expressions import col
from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo
from pyflink.ml.feature.onlinestandardscaler import OnlineStandardScaler, \
OnlineStandardScalerModel
from pyflink.ml.tests.test_utils import PyFlinkMLTestCase, update_existing_params
from pyflink.ml.common.window import GlobalWindows, EventTimeTumblingWindows


class OnlineStandardScalerTest(PyFlinkMLTestCase):
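    """Tests OnlineStandardScaler and OnlineStandardScalerModel."""
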
def setUp(self):
super(OnlineStandardScalerTest, self).setUp()
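        # Serializer string used to declare the DenseVector column as a RAW type in the
        # table schema below.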
dense_vector_serializer = get_gateway().jvm.org.apache.flink.table.types.logical.RawType(
get_gateway().jvm.org.apache.flink.ml.linalg.DenseVector(0).getClass(),
get_gateway().jvm.org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer()
).getSerializerString()
schema = Schema.new_builder() \
.column("ts_in_long", DataTypes.BIGINT()) \
.column("ts", "TIMESTAMP_LTZ(3)") \
.column("input", "RAW('org.apache.flink.ml.linalg.DenseVector', '{serializer}')"
.format(serializer=dense_vector_serializer)) \
.watermark("ts", "ts - INTERVAL '1' SECOND") \
.build()
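        # Build the input table with an event-time column and a 1-second watermark so that
        # the event-time tumbling windows used in the tests below can fire.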
self.input_table = self.t_env.from_data_stream(
self.env.from_collection([
(0, Instant.of_epoch_milli(0), Vectors.dense(-2.5, 9, 1),),
(1000, Instant.of_epoch_milli(1000), Vectors.dense(1.4, -5, 1),),
(2000, Instant.of_epoch_milli(2000), Vectors.dense(2, -1, -2),),
(6000, Instant.of_epoch_milli(6000), Vectors.dense(0.7, 3, 1),),
(7000, Instant.of_epoch_milli(7000), Vectors.dense(0, 1, 1),),
(8000, Instant.of_epoch_milli(8000), Vectors.dense(0.5, 0, -2),),
(9000, Instant.of_epoch_milli(9000), Vectors.dense(0.4, 1, 1),),
(10000, Instant.of_epoch_milli(10000), Vectors.dense(0.3, 2, 1),),
(11000, Instant.of_epoch_milli(11000), Vectors.dense(0.5, 1, -2),)
],
type_info=Types.ROW_NAMED(
['ts_in_long', 'ts', 'input'],
[Types.LONG(), Types.INSTANT(), DenseVectorTypeInfo()])),
schema)
self.window_size_ms = 3000
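        # Expected model data per event-time window: [mean, std, model version, model timestamp].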
self.expected_model_data = [
[
Vectors.dense(0.3, 1, 0),
Vectors.dense(2.4433583, 7.2111026, 1.7320508),
0,
2999
],
[
Vectors.dense(0.35, 1.1666667, 0),
Vectors.dense(1.5630099, 4.6654760, 1.5491933),
1,
8999
],
[
Vectors.dense(0.3666667, 1.2222222, 0),
Vectors.dense(1.2369316, 3.7006005, 1.5),
2,
11999
]
]
self.tolerance = 1e-7
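
    # Verifies the default parameter values and the setters of OnlineStandardScaler.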
def test_param(self):
standard_scaler = OnlineStandardScaler()
self.assertEqual('input', standard_scaler.input_col)
self.assertEqual(False, standard_scaler.with_mean)
self.assertEqual(True, standard_scaler.with_std)
self.assertEqual('output', standard_scaler.output_col)
self.assertEqual('version', standard_scaler.model_version_col)
self.assertEqual(GlobalWindows(), standard_scaler.windows)
self.assertEqual(0, standard_scaler.max_allowed_model_delay_ms)
standard_scaler.set_input_col('test_input') \
.set_with_mean(True) \
.set_with_std(False) \
.set_output_col('test_output') \
.set_model_version_col('test_version') \
.set_windows(EventTimeTumblingWindows.of(Time.milliseconds(3000))) \
.set_max_allowed_model_delay_ms(3000)
self.assertEqual('test_input', standard_scaler.input_col)
self.assertEqual(True, standard_scaler.with_mean)
self.assertEqual(False, standard_scaler.with_std)
self.assertEqual('test_output', standard_scaler.output_col)
self.assertEqual("test_version", standard_scaler.model_version_col)
self.assertEqual(EventTimeTumblingWindows.of(Time.milliseconds(3000)),
standard_scaler.windows)
self.assertEqual(3000, standard_scaler.max_allowed_model_delay_ms)
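
    # Verifies the transform output schema with and without the model version column.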
def test_output_schema(self):
temp_table = self.input_table.alias('ts_in_long', 'ts', "test_input")
        # Tests the case when model_version_col is not None.
standard_scaler = OnlineStandardScaler() \
.set_input_col('test_input') \
.set_output_col('test_output')
output = standard_scaler.fit(temp_table).transform(temp_table)[0]
self.assertEqual(['ts_in_long', 'ts', 'test_input', 'test_output', 'version'],
output.get_schema().get_field_names())
        # Tests the case when model_version_col is None.
standard_scaler.set_model_version_col(None)
output = standard_scaler.fit(temp_table).transform(temp_table)[0]
self.assertEqual(['ts_in_long', 'ts', 'test_input', 'test_output'],
output.get_schema().get_field_names())
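
    # Fits and transforms the input table, then verifies the model version used for each output row.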
def test_fit_and_predict(self):
standard_scaler = OnlineStandardScaler()
standard_scaler \
.set_windows(EventTimeTumblingWindows.of(Time.milliseconds(self.window_size_ms)))
output = standard_scaler.fit(self.input_table).transform(self.input_table)[0]
self.verify_used_model_version(output, standard_scaler.model_version_col,
standard_scaler.max_allowed_model_delay_ms)
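
    # Verifies the schema and the per-window content of the model data.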
def test_get_model_data(self):
standard_scaler = OnlineStandardScaler()
standard_scaler \
.set_windows(EventTimeTumblingWindows.of(Time.milliseconds(self.window_size_ms)))
model_data = standard_scaler.fit(self.input_table).get_model_data()[0]
self.assertEqual(["mean", "std", "version", "timestamp"],
model_data.get_schema().get_field_names())
model_rows = [result for result in
self.t_env.to_data_stream(model_data).execute_and_collect()]
self.assertEqual(len(self.expected_model_data), len(model_rows))
for idx in range(len(self.expected_model_data)):
self.assertListAlmostEqual(self.expected_model_data[idx][0].to_array(),
model_rows[idx][0].to_array(),
delta=self.tolerance)
self.assertListAlmostEqual(self.expected_model_data[idx][1].to_array(),
model_rows[idx][1].to_array(),
delta=self.tolerance)
self.assertEqual(self.expected_model_data[idx][2], model_rows[idx][2])
self.assertEqual(self.expected_model_data[idx][3], model_rows[idx][3])
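
    # Feeds the fitted model data into a new OnlineStandardScalerModel and verifies its output.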
def test_set_model_data(self):
standard_scaler = OnlineStandardScaler()
model = standard_scaler \
.set_windows(EventTimeTumblingWindows.of(Time.milliseconds(self.window_size_ms))) \
.fit(self.input_table)
model_data = model.get_model_data()[0]
new_model = OnlineStandardScalerModel().set_model_data(model_data)
update_existing_params(new_model, model)
output = new_model.transform(self.input_table)[0]
self.verify_used_model_version(output, model.model_version_col,
model.max_allowed_model_delay_ms)
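
    # Round-trips the estimator and the fitted model through save/load before transforming.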
def test_save_load_and_predict(self):
standard_scaler = OnlineStandardScaler() \
.set_windows(EventTimeTumblingWindows.of(Time.milliseconds(self.window_size_ms)))
reloaded_standard_scaler = self.save_and_reload(standard_scaler)
model = reloaded_standard_scaler.fit(self.input_table)
reloaded_model = self.save_and_reload(model)
model_data = model.get_model_data()[0]
reloaded_model.set_model_data(model_data)
output = reloaded_model.transform(self.input_table)[0]
self.verify_used_model_version(output, reloaded_model.model_version_col,
reloaded_model.max_allowed_model_delay_ms)
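
    # Checks that each output row was produced with a model whose timestamp lags the
    # row's event time by at most max_allowed_model_delay_ms.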
def verify_used_model_version(self, output_table, model_version_col,
max_allowed_model_delay_ms):
# TODO: remove this line when Instant values can be acquired without errors.
output_table = output_table.select(col("ts_in_long"), col(model_version_col))
collected_results = [result for result in
self.t_env.to_data_stream(output_table).execute_and_collect()]
for row in collected_results:
data_timestamp = row[0]
model_version = row[1]
            model_timestamp = self.expected_model_data[model_version][3]
            self.assertTrue(data_timestamp - model_timestamp <= max_allowed_model_delay_ms)