blob: d9976e05e50b5877de253262d32610ea7fe2ec13 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Basic tests for a Disco nvshmem support"""
# pylint: disable=missing-docstring
import tempfile
import numpy as np
import pytest
import subprocess
import threading
import sys
from multiprocessing import Process
from typing import Any, Callable, List
from tvm.script import tir as T
import tvm
import tvm.testing
from tvm.runtime import ShapeTuple
from tvm.runtime import disco as di
from tvm.exec import disco_worker as _ # pylint: disable=unused-import
_SOCKET_SESSION_TESTER = None
def get_free_port():
import socket
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
port = s.getsockname()[1]
s.close()
return port
class SocketSessionTester:
def __init__(self, num_workers):
num_nodes = 2
num_groups = 1
assert num_workers % num_nodes == 0
num_workers_per_node = num_workers // num_nodes
server_host = "localhost"
server_port = get_free_port()
self.sess = None
def start_server():
self.sess = di.SocketSession(
num_nodes, num_workers_per_node, num_groups, server_host, server_port
)
thread = threading.Thread(target=start_server)
thread.start()
cmd = "tvm.exec.disco_remote_socket_session"
self.remote_nodes = []
for _ in range(num_nodes - 1):
self.remote_nodes.append(
subprocess.Popen(
[
"python3",
"-m",
cmd,
server_host,
str(server_port),
str(num_workers_per_node),
],
stdout=sys.stdout,
stderr=sys.stderr,
)
)
thread.join()
def __del__(self):
if self.sess is not None:
self.sess.shutdown()
del self.sess
def create_socket_session(num_workers):
global _SOCKET_SESSION_TESTER
if _SOCKET_SESSION_TESTER is not None:
del _SOCKET_SESSION_TESTER
_SOCKET_SESSION_TESTER = SocketSessionTester(num_workers)
assert _SOCKET_SESSION_TESTER.sess is not None
return _SOCKET_SESSION_TESTER.sess
def test_nvshmem_init_finalize(session_kind: di.Session, num_workers: int):
if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
return
sess = session_kind(num_workers=num_workers)
f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
uid = f_init_nvshmem_uid()
init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
init_dfunc(uid, num_workers, 0)
sess.sync_worker_0()
finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
finalize_dfunc()
sess.sync_worker_0()
def test_nvshmem_empty(session_kind: di.Session, num_workers: int):
if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
return
device = tvm.cuda()
sess = session_kind(num_workers=num_workers)
f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
uid = f_init_nvshmem_uid()
init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
init_dfunc(uid, num_workers, 0)
sess.sync_worker_0()
empty_dfunc = sess.get_global_func("runtime.disco.nvshmem.empty")
a = empty_dfunc(ShapeTuple((32, 64)), "float32", device)
b = empty_dfunc(ShapeTuple((64, 32)), "float32", device)
sess.sync_worker_0()
finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
finalize_dfunc()
sess.sync_worker_0()
def test_nvshmem_compile():
if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
return
num_workers = 4
sess = di.ProcessSession(num_workers=num_workers)
f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
uid = f_init_nvshmem_uid()
init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
init_dfunc(uid, num_workers, 0)
sess.sync_worker_0()
@T.prim_func
def main(A: T.Buffer((8, 16), "float32"), B: T.Buffer((16, 8), "float32")):
for i in T.thread_binding(T.int64(8), thread="threadIdx.y"):
for j in T.thread_binding(T.int64(16), thread="threadIdx.x"):
with T.block("T_transpose"):
v0 = T.axis.spatial(T.int64(8), i)
v1 = T.axis.spatial(T.int64(16), j)
T.reads(A[v0, v1])
T.writes(B[v1, v0])
B[v1, v0] = A[v0, v1]
with tempfile.TemporaryDirectory() as tmpdir:
path = tmpdir + "/test.so"
A_np = np.arange(8 * 16).astype("float32").reshape([8, 16])
B_np = np.zeros((16, 8), dtype="float32")
A_array = sess.empty(A_np.shape, "float32")
B_array = sess.empty(B_np.shape, "float32")
A_array.debug_copy_from(0, A_np)
target = tvm.target.Target("cuda")
tvm.compile(main, target=target).export_library(path)
mod = sess.load_vm_module(path)
mod["main"](A_array, B_array)
B_res = B_array.debug_get_from_remote(0).numpy()
np.testing.assert_equal(B_res, A_np.T)
# sync all workers to make sure the temporary files are cleaned up after all workers
# finish the execution
sess._sync_all()
finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
finalize_dfunc()
sess.sync_worker_0()
if __name__ == "__main__":
# After the first call to `nvshmem_init`, a subsequent call to `nvshmem_init`
# or `nvshmem_init_thread` in the same program results in undefined behavior.
# So we always create a new process to run the test. Then no repeated nvshmem
# init happens in the same process, since the worker0 may share the same process.
for session_kind in [create_socket_session, di.ProcessSession]:
for num_workers in [2, 4]:
for test_func in [test_nvshmem_init_finalize, test_nvshmem_empty]:
p = Process(target=test_func, args=[session_kind, num_workers])
p.start()
p.join()
# testing compilation flow
p = Process(target=test_nvshmem_compile)
p.start()
p.join()