#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
"""Spark connector with different catalog types."""
import os
from typing import Any, Dict, List, Optional, Union
from pyspark.sql import SparkSession
class IcebergSparkSession:
"""Create a Spark session that connects to Polaris.
The session is expected to be used within a with statement, as in:
with IcebergSparkSession(
credentials=f"{client_id}:{client_secret}",
aws_region='us-west-2',
polaris_url="http://polaris:8181/api/catalog",
catalog_name="catalog_name"
) as spark:
spark.sql(f"USE catalog.{hybrid_executor.database}.{hybrid_executor.schema}")
table_list = spark.sql("SHOW TABLES").collect()
"""
    def __init__(
        self,
        bearer_token: Optional[str] = None,
        credentials: Optional[str] = None,
        aws_region: Optional[str] = None,
        catalog_name: Optional[str] = None,
        polaris_url: Optional[str] = None,
        realm: str = 'POLARIS',
        use_vended_credentials: bool = True
    ):
"""Constructor for Iceberg Spark session. Sets the member variables."""
self.bearer_token = bearer_token
self.credentials = credentials
        if aws_region is None:
            aws_region = os.environ.get('AWS_REGION', 'us-west-2')
        self.aws_region = aws_region
self.catalog_name = catalog_name
self.polaris_url = polaris_url
self.realm = realm
self.use_vended_credentials = use_vended_credentials
def get_catalog_name(self):
"""Get the catalog name of this spark session based on catalog_type."""
return self.catalog_name
def get_session(self):
"""Get the real spark session."""
return self.spark_session
def sql(self, query: str, args: Optional[Union[Dict[str, Any], List]] = None, **kwargs):
"""Wrapper for the sql function of SparkSession."""
        return self.spark_session.sql(query, args=args, **kwargs)
def __enter__(self):
"""Initial method for Iceberg Spark session. Creates a Spark session with specified configs.
"""
packages = [
"org.apache.iceberg:iceberg-spark-runtime-3.5_2.12:1.10.0",
"org.apache.iceberg:iceberg-aws-bundle:1.10.0",
]
excludes = ["org.checkerframework:checker-qual", "com.google.errorprone:error_prone_annotations"]
packages_string = ",".join(packages)
excludes_string = ",".join(excludes)
catalog_name = self.get_catalog_name()
        creds = self.credentials
        cred_config = f"spark.sql.catalog.{catalog_name}.credential"
        # A bearer token, when provided, takes precedence over client credentials.
        if self.bearer_token is not None:
            creds = self.bearer_token
            cred_config = f"spark.sql.catalog.{catalog_name}.token"
spark_session_builder = (
SparkSession.builder.config("spark.jars.packages", packages_string)
.config("spark.jars.excludes", excludes_string)
.config("spark.sql.iceberg.vectorization.enabled", "false")
.config("spark.hadoop.fs.s3.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")
.config("spark.history.fs.logDirectory", "/home/iceberg/spark-events")
.config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
.config(
"spark.hadoop.fs.s3a.aws.credentials.provider",
"org.apache.hadoop.fs.s3a.TemporaryAWSCredentialsProvider",
)
.config(
f"spark.sql.catalog.{catalog_name}", "org.apache.iceberg.spark.SparkCatalog"
)
.config(f"spark.sql.catalog.{catalog_name}.type", "rest")
.config(f"spark.sql.catalog.{catalog_name}.uri", self.polaris_url)
.config(f"spark.sql.catalog.{catalog_name}.warehouse", self.catalog_name)
.config(f"spark.sql.catalog.{catalog_name}.scope", 'PRINCIPAL_ROLE:ALL')
.config(f"spark.sql.catalog.{catalog_name}.header.realm", self.realm)
.config(f"spark.sql.catalog.{catalog_name}.client.region", self.aws_region)
            .config(cred_config, creds)
            .config("spark.ui.showConsoleProgress", "false")
)
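        # Note: with spark.jars.packages set, Spark resolves these artifacts
        # (via Ivy) on first startup, so initial session creation needs
        # network access and may take a while.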
# Conditionally add vended credentials header
if self.use_vended_credentials:
spark_session_builder = spark_session_builder.config(
f"spark.sql.catalog.{catalog_name}.header.X-Iceberg-Access-Delegation", "vended-credentials"
)
else:
            # Set an empty value to override the header in case it was configured globally.
spark_session_builder = spark_session_builder.config(
f"spark.sql.catalog.{catalog_name}.header.X-Iceberg-Access-Delegation", ""
)
self.spark_session = spark_session_builder.getOrCreate()
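        # Note: getOrCreate() returns any already-running session in this
        # process, in which case the configs above may not be (re)applied.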
self.quiet_logs(self.spark_session.sparkContext)
return self
    def quiet_logs(self, sc):
        """Raise the log level of the noisiest JVM loggers to reduce console output."""
logger = sc._jvm.org.apache.log4j
logger.LogManager.getLogger("org").setLevel(logger.Level.ERROR)
logger.LogManager.getLogger("akka").setLevel(logger.Level.ERROR)
def __exit__(self, exc_type, exc_val, exc_tb):
"""Destructor for Iceberg Spark session. Stops the Spark session."""
self.spark_session.stop()
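

# A minimal usage sketch, assuming a locally running Polaris instance; the
# URL, credentials, and catalog name below are placeholders to adjust for
# your deployment.
if __name__ == "__main__":
    with IcebergSparkSession(
        credentials="client_id:client_secret",  # hypothetical OAuth2 client credentials
        aws_region="us-west-2",
        polaris_url="http://localhost:8181/api/catalog",
        catalog_name="my_catalog",  # hypothetical catalog created in Polaris
    ) as spark:
        # List the namespaces visible in the configured catalog.
        print(spark.sql("SHOW NAMESPACES IN my_catalog").collect())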