blob: f287dee323b00db40e043b93ec6a3e7a7697a5ce [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import os
from pyflink.common import Types
from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo
from pyflink.ml.feature.idf import IDF, IDFModel
from pyflink.ml.tests.test_utils import PyFlinkMLTestCase, update_existing_params
class IDFTest(PyFlinkMLTestCase):
def setUp(self):
super(IDFTest, self).setUp()
self.input_data = self.t_env.from_data_stream(
self.env.from_collection([
(Vectors.dense(0, 1, 0, 2),),
(Vectors.dense(0, 1, 2, 3),),
(Vectors.dense(0, 1, 0, 0),),
],
type_info=Types.ROW_NAMED(
['input', ],
[DenseVectorTypeInfo(), ])))
self.expected_output = [
Vectors.dense(0.0, 0.0, 0.0, 0.5753641),
Vectors.dense(0.0, 0.0, 1.3862943, 0.8630462),
Vectors.dense(0.0, 0.0, 0.0, 0.0),
]
self.expected_output_min_doc_freq_as_two = [
Vectors.dense(0.0, 0.0, 0.0, 0.5753641),
Vectors.dense(0.0, 0.0, 0.0, 0.8630462),
Vectors.dense(0.0, 0.0, 0.0, 0.0),
]
self.tolerance = 1e-7
def verify_prediction_result(self, expected, output_table):
predicted_results = [result[1] for result in
self.t_env.to_data_stream(output_table).execute_and_collect()]
predicted_results.sort(key=lambda x: x[3])
expected.sort(key=lambda x: x[3])
self.assertEqual(len(expected), len(predicted_results))
for i in range(len(expected)):
expected_row = expected[i]
predicted_row = predicted_results[i]
self.assertEqual(len(expected_row), len(predicted_row))
for j in range(len(expected_row)):
self.assertAlmostEqual(expected_row[j], predicted_row[j], delta=self.tolerance)
def test_param(self):
idf = IDF()
self.assertEqual("input", idf.input_col)
self.assertEqual(0, idf.min_doc_freq)
self.assertEqual("output", idf.output_col)
idf \
.set_input_col("test_input") \
.set_min_doc_freq(2) \
.set_output_col("test_output")
self.assertEqual("test_input", idf.input_col)
self.assertEqual(2, idf.min_doc_freq)
self.assertEqual("test_output", idf.output_col)
def test_output_schema(self):
idf = IDF() \
.set_input_col("test_input") \
.set_output_col("test_output")
input_data_table = self.t_env.from_data_stream(
self.env.from_collection([
(Vectors.dense(1), ''),
],
type_info=Types.ROW_NAMED(
['test_input', 'dummy_input'],
[DenseVectorTypeInfo(), Types.STRING()])))
output = idf \
.fit(input_data_table) \
.transform(input_data_table)[0]
self.assertEqual(
[idf.input_col, 'dummy_input', idf.output_col],
output.get_schema().get_field_names())
def test_fit_and_predict(self):
idf = IDF()
# Tests minDocFreq = 0.
output = idf.fit(self.input_data).transform(self.input_data)[0]
self.verify_prediction_result(self.expected_output, output)
# Tests minDocFreq = 2.
idf.set_min_doc_freq(2)
output = idf.fit(self.input_data).transform(self.input_data)[0]
self.verify_prediction_result(self.expected_output_min_doc_freq_as_two, output)
def test_get_model_data(self):
idf = IDF()
model = idf.fit(self.input_data)
model_data = model.get_model_data()[0]
expected_field_names = ['idf', 'docFreq', 'numDocs']
self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
model_rows = [result for result in
self.t_env.to_data_stream(model_data).execute_and_collect()]
self.assertEqual(1, len(model_rows))
self.assertEqual(3, model_rows[0][expected_field_names.index('numDocs')])
self.assertListEqual([0, 3, 1, 2], model_rows[0][expected_field_names.index('docFreq')])
self.assertListAlmostEqual(
[1.3862943, 0, 0.6931471, 0.2876820],
model_rows[0][expected_field_names.index('idf')].to_array(),
delta=self.tolerance)
def test_set_model_data(self):
idf = IDF()
model_a = idf.fit(self.input_data)
model_data = model_a.get_model_data()[0]
model_b = IDFModel().set_model_data(model_data)
update_existing_params(model_b, model_a)
output = model_b.transform(self.input_data)[0]
self.verify_prediction_result(self.expected_output, output)
def test_save_load_predict(self):
idf = IDF()
estimator_path = os.path.join(self.temp_dir, 'test_save_load_predict_idf')
idf.save(estimator_path)
idf = IDF.load(self.t_env, estimator_path)
model = idf.fit(self.input_data)
model_path = os.path.join(self.temp_dir, 'test_save_load_predict_idf_model')
model.save(model_path)
self.env.execute('save_model')
model = IDFModel.load(self.t_env, model_path)
output = model.transform(self.input_data)[0]
self.verify_prediction_result(self.expected_output, output)