blob: ade96da0a4f139b829449389212096db31fdc992 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional
# License banner printed verbatim as the first lines of the generated module.
header = "\n".join(
    (
        "#",
        "# Licensed to the Apache Software Foundation (ASF) under one or more",
        "# contributor license agreements. See the NOTICE file distributed with",
        "# this work for additional information regarding copyright ownership.",
        "# The ASF licenses this file to You under the Apache License, Version 2.0",
        '# (the "License"); you may not use this file except in compliance with',
        "# the License. You may obtain a copy of the License at",
        "#",
        "# http://www.apache.org/licenses/LICENSE-2.0",
        "#",
        "# Unless required by applicable law or agreed to in writing, software",
        '# distributed under the License is distributed on an "AS IS" BASIS,',
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
        "# See the License for the specific language governing permissions and",
        "# limitations under the License.",
        "#",
    )
)
# Code generator for the shared param mixins in shared.py. Run from this directory with:
#     python _shared_params_code_gen.py > shared.py
_type_for_type_converter = {
"TypeConverters.toBoolean": "bool",
"TypeConverters.toFloat": "float",
"TypeConverters.toInt": "int",
"TypeConverters.toListFloat": "List[float]",
"TypeConverters.toListInt": "List[int]",
"TypeConverters.toListString": "List[str]",
"TypeConverters.toString": "str",
}
def _gen_param_header(
name: str, doc: str, defaultValueStr: Optional[str], typeConverter: str, paramType: str
) -> str:
"""
Generates the header part for shared variables
:param name: param name
:param doc: param doc
"""
Name = f"Has{name[0].upper()}{name[1:]}"
template = f'''class {Name}(Params):
"""
Mixin for param {name}: {doc}
"""
{name}: "Param[{paramType}]" = Param(
Params._dummy(),
"{name}",
"{doc}",
typeConverter={typeConverter},
)
def __init__(self) -> None:
super({Name}, self).__init__()'''
if defaultValueStr is not None:
template += f"""
self._setDefault({name}={defaultValueStr})"""
return template
def _gen_param_code(name: str, paramType: str) -> str:
"""
Generates Python code for a shared param class.
:param name: param name
:param doc: param doc
:param defaultValueStr: string representation of the default value
:return: code string
"""
# TODO: How to correctly inherit instance attributes?
return f'''
def get{name[0].upper()}{name[1:]}(self) -> {paramType}:
"""
Gets the value of {name} or its default value.
"""
return self.getOrDefault(self.{name})'''
if __name__ == "__main__":
    # Each entry is (name, doc, defaultValueStr, typeConverter):
    #   name            -- param name; the mixin is Has<Name>, the getter get<Name>
    #   doc             -- param documentation text
    #   defaultValueStr -- Python source for the default value, or None for no default
    #   typeConverter   -- Python source for the converter; must be a key of
    #                      _type_for_type_converter
    shared = [
        (
            "maxIter",
            "max number of iterations (>= 0).",
            None,
            "TypeConverters.toInt",
        ),
        (
            "regParam",
            "regularization parameter (>= 0).",
            None,
            "TypeConverters.toFloat",
        ),
        (
            "featuresCol",
            "features column name.",
            '"features"',
            "TypeConverters.toString",
        ),
        (
            "labelCol",
            "label column name.",
            '"label"',
            "TypeConverters.toString",
        ),
        (
            "predictionCol",
            "prediction column name.",
            '"prediction"',
            "TypeConverters.toString",
        ),
        (
            "probabilityCol",
            "Column name for predicted class conditional probabilities. "
            + "Note: Not all models output well-calibrated probability estimates! "
            + "These probabilities should be treated as confidences, not precise probabilities.",
            '"probability"',
            "TypeConverters.toString",
        ),
        (
            "rawPredictionCol",
            "raw prediction (a.k.a. confidence) column name.",
            '"rawPrediction"',
            "TypeConverters.toString",
        ),
        (
            "inputCol",
            "input column name.",
            None,
            "TypeConverters.toString",
        ),
        (
            "inputCols",
            "input column names.",
            None,
            "TypeConverters.toListString",
        ),
        (
            "outputCol",
            "output column name.",
            'self.uid + "__output"',
            "TypeConverters.toString",
        ),
        (
            "outputCols",
            "output column names.",
            None,
            "TypeConverters.toListString",
        ),
        (
            "numFeatures",
            "Number of features. Should be greater than 0.",
            "262144",
            "TypeConverters.toInt",
        ),
        (
            "checkpointInterval",
            "set checkpoint interval (>= 1) or disable checkpoint (-1). "
            + "E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: "
            + "this setting will be ignored if the checkpoint directory is not set in "
            + "the SparkContext.",
            None,
            "TypeConverters.toInt",
        ),
        (
            "seed",
            "random seed.",
            "hash(type(self).__name__)",
            "TypeConverters.toInt",
        ),
        (
            "tol",
            "the convergence tolerance for iterative algorithms (>= 0).",
            None,
            "TypeConverters.toFloat",
        ),
        (
            "relativeError",
            "the relative target precision for the approximate quantile "
            + "algorithm. Must be in the range [0, 1]",
            "0.001",
            "TypeConverters.toFloat",
        ),
        (
            "stepSize",
            "Step size to be used for each iteration of optimization (>= 0).",
            None,
            "TypeConverters.toFloat",
        ),
        (
            "handleInvalid",
            "how to handle invalid entries. Options are skip (which will filter "
            + "out rows with bad values), or error (which will throw an error). "
            + "More options may be added later.",
            None,
            "TypeConverters.toString",
        ),
        (
            "elasticNetParam",
            "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, "
            + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.",
            "0.0",
            "TypeConverters.toFloat",
        ),
        (
            "fitIntercept",
            "whether to fit an intercept term.",
            "True",
            "TypeConverters.toBoolean",
        ),
        (
            "standardization",
            "whether to standardize the training features before fitting the " + "model.",
            "True",
            "TypeConverters.toBoolean",
        ),
        (
            "thresholds",
            "Thresholds in multi-class classification to adjust the probability of "
            + "predicting each class. Array must have length equal to the number of classes, with "
            + "values > 0, excepting that at most one value may be 0. "
            + "The class with largest value p/t is predicted, where p is the original "
            + "probability of that class and t is the class's threshold.",
            None,
            "TypeConverters.toListFloat",
        ),
        (
            "threshold",
            "threshold in binary classification prediction, in range [0, 1]",
            "0.5",
            "TypeConverters.toFloat",
        ),
        (
            "weightCol",
            "weight column name. If this is not set or empty, we treat "
            + "all instance weights as 1.0.",
            None,
            "TypeConverters.toString",
        ),
        (
            "solver",
            "the solver algorithm for optimization. If this is not set or empty, "
            + "default value is 'auto'.",
            '"auto"',
            "TypeConverters.toString",
        ),
        (
            "varianceCol",
            "column name for the biased sample variance of prediction.",
            None,
            "TypeConverters.toString",
        ),
        (
            "aggregationDepth",
            "suggested depth for treeAggregate (>= 2).",
            "2",
            "TypeConverters.toInt",
        ),
        (
            "parallelism",
            "the number of threads to use when running parallel algorithms (>= 1).",
            "1",
            "TypeConverters.toInt",
        ),
        (
            "collectSubModels",
            "Param for whether to collect a list of sub-models trained during "
            + "tuning. If set to false, then only the single best sub-model will be available "
            + "after fitting. If set to true, then all sub-models will be available. Warning: "
            + "For large models, collecting all sub-models can cause OOMs on the Spark driver.",
            "False",
            "TypeConverters.toBoolean",
        ),
        (
            "loss",
            "the loss function to be optimized.",
            None,
            "TypeConverters.toString",
        ),
        (
            "distanceMeasure",
            "the distance measure. Supported options: 'euclidean' and 'cosine'.",
            '"euclidean"',
            "TypeConverters.toString",
        ),
        (
            "validationIndicatorCol",
            "name of the column that indicates whether each row is for "
            + "training or for validation. False indicates training; true indicates validation.",
            None,
            "TypeConverters.toString",
        ),
        (
            "blockSize",
            "block size for stacking input data in matrices. Data is stacked within "
            + "partitions. If block size is more than remaining data in a partition then it is "
            + "adjusted to the size of this data.",
            None,
            "TypeConverters.toInt",
        ),
        (
            "maxBlockSizeInMB",
            "maximum memory in MB for stacking input data into blocks. Data is "
            + "stacked within partitions. If more than remaining data size in a partition then it "
            + "is adjusted to the data size. Default 0.0 represents choosing optimal value, "
            + "depends on specific algorithm. Must be >= 0.",
            "0.0",
            "TypeConverters.toFloat",
        ),
        (
            "numTrainWorkers",
            "number of training workers",
            "1",
            "TypeConverters.toInt",
        ),
        (
            "batchSize",
            "number of training batch size",
            None,
            "TypeConverters.toInt",
        ),
        (
            "learningRate",
            "learning rate for training",
            None,
            "TypeConverters.toFloat",
        ),
        (
            "momentum",
            "momentum for training optimizer",
            None,
            "TypeConverters.toFloat",
        ),
        (
            "featureSizes",
            "input feature size list for input columns of vector assembler",
            None,
            "TypeConverters.toListInt",
        ),
    ]

    # Generate every mixin before printing anything, so a bad entry aborts the
    # run without emitting a partially generated module.
    code = []
    for name, doc, defaultValueStr, typeConverter in shared:
        # Fail loudly on an unmapped converter rather than writing a bogus
        # ``Param[None]`` / ``-> None`` annotation into the generated file.
        if typeConverter not in _type_for_type_converter:
            raise ValueError(
                f"No Python type registered for {typeConverter!r} (param {name!r}); "
                "add it to _type_for_type_converter."
            )
        paramType = _type_for_type_converter[typeConverter]
        param_code = _gen_param_header(name, doc, defaultValueStr, typeConverter, paramType)
        code.append(param_code + "\n" + _gen_param_code(name, paramType))

    print(header)
    print("\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n")
    # ``List`` backs the List[...] annotations listed in _type_for_type_converter.
    print("from typing import List\n")
    print("from pyspark.ml.param import Param, Params, TypeConverters\n\n")
    print("\n\n\n".join(code))