# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import subprocess
import pytest
from typing import Optional
from pyspark.sql import SparkSession
from datasketches_spark import get_dependency_classpath

# Attempt to determine the Java version -- this may still fail
# based on command line arguments or other pyspark config, but
# it should often work for local testing.
# Favor looking for a java executable in $SPARK_HOME, else
# $JAVA_HOME, else $PATH from just running `java`.
# May return an incorrect number for Java 8 (aka 1.8), but
# we only care about detecting 17 and higher.
def get_java_version() -> Optional[int]:
    java_cmd = 'java'  # fallback option
    # check $JAVA_HOME first
    if 'JAVA_HOME' in os.environ:
        java_cmd = os.path.join(os.environ['JAVA_HOME'], 'bin', 'java')
    # check $SPARK_HOME next -- top choice, if available
    if 'SPARK_HOME' in os.environ:
        spark_java = os.path.join(os.environ['SPARK_HOME'], 'bin', 'java')
        if os.path.exists(spark_java):
            java_cmd = spark_java
    # now attempt to run java -version
    try:
        # java -version prints to stderr.
        # The version is the 3rd token of the 1st line and may be in double quotes;
        # we care only about the major version.
        output = subprocess.run([java_cmd, '-version'],
                                capture_output=True,
                                text=True)
        version = output.stderr.splitlines()[0]
        return int(version.split()[2].strip('"').split('.')[0])
    except Exception as e:
        print(f"Could not determine java version: {e}")
        return None
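

# Illustrative parses for the logic above (the sample strings are assumptions;
# real `java -version` output varies by vendor and version):
#   'openjdk version "17.0.2" 2022-01-18'  -> third token '"17.0.2"'    -> major 17
#   'java version "1.8.0_292"'             -> third token '"1.8.0_292"' -> major 1
# The second case is the Java 8 caveat noted above; 1 still falls safely below 17.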


@pytest.fixture(scope='session')
def spark():
    java_opts = ''
    java_version = get_java_version()
    # Pass extra JVM options on Java 17-20 (incubator foreign-memory module
    # and access to internal NIO classes)
    if java_version is not None and 17 <= java_version < 21:
        java_opts = '--add-modules=jdk.incubator.foreign --add-exports=java.base/sun.nio.ch=ALL-UNNAMED'
    os.environ['PYSPARK_SUBMIT_ARGS'] = f"--driver-java-options '{java_opts}' pyspark-shell"
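    # For example, on Java 17 this sets PYSPARK_SUBMIT_ARGS to
    # "--driver-java-options '--add-modules=jdk.incubator.foreign --add-exports=java.base/sun.nio.ch=ALL-UNNAMED' pyspark-shell";
    # on other versions the quoted options are simply empty. It must be set before
    # SparkSession creation so the driver JVM is launched with these flags.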
    spark = (
        SparkSession.builder
        .appName('test')
        .master('local[1]')
        .config('spark.driver.userClassPathFirst', 'true')
        .config('spark.executor.userClassPathFirst', 'true')
        .config('spark.driver.extraClassPath', get_dependency_classpath())
        .config('spark.executor.extraClassPath', get_dependency_classpath())
        .config('spark.executor.extraJavaOptions', java_opts)
        .config('spark.driver.bindAddress', 'localhost')
        .config('spark.driver.host', 'localhost')
        .getOrCreate()
    )
    yield spark
    spark.stop()
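

# Illustrative usage of the fixture (a hypothetical test, not part of this conftest):
# a test module in the same directory requests `spark` by name, e.g.
#
#     def test_row_count(spark):
#         df = spark.createDataFrame([(1,), (2,), (3,)], ['value'])
#         assert df.count() == 3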