| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| from abc import ABCMeta |
| |
| from pyspark import SparkContext |
| from pyspark.sql import DataFrame |
| from pyspark.ml.param import Params |
| from pyspark.ml.pipeline import Estimator, Transformer, Model |
| from pyspark.mllib.common import inherit_doc, _java2py, _py2java |
| |
| |
def _jvm():
    """
    Fetch the py4j JVM view stored on :py:class:`SparkContext`.

    This is only usable once a SparkContext has been created; if the
    class-level ``_jvm`` attribute is still unset (falsy), an
    ``AttributeError`` is raised instead.
    """
    view = SparkContext._jvm
    if not view:
        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
    return view
| |
| |
@inherit_doc
class JavaWrapper(Params):
    """
    Common base for Python wrappers of Java/Scala pipeline components.

    It holds a companion Java object (``_java_obj``) and provides helpers
    to construct Java objects and to copy param values between the Python
    wrapper and that companion object.
    """

    __metaclass__ = ABCMeta

    #: Companion Java object being wrapped; subclasses are expected to set
    #: it. Its param values should be kept in sync with the Python wrapper
    #: in fit/transform/evaluate/copy.
    _java_obj = None

    @staticmethod
    def _new_java_obj(java_class, *args):
        """
        Instantiate a Java class through the JVM view.

        :param java_class: fully qualified, dot-separated class name (e.g.
                           ``"org.apache.spark.ml.param.ParamMap"``); each
                           segment is looked up as an attribute, starting
                           from the JVM view
        :param args: constructor arguments, each converted with ``_py2java``
        :return: the newly constructed Java object
        """
        sc = SparkContext._active_spark_context
        target = _jvm()
        for segment in java_class.split("."):
            target = getattr(target, segment)
        converted = [_py2java(sc, a) for a in args]
        return target(*converted)

    def _make_java_param_pair(self, param, value):
        """
        Build a Java param pair binding ``param`` to ``value``.

        :param param: the param to bind (resolved via ``_resolveParam``)
        :param value: Python value, converted with ``_py2java``
        :return: the result of ``getParam(name).w(java_value)`` on the
                 companion Java object
        """
        sc = SparkContext._active_spark_context
        resolved = self._resolveParam(param)
        java_param = self._java_obj.getParam(resolved.name)
        return java_param.w(_py2java(sc, value))

    def _transfer_params_to_java(self):
        """
        Push every param that has a value in ``extractParamMap()`` onto the
        companion Java object, one ``set`` call per param.
        """
        param_map = self.extractParamMap()
        for p in self.params:
            if p not in param_map:
                continue
            self._java_obj.set(self._make_java_param_pair(p, param_map[p]))

    def _transfer_params_from_java(self):
        """
        Pull param values from the companion Java object back into this
        wrapper's ``_paramMap``.

        A param is copied only if the Java object has a param of the same
        name (``hasParam``) and that param is defined there (``isDefined``).
        """
        sc = SparkContext._active_spark_context
        java_obj = self._java_obj
        for p in self.params:
            if not java_obj.hasParam(p.name):
                continue
            java_param = java_obj.getParam(p.name)
            if not java_obj.isDefined(java_param):
                continue
            # NOTE(review): isDefined()/getOrDefault() presumably also cover
            # values that come only from Java-side defaults; those then land
            # in _paramMap like explicitly set values -- confirm intended.
            self._paramMap[p] = _java2py(sc, java_obj.getOrDefault(java_param))

    @staticmethod
    def _empty_java_param_map():
        """
        Return a freshly constructed, empty Java ``ParamMap``.
        """
        param_map_cls = _jvm().org.apache.spark.ml.param.ParamMap
        return param_map_cls()
| |
| |
@inherit_doc
class JavaEstimator(Estimator, JavaWrapper):
    """
    Base class for :py:class:`Estimator`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def _create_model(self, java_model):
        """
        Creates a model from the input Java model reference.

        Subclasses must override this; the base implementation always
        raises.

        :param java_model: fitted Java model, as returned by
                           :py:meth:`_fit_java`
        :return: model created from ``java_model``
        :raises NotImplementedError: if not overridden by a subclass
        """
        raise NotImplementedError()

    def _fit_java(self, dataset):
        """
        Fits a Java model to the input dataset.

        The Python-side params are transferred to the companion Java
        object first, so the Java ``fit`` sees the current param values.

        :param dataset: input dataset, which is an instance of
                        :py:class:`pyspark.sql.DataFrame`; its underlying
                        ``_jdf`` is what is passed to the Java estimator
        :return: fitted Java model
        """
        self._transfer_params_to_java()
        return self._java_obj.fit(dataset._jdf)

    def _fit(self, dataset):
        """
        Fits the wrapped Java estimator and turns the resulting Java model
        into a model via :py:meth:`_create_model`.

        :param dataset: input :py:class:`pyspark.sql.DataFrame`
        :return: the model produced by :py:meth:`_create_model`
        """
        java_model = self._fit_java(dataset)
        return self._create_model(java_model)
| |
| |
@inherit_doc
class JavaTransformer(Transformer, JavaWrapper):
    """
    Base class for :py:class:`Transformer`s backed by a Java/Scala
    implementation. Subclasses are responsible for making the Java
    transformer object available as ``_java_obj``.
    """

    __metaclass__ = ABCMeta

    def _transform(self, dataset):
        """
        Sync params to the Java transformer, run it on ``dataset`` and wrap
        the Java result in a Python DataFrame that uses ``dataset``'s SQL
        context.

        :param dataset: input :py:class:`pyspark.sql.DataFrame`
        :return: transformed :py:class:`pyspark.sql.DataFrame`
        """
        self._transfer_params_to_java()
        java_result = self._java_obj.transform(dataset._jdf)
        sql_ctx = dataset.sql_ctx
        return DataFrame(java_result, sql_ctx)
| |
| |
@inherit_doc
class JavaModel(Model, JavaTransformer):
    """
    Base class for :py:class:`Model`s that wrap Java/Scala
    implementations. Subclasses should inherit this class before
    param mix-ins, because this sets the UID from the Java model.
    """

    __metaclass__ = ABCMeta

    def __init__(self, java_model):
        """
        Initialize this instance with a Java model object.
        Subclasses should call this constructor, initialize params,
        and then call _transfer_params_from_java.

        :param java_model: the Java model to wrap; its ``uid()`` becomes
                           this wrapper's ``uid``
        """
        super(JavaModel, self).__init__()
        self._java_obj = java_model
        # Adopt the Java model's uid so the wrapper and the Java model
        # report the same uid.
        self.uid = java_model.uid()

    def copy(self, extra=None):
        """
        Creates a copy of this instance with the same uid and some
        extra params. This implementation first calls Params.copy and
        then makes a copy of the companion Java model with extra params.
        So both the Python wrapper and the Java model get copied.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        that = super(JavaModel, self).copy(extra)
        # Copy the Java model with an empty ParamMap, then push the copy's
        # Python-side params (set up by Params.copy with ``extra``) onto it.
        that._java_obj = self._java_obj.copy(self._empty_java_param_map())
        that._transfer_params_to_java()
        return that

    def _call_java(self, name, *args):
        """
        Invokes the method ``name`` on the companion Java model.

        :param name: name of the Java method to call
        :param args: positional arguments, each converted with ``_py2java``
        :return: the Java method's return value, converted with ``_java2py``
        """
        m = getattr(self._java_obj, name)
        sc = SparkContext._active_spark_context
        java_args = [_py2java(sc, arg) for arg in args]
        return _java2py(sc, m(*java_args))