################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from abc import ABC
from typing import Tuple
from pyflink.ml.core.param import WithParams, Param, ParamValidators, StringParam, IntParam, \
StringArrayParam, FloatParam, WindowsParam
from pyflink.ml.core.windows import Windows, GlobalWindows
class HasDistanceMeasure(WithParams, ABC):
"""
Base class for the shared distance_measure param.
"""
DISTANCE_MEASURE: Param[str] = StringParam(
"distance_measure",
"Distance measure. Supported options: 'euclidean', 'manhattan' and 'cosine'.",
"euclidean",
ParamValidators.in_array(['euclidean', 'manhattan', 'cosine']))
def set_distance_measure(self, distance_measure: str):
return self.set(self.DISTANCE_MEASURE, distance_measure)
def get_distance_measure(self) -> str:
return self.get(self.DISTANCE_MEASURE)
@property
def distance_measure(self) -> str:
return self.get_distance_measure()
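
# Illustrative usage sketch, not part of this module: a hypothetical stage mixing in
# HasDistanceMeasure. It assumes the concrete class supplies the rest of the WithParams
# plumbing (e.g. its param map) and that WithParams.set returns self so setters chain.
#
#   class MyClustering(HasDistanceMeasure):
#       ...
#
#   stage = MyClustering().set_distance_measure('cosine')
#   assert stage.get_distance_measure() == 'cosine'
#   assert stage.distance_measure == 'cosine'      # the property reads the same param
#   # stage.set_distance_measure('hamming')        # rejected by ParamValidators.in_array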
class HasFeaturesCol(WithParams, ABC):
"""
    Base class for the shared features_col param.
`HasFeaturesCol` is typically used for `Stage`s that implement `HasLabelCol`. It is preferred
to use `HasInputCol` for other cases.
"""
FEATURES_COL: Param[str] = StringParam(
"features_col",
"Features column name.",
"features",
ParamValidators.not_null())
    def set_features_col(self, col: str):
return self.set(self.FEATURES_COL, col)
def get_features_col(self) -> str:
return self.get(self.FEATURES_COL)
@property
def features_col(self) -> str:
return self.get_features_col()
class HasGlobalBatchSize(WithParams, ABC):
"""
Base class for the shared global_batch_size param.
"""
GLOBAL_BATCH_SIZE: Param[int] = IntParam(
"global_batch_size",
"Global batch size of training algorithms.",
32,
ParamValidators.gt(0))
def set_global_batch_size(self, global_batch_size: int):
return self.set(self.GLOBAL_BATCH_SIZE, global_batch_size)
def get_global_batch_size(self) -> int:
return self.get(self.GLOBAL_BATCH_SIZE)
@property
def global_batch_size(self) -> int:
return self.get_global_batch_size()
class HasHandleInvalid(WithParams, ABC):
"""
Base class for the shared handle_invalid param.
    Supported options and the corresponding behavior to handle invalid entries are listed
    as follows.
<ul>
<li>error: raise an exception.
<li>skip: filter out rows with bad values.
</ul>
"""
HANDLE_INVALID: Param[str] = StringParam(
"handle_invalid",
"Strategy to handle invalid entries.",
"error",
ParamValidators.in_array(['error', 'skip']))
def set_handle_invalid(self, value: str):
return self.set(self.HANDLE_INVALID, value)
def get_handle_invalid(self) -> str:
return self.get(self.HANDLE_INVALID)
@property
def handle_invalid(self) -> str:
return self.get_handle_invalid()
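
# Illustrative sketch, not part of this module: how a transformer mixing in
# HasHandleInvalid might branch on the param. MyEncoder and the row handling are
# hypothetical; only the 'error'/'skip' options come from the param defined above.
#
#   encoder = MyEncoder()                     # default is 'error'
#   encoder.set_handle_invalid('skip')
#   if encoder.handle_invalid == 'error':
#       ...   # raise an exception on the first invalid entry
#   else:
#       ...   # filter out rows with bad values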
class HasInputCol(WithParams, ABC):
"""
Base class for the shared input col param.
"""
INPUT_COL: Param[str] = StringParam(
"input_col",
"Input column name.",
"input",
ParamValidators.not_null())
def set_input_col(self, col: str):
return self.set(self.INPUT_COL, col)
def get_input_col(self) -> str:
return self.get(self.INPUT_COL)
@property
def input_col(self) -> str:
return self.get_input_col()
class HasInputCols(WithParams, ABC):
"""
Base class for the shared input cols param.
"""
INPUT_COLS: Param[Tuple[str, ...]] = StringArrayParam(
"input_cols",
"Input column names.",
None,
ParamValidators.non_empty_array())
def set_input_cols(self, *cols: str):
return self.set(self.INPUT_COLS, cols)
def get_input_cols(self) -> Tuple[str, ...]:
return self.get(self.INPUT_COLS)
@property
def input_cols(self) -> Tuple[str, ...]:
return self.get_input_cols()
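
# Illustrative sketch, not part of this module: the setter takes varargs, so column names
# are passed as separate arguments and read back as a sequence; MyStage is hypothetical.
#
#   stage = MyStage().set_input_cols('f0', 'f1', 'f2')
#   assert tuple(stage.input_cols) == ('f0', 'f1', 'f2')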
class HasCategoricalCols(WithParams, ABC):
"""
Base class for the shared categorical cols param.
"""
CATEGORICAL_COLS: Param[Tuple[str, ...]] = StringArrayParam(
"categorical_cols",
"Categorical column names.",
[],
ParamValidators.not_null())
def set_categorical_cols(self, *cols: str):
return self.set(self.CATEGORICAL_COLS, cols)
def get_categorical_cols(self) -> Tuple[str, ...]:
return self.get(self.CATEGORICAL_COLS)
@property
def categorical_cols(self) -> Tuple[str, ...]:
return self.get_categorical_cols()
class HasNumFeatures(WithParams, ABC):
"""
    Base class for the shared num_features param.
"""
NUM_FEATURES: Param[int] = IntParam(
"num_features",
"Number of features.",
262144,
ParamValidators.gt(0))
def set_num_features(self, num_features: int):
return self.set(self.NUM_FEATURES, num_features)
def get_num_features(self) -> int:
return self.get(self.NUM_FEATURES)
@property
def num_features(self) -> int:
return self.get_num_features()
class HasLabelCol(WithParams, ABC):
"""
Base class for the shared label column param.
"""
LABEL_COL: Param[str] = StringParam(
"label_col",
"Label column name.",
"label",
ParamValidators.not_null())
def set_label_col(self, col: str):
return self.set(self.LABEL_COL, col)
def get_label_col(self) -> str:
return self.get(self.LABEL_COL)
@property
def label_col(self) -> str:
return self.get_label_col()
class HasLearningRate(WithParams, ABC):
"""
Base class for the shared learning rate param.
"""
LEARNING_RATE: Param[float] = FloatParam(
"learning_rate",
"Learning rate of optimization method.",
0.1,
ParamValidators.gt(0))
def set_learning_rate(self, learning_rate: float):
return self.set(self.LEARNING_RATE, learning_rate)
def get_learning_rate(self) -> float:
return self.get(self.LEARNING_RATE)
@property
def learning_rate(self) -> float:
return self.get_learning_rate()
class HasMaxIter(WithParams, ABC):
"""
    Base class for the shared max_iter param.
"""
MAX_ITER: Param[int] = IntParam(
"max_iter",
"Maximum number of iterations.",
20,
ParamValidators.gt(0))
def set_max_iter(self, max_iter: int):
return self.set(self.MAX_ITER, max_iter)
def get_max_iter(self) -> int:
return self.get(self.MAX_ITER)
@property
def max_iter(self) -> int:
return self.get_max_iter()
class HasMultiClass(WithParams, ABC):
"""
    Base class for the shared multi_class param.
    Supported options:
    <ul>
        <li>auto: selects the classification type based on the number of classes:
            If the number of unique label values from the input data is one or two,
            set to "binomial". Otherwise, set to "multinomial".
        <li>binomial: binary logistic regression.
        <li>multinomial: multinomial logistic regression.
    </ul>
"""
MULTI_CLASS: Param[str] = StringParam(
"multi_class",
"Classification type. Supported options: 'auto', 'binomial' and 'multinomial'.",
'auto',
ParamValidators.in_array(['auto', 'binomial', 'multinomial']))
def set_multi_class(self, class_type: str):
return self.set(self.MULTI_CLASS, class_type)
def get_multi_class(self) -> str:
return self.get(self.MULTI_CLASS)
@property
def multi_class(self) -> str:
return self.get_multi_class()
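
# Illustrative sketch, not part of this module, of how a logistic-regression-style stage
# could resolve the 'auto' option documented above; num_distinct_labels is hypothetical.
#
#   def resolve_multi_class(stage: HasMultiClass, num_distinct_labels: int) -> str:
#       if stage.multi_class != 'auto':
#           return stage.multi_class
#       return 'binomial' if num_distinct_labels <= 2 else 'multinomial'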
class HasOutputCol(WithParams, ABC):
"""
Base class for the shared output_col param.
"""
OUTPUT_COL: Param[str] = StringParam(
"output_col",
"Output column name.",
"output",
ParamValidators.not_null())
def set_output_col(self, col: str):
return self.set(self.OUTPUT_COL, col)
def get_output_col(self) -> str:
return self.get(self.OUTPUT_COL)
@property
def output_col(self) -> str:
return self.get_output_col()
class HasOutputCols(WithParams, ABC):
"""
Base class for the shared output_cols param.
"""
OUTPUT_COLS: Param[Tuple[str, ...]] = StringArrayParam(
"output_cols",
"Output column names.",
None,
ParamValidators.non_empty_array())
def set_output_cols(self, *cols: str):
return self.set(self.OUTPUT_COLS, cols)
def get_output_cols(self) -> Tuple[str, ...]:
return self.get(self.OUTPUT_COLS)
@property
def output_cols(self) -> Tuple[str, ...]:
return self.get_output_cols()
class HasPredictionCol(WithParams, ABC):
"""
Base class for the shared prediction column param.
"""
PREDICTION_COL: Param[str] = StringParam(
"prediction_col",
"Prediction column name.",
"prediction",
ParamValidators.not_null())
def set_prediction_col(self, col: str):
return self.set(self.PREDICTION_COL, col)
def get_prediction_col(self) -> str:
return self.get(self.PREDICTION_COL)
@property
def prediction_col(self) -> str:
return self.get_prediction_col()
class HasRawPredictionCol(WithParams, ABC):
"""
Base class for the shared raw prediction column param.
"""
RAW_PREDICTION_COL: Param[str] = StringParam(
"raw_prediction_col",
"Raw prediction column name.",
"raw_prediction")
def set_raw_prediction_col(self, col: str):
return self.set(self.RAW_PREDICTION_COL, col)
    def get_raw_prediction_col(self) -> str:
return self.get(self.RAW_PREDICTION_COL)
@property
def raw_prediction_col(self) -> str:
return self.get_raw_prediction_col()
class HasReg(WithParams, ABC):
"""
Base class for the shared regularization param.
"""
REG: Param[float] = FloatParam(
"reg",
"Regularization parameter.",
0.,
ParamValidators.gt_eq(0.))
def set_reg(self, value: float):
return self.set(self.REG, value)
def get_reg(self) -> float:
return self.get(self.REG)
@property
def reg(self) -> float:
return self.get_reg()
class HasSeed(WithParams, ABC):
"""
Base class for the shared seed param.
"""
SEED: Param[int] = IntParam(
"seed",
"The random seed.",
None)
def set_seed(self, seed: int):
        return self.set(self.SEED, seed if seed is not None else hash(self.__class__.__name__))
def get_seed(self) -> int:
return self.get(self.SEED)
@property
def seed(self) -> int:
return self.get_seed()
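
# Illustrative sketch, not part of this module: set_seed above falls back to a
# deterministic per-class value when None is passed, so two instances of the same
# (hypothetical) stage class share the same default seed within a process.
#
#   a = MyStage().set_seed(None)
#   b = MyStage().set_seed(None)
#   assert a.get_seed() == b.get_seed() == hash(MyStage.__name__)
#   c = MyStage().set_seed(2023)
#   assert c.seed == 2023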
class HasTol(WithParams, ABC):
"""
Base class for the shared tolerance param.
"""
TOL: Param[float] = FloatParam(
"tol",
"Convergence tolerance for iterative algorithms.",
1e-6,
ParamValidators.gt_eq(0))
def set_tol(self, value: float):
return self.set(self.TOL, value)
def get_tol(self) -> float:
return self.get(self.TOL)
@property
def tol(self) -> float:
return self.get_tol()
class HasWeightCol(WithParams, ABC):
"""
Base class for the shared weight column param. If this is not set, we treat all instance weights
as 1.0.
"""
WEIGHT_COL: Param[str] = StringParam(
"weight_col",
"Weight column name.",
None)
def set_weight_col(self, col: str):
return self.set(self.WEIGHT_COL, col)
def get_weight_col(self) -> str:
return self.get(self.WEIGHT_COL)
@property
    def weight_col(self) -> str:
return self.get_weight_col()
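
# Illustrative sketch, not part of this module: when weight_col is left at its None
# default, a trainer treats every instance weight as 1.0, as documented above;
# weight_of and row are hypothetical.
#
#   def weight_of(stage: HasWeightCol, row) -> float:
#       return row[stage.weight_col] if stage.weight_col is not None else 1.0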
class HasBatchStrategy(WithParams, ABC):
"""
Base class for the shared batch strategy param.
"""
BATCH_STRATEGY: Param[str] = StringParam(
"batch_strategy",
"Strategy to create mini batch from online train data.",
"count",
ParamValidators.in_array(["count"]))
def get_batch_strategy(self) -> str:
return self.get(self.BATCH_STRATEGY)
@property
    def batch_strategy(self) -> str:
return self.get_batch_strategy()
class HasDecayFactor(WithParams, ABC):
"""
Base class for the shared decay factor param.
"""
DECAY_FACTOR: Param[float] = FloatParam(
"decay_factor",
"The forgetfulness of the previous centroids.",
0.,
ParamValidators.in_range(0, 1))
def set_decay_factor(self, value: float):
return self.set(self.DECAY_FACTOR, value)
def get_decay_factor(self) -> float:
return self.get(self.DECAY_FACTOR)
@property
    def decay_factor(self) -> float:
        return self.get_decay_factor()
class HasElasticNet(WithParams, ABC):
"""
    Base class for the shared elastic_net param.
"""
ELASTIC_NET: Param[float] = FloatParam(
"elastic_net",
"ElasticNet parameter.",
0.,
ParamValidators.in_range(0.0, 1.0))
def set_elastic_net(self, value: float):
return self.set(self.ELASTIC_NET, value)
def get_elastic_net(self) -> float:
return self.get(self.ELASTIC_NET)
@property
    def elastic_net(self) -> float:
        return self.get_elastic_net()
class HasWindows(WithParams, ABC):
"""
Base class for the shared windows param.
"""
WINDOWS: Param[Windows] = WindowsParam(
"windows",
"Windowing strategy that determines how to create mini-batches from input data.",
GlobalWindows(),
ParamValidators.not_null())
    def set_windows(self, value: Windows):
        return self.set(self.WINDOWS, value)
def get_windows(self) -> Windows:
return self.get(self.WINDOWS)
@property
    def windows(self) -> Windows:
        return self.get_windows()
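
# Illustrative sketch, not part of this module: an online trainer mixing in HasWindows
# defaults to GlobalWindows (imported above); MyOnlineTrainer is hypothetical, and any
# other Windows implementation could be passed to set_windows as long as it is not None.
#
#   trainer = MyOnlineTrainer()
#   assert isinstance(trainer.get_windows(), GlobalWindows)   # default
#   trainer.set_windows(GlobalWindows())                      # passes the not_null check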
class HasRelativeError(WithParams, ABC):
"""
    Base class for the shared relative_error param.
"""
RELATIVE_ERROR: Param[float] = FloatParam(
"relative_error",
"The relative target precision for the approximate quantile algorithm.",
0.001,
ParamValidators.in_range(0.0, 1.0))
def set_relative_error(self, value: float):
return self.set(self.RELATIVE_ERROR, value)
def get_relative_error(self) -> float:
return self.get(self.RELATIVE_ERROR)
@property
    def relative_error(self) -> float:
        return self.get_relative_error()