#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from typing import Dict, Optional, TYPE_CHECKING
from pyspark.sql.column import Column
from pyspark.sql.utils import to_scala_map
if TYPE_CHECKING:
from pyspark.sql.dataframe import DataFrame
__all__ = ["MergeIntoWriter"]
class MergeIntoWriter:
"""
`MergeIntoWriter` provides methods to define and execute merge actions based
on specified conditions.
.. versionadded: 4.0.0
"""
def __init__(self, df: "DataFrame", table: str, condition: Column):
self._spark = df.sparkSession
from pyspark.sql.classic.column import _to_java_column
self._jwriter = df._jdf.mergeInto(table, _to_java_column(condition))
def whenMatched(self, condition: Optional[Column] = None) -> "MergeIntoWriter.WhenMatched":
"""
Initialize a `WhenMatched` action with a condition.
This `WhenMatched` action will be executed when a source row matches a target table row
based on the merge condition and the specified `condition` is satisfied.
This `WhenMatched` can be followed by one of the following merge actions:
- `updateAll`: Update all the matched target table rows with source dataset rows.
- `update(Dict)`: Update all the matched target table rows while changing only
a subset of columns based on the provided assignment.
- `delete`: Delete all target rows that have a match in the source table.
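
Examples
--------
A minimal sketch of a matched-update clause; the source DataFrame ``source``
and the table name ``target`` are illustrative, not part of this API:

>>> from pyspark.sql.functions import col
>>> (source.mergeInto("target", col("source.id") == col("target.id"))
...     .whenMatched()
...     .updateAll()
...     .merge())  # doctest: +SKIP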
"""
return self.WhenMatched(self, condition)
def whenNotMatched(
self, condition: Optional[Column] = None
) -> "MergeIntoWriter.WhenNotMatched":
"""
Initialize a `WhenNotMatched` action with a condition.
This `WhenNotMatched` action will be executed when a source row does not match any target
row based on the merge condition and the specified `condition` is satisfied.
This `WhenNotMatched` can be followed by one of the following merge actions:
- `insertAll`: Insert all rows from the source that are not already in the target table.
- `insert(Dict)`: Insert all rows from the source that are not already in the target
table, with the specified columns based on the provided assignment.
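
Examples
--------
A minimal sketch that inserts unmatched source rows; ``source`` and
``target`` are illustrative names:

>>> from pyspark.sql.functions import col
>>> (source.mergeInto("target", col("source.id") == col("target.id"))
...     .whenNotMatched()
...     .insertAll()
...     .merge())  # doctest: +SKIP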
"""
return self.WhenNotMatched(self, condition)
def whenNotMatchedBySource(
self, condition: Optional[Column] = None
) -> "MergeIntoWriter.WhenNotMatchedBySource":
"""
Initialize a `WhenNotMatchedBySource` action with a condition.
This `WhenNotMatchedBySource` action will be executed when a target row does not match any
rows in the source table based on the merge condition and the specified `condition`
is satisfied.
This `WhenNotMatchedBySource` can be followed by one of the following merge actions:
- `updateAll`: Update all the not matched target table rows with source dataset rows.
- `update(Dict)`: Update all the not matched target table rows while changing only
the specified columns based on the provided assignment.
- `delete`: Delete all target rows that have no matches in the source table.
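
Examples
--------
A minimal sketch that deletes target rows with no counterpart in the source;
``source`` and ``target`` are illustrative names:

>>> from pyspark.sql.functions import col
>>> (source.mergeInto("target", col("source.id") == col("target.id"))
...     .whenNotMatchedBySource()
...     .delete()
...     .merge())  # doctest: +SKIP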
"""
return self.WhenNotMatchedBySource(self, condition)
def withSchemaEvolution(self) -> "MergeIntoWriter":
"""
Enable automatic schema evolution for this merge operation.
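
Examples
--------
A minimal sketch; schema evolution lets new source columns be added to the
(illustrative) target table ``target`` during the merge:

>>> from pyspark.sql.functions import col
>>> (source.mergeInto("target", col("source.id") == col("target.id"))
...     .withSchemaEvolution()
...     .whenNotMatched()
...     .insertAll()
...     .merge())  # doctest: +SKIP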
"""
self._jwriter = self._jwriter.withSchemaEvolution()
return self
def merge(self) -> None:
"""
Execute the merge operation.
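
Examples
--------
A sketch of a complete merge combining all three clause types; ``source``
and ``target`` are illustrative names:

>>> from pyspark.sql.functions import col
>>> (source.mergeInto("target", col("source.id") == col("target.id"))
...     .whenMatched()
...     .updateAll()
...     .whenNotMatched()
...     .insertAll()
...     .whenNotMatchedBySource()
...     .delete()
...     .merge())  # doctest: +SKIP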
"""
self._jwriter.merge()
class WhenMatched:
"""
A class for defining actions to be taken when matching rows in a DataFrame during
a merge operation.
"""
def __init__(self, writer: "MergeIntoWriter", condition: Optional[Column]):
self.writer = writer
if condition is None:
self.when_matched = writer._jwriter.whenMatched()
else:
from pyspark.sql.classic.column import _to_java_column
self.when_matched = writer._jwriter.whenMatched(_to_java_column(condition))
def updateAll(self) -> "MergeIntoWriter":
"""
Specifies an action to update all matched rows in the DataFrame.
"""
self.writer._jwriter = self.when_matched.updateAll()
return self.writer
def update(self, assignments: Dict[str, Column]) -> "MergeIntoWriter":
"""
Specifies an action to update matched rows in the DataFrame with the provided column
assignments.
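
Examples
--------
A minimal sketch; the column names ``salary`` and ``active`` are
illustrative:

>>> from pyspark.sql.functions import col, lit
>>> (source.mergeInto("target", col("source.id") == col("target.id"))
...     .whenMatched()
...     .update({"salary": col("source.salary"), "active": lit(True)})
...     .merge())  # doctest: +SKIP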
"""
jvm = self.writer._spark._jvm
from pyspark.sql.classic.column import _to_java_column
jmap = to_scala_map(jvm, {k: _to_java_column(v) for k, v in assignments.items()})
self.writer._jwriter = self.when_matched.update(jmap)
return self.writer
def delete(self) -> "MergeIntoWriter":
"""
Specifies an action to delete matched rows from the DataFrame.
"""
self.writer._jwriter = self.when_matched.delete()
return self.writer
class WhenNotMatched:
"""
A class for defining actions to be taken when no matching rows are found in a DataFrame
during a merge operation.
"""
def __init__(self, writer: "MergeIntoWriter", condition: Optional[Column]):
self.writer = writer
if condition is None:
self.when_not_matched = writer._jwriter.whenNotMatched()
else:
from pyspark.sql.classic.column import _to_java_column
self.when_not_matched = writer._jwriter.whenNotMatched(_to_java_column(condition))
def insertAll(self) -> "MergeIntoWriter":
"""
Specifies an action to insert all non-matched rows into the DataFrame.
"""
self.writer._jwriter = self.when_not_matched.insertAll()
return self.writer
def insert(self, assignments: Dict[str, Column]) -> "MergeIntoWriter":
"""
Specifies an action to insert non-matched rows into the DataFrame with the provided
column assignments.
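
Examples
--------
A minimal sketch; the column names ``id`` and ``salary`` are illustrative:

>>> from pyspark.sql.functions import col
>>> (source.mergeInto("target", col("source.id") == col("target.id"))
...     .whenNotMatched()
...     .insert({"id": col("source.id"), "salary": col("source.salary")})
...     .merge())  # doctest: +SKIP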
"""
jvm = self.writer._spark._jvm
from pyspark.sql.classic.column import _to_java_column
jmap = to_scala_map(jvm, {k: _to_java_column(v) for k, v in assignments.items()})
self.writer._jwriter = self.when_not_matched.insert(jmap)
return self.writer
class WhenNotMatchedBySource:
"""
A class for defining actions to be performed when there is no match by source
during a merge operation in a MergeIntoWriter.
"""
def __init__(self, writer: "MergeIntoWriter", condition: Optional[Column]):
self.writer = writer
if condition is None:
self.when_not_matched_by_source = writer._jwriter.whenNotMatchedBySource()
else:
from pyspark.sql.classic.column import _to_java_column
self.when_not_matched_by_source = writer._jwriter.whenNotMatchedBySource(
_to_java_column(condition)
)
def updateAll(self) -> "MergeIntoWriter":
"""
Specifies an action to update all non-matched rows in the target DataFrame when
not matched by the source.
"""
self.writer._jwriter = self.when_not_matched_by_source.updateAll()
return self.writer
def update(self, assignments: Dict[str, Column]) -> "MergeIntoWriter":
"""
Specifies an action to update non-matched rows in the target DataFrame with the provided
column assignments when not matched by the source.
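
Examples
--------
A minimal sketch that flags target rows absent from the source; the column
name ``active`` is illustrative:

>>> from pyspark.sql.functions import col, lit
>>> (source.mergeInto("target", col("source.id") == col("target.id"))
...     .whenNotMatchedBySource()
...     .update({"active": lit(False)})
...     .merge())  # doctest: +SKIP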
"""
jvm = self.writer._spark._jvm
from pyspark.sql.classic.column import _to_java_column
jmap = to_scala_map(jvm, {k: _to_java_column(v) for k, v in assignments.items()})
self.writer._jwriter = self.when_not_matched_by_source.update(jmap)
return self.writer
def delete(self) -> "MergeIntoWriter":
"""
Specifies an action to delete target rows that have no match in the source table
from the DataFrame.
"""
self.writer._jwriter = self.when_not_matched_by_source.delete()
return self.writer
def _test() -> None:
import doctest
import os
import py4j
from pyspark.core.context import SparkContext
from pyspark.sql import SparkSession
import pyspark.sql.merge
os.chdir(os.environ["SPARK_HOME"])
globs = pyspark.sql.merge.__dict__.copy()
sc = SparkContext("local[4]", "PythonTest")
try:
spark = SparkSession._getActiveSessionOrCreate()
except py4j.protocol.Py4JError:
spark = SparkSession(sc)
globs["spark"] = spark
(failure_count, test_count) = doctest.testmod(
pyspark.sql.merge,
globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
)
spark.stop()
if failure_count:
sys.exit(-1)
if __name__ == "__main__":
_test()