| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| import unittest |
| from typing import Dict, Any |
| |
| from pyflink.ml.core.param import Param |
| from pyflink.ml.lib.param import HasDistanceMeasure, HasFeaturesCol, HasGlobalBatchSize, \ |
| HasHandleInvalid, HasInputCols, HasLabelCol, HasLearningRate, HasMaxIter, HasMultiClass, \ |
| HasOutputCols, HasPredictionCol, HasRawPredictionCol, HasReg, HasSeed, HasTol, HasWeightCol |
| |
| |
| class TestParams(HasDistanceMeasure, HasFeaturesCol, HasGlobalBatchSize, HasHandleInvalid, |
| HasInputCols, HasLabelCol, HasLearningRate, HasMaxIter, HasMultiClass, |
| HasOutputCols, HasPredictionCol, HasRawPredictionCol, HasReg, HasSeed, HasTol, |
| HasWeightCol): |
| def __init__(self): |
| self._param_map = {} |
| |
| def get_param_map(self) -> Dict['Param[Any]', Any]: |
| return self._param_map |
| |
| |
| class ParamTests(unittest.TestCase): |
| def test_distance_measure_param(self): |
| param = TestParams() |
| distance_measure = param.DISTANCE_MEASURE |
| self.assertEqual(distance_measure.name, "distance_measure") |
| self.assertEqual(distance_measure.description, |
| "Distance measure. Supported options: 'euclidean' and 'cosine'.") |
| self.assertEqual(distance_measure.default_value, "euclidean") |
| |
| param.set_distance_measure("cosine") |
| self.assertEqual(param.get_distance_measure(), "cosine") |
| |
| def test_feature_col_param(self): |
| param = TestParams() |
| feature_col = param.FEATURES_COL |
| self.assertEqual(feature_col.name, "features_col") |
| self.assertEqual(feature_col.description, "Features column name.") |
| self.assertEqual(feature_col.default_value, "features") |
| |
| param.set_features_col("test_features") |
| self.assertEqual(param.get_features_col(), "test_features") |
| |
| def test_global_batch_size_param(self): |
| param = TestParams() |
| global_batch_size = param.GLOBAL_BATCH_SIZE |
| self.assertEqual(global_batch_size.name, "global_batch_size") |
| self.assertEqual(global_batch_size.description, |
| "Global batch size of training algorithms.") |
| self.assertEqual(global_batch_size.default_value, 32) |
| |
| param.set_global_batch_size(100) |
| self.assertEqual(param.get_global_batch_size(), 100) |
| |
| def test_handle_invalid_param(self): |
| param = TestParams() |
| handle_invalid = param.HANDLE_INVALID |
| self.assertEqual(handle_invalid.name, "handle_invalid") |
| self.assertEqual(handle_invalid.description, "Strategy to handle invalid entries.") |
| self.assertEqual(handle_invalid.default_value, "error") |
| |
| param.set_handle_invalid("skip") |
| self.assertEqual(param.get_handle_invalid(), "skip") |
| |
| def test_input_cols_param(self): |
| param = TestParams() |
| input_cols = param.INPUT_COLS |
| self.assertEqual(input_cols.name, "input_cols") |
| self.assertEqual(input_cols.description, "Input column names.") |
| self.assertEqual(input_cols.default_value, None) |
| |
| param.set_input_cols('a', 'b', 'c') |
| self.assertEqual(param.get_input_cols(), ('a', 'b', 'c')) |
| |
| def test_label_col_param(self): |
| param = TestParams() |
| label_col = param.LABEL_COL |
| self.assertEqual(label_col.name, "label_col") |
| self.assertEqual(label_col.description, "Label column name.") |
| self.assertEqual(label_col.default_value, "label") |
| |
| param.set_label_col('test_label') |
| self.assertEqual(param.get_label_col(), 'test_label') |
| |
| def test_learning_rate_param(self): |
| param = TestParams() |
| learning_rate = param.LEARNING_RATE |
| self.assertEqual(learning_rate.name, "learning_rate") |
| self.assertEqual(learning_rate.description, "Learning rate of optimization method.") |
| self.assertEqual(learning_rate.default_value, 0.1) |
| |
| param.set_learning_rate(0.2) |
| self.assertEqual(param.get_learning_rate(), 0.2) |
| |
| def test_max_iter_param(self): |
| param = TestParams() |
| max_iter = param.MAX_ITER |
| self.assertEqual(max_iter.name, "max_iter") |
| self.assertEqual(max_iter.description, "Maximum number of iterations.") |
| self.assertEqual(max_iter.default_value, 20) |
| |
| param.set_max_iter(50) |
| self.assertEqual(param.get_max_iter(), 50) |
| |
| def test_multi_class_param(self): |
| param = TestParams() |
| multi_class = param.MULTI_CLASS |
| self.assertEqual(multi_class.name, "multi_class") |
| self.assertEqual(multi_class.description, |
| "Classification type. Supported options: " |
| "'auto', 'binomial' and 'multinomial'.") |
| self.assertEqual(multi_class.default_value, 'auto') |
| |
| param.set_multi_class('binomial') |
| self.assertEqual(param.get_multi_class(), 'binomial') |
| |
| def test_output_cols_param(self): |
| param = TestParams() |
| output_cols = param.OUTPUT_COLS |
| self.assertEqual(output_cols.name, "output_cols") |
| self.assertEqual(output_cols.description, "Output column names.") |
| self.assertEqual(output_cols.default_value, None) |
| |
| param.set_output_cols('a', 'b') |
| self.assertEqual(param.get_output_cols(), ('a', 'b')) |
| |
| def test_prediction_col_param(self): |
| param = TestParams() |
| prediction_col = param.PREDICTION_COL |
| self.assertEqual(prediction_col.name, "prediction_col") |
| self.assertEqual(prediction_col.description, "Prediction column name.") |
| self.assertEqual(prediction_col.default_value, "prediction") |
| |
| param.set_prediction_col('test_prediction') |
| self.assertEqual(param.get_prediction_col(), 'test_prediction') |
| |
| def test_raw_prediction_col_param(self): |
| param = TestParams() |
| raw_prediction_col = param.RAW_PREDICTION_COL |
| self.assertEqual(raw_prediction_col.name, "raw_prediction_col") |
| self.assertEqual(raw_prediction_col.description, "Raw prediction column name.") |
| self.assertEqual(raw_prediction_col.default_value, "raw_prediction") |
| |
| param.set_raw_prediction_col('test_raw_prediction') |
| self.assertEqual(param.get_raw_prediction_col(), 'test_raw_prediction') |
| |
| def test_reg_param(self): |
| param = TestParams() |
| reg = param.REG |
| self.assertEqual(reg.name, "reg") |
| self.assertEqual(reg.description, "Regularization parameter.") |
| self.assertEqual(reg.default_value, 0.) |
| |
| param.set_reg(0.4) |
| self.assertEqual(param.get_reg(), 0.4) |
| |
| def test_seed_param(self): |
| param = TestParams() |
| seed = param.SEED |
| self.assertEqual(seed.name, "seed") |
| self.assertEqual(seed.description, "The random seed.") |
| self.assertEqual(seed.default_value, None) |
| |
| param.set_seed(1) |
| self.assertEqual(param.get_seed(), 1) |
| |
| def test_tol(self): |
| param = TestParams() |
| tol = param.TOL |
| self.assertEqual(tol.name, "tol") |
| self.assertEqual(tol.description, "Convergence tolerance for iterative algorithms.") |
| self.assertEqual(tol.default_value, 1e-6) |
| |
| param.set_tol(1e-5) |
| self.assertEqual(param.get_tol(), 1e-5) |
| |
| def test_weight_col(self): |
| param = TestParams() |
| weight_col = param.WEIGHT_COL |
| self.assertEqual(weight_col.name, "weight_col") |
| self.assertEqual(weight_col.description, "Weight column name.") |
| self.assertEqual(weight_col.default_value, None) |
| |
| param.set_weight_col('test_weight_col') |
| self.assertEqual(param.get_weight_col(), 'test_weight_col') |