#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, List, TYPE_CHECKING, Mapping, Dict

import pyspark.sql.connect.proto as pb2
from pyspark.sql.types import DataType
from pyspark.ml.linalg import (
    DenseVector,
    SparseVector,
    DenseMatrix,
    SparseMatrix,
)

if TYPE_CHECKING:
    from pyspark.sql.connect.client import SparkConnectClient
    from pyspark.ml.param import Params


def literal_null() -> pb2.Expression.Literal:
    """Return a literal NULL, used below for fields absent from a UDT struct."""
    dt = pb2.DataType()
    dt.null.CopyFrom(pb2.DataType.NULL())
    return pb2.Expression.Literal(null=dt)


def build_int_list(value: List[int]) -> pb2.Expression.Literal:
    p = pb2.Expression.Literal()
    p.specialized_array.ints.values.extend(value)
    return p


def build_float_list(value: List[float]) -> pb2.Expression.Literal:
    p = pb2.Expression.Literal()
    p.specialized_array.doubles.values.extend(value)
    return p
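

# Illustrative sketch (not part of the original module): the two builders above
# pack plain Python lists into the literal's `specialized_array` field.
#
#   >>> lit = build_int_list([0, 2])
#   >>> list(lit.specialized_array.ints.values)
#   [0, 2]
#   >>> lit = build_float_list([3.0, 4.0])
#   >>> list(lit.specialized_array.doubles.values)
#   [3.0, 4.0]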


def build_proto_udt(jvm_class: str) -> pb2.DataType:
    ret = pb2.DataType()
    ret.udt.type = "udt"
    ret.udt.jvm_class = jvm_class
    return ret


proto_vector_udt = build_proto_udt("org.apache.spark.ml.linalg.VectorUDT")
proto_matrix_udt = build_proto_udt("org.apache.spark.ml.linalg.MatrixUDT")
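

# Illustrative sketch: the proto DataType built above carries only the UDT tag
# and the JVM class name; the server resolves the concrete UDT from the class.
#
#   >>> proto_vector_udt.udt.type
#   'udt'
#   >>> proto_vector_udt.udt.jvm_class
#   'org.apache.spark.ml.linalg.VectorUDT'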


def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.Literal:
    """Serialize a Python value into an Expression.Literal.

    ml.linalg vectors and matrices are encoded as UDT structs mirroring their
    serialized layout; any other value falls back to a plain literal expression.
    """
    from pyspark.sql.connect.expressions import LiteralExpression

    if isinstance(value, SparseVector):
        p = pb2.Expression.Literal()
        p.struct.struct_type.CopyFrom(proto_vector_udt)
        # type = 0
        p.struct.elements.append(pb2.Expression.Literal(byte=0))
        # size
        p.struct.elements.append(pb2.Expression.Literal(integer=value.size))
        # indices
        p.struct.elements.append(build_int_list(value.indices.tolist()))
        # values
        p.struct.elements.append(build_float_list(value.values.tolist()))
        return p
    elif isinstance(value, DenseVector):
        p = pb2.Expression.Literal()
        p.struct.struct_type.CopyFrom(proto_vector_udt)
        # type = 1
        p.struct.elements.append(pb2.Expression.Literal(byte=1))
        # size = null
        p.struct.elements.append(literal_null())
        # indices = null
        p.struct.elements.append(literal_null())
        # values
        p.struct.elements.append(build_float_list(value.values.tolist()))
        return p
    elif isinstance(value, SparseMatrix):
        p = pb2.Expression.Literal()
        p.struct.struct_type.CopyFrom(proto_matrix_udt)
        # type = 0
        p.struct.elements.append(pb2.Expression.Literal(byte=0))
        # numRows
        p.struct.elements.append(pb2.Expression.Literal(integer=value.numRows))
        # numCols
        p.struct.elements.append(pb2.Expression.Literal(integer=value.numCols))
        # colPtrs
        p.struct.elements.append(build_int_list(value.colPtrs.tolist()))
        # rowIndices
        p.struct.elements.append(build_int_list(value.rowIndices.tolist()))
        # values
        p.struct.elements.append(build_float_list(value.values.tolist()))
        # isTransposed
        p.struct.elements.append(pb2.Expression.Literal(boolean=value.isTransposed))
        return p
    elif isinstance(value, DenseMatrix):
        p = pb2.Expression.Literal()
        p.struct.struct_type.CopyFrom(proto_matrix_udt)
        # type = 1
        p.struct.elements.append(pb2.Expression.Literal(byte=1))
        # numRows
        p.struct.elements.append(pb2.Expression.Literal(integer=value.numRows))
        # numCols
        p.struct.elements.append(pb2.Expression.Literal(integer=value.numCols))
        # colPtrs = null
        p.struct.elements.append(literal_null())
        # rowIndices = null
        p.struct.elements.append(literal_null())
        # values
        p.struct.elements.append(build_float_list(value.values.tolist()))
        # isTransposed
        p.struct.elements.append(pb2.Expression.Literal(boolean=value.isTransposed))
        return p
    else:
        return LiteralExpression._from_value(value).to_plan(client).literal
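

# Illustrative sketch (assumes an active Spark Connect session whose client is
# bound to `client`; the client is only consulted on the fallback literal path):
#
#   >>> from pyspark.ml.linalg import Vectors
#   >>> lit = serialize_param(Vectors.sparse(4, [1, 3], [3.0, 4.0]), client)
#   >>> len(lit.struct.elements)  # type tag, size, indices, values
#   4
#   >>> serialize_param(Vectors.dense([1.0, 2.0]), client).struct.elements[0].byte
#   1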


def serialize(client: "SparkConnectClient", *args: Any) -> List[Any]:
    """Serialize positional arguments into Fetch.Method.Args messages."""
    from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
    from pyspark.sql.connect.expressions import LiteralExpression

    result = []
    for arg in args:
        if isinstance(arg, ConnectDataFrame):
            result.append(pb2.Fetch.Method.Args(input=arg._plan.plan(client)))
        elif isinstance(arg, tuple) and len(arg) == 2 and isinstance(arg[1], DataType):
            # explicitly specify the data type, for cases like an empty list[str]
            result.append(
                pb2.Fetch.Method.Args(
                    param=LiteralExpression(value=arg[0], dataType=arg[1]).to_plan(client).literal
                )
            )
        else:
            result.append(pb2.Fetch.Method.Args(param=serialize_param(arg, client)))
    return result
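

# Illustrative sketch (assumes `df` is a Connect DataFrame and `client` its
# session client): DataFrames travel as plan relations, everything else as
# literal params; a (value, DataType) pair pins an otherwise ambiguous type.
#
#   >>> from pyspark.sql.types import ArrayType, StringType
#   >>> args = serialize(client, df, ([], ArrayType(StringType())), 0.5)
#   >>> args[0].HasField("input"), args[1].HasField("param"), args[2].HasField("param")
#   (True, True, True)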


def deserialize_param(literal: pb2.Expression.Literal) -> Any:
    """Reconstruct a Python value from an Expression.Literal, inverting serialize_param."""
    from pyspark.sql.connect.expressions import LiteralExpression

    if literal.HasField("struct"):
        s = literal.struct
        jvm_class = s.struct_type.udt.jvm_class
        if jvm_class == "org.apache.spark.ml.linalg.VectorUDT":
            assert len(s.elements) == 4
            tpe = s.elements[0].byte
            if tpe == 0:
                size = s.elements[1].integer
                indices = s.elements[2].specialized_array.ints.values
                values = s.elements[3].specialized_array.doubles.values
                return SparseVector(size, indices, values)
            elif tpe == 1:
                values = s.elements[3].specialized_array.doubles.values
                return DenseVector(values)
            else:
                raise ValueError(f"Unknown Vector type {tpe}")
        elif jvm_class == "org.apache.spark.ml.linalg.MatrixUDT":
            assert len(s.elements) == 7
            tpe = s.elements[0].byte
            if tpe == 0:
                numRows = s.elements[1].integer
                numCols = s.elements[2].integer
                colPtrs = s.elements[3].specialized_array.ints.values
                rowIndices = s.elements[4].specialized_array.ints.values
                values = s.elements[5].specialized_array.doubles.values
                isTransposed = s.elements[6].boolean
                return SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
            elif tpe == 1:
                numRows = s.elements[1].integer
                numCols = s.elements[2].integer
                values = s.elements[5].specialized_array.doubles.values
                isTransposed = s.elements[6].boolean
                return DenseMatrix(numRows, numCols, values, isTransposed)
            else:
                raise ValueError(f"Unknown Matrix type {tpe}")
        else:
            raise ValueError(f"Unknown UDT {jvm_class}")
    else:
        return LiteralExpression._to_value(literal)
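

# Illustrative round-trip sketch (assumes a `client` as above; it is unused on
# the UDT path): deserialize_param inverts serialize_param for vectors/matrices.
#
#   >>> from pyspark.ml.linalg import Vectors
#   >>> v = Vectors.sparse(4, [1, 3], [3.0, 4.0])
#   >>> deserialize_param(serialize_param(v, client)) == v
#   True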


def deserialize(ml_command_result_properties: Dict[str, Any]) -> Any:
    """Extract the Python-side result from an ml command result: either
    operator info or a deserialized param."""
    ml_command_result = ml_command_result_properties["ml_command_result"]
    if ml_command_result.HasField("operator_info"):
        return ml_command_result.operator_info
    if ml_command_result.HasField("param"):
        return deserialize_param(ml_command_result.param)
    raise ValueError("Unsupported result type")


def serialize_ml_params(instance: "Params", client: "SparkConnectClient") -> pb2.MlParams:
    """Serialize the explicitly-set params of a Params instance into MlParams."""
    params: Mapping[str, pb2.Expression.Literal] = {
        k.name: serialize_param(v, client) for k, v in instance._paramMap.items()
    }
    return pb2.MlParams(params=params)


def serialize_ml_params_values(
    values: Dict[str, Any], client: "SparkConnectClient"
) -> pb2.MlParams:
    """Serialize a plain name-to-value mapping into MlParams."""
    params: Mapping[str, pb2.Expression.Literal] = {
        k: serialize_param(v, client) for k, v in values.items()
    }
    return pb2.MlParams(params=params)
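

# Illustrative sketch (assumes a connect-backed estimator and a session `client`):
# only explicitly-set params are serialized, keyed by param name.
#
#   >>> from pyspark.ml.classification import LogisticRegression
#   >>> lr = LogisticRegression(maxIter=10)
#   >>> "maxIter" in serialize_ml_params(lr, client).params
#   True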