#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import cast, Type, TYPE_CHECKING, Union, Dict, Any

import pyspark.sql.connect.proto as pb2
from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, deserialize_param
from pyspark.ml.util import MLWriter, MLReader, RL
from pyspark.ml.wrapper import JavaWrapper

if TYPE_CHECKING:
    from pyspark.core.context import SparkContext
    from pyspark.sql.connect.session import SparkSession
    from pyspark.ml.util import JavaMLReadable, JavaMLWritable


class RemoteMLWriter(MLWriter):
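    """MLWriter implementation for Spark Connect.

    Rather than going through a local JVM gateway, it serializes the instance's
    params and sends an ``MlCommand.Write`` through the Connect client.

    A minimal usage sketch, assuming a trained Java-backed model ``model`` on an
    active Connect session (the path is a placeholder)::

        model.write().overwrite().save("/tmp/my_model")
    """
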
def __init__(self, instance: "JavaMLWritable") -> None:
super().__init__()
        self._instance = instance

    @property
def sc(self) -> "SparkContext":
        raise RuntimeError("Accessing SparkContext is not supported on Connect")

    def save(self, path: str) -> None:
        from pyspark.sql.connect.session import SparkSession

        session = SparkSession.getActiveSession()
        assert session is not None

        RemoteMLWriter.saveInstance(
self._instance,
path,
session,
self.shouldOverwrite,
self.optionMap,
        )

    @staticmethod
def saveInstance(
instance: "JavaMLWritable",
path: str,
session: "SparkSession",
shouldOverwrite: bool = False,
optionMap: Dict[str, Any] = {},
) -> None:
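        """Save ``instance`` to ``path`` through the given Connect ``session``.

        Java-backed operators are written with a single ``MlCommand.Write``;
        composite operators (pipelines, tuners, OneVsRest) are delegated to
        their regular writers.
        """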
from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml.classification import OneVsRest, OneVsRestModel
from pyspark.ml.clustering import PowerIterationClustering
from pyspark.ml.tuning import (
CrossValidator,
CrossValidatorModel,
TrainValidationSplit,
TrainValidationSplitModel,
        )

        # Spark Connect ML is built on Scala Spark ML, which means only
        # JavaModel, JavaEstimator, JavaTransformer, and JavaEvaluator
        # instances can be written directly.
        if isinstance(instance, JavaModel):
            from pyspark.ml.util import RemoteModelRef

            model = cast("JavaModel", instance)
params = serialize_ml_params(model, session.client)
assert isinstance(model._java_obj, RemoteModelRef)
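            # The fitted model lives on the server; identify it by its remote
            # object reference id instead of transferring the model itself.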
writer = pb2.MlCommand.Write(
obj_ref=pb2.ObjectRef(id=model._java_obj.ref_id),
params=params,
path=path,
should_overwrite=shouldOverwrite,
options=optionMap,
)
command = pb2.Command()
command.ml_command.write.CopyFrom(writer)
session.client.execute_command(command)
elif isinstance(instance, (JavaEstimator, JavaTransformer, JavaEvaluator)):
operator: Union[JavaEstimator, JavaTransformer, JavaEvaluator]
if isinstance(instance, JavaEstimator):
ml_type = pb2.MlOperator.OPERATOR_TYPE_ESTIMATOR
operator = cast("JavaEstimator", instance)
elif isinstance(instance, JavaEvaluator):
ml_type = pb2.MlOperator.OPERATOR_TYPE_EVALUATOR
operator = cast("JavaEvaluator", instance)
else:
ml_type = pb2.MlOperator.OPERATOR_TYPE_TRANSFORMER
operator = cast("JavaTransformer", instance)
params = serialize_ml_params(operator, session.client)
assert isinstance(operator._java_obj, str)
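            # Non-model operators are identified by the Java class name held in
            # _java_obj together with their uid and serialized params.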
writer = pb2.MlCommand.Write(
operator=pb2.MlOperator(name=operator._java_obj, uid=operator.uid, type=ml_type),
params=params,
path=path,
should_overwrite=shouldOverwrite,
options=optionMap,
)
command = pb2.Command()
command.ml_command.write.CopyFrom(writer)
session.client.execute_command(command)
        elif isinstance(instance, Pipeline):
            from pyspark.ml.pipeline import PipelineWriter

            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
            pl_writer = PipelineWriter(instance)
            pl_writer.session(session)  # type: ignore[arg-type]
            pl_writer.save(path)
        elif isinstance(instance, PipelineModel):
            from pyspark.ml.pipeline import PipelineModelWriter

            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
            plm_writer = PipelineModelWriter(instance)
            plm_writer.session(session)  # type: ignore[arg-type]
            plm_writer.save(path)
        elif isinstance(instance, CrossValidator):
            from pyspark.ml.tuning import CrossValidatorWriter

            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
            cv_writer = CrossValidatorWriter(instance)
            cv_writer.session(session)  # type: ignore[arg-type]
            cv_writer.save(path)
        elif isinstance(instance, CrossValidatorModel):
            from pyspark.ml.tuning import CrossValidatorModelWriter

            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
            cvm_writer = CrossValidatorModelWriter(instance)
            cvm_writer.optionMap = optionMap
            cvm_writer.session(session)  # type: ignore[arg-type]
            cvm_writer.save(path)
        elif isinstance(instance, TrainValidationSplit):
            from pyspark.ml.tuning import TrainValidationSplitWriter

            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
            tvs_writer = TrainValidationSplitWriter(instance)
            # Set the session for consistency with the sibling branches.
            tvs_writer.session(session)  # type: ignore[arg-type]
            tvs_writer.save(path)
        elif isinstance(instance, TrainValidationSplitModel):
            from pyspark.ml.tuning import TrainValidationSplitModelWriter

            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
            tvsm_writer = TrainValidationSplitModelWriter(instance)
            tvsm_writer.optionMap = optionMap
            tvsm_writer.session(session)  # type: ignore[arg-type]
            tvsm_writer.save(path)
        elif isinstance(instance, OneVsRest):
            from pyspark.ml.classification import OneVsRestWriter

            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
            ovr_writer = OneVsRestWriter(instance)
            ovr_writer.session(session)  # type: ignore[arg-type]
            ovr_writer.save(path)
        elif isinstance(instance, OneVsRestModel):
            from pyspark.ml.classification import OneVsRestModelWriter

            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
            ovrm_writer = OneVsRestModelWriter(instance)
            ovrm_writer.session(session)  # type: ignore[arg-type]
            ovrm_writer.save(path)
elif isinstance(instance, PowerIterationClustering):
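            # PowerIterationClustering is neither a JavaEstimator nor a JavaTransformer,
            # so wrap its uid and params in a plain JavaTransformer pointing at the
            # server-side wrapper class and reuse the generic save path above.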
transformer = JavaTransformer(
"org.apache.spark.ml.clustering.PowerIterationClusteringWrapper"
)
transformer._resetUid(instance.uid)
transformer._paramMap = instance._paramMap
RemoteMLWriter.saveInstance(
transformer, # type: ignore[arg-type]
path,
session,
shouldOverwrite,
optionMap,
)
else:
            raise NotImplementedError(f"Unsupported write for {instance.__class__}")

    @staticmethod
def handleOverwrite(path: str, shouldOverwrite: bool) -> None:
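        """Delegate overwrite handling for ``path`` (e.g. removing an existing
        directory) to the server-side ML helper object."""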
        from pyspark.ml.util import ML_CONNECT_HELPER_ID

        if shouldOverwrite:
helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
            helper._call_java("handleOverwrite", path, shouldOverwrite)


class RemoteMLReader(MLReader[RL]):
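    """MLReader implementation for Spark Connect.

    A minimal usage sketch, assuming ``path`` was produced by ``RemoteMLWriter``
    (class and path are placeholders)::

        from pyspark.ml.classification import LogisticRegressionModel

        model = LogisticRegressionModel.load("/tmp/my_model")
    """
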
def __init__(self, clazz: Type["JavaMLReadable[RL]"]) -> None:
super().__init__()
        self._clazz = clazz

    def load(self, path: str) -> RL:
        from pyspark.sql.connect.session import SparkSession

        session = SparkSession.getActiveSession()
        assert session is not None

        return RemoteMLReader.loadInstance(self._clazz, path, session)

@staticmethod
def loadInstance(
clazz: Type["JavaMLReadable[RL]"],
path: str,
session: "SparkSession",
) -> RL:
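        """Load an instance of ``clazz`` from ``path`` through the Connect ``session``.

        Java-backed operators are fetched with a single ``MlCommand.Read`` and
        rebuilt locally from the returned uid and params; composite operators
        are delegated to their regular readers.
        """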
from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml.classification import OneVsRest, OneVsRestModel
from pyspark.ml.clustering import PowerIterationClustering
from pyspark.ml.tuning import (
CrossValidator,
CrossValidatorModel,
TrainValidationSplit,
TrainValidationSplitModel,
        )

        if (
issubclass(clazz, JavaModel)
or issubclass(clazz, JavaEstimator)
or issubclass(clazz, JavaEvaluator)
or issubclass(clazz, JavaTransformer)
):
if issubclass(clazz, JavaModel):
ml_type = pb2.MlOperator.OPERATOR_TYPE_MODEL
elif issubclass(clazz, JavaEstimator):
ml_type = pb2.MlOperator.OPERATOR_TYPE_ESTIMATOR
elif issubclass(clazz, JavaEvaluator):
ml_type = pb2.MlOperator.OPERATOR_TYPE_EVALUATOR
else:
                ml_type = pb2.MlOperator.OPERATOR_TYPE_TRANSFORMER

            # Derive the fully qualified name of the corresponding Java class.
            java_qualified_class_name = (
                clazz.__module__.replace("pyspark", "org.apache.spark") + "." + clazz.__name__
            )

command = pb2.Command()
command.ml_command.read.CopyFrom(
pb2.MlCommand.Read(
operator=pb2.MlOperator(name=java_qualified_class_name, type=ml_type), path=path
)
)
            _, properties, _ = session.client.execute_command(command)
            result = deserialize(properties)

            # Resolve the Python class to instantiate from its module path.
            def _get_class() -> Type[RL]:
                parts = (clazz.__module__ + "." + clazz.__name__).split(".")
                module = ".".join(parts[:-1])
                m = __import__(module, fromlist=[parts[-1]])
                return getattr(m, parts[-1])

            py_type = _get_class()
            # It must be a JavaWrapper subclass, since the class name string (or a
            # remote model reference) is passed to its _java_obj.
            if issubclass(py_type, JavaWrapper):
                from pyspark.ml.util import RemoteModelRef

                if ml_type == pb2.MlOperator.OPERATOR_TYPE_MODEL:
remote_model_ref = RemoteModelRef(result.obj_ref.id)
instance = py_type(remote_model_ref)
else:
instance = py_type()
instance._resetUid(result.uid)
params = {k: deserialize_param(v) for k, v in result.params.params.items()}
instance._set(**params)
return instance
else:
raise RuntimeError(f"Unsupported python type {py_type}")
        elif issubclass(clazz, Pipeline):
            from pyspark.ml.pipeline import PipelineReader

            pl_reader = PipelineReader(Pipeline)
            pl_reader.session(session)
            return pl_reader.load(path)
        elif issubclass(clazz, PipelineModel):
            from pyspark.ml.pipeline import PipelineModelReader

            plm_reader = PipelineModelReader(PipelineModel)
            plm_reader.session(session)
            return plm_reader.load(path)
        elif issubclass(clazz, CrossValidator):
            from pyspark.ml.tuning import CrossValidatorReader

            cv_reader = CrossValidatorReader(CrossValidator)
            cv_reader.session(session)
            return cv_reader.load(path)
        elif issubclass(clazz, CrossValidatorModel):
            from pyspark.ml.tuning import CrossValidatorModelReader

            cvm_reader = CrossValidatorModelReader(CrossValidator)
            cvm_reader.session(session)
            return cvm_reader.load(path)
        elif issubclass(clazz, TrainValidationSplit):
            from pyspark.ml.tuning import TrainValidationSplitReader

            tvs_reader = TrainValidationSplitReader(TrainValidationSplit)
            tvs_reader.session(session)
            return tvs_reader.load(path)
        elif issubclass(clazz, TrainValidationSplitModel):
            from pyspark.ml.tuning import TrainValidationSplitModelReader

            tvsm_reader = TrainValidationSplitModelReader(TrainValidationSplitModel)
            tvsm_reader.session(session)
            return tvsm_reader.load(path)
        elif issubclass(clazz, OneVsRest):
            from pyspark.ml.classification import OneVsRestReader

            ovr_reader = OneVsRestReader(OneVsRest)
            ovr_reader.session(session)
            return ovr_reader.load(path)
        elif issubclass(clazz, OneVsRestModel):
            from pyspark.ml.classification import OneVsRestModelReader

            ovrm_reader = OneVsRestModelReader(OneVsRestModel)
            ovrm_reader.session(session)
            return ovrm_reader.load(path)
elif issubclass(clazz, PowerIterationClustering):
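            # Mirror of the PowerIterationClustering write path: the params were
            # saved through the server-side transformer wrapper, so read them back
            # as a transformer and rebuild the Python instance from uid and params.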
java_qualified_class_name = (
"org.apache.spark.ml.clustering.PowerIterationClusteringWrapper"
)
command = pb2.Command()
command.ml_command.read.CopyFrom(
pb2.MlCommand.Read(
operator=pb2.MlOperator(
name=java_qualified_class_name,
type=pb2.MlOperator.OPERATOR_TYPE_TRANSFORMER,
),
path=path,
)
)
            _, properties, _ = session.client.execute_command(command)
            result = deserialize(properties)

            instance = PowerIterationClustering()
instance._resetUid(result.uid)
params = {k: deserialize_param(v) for k, v in result.params.params.items()}
instance._set(**params)
return instance # type: ignore[return-value]
else:
raise RuntimeError(f"Unsupported read for {clazz}")