# blob: ad33e901cd3a90cacb8d62587e6471cf9a6a4549 [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from typing import Tuple
from pyflink.ml.param import Param, StringArrayParam, ParamValidators
from pyflink.ml.wrapper import JavaWithParams
from pyflink.ml.evaluation.common import JavaEvaluationAlgoOperator
from pyflink.ml.common.param import HasLabelCol, HasRawPredictionCol, HasWeightCol
class _BinaryClassificationEvaluatorParams(
    JavaWithParams, HasLabelCol, HasRawPredictionCol, HasWeightCol
):
    """
    Parameter holder used by :class:`BinaryClassificationEvaluator`.

    On top of the label, raw-prediction and weight column params inherited
    from the mixins, it declares ``METRICS_NAMES``, the param that names the
    metrics to output.
    """

    # Defaults to the two area-under-curve metrics; the validator is built by
    # ``ParamValidators.is_sub_set`` over the four names listed below.
    METRICS_NAMES: Param[Tuple[str, ...]] = StringArrayParam(
        "metrics_names",
        "Names of output metrics.",
        ["areaUnderROC", "areaUnderPR"],
        ParamValidators.is_sub_set(
            ["areaUnderROC", "areaUnderPR", "areaUnderLorenz", "ks"]
        ),
    )

    def __init__(self, java_params):
        """Forward ``java_params`` to the parent initializers."""
        super().__init__(java_params)

    def set_metrics_names(self, *value: str):
        """
        Assign the metric names passed as positional string arguments.

        :return: whatever ``self.set`` returns for the ``METRICS_NAMES`` param.
        """
        metric_names = value  # tuple collected from the varargs
        return self.set(self.METRICS_NAMES, metric_names)

    def get_metrics_names(self) -> Tuple[str, ...]:
        """Return the value currently held by the ``METRICS_NAMES`` param."""
        current_names = self.get(self.METRICS_NAMES)
        return current_names

    @property
    def metrics_names(self) -> Tuple[str, ...]:
        """Property alias for :meth:`get_metrics_names`."""
        return self.get_metrics_names()
class BinaryClassificationEvaluator(
    JavaEvaluationAlgoOperator, _BinaryClassificationEvaluatorParams
):
    """
    AlgoOperator that computes evaluation metrics for binary classification.

    The input data carries a rawPrediction column, a label column and,
    optionally, a weight column. rawPrediction may be a double (a binary 0/1
    prediction, or the probability of label 1) or a length-2 vector (raw
    predictions, scores, or label probabilities). Which metrics appear in the
    output is determined by the ``metrics_names`` parameter.
    """

    def __init__(self, java_algo_operator=None):
        """Forward the optional ``java_algo_operator`` to the parent initializers."""
        super().__init__(java_algo_operator)

    @classmethod
    def _java_algo_operator_package_name(cls) -> str:
        """Return the package-name string for the Java-side operator."""
        return "binaryclassification"

    @classmethod
    def _java_algo_operator_class_name(cls) -> str:
        """Return the class-name string for the Java-side operator."""
        return "BinaryClassificationEvaluator"