blob: 62399e2db68d58d6089d1a9f14009d3edbd95ea9 [file]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import contextlib
import os
import shutil
from io import StringIO
import stat
import subprocess
import sys
import time
import tempfile
import threading
from typing import Callable, Dict, Any
import unittest
from unittest.mock import patch
from pyspark import SparkConf, SparkContext
from pyspark.ml.torch.distributor import TorchDistributor, _get_gpus_owned
from pyspark.ml.torch.torch_run_process_wrapper import clean_and_terminate, check_parent_alive
from pyspark.sql import SparkSession
from pyspark.testing.sqlutils import SPARK_HOME
from pyspark.testing.utils import have_torch, torch_requirement_message
@contextlib.contextmanager
def patch_stdout() -> StringIO:
    """Temporarily swap ``sys.stdout`` for an in-memory buffer.

    Yields the ``StringIO`` that collects everything written to ``sys.stdout``
    while the ``with`` body runs. The previous ``sys.stdout`` is put back when
    the body finishes, including when it raises.
    """
    captured = StringIO()
    # redirect_stdout performs the same save / replace / restore dance on sys.stdout.
    with contextlib.redirect_stdout(captured):
        yield captured
def create_training_function(mnist_dir_path: str) -> Callable:
    """Build a distributed MNIST training function for the end-to-end tests.

    The MNIST training split is loaded eagerly from ``mnist_dir_path`` (with
    ``download=True``) and captured, together with the model class and the
    hyper-parameters, by the returned closure.

    Parameters
    ----------
    mnist_dir_path : str
        Root directory passed to ``torchvision.datasets.MNIST``.

    Returns
    -------
    Callable
        ``train_fn(learning_rate)``, which trains a small convolutional network
        for one epoch under ``DistributedDataParallel`` on the "gloo" process
        group and returns ``"success" * 4096``.
    """
    import torch.nn as nn
    import torch.nn.functional as F
    from torchvision import transforms, datasets

    batch_size = 100
    num_epochs = 1
    momentum = 0.5

    train_dataset = datasets.MNIST(
        mnist_dir_path,
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    )

    class Net(nn.Module):
        """Two convolution layers followed by two fully connected layers."""

        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
            self.conv2_drop = nn.Dropout2d()
            self.fc1 = nn.Linear(320, 50)
            self.fc2 = nn.Linear(50, 10)

        def forward(self, x: Any) -> Any:
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
            x = x.view(-1, 320)
            x = F.relu(self.fc1(x))
            x = F.dropout(x, training=self.training)
            x = self.fc2(x)
            # `x` is 2-D after the `view` above, so the class scores live on
            # dim 1. Naming the dim explicitly avoids the deprecated implicit
            # dimension choice of `log_softmax` without changing the result.
            return F.log_softmax(x, dim=1)

    def train_fn(learning_rate: float) -> Any:
        """Train ``Net`` for ``num_epochs`` epochs and return a large marker string."""
        import torch
        import torch.optim as optim
        import torch.distributed as dist
        from torch.nn.parallel import DistributedDataParallel as DDP
        from torch.utils.data.distributed import DistributedSampler

        dist.init_process_group("gloo")

        train_sampler = DistributedSampler(dataset=train_dataset)
        data_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, sampler=train_sampler
        )

        model = Net()
        ddp_model = DDP(model)
        optimizer = optim.SGD(ddp_model.parameters(), lr=learning_rate, momentum=momentum)

        for epoch in range(1, num_epochs + 1):
            ddp_model.train()
            for data, target in data_loader:
                optimizer.zero_grad()
                output = ddp_model(data)
                loss = F.nll_loss(output, target)
                loss.backward()
                optimizer.step()
            print(f"epoch {epoch} finished.")

        # The callers compare against this exact value.
        return "success" * 4096

    return train_fn
def set_up_test_dirs():
    """Create the on-disk fixtures shared by the GPU-configured test suites.

    Returns
    -------
    tuple of (str, str)
        ``(gpu_discovery_script_file_name, mnist_dir_path)``: the path of an
        executable script that echoes a JSON description of a ``gpu`` resource
        with addresses ``0``, ``1`` and ``2``, and the path of a fresh, empty
        temporary directory. Nothing is cleaned up here; the caller owns both
        paths and is responsible for removing them.
    """
    # delete=False: the script must outlive this function so Spark can run it.
    # The `with` block only guarantees the handle is flushed and closed.
    with tempfile.NamedTemporaryFile(delete=False) as gpu_discovery_script_file:
        gpu_discovery_script_file.write(
            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
        )
    gpu_discovery_script_file_name = gpu_discovery_script_file.name

    # rwx for the owner, r-x for group and others (0o755).
    os.chmod(
        gpu_discovery_script_file_name,
        stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
    )

    mnist_dir_path = tempfile.mkdtemp()
    return (gpu_discovery_script_file_name, mnist_dir_path)
def get_local_mode_conf():
    """Return a new dict of the Spark settings shared by the local-mode suites.

    Besides ``spark.test.home`` it requests three driver GPUs and 512M of
    memory for both the driver and the executors.
    """
    memory = "512M"
    settings = {"spark.test.home": SPARK_HOME}
    settings["spark.driver.resource.gpu.amount"] = "3"
    settings["spark.driver.memory"] = memory
    settings["spark.executor.memory"] = memory
    return settings
def get_distributed_mode_conf():
    """Return a new dict of the Spark settings shared by the distributed-mode suites.

    Besides ``spark.test.home`` it declares three GPUs per worker, one GPU per
    executor and per task, two CPUs per task, and 512M of memory for both the
    driver and the executors.
    """
    memory = "512M"
    pairs = [
        ("spark.test.home", SPARK_HOME),
        ("spark.worker.resource.gpu.amount", "3"),
        ("spark.task.cpus", "2"),
        ("spark.task.resource.gpu.amount", "1"),
        ("spark.executor.resource.gpu.amount", "1"),
        ("spark.driver.memory", memory),
        ("spark.executor.memory", memory),
    ]
    return dict(pairs)
class TorchDistributorBaselineUnitTestsMixin:
    """Tests of ``TorchDistributor`` that need a Spark session but no GPU settings.

    The concrete test case mixing this in must provide ``self.spark`` (used by
    the encryption tests) and the ``unittest.TestCase`` assertion API.
    """

    def setup_env_vars(self, input_map: Dict[str, str]) -> None:
        """Export every ``key: value`` pair of ``input_map`` into ``os.environ``."""
        os.environ.update(input_map)

    def delete_env_vars(self, input_map: Dict[str, str]) -> None:
        """Remove every key of ``input_map`` from ``os.environ``.

        Keys that are already absent are skipped, so calling this twice (or
        after code under test removed a variable itself) does not raise.
        """
        for key in input_map:
            os.environ.pop(key, None)

    def test_validate_correct_inputs(self) -> None:
        inputs = [
            (1, True, False),
            (100, True, False),
            (1, False, False),
            (100, False, False),
        ]
        for num_processes, local_mode, use_gpu in inputs:
            # Naming the parameters makes a failing combination identifiable.
            with self.subTest(
                num_processes=num_processes, local_mode=local_mode, use_gpu=use_gpu
            ):
                expected_params = {
                    "num_processes": num_processes,
                    "local_mode": local_mode,
                    "use_gpu": use_gpu,
                    "num_tasks": num_processes,
                }
                dist = TorchDistributor(num_processes, local_mode, use_gpu)
                self.assertEqual(expected_params, dist.input_params)

    def test_validate_incorrect_inputs(self) -> None:
        inputs = [
            (0, False, False, ValueError, "positive"),
        ]
        for num_processes, local_mode, use_gpu, error, message in inputs:
            with self.subTest(
                num_processes=num_processes, local_mode=local_mode, use_gpu=use_gpu
            ):
                with self.assertRaisesRegex(error, message):
                    TorchDistributor(num_processes, local_mode, use_gpu)

    def test_encryption_passes(self) -> None:
        inputs = [
            ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "true"),
            ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "false"),
            ("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "true"),
        ]
        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in inputs:
            with self.subTest(ssl=ssl_conf_value, ignore_ssl=pytorch_conf_value):
                self.spark.conf.set(ssl_conf_key, ssl_conf_value)
                self.spark.conf.set(pytorch_conf_key, pytorch_conf_value)
                distributor = TorchDistributor(1, True, False)
                distributor._check_encryption()

    def test_encryption_fails(self) -> None:
        # this is the only combination that should fail
        inputs = [("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "false")]
        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in inputs:
            with self.subTest(ssl=ssl_conf_value, ignore_ssl=pytorch_conf_value):
                with self.assertRaisesRegex(Exception, "encryption"):
                    self.spark.conf.set(ssl_conf_key, ssl_conf_value)
                    self.spark.conf.set(pytorch_conf_key, pytorch_conf_value)
                    distributor = TorchDistributor(1, True, False)
                    distributor._check_encryption()

    def test_get_num_tasks_fails(self) -> None:
        inputs = [1, 5, 4]

        # This is when the conf isn't set and we request GPUs
        for num_processes in inputs:
            with self.subTest(num_processes=num_processes):
                with self.assertRaisesRegex(RuntimeError, "driver"):
                    TorchDistributor(num_processes, True, True)
                with self.assertRaisesRegex(RuntimeError, "unset"):
                    TorchDistributor(num_processes, False, True)

    def test_execute_command(self) -> None:
        """Test that run command runs the process and logs are written correctly"""

        with patch_stdout() as output:
            stdout_command = ["echo", "hello_stdout"]
            TorchDistributor._execute_command(stdout_command)
            self.assertIn(
                "hello_stdout", output.getvalue().strip(), "hello_stdout should print to stdout"
            )

        with patch_stdout() as output:
            stderr_command = ["bash", "-c", "echo hello_stderr >&2"]
            TorchDistributor._execute_command(stderr_command)
            self.assertIn(
                "hello_stderr", output.getvalue().strip(), "hello_stderr should print to stdout"
            )

        # include command in the exception message
        with self.assertRaisesRegex(RuntimeError, "exit 1"):
            error_command = ["bash", "-c", "exit 1"]
            TorchDistributor._execute_command(error_command)

        with self.assertRaisesRegex(RuntimeError, "abcdef"):
            error_command = ["bash", "-c", "'abc''def'"]
            TorchDistributor._execute_command(error_command)

    def test_create_torchrun_command(self) -> None:
        train_path = "train.py"
        args_string = ["1", "3"]
        local_mode_input_params = {"num_processes": 4, "local_mode": True}

        expected_local_mode_output = [
            sys.executable,
            "-m",
            "pyspark.ml.torch.torch_run_process_wrapper",
            "--standalone",
            "--nnodes=1",
            "--nproc_per_node=4",
            "train.py",
            "1",
            "3",
        ]
        self.assertEqual(
            TorchDistributor._create_torchrun_command(
                local_mode_input_params, train_path, *args_string
            ),
            expected_local_mode_output,
        )

        distributed_mode_input_params = {"num_processes": 4, "local_mode": False}
        input_env_vars = {"MASTER_ADDR": "localhost", "MASTER_PORT": "9350", "RANK": "3"}

        args_number = [1, 3]  # testing conversion to strings
        expected_distributed_mode_output = [
            sys.executable,
            "-m",
            "pyspark.ml.torch.torch_run_process_wrapper",
            "--nnodes=4",
            "--node_rank=3",
            "--rdzv_endpoint=localhost:9350",
            "--rdzv_id=0",
            "--nproc_per_node=1",
            "train.py",
            "1",
            "3",
        ]
        # patch.dict restores the environment even when the call below raises,
        # so a failure here cannot leak MASTER_ADDR/MASTER_PORT/RANK into the
        # tests that run afterwards.
        with patch.dict(os.environ, input_env_vars):
            distributed_mode_output = TorchDistributor._create_torchrun_command(
                distributed_mode_input_params, train_path, *args_number
            )
        self.assertEqual(distributed_mode_output, expected_distributed_mode_output)

    @patch.dict(
        os.environ,
        {
            "CUDA_VISIBLE_DEVICES": "0,1,2,3",
            "MASTER_ADDR": "11.22.33.44",
            "MASTER_PORT": "6677",
            "RANK": "1",
        },
    )
    def test_multi_gpu_node_get_torchrun_args(self) -> None:
        torchrun_args, processes_per_node = TorchDistributor._get_torchrun_args(False, 8)
        self.assertEqual(
            torchrun_args,
            ["--nnodes=2", "--node_rank=1", "--rdzv_endpoint=11.22.33.44:6677", "--rdzv_id=0"],
        )
        self.assertEqual(processes_per_node, 4)
@unittest.skipIf(not have_torch, torch_requirement_message)
class TorchDistributorBaselineUnitTests(TorchDistributorBaselineUnitTestsMixin, unittest.TestCase):
    """Runs the baseline mixin against a ``local[4]`` Spark session.

    The whole class is skipped when torch is not available.
    """

    @classmethod
    def setUpClass(cls):
        """Start one Spark session shared by every test of the class."""
        spark_context = SparkContext("local[4]", conf=SparkConf())
        cls.spark = SparkSession(spark_context)

    @classmethod
    def tearDownClass(cls):
        """Stop the shared Spark session."""
        cls.spark.stop()
class TorchDistributorLocalUnitTestsMixin:
    """Local-mode tests of ``TorchDistributor``.

    The concrete test case mixing this in must provide ``self.spark`` and
    ``self.mnist_dir_path`` in addition to the ``unittest.TestCase`` API.
    """

    def setup_env_vars(self, input_map: Dict[str, str]) -> None:
        """Export every ``key: value`` pair of ``input_map`` into ``os.environ``."""
        os.environ.update(input_map)

    def delete_env_vars(self, input_map: Dict[str, str]) -> None:
        """Remove every key of ``input_map`` from ``os.environ``.

        Keys that are already absent are skipped instead of raising ``KeyError``.
        """
        for key in input_map:
            os.environ.pop(key, None)

    def test_get_num_tasks_locally(self) -> None:
        succeeds = [1, 2]
        fails = [4, 8]
        for num_processes in succeeds:
            with self.subTest(num_processes=num_processes):
                expected_output = num_processes
                distributor = TorchDistributor(num_processes, True, True)
                self.assertEqual(distributor._get_num_tasks(), expected_output)

        for num_processes in fails:
            with self.subTest(num_processes=num_processes):
                with self.assertLogs("TorchDistributor", level="WARNING") as log:
                    distributor = TorchDistributor(num_processes, True, True)
                    self.assertEqual(len(log.records), 1)
                    self.assertEqual(distributor.num_processes, 3)

    def test_get_gpus_owned_local(self) -> None:
        addresses = ["0", "1", "2"]
        self.assertEqual(_get_gpus_owned(self.spark), addresses)

        # patch.dict undoes the override even when the assertion fails.
        with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "3,4,5"}):
            self.assertEqual(_get_gpus_owned(self.spark), ["3", "4", "5"])

    def _get_inputs_for_test_local_training_succeeds(self):
        """Return ``(cuda_env_var, num_processes, use_gpu, expected)`` cases."""
        return [
            ("0,1,2", 3, True, "0,1,2"),
            ("0,1,2", 2, False, "0,1,2"),
            (None, 3, False, "NONE"),
        ]

    def test_local_training_succeeds(self) -> None:
        CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
        inputs = self._get_inputs_for_test_local_training_succeeds()

        for i, (cuda_env_var, num_processes, use_gpu, expected) in enumerate(inputs):
            with self.subTest(f"subtest: {i + 1}"):
                # setup: only override the variable when the case asks for it.
                # The context manager also performs the cleanup, so a failing
                # case cannot leave CUDA_VISIBLE_DEVICES behind and break the
                # later case that expects it to be absent.
                if cuda_env_var:
                    env_override = patch.dict(os.environ, {CUDA_VISIBLE_DEVICES: cuda_env_var})
                else:
                    env_override = contextlib.nullcontext()

                with env_override:
                    dist = TorchDistributor(num_processes, True, use_gpu)
                    dist._run_training_on_pytorch_file = lambda *args: os.environ.get(
                        CUDA_VISIBLE_DEVICES, "NONE"
                    )
                    output = dist._run_local_training(
                        dist._run_training_on_pytorch_file, "train.py", None
                    )
                    self.assertEqual(
                        sorted(expected.split(",")),
                        sorted(output.split(",")),
                    )

    def test_local_file_with_pytorch(self) -> None:
        test_file_path = "python/test_support/pytorch_training_test_file.py"
        learning_rate_str = "0.01"
        TorchDistributor(num_processes=2, local_mode=True, use_gpu=False).run(
            test_file_path, learning_rate_str
        )

    # `sys.version_info` is a 5-tuple, so `>= (3, 12)` selects exactly the same
    # interpreters as the former `> (3, 12)` while reading as intended.
    @unittest.skipIf(
        sys.version_info >= (3, 12), "SPARK-46078: Fails with dev torch with Python 3.12"
    )
    def test_end_to_end_run_locally(self) -> None:
        train_fn = create_training_function(self.mnist_dir_path)
        output = TorchDistributor(num_processes=2, local_mode=True, use_gpu=False).run(
            train_fn, 0.001
        )
        self.assertEqual(output, "success" * 4096)
# @unittest.skipIf(not have_torch, torch_requirement_message)
# TODO(SPARK-50864): Re-enable this test after fixing the slowness
@unittest.skip("Disabled due to slowness")
class TorchDistributorLocalUnitTests(TorchDistributorLocalUnitTestsMixin, unittest.TestCase):
    """Runs the local-mode mixin on a ``local-cluster[2,2,512]`` master."""

    @classmethod
    def setUpClass(cls):
        """Create the file fixtures and a Spark session with three driver GPUs."""
        cls.gpu_discovery_script_file_name, cls.mnist_dir_path = set_up_test_dirs()
        conf = SparkConf()
        for k, v in get_local_mode_conf().items():
            conf = conf.set(k, v)
        conf = conf.set(
            "spark.driver.resource.gpu.discoveryScript", cls.gpu_discovery_script_file_name
        )

        sc = SparkContext("local-cluster[2,2,512]", cls.__name__, conf=conf)
        cls.spark = SparkSession(sc)

    @classmethod
    def tearDownClass(cls):
        """Remove the file fixtures and stop Spark.

        Spark is stopped in a ``finally`` clause so that a failure while
        deleting the fixtures cannot leave the session running for the test
        classes that follow.
        """
        try:
            shutil.rmtree(cls.mnist_dir_path)
            os.unlink(cls.gpu_discovery_script_file_name)
        finally:
            cls.spark.stop()
# @unittest.skipIf(not have_torch, torch_requirement_message)
# TODO(SPARK-50864): Re-enable this test after fixing the slowness
@unittest.skip("Disabled due to slowness")
class TorchDistributorLocalUnitTestsII(TorchDistributorLocalUnitTestsMixin, unittest.TestCase):
    """Runs the local-mode mixin on a ``local[4]`` master."""

    @classmethod
    def setUpClass(cls):
        """Create the file fixtures and a Spark session with three driver GPUs."""
        cls.gpu_discovery_script_file_name, cls.mnist_dir_path = set_up_test_dirs()
        conf = SparkConf()
        for k, v in get_local_mode_conf().items():
            conf = conf.set(k, v)
        conf = conf.set(
            "spark.driver.resource.gpu.discoveryScript", cls.gpu_discovery_script_file_name
        )

        sc = SparkContext("local[4]", cls.__name__, conf=conf)
        cls.spark = SparkSession(sc)

    @classmethod
    def tearDownClass(cls):
        """Remove the file fixtures and stop Spark.

        Spark is stopped in a ``finally`` clause so that a failure while
        deleting the fixtures cannot leave the session running for the test
        classes that follow.
        """
        try:
            shutil.rmtree(cls.mnist_dir_path)
            os.unlink(cls.gpu_discovery_script_file_name)
        finally:
            cls.spark.stop()
class TorchDistributorDistributedUnitTestsMixin:
    """Distributed-mode tests of ``TorchDistributor``.

    The concrete test case mixing this in must provide ``self.spark`` and
    ``self.mnist_dir_path`` in addition to the ``unittest.TestCase`` API.
    """

    def test_dist_training_succeeds(self) -> None:
        CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
        inputs = [
            ("0,1,2", 2, True, "0"),
        ]

        for i, (_, num_processes, use_gpu, expected) in enumerate(inputs):
            with self.subTest(f"subtest: {i + 1}"):
                dist = TorchDistributor(num_processes, False, use_gpu)
                # Replace the real training entry point with a probe that
                # reports the CUDA_VISIBLE_DEVICES value seen where it runs.
                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(
                    CUDA_VISIBLE_DEVICES, "NONE"
                )
                self.assertEqual(
                    expected,
                    dist._run_distributed_training(
                        dist._run_training_on_pytorch_file,
                        "...",
                        TorchDistributor._run_training_on_pytorch_file,
                        None,
                    ),
                )

    def test_get_num_tasks_distributed(self) -> None:
        inputs = [(1, 8, 8), (2, 8, 4), (3, 8, 3)]

        try:
            for spark_conf_value, num_processes, expected_output in inputs:
                # Naming the parameters makes a failing combination identifiable.
                with self.subTest(gpus_per_task=spark_conf_value, num_processes=num_processes):
                    self.spark.conf.set("spark.task.resource.gpu.amount", str(spark_conf_value))
                    distributor = TorchDistributor(num_processes, False, True)
                    self.assertEqual(distributor._get_num_tasks(), expected_output)
        finally:
            # Always put the shared session back to one GPU per task, whatever
            # happened in the loop, so later tests see the value they expect.
            self.spark.conf.set("spark.task.resource.gpu.amount", "1")

    def test_distributed_file_with_pytorch(self) -> None:
        test_file_path = "python/test_support/pytorch_training_test_file.py"
        learning_rate_str = "0.01"
        TorchDistributor(num_processes=2, local_mode=False, use_gpu=False).run(
            test_file_path, learning_rate_str
        )

    def test_end_to_end_run_distributedly(self) -> None:
        train_fn = create_training_function(self.mnist_dir_path)
        output = TorchDistributor(num_processes=2, local_mode=False, use_gpu=False).run(
            train_fn, 0.001
        )
        self.assertEqual(output, "success" * 4096)
# @unittest.skipIf(not have_torch, torch_requirement_message)
# TODO(SPARK-50864): Re-enable this test after fixing the slowness
@unittest.skip("Disabled due to slowness")
class TorchDistributorDistributedUnitTests(
    TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
):
    """Runs the distributed-mode mixin on a ``local-cluster[2,2,512]`` master."""

    @classmethod
    def setUpClass(cls):
        """Create the file fixtures and a Spark session with worker GPUs."""
        cls.gpu_discovery_script_file_name, cls.mnist_dir_path = set_up_test_dirs()
        conf = SparkConf()
        for k, v in get_distributed_mode_conf().items():
            conf = conf.set(k, v)
        conf = conf.set(
            "spark.worker.resource.gpu.discoveryScript", cls.gpu_discovery_script_file_name
        )

        sc = SparkContext("local-cluster[2,2,512]", cls.__name__, conf=conf)
        cls.spark = SparkSession(sc)

    @classmethod
    def tearDownClass(cls):
        """Remove the file fixtures and stop Spark.

        Spark is stopped in a ``finally`` clause so that a failure while
        deleting the fixtures cannot leave the session running for the test
        classes that follow.
        """
        try:
            shutil.rmtree(cls.mnist_dir_path)
            os.unlink(cls.gpu_discovery_script_file_name)
        finally:
            cls.spark.stop()
class TorchWrapperUnitTestsMixin:
    """Tests for the helpers imported from ``pyspark.ml.torch.torch_run_process_wrapper``."""

    def test_clean_and_terminate(self) -> None:
        """``clean_and_terminate`` must end a child process that is still running."""

        def kill_task(task: "subprocess.Popen") -> None:
            time.sleep(1)
            clean_and_terminate(task)

        # The program text is its own argv entry and must NOT be wrapped in an
        # extra pair of quotes: with them the child merely evaluates a string
        # literal and exits with status 0 straight away, so the test would pass
        # without `clean_and_terminate` terminating anything.
        command = [sys.executable, "-c", "import time; time.sleep(20)"]
        task = subprocess.Popen(command)
        try:
            t = threading.Thread(target=kill_task, args=(task,))
            t.start()
            t.join()
            # The child would otherwise run for 20 seconds; give the signal a
            # bounded amount of time to take effect instead of a fixed sleep.
            try:
                task.wait(timeout=5)
            except subprocess.TimeoutExpired:
                self.fail("child process is still running after clean_and_terminate")
            self.assertIsNotNone(task.poll())  # implies task ended
        finally:
            # Never leave the 20-second sleeper behind when the test fails.
            if task.poll() is None:
                task.kill()
                task.wait()

    @patch("pyspark.ml.torch.torch_run_process_wrapper.clean_and_terminate")
    def test_check_parent_alive(self, mock_clean_and_terminate: Callable) -> None:
        """``clean_and_terminate`` must not be called during the two-second watch.

        NOTE(review): presumably ``check_parent_alive`` only triggers the
        cleanup once the parent process goes away -- confirm against its
        implementation.
        """
        # No extra quotes around the program text, so the child really sleeps.
        command = [sys.executable, "-c", "import time; time.sleep(2)"]
        task = subprocess.Popen(command)
        try:
            # daemon=True: the watcher thread must not keep the interpreter alive.
            t = threading.Thread(target=check_parent_alive, args=(task,), daemon=True)
            t.start()
            time.sleep(2)
            self.assertEqual(mock_clean_and_terminate.call_count, 0)
        finally:
            # Reap the child so it does not linger as a zombie process.
            if task.poll() is None:
                task.kill()
            task.wait()
@unittest.skipIf(not have_torch, torch_requirement_message)
class TorchWrapperUnitTests(TorchWrapperUnitTestsMixin, unittest.TestCase):
    """Runs the wrapper mixin as a plain ``unittest.TestCase``; skipped without torch."""
if __name__ == "__main__":
    # Imported lazily so that merely importing this module does not pull in
    # the test runner.
    from pyspark.testing import main as run_tests

    run_tests()