# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Basic tests for a Disco nvshmem support"""
# pylint: disable=missing-docstring
import multiprocessing
import shutil
import subprocess
import sys
import tempfile
import threading
from multiprocessing import Process
from typing import Any, Callable, List

import numpy as np
import pytest

import tvm
import tvm.testing
from tvm.exec import disco_worker as _  # pylint: disable=unused-import
from tvm.runtime import ShapeTuple
from tvm.runtime import disco as di
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
_SOCKET_SESSION_TESTER = None
def get_free_port():
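    """Return a free TCP port picked by the OS for the disco socket session server to bind."""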
import socket
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
port = s.getsockname()[1]
s.close()
return port
class SocketSessionTester:
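    """Helper that stands up a two-node disco socket session for testing.

    A `SocketSession` server is started on this node in a background thread, and the
    remaining nodes are launched as `tvm.exec.disco_remote_socket_session` subprocesses
    that connect back to the server.
    """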
def __init__(self, num_workers):
num_nodes = 2
num_groups = 1
assert num_workers % num_nodes == 0
num_workers_per_node = num_workers // num_nodes
server_host = "localhost"
server_port = get_free_port()
self.sess = None
def start_server():
self.sess = di.SocketSession(
num_nodes, num_workers_per_node, num_groups, server_host, server_port
)
thread = threading.Thread(target=start_server)
thread.start()
cmd = "tvm.exec.disco_remote_socket_session"
self.remote_nodes = []
for _ in range(num_nodes - 1):
self.remote_nodes.append(
subprocess.Popen(
[
"python3",
"-m",
cmd,
server_host,
str(server_port),
str(num_workers_per_node),
],
stdout=sys.stdout,
stderr=sys.stderr,
)
)
thread.join()
def __del__(self):
if self.sess is not None:
self.sess.shutdown()
del self.sess
def create_socket_session(num_workers):
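    """Create a socket session with `num_workers` workers, replacing any previous one."""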
global _SOCKET_SESSION_TESTER
if _SOCKET_SESSION_TESTER is not None:
del _SOCKET_SESSION_TESTER
_SOCKET_SESSION_TESTER = SocketSessionTester(num_workers)
assert _SOCKET_SESSION_TESTER.sess is not None
return _SOCKET_SESSION_TESTER.sess
def test_nvshmem_init_finalize(session_kind: di.Session, num_workers: int):
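    """Initialize NVSHMEM on all workers from a shared unique ID, then finalize it."""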
if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
return
sess = session_kind(num_workers=num_workers)
f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
uid = f_init_nvshmem_uid()
init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
init_dfunc(uid, num_workers, 0)
sess.sync_worker_0()
finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
finalize_dfunc()
sess.sync_worker_0()
def test_nvshmem_empty(session_kind: di.Session, num_workers: int):
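    """Allocate NVSHMEM-backed arrays on every worker via `runtime.disco.nvshmem.empty`."""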
if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
return
device = tvm.cuda()
sess = session_kind(num_workers=num_workers)
f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
uid = f_init_nvshmem_uid()
init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
init_dfunc(uid, num_workers, 0)
sess.sync_worker_0()
empty_dfunc = sess.get_global_func("runtime.disco.nvshmem.empty")
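    # Allocate two tensors of different shapes on every worker. Their contents are never
    # read; the test only exercises the NVSHMEM allocation path.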
a = empty_dfunc(ShapeTuple((32, 64)), "float32", device)
b = empty_dfunc(ShapeTuple((64, 32)), "float32", device)
sess.sync_worker_0()
finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
finalize_dfunc()
sess.sync_worker_0()
def test_nvshmem_compile():
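    """Compile and run a plain TIR kernel on a disco session while NVSHMEM is initialized."""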
if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
return
num_workers = 2
sess = di.ProcessSession(num_workers=num_workers)
f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
uid = f_init_nvshmem_uid()
init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
init_dfunc(uid, num_workers, 0)
sess.sync_worker_0()
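    # A simple transpose kernel. It does not use NVSHMEM itself; it only checks that
    # regular compilation and execution keep working while NVSHMEM is initialized.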
@T.prim_func
def main(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")):
for i in T.thread_binding(T.int64(8), thread="threadIdx.y"):
for j in T.thread_binding(T.int64(16), thread="threadIdx.x"):
with T.sblock("T_transpose"):
v0 = T.axis.spatial(T.int64(8), i)
v1 = T.axis.spatial(T.int64(16), j)
T.reads(A[v0, v1])
T.writes(B[v1, v0])
B[v1, v0] = A[v0, v1]
tmpdir = tempfile.mkdtemp()
try:
path = tmpdir + "/test.so"
A_np = np.arange(8 * 16).astype("float32").reshape([8, 16])
B_np = np.zeros((16, 8), dtype="float32")
A_array = sess.empty(A_np.shape, "float32")
B_array = sess.empty(B_np.shape, "float32")
A_array.debug_copy_from(0, A_np)
target = tvm.target.Target("cuda")
tvm.compile(main, target=target).export_library(path)
mod = sess.load_vm_module(path)
mod["main"](A_array, B_array)
B_res = B_array.debug_get_from_remote(0).numpy()
np.testing.assert_equal(B_res, A_np.T)
        # Sync all workers so that the temporary files are not removed until every
        # worker has finished execution.
sess._sync_all()
finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
finalize_dfunc()
sess.sync_worker_0()
finally:
sess.shutdown()
shutil.rmtree(tmpdir, ignore_errors=True)
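# Raw CUDA source for a kernel that reports each worker's PE index and the total number
# of PEs via `nvshmem_my_pe()` and `nvshmem_n_pes()`.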
NVSHMEM_QUERY_KERNEL_SOURCE = """
#include <nvshmem.h>
extern "C" __global__ void nvshmem_query_kernel(int* my_pe_out, int* n_pes_out) {
my_pe_out[0] = nvshmem_my_pe();
n_pes_out[0] = nvshmem_n_pes();
}
"""
def _test_nvshmem_kernel_compile_impl():
"""Test compiling and running a kernel that calls NVSHMEM functions"""
if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
return
num_workers = 2
sess = di.ProcessSession(num_workers=num_workers)
f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
uid = f_init_nvshmem_uid()
init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
init_dfunc(uid, num_workers, 0)
sess.sync_worker_0()
try:
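        # The TIR prim_func below launches the raw CUDA source above through T.call_kernel,
        # and the Relax `main` returns the (my_pe, n_pes) pair so each worker's answer can
        # be checked from the controller.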
@I.ir_module
class NvshmemQueryModule:
@T.prim_func
def query_pe(
my_pe_out: T.Buffer((1,), "int32"),
n_pes_out: T.Buffer((1,), "int32"),
):
with T.sblock("root"):
T.reads()
T.writes(my_pe_out[0:1], n_pes_out[0:1])
T.call_kernel(
NVSHMEM_QUERY_KERNEL_SOURCE,
((1,), (1,)), # grid=(1,), block=(1,)
my_pe_out.data,
n_pes_out.data,
kernel_name="nvshmem_query_kernel",
)
@R.function
def main() -> R.Tuple(R.Tensor((1,), "int32"), R.Tensor((1,), "int32")):
cls = NvshmemQueryModule
with R.dataflow():
my_pe = R.call_tir(
cls.query_pe,
(),
out_sinfo=[
R.Tensor((1,), "int32"),
R.Tensor((1,), "int32"),
],
)
R.output(my_pe)
return my_pe
tmpdir = tempfile.mkdtemp()
try:
path = tmpdir + "/test_nvshmem_kernel.so"
target = tvm.target.Target("cuda")
tvm.compile(NvshmemQueryModule, target=target).export_library(path)
mod = sess.load_vm_module(path)
result = mod["main"]()
# Verify results from each worker
for worker_id in range(num_workers):
my_pe_result, n_pes_result = result.debug_get_from_remote(worker_id)
my_pe_val = my_pe_result.numpy()[0]
n_pes_val = n_pes_result.numpy()[0]
assert (
my_pe_val == worker_id
), f"Worker {worker_id} reported my_pe={my_pe_val}, expected {worker_id}"
assert (
n_pes_val == num_workers
), f"Worker {worker_id} reported n_pes={n_pes_val}, expected {num_workers}"
# Sync all workers before cleanup
sess._sync_all()
finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
finalize_dfunc()
sess.sync_worker_0()
finally:
shutil.rmtree(tmpdir, ignore_errors=True)
finally:
sess.shutdown()
def test_nvshmem_kernel_compile_nvcc():
"""Test NVSHMEM kernel compilation with nvcc."""
# Since this test runs in a separate process, we can safely set the env var
import os
os.environ["TVM_CUDA_COMPILE_MODE"] = "nvcc"
_test_nvshmem_kernel_compile_impl()
def test_nvshmem_kernel_compile_nvrtc():
"""Test NVSHMEM kernel compilation with nvrtc."""
try:
from cuda.bindings import nvrtc # noqa: F401
except ImportError:
pytest.skip("cuda-python not available, skipping nvrtc test")
# Since this test runs in a separate process, we can safely set the env var
import os
os.environ["TVM_CUDA_COMPILE_MODE"] = "nvrtc"
_test_nvshmem_kernel_compile_impl()
if __name__ == "__main__":
    # After the first call to `nvshmem_init`, a subsequent call to `nvshmem_init`
    # or `nvshmem_init_thread` in the same program results in undefined behavior.
    # Therefore every test is run in a fresh child process, so NVSHMEM is never
    # initialized twice in the same process (worker 0 may share the launcher's process).
# Use 'spawn' start method to avoid inheriting CUDA state from parent process
# 'fork' (default on Linux) can cause issues with CUDA contexts in child processes
multiprocessing.set_start_method("spawn", force=True)
for session_kind in [create_socket_session, di.ProcessSession]:
for num_workers in [2, 4]:
for test_func in [test_nvshmem_init_finalize, test_nvshmem_empty]:
p = Process(target=test_func, args=[session_kind, num_workers])
p.start()
p.join()
# Ensure the process finished successfully
assert (
p.exitcode == 0
), f"Test {test_func.__name__} failed with exit code {p.exitcode}"
p = Process(target=test_nvshmem_compile)
p.start()
p.join()
assert p.exitcode == 0, f"Test test_nvshmem_compile failed with exit code {p.exitcode}"
for test_func in [test_nvshmem_kernel_compile_nvcc, test_nvshmem_kernel_compile_nvrtc]:
p = Process(target=test_func)
p.start()
p.join()
assert p.exitcode == 0, f"Test {test_func.__name__} failed with exit code {p.exitcode}"