#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
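"""Spark Connect implementation of ``GraphElementRegistry``.

Each ``register_*`` call below is encoded as a ``PipelineCommand`` proto and sent to
the Spark Connect server, which owns the dataflow graph identified by
``dataflow_graph_id``.
"""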
from pathlib import Path
from typing import Any, cast

import pyspark.sql.connect.proto as pb2
from pyspark.errors import PySparkTypeError
from pyspark.pipelines.block_connect_access import block_spark_connect_execution_and_analysis
from pyspark.pipelines.flow import Flow
from pyspark.pipelines.graph_element_registry import GraphElementRegistry
from pyspark.pipelines.output import (
    Output,
    MaterializedView,
    Table,
    StreamingTable,
    TemporaryView,
)
from pyspark.pipelines.source_code_location import SourceCodeLocation
from pyspark.sql import SparkSession
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
from pyspark.sql.connect.types import pyspark_types_to_proto_types
from pyspark.sql.types import StructType


class SparkConnectGraphElementRegistry(GraphElementRegistry):
    """Registers outputs and flows in a dataflow graph held in a Spark Connect server."""

    def __init__(self, spark: SparkSession, dataflow_graph_id: str) -> None:
        # Cast because mypy seems to think `spark` is a function, not an object. Likely
        # related to SPARK-47544.
        self._client = cast(Any, spark).client
        self._dataflow_graph_id = dataflow_graph_id

    def register_output(self, output: Output) -> None:
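        """Register an output (table, materialized view, streaming table, or temporary
        view) by sending a DefineOutput command to the Spark Connect server."""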
        if isinstance(output, Table):
            # The schema may be declared either as a DDL string or as a StructType, so
            # at most one of the two corresponding proto fields is populated. If no
            # schema was declared, both are left unset.
            if isinstance(output.schema, str):
                schema_string = output.schema
                schema_data_type = None
            elif isinstance(output.schema, StructType):
                schema_string = None
                schema_data_type = pyspark_types_to_proto_types(output.schema)
            else:
                schema_string = None
                schema_data_type = None

            table_details = pb2.PipelineCommand.DefineOutput.TableDetails(
                table_properties=output.table_properties,
                partition_cols=output.partition_cols,
                format=output.format,
                # Even though schema_string is not required, the generated Python code
                # seems to erroneously think it is required.
                schema_string=schema_string,  # type: ignore[arg-type]
                schema_data_type=schema_data_type,
            )

            # Materialized views get their own output type; streaming tables map to the
            # generic TABLE output type in the proto.
            if isinstance(output, MaterializedView):
                output_type = pb2.OutputType.MATERIALIZED_VIEW
            elif isinstance(output, StreamingTable):
                output_type = pb2.OutputType.TABLE
            else:
                raise PySparkTypeError(
                    errorClass="UNSUPPORTED_PIPELINES_DATASET_TYPE",
                    messageParameters={"output_type": type(output).__name__},
                )
        elif isinstance(output, TemporaryView):
            output_type = pb2.OutputType.TEMPORARY_VIEW
            table_details = None
        else:
            raise PySparkTypeError(
                errorClass="UNSUPPORTED_PIPELINES_DATASET_TYPE",
                messageParameters={"output_type": type(output).__name__},
            )

        inner_command = pb2.PipelineCommand.DefineOutput(
            dataflow_graph_id=self._dataflow_graph_id,
            output_name=output.name,
            output_type=output_type,
            comment=output.comment,
            table_details=table_details,
            source_code_location=source_code_location_to_proto(output.source_code_location),
        )
        command = pb2.Command()
        command.pipeline_command.define_output.CopyFrom(inner_command)
        self._client.execute_command(command)

    def register_flow(self, flow: Flow) -> None:
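        """Register a flow by sending a DefineFlow command to the Spark Connect server."""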
        # Evaluate the user's flow function with eager execution and analysis against
        # the Spark Connect server blocked, so only an unresolved plan is captured here.
        with block_spark_connect_execution_and_analysis():
            df = flow.func()
        relation = cast(ConnectDataFrame, df)._plan.plan(self._client)

        relation_flow_details = pb2.PipelineCommand.DefineFlow.WriteRelationFlowDetails(
            relation=relation,
        )
        inner_command = pb2.PipelineCommand.DefineFlow(
            dataflow_graph_id=self._dataflow_graph_id,
            flow_name=flow.name,
            target_dataset_name=flow.target,
            relation_flow_details=relation_flow_details,
            sql_conf=flow.spark_conf,
            source_code_location=source_code_location_to_proto(flow.source_code_location),
        )
        command = pb2.Command()
        command.pipeline_command.define_flow.CopyFrom(inner_command)
        self._client.execute_command(command)

    def register_sql(self, sql_text: str, file_path: Path) -> None:
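        """Register the SQL graph elements in ``sql_text`` by sending a
        DefineSqlGraphElements command to the Spark Connect server."""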
        inner_command = pb2.PipelineCommand.DefineSqlGraphElements(
            dataflow_graph_id=self._dataflow_graph_id,
            sql_text=sql_text,
            sql_file_path=str(file_path),
        )
        command = pb2.Command()
        command.pipeline_command.define_sql_graph_elements.CopyFrom(inner_command)
        self._client.execute_command(command)


def source_code_location_to_proto(
    source_code_location: SourceCodeLocation,
) -> pb2.SourceCodeLocation:
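    """Convert a ``SourceCodeLocation`` into its Spark Connect proto representation."""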
    return pb2.SourceCodeLocation(
        file_name=source_code_location.filename, line_number=source_code_location.line_number
    )