# git blob 64cd5ac4024f2009698c4c085e977636a9fce3b0
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from typing import Tuple

# NOTE(review): the module path was missing from this copy ("from import ..."); restored from
# upstream Flink ML -- confirm.
from pyflink.ml.core.param import WithParams, Param, ParamValidators, StringParam, IntParam, \
    StringArrayParam, FloatParam
class HasDistanceMeasure(WithParams, ABC):
    """
    Base class for the shared distance_measure param.

    Supported options are 'euclidean' and 'cosine', enforced by
    ``ParamValidators.in_array``.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, @property); they are restored from upstream Flink ML -- confirm.
    DISTANCE_MEASURE: Param[str] = StringParam(
        "distance_measure",
        "Distance measure. Supported options: 'euclidean' and 'cosine'.",
        "euclidean",
        ParamValidators.in_array(['euclidean', 'cosine']))

    def set_distance_measure(self, distance_measure: str):
        """Set the distance measure via ``WithParams.set`` and return its result."""
        return self.set(self.DISTANCE_MEASURE, distance_measure)

    def get_distance_measure(self) -> str:
        """Return the current value of ``DISTANCE_MEASURE``."""
        return self.get(self.DISTANCE_MEASURE)

    @property
    def distance_measure(self) -> str:
        """Read-only alias for :meth:`get_distance_measure`."""
        return self.get_distance_measure()
class HasFeaturesCol(WithParams, ABC):
    """
    Base class for the shared features_col param.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, validator, @property); they are restored from upstream Flink ML -- confirm.
    FEATURES_COL: Param[str] = StringParam(
        "features_col",
        "Features column name.",
        "features",
        ParamValidators.not_null())

    def set_features_col(self, col: str):
        """Set the features column name via ``WithParams.set`` and return its result."""
        return self.set(self.FEATURES_COL, col)

    def get_features_col(self) -> str:
        """Return the current value of ``FEATURES_COL``."""
        return self.get(self.FEATURES_COL)

    @property
    def features_col(self) -> str:
        """Read-only alias for :meth:`get_features_col`."""
        return self.get_features_col()
class HasGlobalBatchSize(WithParams, ABC):
    """
    Base class for the shared global_batch_size param.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, validator, @property); they are restored from upstream Flink ML -- confirm.
    GLOBAL_BATCH_SIZE: Param[int] = IntParam(
        "global_batch_size",
        "Global batch size of training algorithms.",
        32,
        ParamValidators.gt(0))

    def set_global_batch_size(self, global_batch_size: int):
        """Set the global batch size via ``WithParams.set`` and return its result."""
        return self.set(self.GLOBAL_BATCH_SIZE, global_batch_size)

    def get_global_batch_size(self) -> int:
        """Return the current value of ``GLOBAL_BATCH_SIZE``."""
        return self.get(self.GLOBAL_BATCH_SIZE)

    @property
    def global_batch_size(self) -> int:
        """Read-only alias for :meth:`get_global_batch_size`."""
        return self.get_global_batch_size()
class HasHandleInvalid(WithParams, ABC):
    """
    Base class for the shared handle_invalid param.

    Supported options and the corresponding behavior to handle invalid entries are listed
    as follows.

    - error: raise an exception.
    - skip: filter out rows with bad values.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, @property); they are restored from upstream Flink ML -- confirm.
    HANDLE_INVALID: Param[str] = StringParam(
        "handle_invalid",
        "Strategy to handle invalid entries.",
        "error",
        ParamValidators.in_array(['error', 'skip']))

    def set_handle_invalid(self, value: str):
        """Set the invalid-entry strategy via ``WithParams.set`` and return its result."""
        return self.set(self.HANDLE_INVALID, value)

    def get_handle_invalid(self) -> str:
        """Return the current value of ``HANDLE_INVALID``."""
        return self.get(self.HANDLE_INVALID)

    @property
    def handle_invalid(self) -> str:
        """Read-only alias for :meth:`get_handle_invalid`."""
        return self.get_handle_invalid()
class HasInputCol(WithParams, ABC):
    """
    Base class for the shared input col param.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, validator, @property); they are restored from upstream Flink ML -- confirm.
    INPUT_COL: Param[str] = StringParam(
        "input_col",
        "Input column name.",
        "input",
        ParamValidators.not_null())

    def set_input_col(self, col: str):
        """Set the input column name via ``WithParams.set`` and return its result."""
        return self.set(self.INPUT_COL, col)

    def get_input_col(self) -> str:
        """Return the current value of ``INPUT_COL``."""
        return self.get(self.INPUT_COL)

    @property
    def input_col(self) -> str:
        """Read-only alias for :meth:`get_input_col`."""
        return self.get_input_col()
class HasInputCols(WithParams, ABC):
    """
    Base class for the shared input cols param.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, validator, @property); they are restored from upstream Flink ML -- confirm.
    INPUT_COLS: Param[Tuple[str, ...]] = StringArrayParam(
        "input_cols",
        "Input column names.",
        None,
        ParamValidators.non_empty_array())

    def set_input_cols(self, *cols: str):
        """
        Set the input column names via ``WithParams.set`` and return its result.

        The names are passed as varargs and stored as a tuple.
        """
        return self.set(self.INPUT_COLS, cols)

    def get_input_cols(self) -> Tuple[str, ...]:
        """Return the current value of ``INPUT_COLS``."""
        return self.get(self.INPUT_COLS)

    @property
    def input_cols(self) -> Tuple[str, ...]:
        """Read-only alias for :meth:`get_input_cols`."""
        return self.get_input_cols()
class HasLabelCol(WithParams, ABC):
    """
    Base class for the shared label column param.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, validator, @property); they are restored from upstream Flink ML -- confirm.
    LABEL_COL: Param[str] = StringParam(
        "label_col",
        "Label column name.",
        "label",
        ParamValidators.not_null())

    def set_label_col(self, col: str):
        """Set the label column name via ``WithParams.set`` and return its result."""
        return self.set(self.LABEL_COL, col)

    def get_label_col(self) -> str:
        """Return the current value of ``LABEL_COL``."""
        return self.get(self.LABEL_COL)

    @property
    def label_col(self) -> str:
        """Read-only alias for :meth:`get_label_col`."""
        return self.get_label_col()
class HasLearningRate(WithParams, ABC):
    """
    Base class for the shared learning rate param.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, validator, @property); they are restored from upstream Flink ML -- confirm.
    LEARNING_RATE: Param[float] = FloatParam(
        "learning_rate",
        "Learning rate of optimization method.",
        0.1,
        ParamValidators.gt(0))

    def set_learning_rate(self, learning_rate: float):
        """Set the learning rate via ``WithParams.set`` and return its result."""
        return self.set(self.LEARNING_RATE, learning_rate)

    def get_learning_rate(self) -> float:
        """Return the current value of ``LEARNING_RATE``."""
        return self.get(self.LEARNING_RATE)

    @property
    def learning_rate(self) -> float:
        """Read-only alias for :meth:`get_learning_rate`."""
        return self.get_learning_rate()
class HasMaxIter(WithParams, ABC):
    """
    Base class for the shared max_iter param.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, validator, @property); they are restored from upstream Flink ML -- confirm.
    MAX_ITER: Param[int] = IntParam(
        "max_iter",
        "Maximum number of iterations.",
        20,
        ParamValidators.gt(0))

    def set_max_iter(self, max_iter: int):
        """Set the maximum iteration count via ``WithParams.set`` and return its result."""
        return self.set(self.MAX_ITER, max_iter)

    def get_max_iter(self) -> int:
        """Return the current value of ``MAX_ITER``."""
        return self.get(self.MAX_ITER)

    @property
    def max_iter(self) -> int:
        """Read-only alias for :meth:`get_max_iter`."""
        return self.get_max_iter()
class HasMultiClass(WithParams, ABC):
    """
    Base class for the shared multi class param.

    Supported options:

    - auto: selects the classification type based on the number of classes: if the number
      of unique label values from the input data is one or two, set to "binomial";
      otherwise, set to "multinomial".
    - binomial: binary logistic regression.
    - multinomial: multinomial logistic regression.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, @property); they are restored from upstream Flink ML -- confirm.
    MULTI_CLASS: Param[str] = StringParam(
        "multi_class",
        "Classification type. Supported options: 'auto', 'binomial' and 'multinomial'.",
        "auto",
        ParamValidators.in_array(['auto', 'binomial', 'multinomial']))

    def set_multi_class(self, class_type: str):
        """Set the classification type via ``WithParams.set`` and return its result."""
        return self.set(self.MULTI_CLASS, class_type)

    def get_multi_class(self) -> str:
        """Return the current value of ``MULTI_CLASS``."""
        return self.get(self.MULTI_CLASS)

    @property
    def multi_class(self) -> str:
        """Read-only alias for :meth:`get_multi_class`."""
        return self.get_multi_class()
class HasOutputCol(WithParams, ABC):
    """
    Base class for the shared output_col param.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, validator, @property); they are restored from upstream Flink ML -- confirm.
    OUTPUT_COL: Param[str] = StringParam(
        "output_col",
        "Output column name.",
        "output",
        ParamValidators.not_null())

    def set_output_col(self, col: str):
        """Set the output column name via ``WithParams.set`` and return its result."""
        return self.set(self.OUTPUT_COL, col)

    def get_output_col(self) -> str:
        """Return the current value of ``OUTPUT_COL``."""
        return self.get(self.OUTPUT_COL)

    @property
    def output_col(self) -> str:
        """Read-only alias for :meth:`get_output_col`."""
        return self.get_output_col()
class HasOutputCols(WithParams, ABC):
    """
    Base class for the shared output_cols param.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, validator, @property); they are restored from upstream Flink ML -- confirm.
    OUTPUT_COLS: Param[Tuple[str, ...]] = StringArrayParam(
        "output_cols",
        "Output column names.",
        None,
        ParamValidators.non_empty_array())

    def set_output_cols(self, *cols: str):
        """
        Set the output column names via ``WithParams.set`` and return its result.

        The names are passed as varargs and stored as a tuple.
        """
        return self.set(self.OUTPUT_COLS, cols)

    def get_output_cols(self) -> Tuple[str, ...]:
        """Return the current value of ``OUTPUT_COLS``."""
        return self.get(self.OUTPUT_COLS)

    @property
    def output_cols(self) -> Tuple[str, ...]:
        """Read-only alias for :meth:`get_output_cols`."""
        return self.get_output_cols()
class HasPredictionCol(WithParams, ABC):
    """
    Base class for the shared prediction column param.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, validator, @property); they are restored from upstream Flink ML -- confirm.
    PREDICTION_COL: Param[str] = StringParam(
        "prediction_col",
        "Prediction column name.",
        "prediction",
        ParamValidators.not_null())

    def set_prediction_col(self, col: str):
        """Set the prediction column name via ``WithParams.set`` and return its result."""
        return self.set(self.PREDICTION_COL, col)

    def get_prediction_col(self) -> str:
        """Return the current value of ``PREDICTION_COL``."""
        return self.get(self.PREDICTION_COL)

    @property
    def prediction_col(self) -> str:
        """Read-only alias for :meth:`get_prediction_col`."""
        return self.get_prediction_col()
class HasRawPredictionCol(WithParams, ABC):
    """
    Base class for the shared raw prediction column param.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, validator, @property); they are restored from upstream Flink ML -- confirm.
    RAW_PREDICTION_COL: Param[str] = StringParam(
        "raw_prediction_col",
        "Raw prediction column name.",
        "rawPrediction",
        ParamValidators.not_null())

    def set_raw_prediction_col(self, col: str):
        """Set the raw prediction column name via ``WithParams.set`` and return its result."""
        return self.set(self.RAW_PREDICTION_COL, col)

    def get_raw_prediction_col(self) -> str:
        """Return the current value of ``RAW_PREDICTION_COL``."""
        return self.get(self.RAW_PREDICTION_COL)

    @property
    def raw_prediction_col(self) -> str:
        """Read-only alias for :meth:`get_raw_prediction_col`."""
        return self.get_raw_prediction_col()
class HasReg(WithParams, ABC):
    """
    Base class for the shared regularization param.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, validator, @property); they are restored from upstream Flink ML -- confirm.
    REG: Param[float] = FloatParam(
        "reg",
        "Regularization parameter.",
        0.,
        ParamValidators.gt_eq(0.))

    def set_reg(self, value: float):
        """Set the regularization strength via ``WithParams.set`` and return its result."""
        return self.set(self.REG, value)

    def get_reg(self) -> float:
        """Return the current value of ``REG``."""
        return self.get(self.REG)

    @property
    def reg(self) -> float:
        """Read-only alias for :meth:`get_reg`."""
        return self.get_reg()
class HasSeed(WithParams, ABC):
    """
    Base class for the shared seed param.

    When the stored seed is ``None``, :meth:`get_seed` falls back to ``hash`` of the
    concrete class name instead of returning ``None``.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, @property); they are restored from upstream Flink ML -- confirm.
    SEED: Param[int] = IntParam(
        "seed",
        "The random seed.",
        None)

    def set_seed(self, seed: int):
        """
        Set the random seed via ``WithParams.set`` and return its result.

        ``None`` is stored like any other value, so the call always updates the param and
        stays chainable; the class-name fallback is applied on read by :meth:`get_seed`.
        """
        return self.set(self.SEED, seed)

    def get_seed(self) -> int:
        """Return the stored seed, or ``hash(type(self).__name__)`` if it is ``None``."""
        seed = self.get(self.SEED)
        if seed is not None:
            return seed
        # NOTE(review): str hashing is salted per interpreter process unless PYTHONHASHSEED
        # is fixed, so this fallback is not reproducible across runs -- consider a stable hash.
        return hash(self.__class__.__name__)

    @property
    def seed(self) -> int:
        """Read-only alias for :meth:`get_seed`."""
        return self.get_seed()
class HasTol(WithParams, ABC):
    """
    Base class for the shared tolerance param.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, validator, @property); they are restored from upstream Flink ML -- confirm.
    TOL: Param[float] = FloatParam(
        "tol",
        "Convergence tolerance for iterative algorithms.",
        1e-6,
        ParamValidators.gt_eq(0))

    def set_tol(self, value: float):
        """Set the convergence tolerance via ``WithParams.set`` and return its result."""
        return self.set(self.TOL, value)

    def get_tol(self) -> float:
        """Return the current value of ``TOL``."""
        return self.get(self.TOL)

    @property
    def tol(self) -> float:
        """Read-only alias for :meth:`get_tol`."""
        return self.get_tol()
class HasWeightCol(WithParams, ABC):
    """
    Base class for the shared weight column param. If this is not set, we treat all
    instance weights as 1.0.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, @property); they are restored from upstream Flink ML -- confirm.
    WEIGHT_COL: Param[str] = StringParam(
        "weight_col",
        "Weight column name.",
        None)

    def set_weight_col(self, col: str):
        """Set the weight column name via ``WithParams.set`` and return its result."""
        return self.set(self.WEIGHT_COL, col)

    def get_weight_col(self) -> str:
        """Return the current value of ``WEIGHT_COL``."""
        return self.get(self.WEIGHT_COL)

    @property
    def weight_col(self) -> str:
        """Read-only alias for :meth:`get_weight_col`."""
        return self.get_weight_col()
class HasBatchStrategy(WithParams, ABC):
    """
    Base class for the shared batch strategy param.

    Only a getter is exposed here; no setter is defined on this mixin.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, validator, @property); they are restored from upstream Flink ML -- confirm.
    BATCH_STRATEGY: Param[str] = StringParam(
        "batch_strategy",
        "Strategy to create mini batch from online train data.",
        "count",
        ParamValidators.in_array(["count"]))

    def get_batch_strategy(self) -> str:
        """Return the current value of ``BATCH_STRATEGY``."""
        return self.get(self.BATCH_STRATEGY)

    @property
    def batch_strategy(self) -> str:
        """Read-only alias for :meth:`get_batch_strategy`."""
        return self.get_batch_strategy()
class HasDecayFactor(WithParams, ABC):
    """
    Base class for the shared decay factor param.

    Values are validated with ``ParamValidators.in_range(0, 1)``.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, @property); they are restored from upstream Flink ML -- confirm.
    DECAY_FACTOR: Param[float] = FloatParam(
        "decay_factor",
        "The forgetfulness of the previous centroids.",
        0.,
        ParamValidators.in_range(0, 1))

    def set_decay_factor(self, value: float):
        """Set the decay factor via ``WithParams.set`` and return its result."""
        return self.set(self.DECAY_FACTOR, value)

    def get_decay_factor(self) -> float:
        """Return the current value of ``DECAY_FACTOR``."""
        return self.get(self.DECAY_FACTOR)

    @property
    def decay_factor(self) -> float:
        """Read-only alias for :meth:`get_decay_factor`."""
        # Route through the getter like every sibling mixin, so an override of
        # get_decay_factor is honored by the property as well.
        return self.get_decay_factor()
class HasElasticNet(WithParams, ABC):
    """
    Base class for the shared elastic net param.

    Values are validated with ``ParamValidators.in_range(0.0, 1.0)``.
    """
    # NOTE(review): lines of this class were lost in this copy of the file (param name,
    # default value, @property); they are restored from upstream Flink ML -- confirm.
    ELASTIC_NET: Param[float] = FloatParam(
        "elastic_net",
        "ElasticNet parameter.",
        0.0,
        ParamValidators.in_range(0.0, 1.0))

    def set_elastic_net(self, value: float):
        """Set the elastic-net mixing parameter via ``WithParams.set`` and return its result."""
        return self.set(self.ELASTIC_NET, value)

    def get_elastic_net(self) -> float:
        """Return the current value of ``ELASTIC_NET``."""
        return self.get(self.ELASTIC_NET)

    @property
    def elastic_net(self) -> float:
        """Read-only alias for :meth:`get_elastic_net`."""
        # Route through the getter like every sibling mixin, so an override of
        # get_elastic_net is honored by the property as well.
        return self.get_elastic_net()