#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hashlib
import shutil
import tempfile
import unittest
import os

from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect
from pyspark.testing.sqlutils import SPARK_HOME
from pyspark.sql.functions import udf, assert_true, lit
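
# The Spark Connect client pieces are imported only when the optional Connect
# dependencies (e.g. grpcio) are available; otherwise `should_test_connect` is False
# and the test infrastructure skips these tests instead of failing at import time.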
if should_test_connect:
    from pyspark.sql.connect.client.artifact import ArtifactManager
    from pyspark.sql.connect.client import DefaultChannelBuilder
    from pyspark.errors import SparkRuntimeException


class ArtifactTestsMixin:
    """Shared checks for adding Python files, zipped packages, archives and plain
    files as Spark Connect session artifacts; reused by the concrete test classes."""

    def check_add_pyfile(self, spark_session):
        with tempfile.TemporaryDirectory(prefix="check_add_pyfile") as d:
            pyfile_path = os.path.join(d, "my_pyfile.py")
            with open(pyfile_path, "w") as f:
                f.write("my_func = lambda: 10")

            @udf("int")
            def func(x):
                import my_pyfile

                return my_pyfile.my_func()

            spark_session.addArtifacts(pyfile_path, pyfile=True)
            spark_session.range(1).select(assert_true(func("id") == lit(10))).show()

    def test_add_pyfile(self):
        self.check_add_pyfile(self.spark)

        # Test multiple sessions: the same file should be addable
        # from a different session.
        self.check_add_pyfile(
            SparkSession.builder.remote(
                f"sc://localhost:{DefaultChannelBuilder.default_port()}"
            ).create()
        )

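    # Note: `SparkSession.builder.remote(...).create()` above deliberately creates a
    # brand-new Connect session against the same local server instead of reusing the
    # default one, so the artifact really is added from a different session. The other
    # test_add_* methods below follow the same pattern.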
    def test_artifacts_cannot_be_overwritten(self):
        with tempfile.TemporaryDirectory(prefix="test_artifacts_cannot_be_overwritten") as d:
            pyfile_path = os.path.join(d, "my_pyfile.py")
            with open(pyfile_path, "w+") as f:
                f.write("my_func = lambda: 10")
            self.spark.addArtifacts(pyfile_path, pyfile=True)

            # Adding the same, unchanged file twice is fine and should not throw.
            self.spark.addArtifacts(pyfile_path, pyfile=True)

            # Adding a file with the same name but different contents must fail.
            with open(pyfile_path, "w+") as f:
                f.write("my_func = lambda: 11")
            with self.assertRaises(SparkRuntimeException) as pe:
                self.spark.addArtifacts(pyfile_path, pyfile=True)

            self.check_error(
                exception=pe.exception,
                errorClass="ARTIFACT_ALREADY_EXISTS",
                messageParameters={"normalizedRemoteRelativePath": "pyfiles/my_pyfile.py"},
            )

    def check_add_zipped_package(self, spark_session):
        with tempfile.TemporaryDirectory(prefix="check_add_zipped_package") as d:
            package_path = os.path.join(d, "my_zipfile")
            os.mkdir(package_path)
            pyfile_path = os.path.join(package_path, "__init__.py")
            with open(pyfile_path, "w") as f:
                _ = f.write("my_func = lambda: 5")
            shutil.make_archive(package_path, "zip", d, "my_zipfile")

            @udf("long")
            def func(x):
                import my_zipfile

                return my_zipfile.my_func()

            spark_session.addArtifacts(f"{package_path}.zip", pyfile=True)
            spark_session.range(1).select(assert_true(func("id") == lit(5))).show()

    def test_add_zipped_package(self):
        self.check_add_zipped_package(self.spark)

        # Test multiple sessions: the same zipped package should be addable
        # from a different session.
        self.check_add_zipped_package(
            SparkSession.builder.remote(
                f"sc://localhost:{DefaultChannelBuilder.default_port()}"
            ).create()
        )

    def check_add_archive(self, spark_session):
        with tempfile.TemporaryDirectory(prefix="check_add_archive") as d:
            archive_path = os.path.join(d, "my_archive")
            os.mkdir(archive_path)
            pyfile_path = os.path.join(archive_path, "my_file.txt")
            with open(pyfile_path, "w") as f:
                _ = f.write("hello world!")
            shutil.make_archive(archive_path, "zip", d, "my_archive")

            # Call addArtifacts first so that the session state is set up and
            # the artifact 'root' directory can be resolved properly.
            spark_session.addArtifacts(f"{archive_path}.zip#my_files", archive=True)
            root = self.root()

            @udf("string")
            def func(x):
                with open(
                    os.path.join(root, "my_files", "my_archive", "my_file.txt"),
                    "r",
                ) as my_file:
                    return my_file.read().strip()

            spark_session.range(1).select(assert_true(func("id") == lit("hello world!"))).show()

    def test_add_archive(self):
        self.check_add_archive(self.spark)

        # Test multiple sessions: the same archive should be addable
        # from a different session.
        self.check_add_archive(
            SparkSession.builder.remote(
                f"sc://localhost:{DefaultChannelBuilder.default_port()}"
            ).create()
        )

    def check_add_file(self, spark_session):
        with tempfile.TemporaryDirectory(prefix="check_add_file") as d:
            file_path = os.path.join(d, "my_file.txt")
            with open(file_path, "w") as f:
                f.write("Hello world!!")

            # Call addArtifacts first so that the session state is set up and
            # the artifact 'root' directory can be resolved properly.
            spark_session.addArtifacts(file_path, file=True)
            root = self.root()

            @udf("string")
            def func(x):
                with open(os.path.join(root, "my_file.txt"), "r") as my_file:
                    return my_file.read().strip()

            spark_session.range(1).select(assert_true(func("id") == lit("Hello world!!"))).show()

    def test_add_file(self):
        self.check_add_file(self.spark)

        # Test multiple sessions: the same file should be addable
        # from a different session.
        self.check_add_file(
            SparkSession.builder.remote(
                f"sc://localhost:{DefaultChannelBuilder.default_port()}"
            ).create()
        )


@unittest.skipIf(is_remote_only(), "Requires JVM access")
class ArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin):
    @classmethod
    def root(cls):
        from pyspark.core.files import SparkFiles

        # In local mode the artifact location is the same as the driver's, since
        # the executors run as threads inside the same process.
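        # The Scala singleton `JobArtifactSet$.MODULE$` is reached through the JVM
        # gateway below; its last-seen artifact state carries the UUID associated with
        # the current session, and that UUID names the session-specific subdirectory
        # under the SparkFiles root directory returned by this helper.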
        jvm = SparkSession._instantiatedSession._jvm
        current_uuid = (
            getattr(
                getattr(
                    jvm.org.apache.spark,  # type: ignore[union-attr]
                    "JobArtifactSet$",
                ),
                "MODULE$",
            )
            .lastSeenState()
            .get()
            .uuid()
        )
        return os.path.join(SparkFiles.getRootDirectory(), current_uuid)

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.artifact_manager: ArtifactManager = cls.spark._client._artifact_manager
        cls.base_resource_dir = os.path.join(SPARK_HOME, "data")
        cls.artifact_file_path = os.path.join(
            cls.base_resource_dir,
            "artifact-tests",
        )
        cls.artifact_crc_path = os.path.join(
            cls.artifact_file_path,
            "crc",
        )

    @classmethod
    def conf(cls):
        conf = super().conf()
        conf.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
        return conf

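    # The allowDestLocal setting above permits `copyFromLocalToFs` to write to a
    # destination on the local file system, which `test_copy_from_local_to_fs` below
    # relies on when it copies into a local temporary directory.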
    def test_basic_requests(self):
        file_name = "smallJar"
        small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar")
        if not os.path.isfile(small_jar_path):
            raise unittest.SkipTest(f"Skipped as {small_jar_path} does not exist.")

        response = self.artifact_manager._retrieve_responses(
            self.artifact_manager._create_requests(
                small_jar_path, pyfile=False, archive=False, file=False
            )
        )
        self.assertTrue(response.artifacts[0].name.endswith(f"{file_name}.jar"))

    def test_single_chunk_artifact(self):
        file_name = "smallJar"
        small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar")
        small_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt")
        if not os.path.isfile(small_jar_path):
            raise unittest.SkipTest(f"Skipped as {small_jar_path} does not exist.")

        requests = list(
            self.artifact_manager._create_requests(
                small_jar_path, pyfile=False, archive=False, file=False
            )
        )
        self.assertEqual(len(requests), 1)

        request = requests[0]
        self.assertIsNotNone(request.batch)

        batch = request.batch
        self.assertEqual(len(batch.artifacts), 1)

        single_artifact = batch.artifacts[0]
        self.assertTrue(single_artifact.name.endswith(".jar"))
        self.assertEqual(os.path.join("jars", f"{file_name}.jar"), single_artifact.name)

        with open(small_jar_crc_path) as f1, open(small_jar_path, "rb") as f2:
            self.assertEqual(single_artifact.data.crc, int(f1.readline()))
            self.assertEqual(single_artifact.data.data, f2.read())

    def test_chunked_artifacts(self):
        file_name = "junitLargeJar"
        large_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar")
        large_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt")
        if not os.path.isfile(large_jar_path):
            raise unittest.SkipTest(f"Skipped as {large_jar_path} does not exist.")

        requests = list(
            self.artifact_manager._create_requests(
                large_jar_path, pyfile=False, archive=False, file=False
            )
        )
        # Expected chunks = roundUp(file_size / chunk_size) = 12.
        # The file size of `junitLargeJar.jar` is 384581 bytes.
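        # As a worked example (assuming a 32 KiB chunk size, i.e. CHUNK_SIZE = 32768,
        # which is consistent with the 12 chunks asserted for this jar elsewhere in
        # this file): ceil(384581 / 32768) = ceil(11.74) = 12.
        # The computation below derives the same figure from the actual file size and
        # ArtifactManager.CHUNK_SIZE, so the test does not hard-code that constant.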
        large_jar_size = os.path.getsize(large_jar_path)
        expected_chunks = int(
            (large_jar_size + (ArtifactManager.CHUNK_SIZE - 1)) / ArtifactManager.CHUNK_SIZE
        )
        self.assertEqual(len(requests), expected_chunks)

        request = requests[0]
        self.assertIsNotNone(request.begin_chunk)

        begin_chunk = request.begin_chunk
        self.assertEqual(begin_chunk.name, os.path.join("jars", f"{file_name}.jar"))
        self.assertEqual(begin_chunk.total_bytes, large_jar_size)
        self.assertEqual(begin_chunk.num_chunks, expected_chunks)

        other_requests = requests[1:]
        data_chunks = [begin_chunk.initial_chunk] + [req.chunk for req in other_requests]

        with open(large_jar_crc_path) as f1, open(large_jar_path, "rb") as f2:
            crcs = [chunk.crc for chunk in data_chunks]
            expected_crcs = [int(line.rstrip()) for line in f1]
            self.assertEqual(crcs, expected_crcs)

            binaries = [chunk.data for chunk in data_chunks]
            expected_binaries = list(iter(lambda: f2.read(ArtifactManager.CHUNK_SIZE), b""))
            self.assertEqual(binaries, expected_binaries)

    def test_batched_artifacts(self):
        file_name = "smallJar"
        small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar")
        small_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt")
        if not os.path.isfile(small_jar_path):
            raise unittest.SkipTest(f"Skipped as {small_jar_path} does not exist.")

        requests = list(
            self.artifact_manager._create_requests(
                small_jar_path, small_jar_path, pyfile=False, archive=False, file=False
            )
        )
        # A single request containing two artifacts.
        self.assertEqual(len(requests), 1)

        request = requests[0]
        self.assertIsNotNone(request.batch)

        batch = request.batch
        self.assertEqual(len(batch.artifacts), 2)

        artifact1 = batch.artifacts[0]
        self.assertTrue(artifact1.name.endswith(".jar"))
        artifact2 = batch.artifacts[1]
        self.assertTrue(artifact2.name.endswith(".jar"))
        self.assertEqual(os.path.join("jars", f"{file_name}.jar"), artifact1.name)

        with open(small_jar_crc_path) as f1, open(small_jar_path, "rb") as f2:
            crc = int(f1.readline())
            data = f2.read()
            self.assertEqual(artifact1.data.crc, crc)
            self.assertEqual(artifact1.data.data, data)
            self.assertEqual(artifact2.data.crc, crc)
            self.assertEqual(artifact2.data.data, data)

    def test_single_chunked_and_chunked_artifact(self):
        file_name1 = "smallJar"
        file_name2 = "junitLargeJar"
        small_jar_path = os.path.join(self.artifact_file_path, f"{file_name1}.jar")
        small_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name1}.txt")
        large_jar_path = os.path.join(self.artifact_file_path, f"{file_name2}.jar")
        large_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name2}.txt")
        if not os.path.isfile(small_jar_path):
            raise unittest.SkipTest(f"Skipped as {small_jar_path} does not exist.")
        if not os.path.isfile(large_jar_path):
            raise unittest.SkipTest(f"Skipped as {large_jar_path} does not exist.")
        # Query the size only after the existence checks, so a missing jar skips the
        # test instead of failing it with a FileNotFoundError.
        large_jar_size = os.path.getsize(large_jar_path)

        requests = list(
            self.artifact_manager._create_requests(
                small_jar_path,
                large_jar_path,
                small_jar_path,
                small_jar_path,
                pyfile=False,
                archive=False,
                file=False,
            )
        )
        # There are 14 requests in total:
        #   - Request 1 contains a single artifact, smallJar.jar. Nothing is batched
        #     with it because the next artifact is a large, multi-chunk one.
        #   - Requests 2-13 (1-indexed) transfer junitLargeJar.jar: the initial
        #     "begin chunk" request followed by the remaining data chunks.
        #   - The final request (14) batches the last two copies of smallJar.jar together.
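        # The 12 requests for junitLargeJar.jar correspond one-to-one to its chunks
        # (ceil(384581 / CHUNK_SIZE), assuming the same 32 KiB chunk size as in
        # test_chunked_artifacts); the first chunk travels inside the begin_chunk
        # request itself rather than in a separate data-chunk request.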
        self.assertEqual(len(requests), 1 + 12 + 1)

        first_req_batch = requests[0].batch.artifacts
        self.assertEqual(len(first_req_batch), 1)
        self.assertEqual(first_req_batch[0].name, os.path.join("jars", f"{file_name1}.jar"))
        with open(small_jar_crc_path) as f1, open(small_jar_path, "rb") as f2:
            self.assertEqual(first_req_batch[0].data.crc, int(f1.readline()))
            self.assertEqual(first_req_batch[0].data.data, f2.read())

        second_request = requests[1]
        self.assertIsNotNone(second_request.begin_chunk)

        begin_chunk = second_request.begin_chunk
        self.assertEqual(begin_chunk.name, os.path.join("jars", f"{file_name2}.jar"))
        self.assertEqual(begin_chunk.total_bytes, large_jar_size)
        self.assertEqual(begin_chunk.num_chunks, 12)

        other_requests = requests[2:-1]
        data_chunks = [begin_chunk.initial_chunk] + [req.chunk for req in other_requests]

        with open(large_jar_crc_path) as f1, open(large_jar_path, "rb") as f2:
            crcs = [chunk.crc for chunk in data_chunks]
            expected_crcs = [int(line.rstrip()) for line in f1]
            self.assertEqual(crcs, expected_crcs)

            binaries = [chunk.data for chunk in data_chunks]
            expected_binaries = list(iter(lambda: f2.read(ArtifactManager.CHUNK_SIZE), b""))
            self.assertEqual(binaries, expected_binaries)

        last_request = requests[-1]
        self.assertIsNotNone(last_request.batch)

        batch = last_request.batch
        self.assertEqual(len(batch.artifacts), 2)

        artifact1 = batch.artifacts[0]
        self.assertTrue(artifact1.name.endswith(".jar"))
        artifact2 = batch.artifacts[1]
        self.assertTrue(artifact2.name.endswith(".jar"))
        self.assertEqual(os.path.join("jars", f"{file_name1}.jar"), artifact1.name)

        with open(small_jar_crc_path) as f1, open(small_jar_path, "rb") as f2:
            crc = int(f1.readline())
            data = f2.read()
            self.assertEqual(artifact1.data.crc, crc)
            self.assertEqual(artifact1.data.data, data)
            self.assertEqual(artifact2.data.crc, crc)
            self.assertEqual(artifact2.data.data, data)

    def test_copy_from_local_to_fs(self):
        with tempfile.TemporaryDirectory(prefix="test_copy_from_local_to_fs1") as d:
            with tempfile.TemporaryDirectory(prefix="test_copy_from_local_to_fs2") as d2:
                file_path = os.path.join(d, "file1")
                dest_path = os.path.join(d2, "file1_dest")
                file_content = "test_copy_from_local_to_FS"

                with open(file_path, "w") as f:
                    f.write(file_content)
                self.spark.copyFromLocalToFs(file_path, dest_path)

                with open(dest_path, "r") as f:
                    self.assertEqual(f.read(), file_content)

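    # `cache_artifact` is expected to return the SHA-256 hex digest of the payload,
    # and that same digest is the key that `is_cached_artifact` looks up; the test
    # below verifies both directions.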
    def test_cache_artifact(self):
        s = "Hello, World!"
        blob = bytearray(s, "utf-8")
        expected_hash = hashlib.sha256(blob).hexdigest()
        self.assertEqual(self.artifact_manager.is_cached_artifact(expected_hash), False)
        actual_hash = self.artifact_manager.cache_artifact(blob)
        self.assertEqual(actual_hash, expected_hash)
        self.assertEqual(self.artifact_manager.is_cached_artifact(expected_hash), True)

    def test_add_not_existing_artifact(self):
        with tempfile.TemporaryDirectory(prefix="test_add_not_existing_artifact") as d:
            with self.assertRaises(FileNotFoundError):
                self.artifact_manager.add_artifacts(
                    os.path.join(d, "not_existing"), file=True, pyfile=False, archive=False
                )