#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from datetime import timezone
from typing import Any, Dict, Iterator, Mapping, Optional, Sequence, cast

import pyspark.sql.connect.proto as pb2
from pyspark.sql import SparkSession
from pyspark.errors.exceptions.base import PySparkValueError
from pyspark.pipelines.logging_utils import log_with_provided_timestamp


def create_dataflow_graph(
    spark: SparkSession,
    default_catalog: Optional[str],
    default_database: Optional[str],
    sql_conf: Optional[Mapping[str, str]],
) -> str:
    """Create a dataflow graph in the Spark Connect server.

    :returns: The ID of the created dataflow graph.
    """
    inner_command = pb2.PipelineCommand.CreateDataflowGraph(
        default_catalog=default_catalog,
        default_database=default_database,
        sql_conf=sql_conf,
    )
    command = pb2.Command()
    command.pipeline_command.create_dataflow_graph.CopyFrom(inner_command)
    # Cast because mypy seems to think `spark` is a function, not an object. Likely related to
    # SPARK-47544.
    (_, properties, _) = cast(Any, spark).client.execute_command(command)
    return properties["pipeline_command_result"].create_dataflow_graph_result.dataflow_graph_id
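

# Usage sketch (illustrative, not part of the original module): creating a graph
# against an existing Spark Connect session. The catalog and database names and
# the SQL conf below are hypothetical placeholders.
def _example_create_graph(spark: SparkSession) -> str:
    graph_id = create_dataflow_graph(
        spark,
        default_catalog="my_catalog",  # hypothetical
        default_database="my_database",  # hypothetical
        sql_conf={"spark.sql.shuffle.partitions": "4"},
    )
    # The returned ID identifies the graph in follow-up commands such as start_run.
    return graph_id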


def handle_pipeline_events(iter: Iterator[Dict[str, Any]]) -> None:
    """
    Prints out the pipeline events received from the Spark Connect server.
    """
    for result in iter:
        if "pipeline_command_result" in result:
            # We expect to get a pipeline_command_result back in response to the initial StartRun
            # command.
            continue
        elif "pipeline_event_result" not in result:
            raise PySparkValueError(
                f"Pipeline logs stream handler received an unexpected result: {result}"
            )
        else:
            event = result["pipeline_event_result"].event
            dt = event.timestamp.ToDatetime().replace(tzinfo=timezone.utc)
            log_with_provided_timestamp(event.message, dt)
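

# Sketch of how this handler is typically wired up (assumed usage, not part of the
# original module): start_run below yields exactly the dicts this handler consumes,
# keyed by whichever result field the server set on each response. A dry run is
# used here so no datasets are actually updated.
def _example_drain_events(spark: SparkSession, graph_id: str) -> None:
    handle_pipeline_events(start_run(spark, graph_id, None, False, None, dry=True))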


def start_run(
    spark: SparkSession,
    dataflow_graph_id: str,
    full_refresh: Optional[Sequence[str]],
    full_refresh_all: bool,
    refresh: Optional[Sequence[str]],
    dry: bool,
) -> Iterator[Dict[str, Any]]:
    """Start a run of the dataflow graph in the Spark Connect server.

    :param dataflow_graph_id: The ID of the dataflow graph to start.
    :param full_refresh: List of datasets to reset and recompute.
    :param full_refresh_all: Perform a full graph reset and recompute.
    :param refresh: List of datasets to update.
    :param dry: Validate the graph without actually executing any flows.
    :returns: An iterator over the results the server streams back for the run.
    """
    inner_command = pb2.PipelineCommand.StartRun(
        dataflow_graph_id=dataflow_graph_id,
        full_refresh_selection=full_refresh or [],
        full_refresh_all=full_refresh_all,
        refresh_selection=refresh or [],
        dry=dry,
    )
    command = pb2.Command()
    command.pipeline_command.start_run.CopyFrom(inner_command)
    # Cast because mypy seems to think `spark` is a function, not an object. Likely related to
    # SPARK-47544.
    return cast(Any, spark).client.execute_command_as_iterator(command)
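

# End-to-end sketch (illustrative only): create a graph, trigger a run that
# refreshes a single dataset, and stream its event log. The connect URL and
# dataset name are hypothetical, and real pipeline definitions would need to be
# registered against the graph before such a run could succeed.
def _example_run_pipeline() -> None:
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
    graph_id = create_dataflow_graph(
        spark, default_catalog=None, default_database=None, sql_conf=None
    )
    events = start_run(
        spark,
        dataflow_graph_id=graph_id,
        full_refresh=None,
        full_refresh_all=False,
        refresh=["my_dataset"],  # hypothetical dataset name
        dry=False,
    )
    handle_pipeline_events(events)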