#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pathlib import Path
from typing import Any, cast

import pyspark.sql.connect.proto as pb2
from pyspark.errors import PySparkTypeError
from pyspark.sql import SparkSession
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
from pyspark.pipelines.block_connect_access import block_spark_connect_execution_and_analysis
from pyspark.pipelines.dataset import (
    Dataset,
    MaterializedView,
    Table,
    StreamingTable,
    TemporaryView,
)
from pyspark.pipelines.flow import Flow
from pyspark.pipelines.graph_element_registry import GraphElementRegistry


class SparkConnectGraphElementRegistry(GraphElementRegistry):
    """Registers datasets and flows in a dataflow graph held in a Spark Connect server."""

    def __init__(self, spark: SparkSession, dataflow_graph_id: str) -> None:
        # Cast because mypy seems to think `spark` is a function, not an object. Likely
        # related to SPARK-47544.
        self._client = cast(Any, spark).client
        self._dataflow_graph_id = dataflow_graph_id

    def register_dataset(self, dataset: Dataset) -> None:
        if isinstance(dataset, Table):
            table_properties = dataset.table_properties
            partition_cols = dataset.partition_cols
            schema = None  # TODO
            format = dataset.format
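
            # Map the Python dataset subclass to its proto DatasetType; streaming
            # tables are represented as plain TABLEs at the proto level.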
            if isinstance(dataset, MaterializedView):
                dataset_type = pb2.DatasetType.MATERIALIZED_VIEW
            elif isinstance(dataset, StreamingTable):
                dataset_type = pb2.DatasetType.TABLE
            else:
                raise PySparkTypeError(
                    errorClass="UNSUPPORTED_PIPELINES_DATASET_TYPE",
                    messageParameters={"dataset_type": type(dataset).__name__},
                )
        elif isinstance(dataset, TemporaryView):
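            # Temporary views are not backed by physical storage, so the
            # table-specific fields are left unset.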
            table_properties = None
            partition_cols = None
            schema = None
            format = None
            dataset_type = pb2.DatasetType.TEMPORARY_VIEW
        else:
            raise PySparkTypeError(
                errorClass="UNSUPPORTED_PIPELINES_DATASET_TYPE",
                messageParameters={"dataset_type": type(dataset).__name__},
            )
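
        # Package the dataset metadata into a DefineDataset command and submit it
        # over the Spark Connect client.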
        inner_command = pb2.PipelineCommand.DefineDataset(
            dataflow_graph_id=self._dataflow_graph_id,
            dataset_name=dataset.name,
            dataset_type=dataset_type,
            comment=dataset.comment,
            table_properties=table_properties,
            partition_cols=partition_cols,
            schema=schema,
            format=format,
        )
        command = pb2.Command()
        command.pipeline_command.define_dataset.CopyFrom(inner_command)
        self._client.execute_command(command)

    def register_flow(self, flow: Flow) -> None:
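        # Evaluate the user-supplied flow function with Spark Connect execution and
        # analysis blocked, so that building the DataFrame cannot trigger RPCs.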
        with block_spark_connect_execution_and_analysis():
            df = flow.func()
        relation = cast(ConnectDataFrame, df)._plan.plan(self._client)
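
        # Forward the flow definition, including the plan's Relation proto, to the server.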
        inner_command = pb2.PipelineCommand.DefineFlow(
            dataflow_graph_id=self._dataflow_graph_id,
            flow_name=flow.name,
            target_dataset_name=flow.target,
            relation=relation,
            sql_conf=flow.spark_conf,
        )
        command = pb2.Command()
        command.pipeline_command.define_flow.CopyFrom(inner_command)
        self._client.execute_command(command)

    def register_sql(self, sql_text: str, file_path: Path) -> None:
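        # The SQL text is forwarded unparsed, along with its source file path, in a
        # DefineSqlGraphElements command.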
        inner_command = pb2.PipelineCommand.DefineSqlGraphElements(
            dataflow_graph_id=self._dataflow_graph_id,
            sql_text=sql_text,
            sql_file_path=str(file_path),
        )
        command = pb2.Command()
        command.pipeline_command.define_sql_graph_elements.CopyFrom(inner_command)
        self._client.execute_command(command)