blob: 45446a46e7120571874edb0f1669d96bf80826ec [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from pyflink.common import Types
from pyflink.table import Table
from typing import List
from pyflink.ml.recommendation.swing import Swing
from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
# Tests Swing.
class SwingTest(PyFlinkMLTestCase):
    """Tests for the ``Swing`` stage: parameters, output schema and transform results."""

    def setUp(self):
        """Create the input tables and the expected output shared by the tests."""
        super().setUp()

        # (user, item) pairs, both columns typed as LONG.
        self.input_table = self._user_item_table(
            [
                (0, 10),
                (0, 11),
                (0, 12),
                (1, 13),
                (1, 12),
                (2, 10),
                (2, 11),
                (2, 12),
                (3, 13),
                (3, 12),
            ],
            Types.LONG())

        # Same layout, but the user column is declared as INT instead of LONG.
        # NOTE(review): not referenced by any test in this class.
        self.wrong_type_data = self._user_item_table(
            [
                (0, 10),
                (1, 11),
                (2, 12),
            ],
            Types.INT())

        # LONG columns with a None user id in the second row.
        # NOTE(review): not referenced by any test in this class.
        self.none_value_data = self._user_item_table(
            [
                (0, 10),
                (None, 11),
                (2, 12),
            ],
            Types.LONG())

        # One [item, score_string] entry per item, ordered by item; each score string
        # is a ';'-separated list of 'item,score' pairs.
        self.expected_data = [
            [10, '11,0.058845768947156235;12,0.058845768947156235'],
            [11, '10,0.058845768947156235;12,0.058845768947156235'],
            [12, '13,0.09134833828228624;10,0.058845768947156235;11,0.058845768947156235'],
            [13, '12,0.09134833828228624']]

    def _user_item_table(self, rows: List, user_type) -> Table:
        """Build a table with columns ``user`` and ``item`` from ``rows``.

        :param rows: List of ``(user, item)`` tuples.
        :param user_type: Type information of the ``user`` column; ``item`` is always LONG.
        :return: The table created from the collection.
        """
        row_type = Types.ROW_NAMED(['user', 'item'], [user_type, Types.LONG()])
        stream = self.env.from_collection(rows, type_info=row_type)
        return self.t_env.from_data_stream(stream)

    def _collect(self, table: Table) -> List:
        """Execute ``table`` and return all of its rows as a list.

        The iterator returned by ``execute_and_collect`` is used as a context manager
        so that it is closed even when collecting fails part-way.
        """
        with self.t_env.to_data_stream(table).execute_and_collect() as rows:
            return list(rows)

    def test_param(self):
        """Check the default parameter values and the chained setters."""
        swing = Swing()

        # Defaults of a freshly constructed stage.
        self.assertEqual("item", swing.item_col)
        self.assertEqual("user", swing.user_col)
        self.assertEqual(100, swing.k)
        self.assertEqual(1000, swing.max_user_num_per_item)
        self.assertEqual(10, swing.min_user_behavior)
        self.assertEqual(1000, swing.max_user_behavior)
        self.assertEqual(15, swing.alpha1)
        self.assertEqual(0, swing.alpha2)
        self.assertAlmostEqual(0.3, swing.beta, delta=1e-9)
        self.assertEqual(438758276, swing.seed)

        swing.set_item_col("item_1") \
            .set_user_col("user_1") \
            .set_k(20) \
            .set_max_user_num_per_item(500) \
            .set_min_user_behavior(20) \
            .set_max_user_behavior(50) \
            .set_alpha1(5) \
            .set_alpha2(1) \
            .set_beta(0.35) \
            .set_seed(1)

        # Every setter must be reflected by the corresponding property.
        self.assertEqual("item_1", swing.item_col)
        self.assertEqual("user_1", swing.user_col)
        self.assertEqual(20, swing.k)
        self.assertEqual(500, swing.max_user_num_per_item)
        self.assertEqual(20, swing.min_user_behavior)
        self.assertEqual(50, swing.max_user_behavior)
        self.assertEqual(5, swing.alpha1)
        self.assertEqual(1, swing.alpha2)
        self.assertAlmostEqual(0.35, swing.beta, delta=1e-9)
        self.assertEqual(1, swing.seed)

    def test_output_schema(self):
        """The output table holds the item column followed by the output column."""
        swing = Swing() \
            .set_item_col('test_item') \
            .set_user_col('test_user') \
            .set_output_col("item_score")

        renamed_input = self.input_table.alias('test_user', 'test_item')
        output = swing.transform(renamed_input)[0]

        self.assertEqual(
            ['test_item', 'item_score'],
            output.get_schema().get_field_names())

    def test_transform(self):
        """Transforming the input table yields ``expected_data``."""
        swing = Swing().set_min_user_behavior(1)
        output = swing.transform(self.input_table)[0]
        self.verify_output_result(
            output,
            swing.get_item_col(),
            output.get_schema().get_field_names(),
            self.expected_data)

    def test_save_load_and_transform(self):
        """A saved and reloaded stage yields the same ``expected_data``."""
        swing = Swing().set_min_user_behavior(1)
        reloaded_swing = self.save_and_reload(swing)
        output = reloaded_swing.transform(self.input_table)[0]
        self.verify_output_result(
            output,
            swing.get_item_col(),
            output.get_schema().get_field_names(),
            self.expected_data)

    def verify_output_result(
            self, output: Table,
            item_col: str,
            field_names: List[str],
            expected_result: List):
        """Collect ``output`` and compare it, ordered by item, with ``expected_result``.

        :param output: The table produced by the stage under test.
        :param item_col: Name of the item column in ``output``.
        :param field_names: Field names of ``output``, used to locate ``item_col``.
        :param expected_result: Expected ``[item, score_string]`` pairs sorted by item.
        """
        item_index = field_names.index(item_col)
        # The score is read positionally from index 1, i.e. the second field of the
        # [item, output] layout asserted in test_output_schema.
        results = sorted(
            ([row[item_index], row[1]] for row in self._collect(output)),
            key=lambda pair: pair[0])
        self.assertEqual(expected_result, results)

    def test_sampling_method(self):
        """Runs with a custom seed and the default seed give differently sized results.

        Both stages cap ``max_user_num_per_item`` at 2; only the seed differs.
        """
        swing1 = Swing().set_min_user_behavior(1).set_max_user_num_per_item(2).set_seed(3)
        swing2 = Swing().set_min_user_behavior(1).set_max_user_num_per_item(2)

        result1 = self._collect(swing1.transform(self.input_table)[0])
        result2 = self._collect(swing2.transform(self.input_table)[0])

        self.assertNotEqual(len(result1), len(result2))