# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test sharded loader"""
# pylint: disable=missing-docstring
import pathlib
import tempfile
import numpy as np
import tvm
import tvm.testing
from tvm.script import relax as R, tir as T


@tvm.testing.requires_nccl
def test_callback():
"""Simulate lazy loading of parameters in a callback
The output of a lazy parameter loading, which would accept a
callback to load the parameters.
"""
    @R.function
    def transform_params(
        rank_arg: R.Prim(value="rank"),
        fget_item: R.Callable([R.Object, R.Prim("int64")], R.Object),
    ):
        rank = T.int64()

        A = fget_item(R.str("A"), R.prim_value(0))
        A = R.match_cast(A, R.Tensor([4, 4], "int32"))
        A = R.strided_slice(A, axes=[0], begin=[rank * 2], end=[(rank + 1) * 2])

        B = fget_item(R.str("B"), R.prim_value(1))
        B = R.match_cast(B, R.Tensor([2, 2], "float32"))
        B = R.strided_slice(B, axes=[1], begin=[rank * 1], end=[(rank + 1) * 1])

        return (A, B)
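
    # LegalizeOps lowers each Relax operator to a TIR PrimFunc, and
    # dlight's fallback schedule then makes those PrimFuncs runnable
    # on the GPU.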
    pipeline = tvm.ir.transform.Sequential(
        [
            tvm.relax.transform.LegalizeOps(),
            tvm.dlight.ApplyDefaultSchedule(tvm.dlight.gpu.Fallback()),
        ],
        name="pipeline",
    )

    with tvm.target.Target("cuda"):
        mod = tvm.IRModule.from_expr(transform_params)
        mod = pipeline(mod)
        built = tvm.compile(mod, "cuda")
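
    # Launch one worker process per shard; `init_ccl` initializes the
    # NCCL communicator, with worker `i` bound to CUDA device `i`.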
    num_shards = 2
    session = tvm.runtime.disco.ProcessSession(num_workers=num_shards)
    session.import_python_module("tvm.exec.disco_worker")
    session.init_ccl("nccl", *range(num_shards))

    worker_device = session.get_global_func("runtime.disco.device")()
    worker_id = session.get_global_func("runtime.disco.worker_rank")()
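
    # `tests.disco.test_callback` is registered on the worker side
    # (outside this file).  The callback it returns produces each
    # named parameter on the worker's own device; judging from the
    # assertions below, "A" holds the values 0..15 as a 4x4 int32
    # tensor and "B" holds 0..3 as a 2x2 float32 tensor.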
    callback_maker = session.get_global_func("tests.disco.test_callback")
    fget_item = callback_maker(worker_device)

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir = pathlib.Path(temp_dir)

        # TODO(Lunderberg): Update `disco.Session.load_vm_module` to
        # allow a `tvm.runtime.Module` argument.  This would avoid the
        # need for a temporary file.
        shlib_path = temp_dir.joinpath("libtemp.so")
        built.export_library(shlib_path)

        vm = session.load_vm_module(shlib_path.as_posix())
        transform_params = vm["transform_params"]
        params = transform_params(worker_id, fget_item)
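
        # `params` refers to the tuple produced on each worker;
        # `debug_get_from_remote(i)` copies worker `i`'s value back to
        # the controlling scope for inspection.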

        # Worker 0 runs in the same PID as the controlling scope, so
        # `debug_get_from_remote(0)` returns the Tensor containing
        # the output.
        params_gpu0 = params.debug_get_from_remote(0)
        assert params_gpu0[0].device == tvm.cuda(0)
        assert params_gpu0[1].device == tvm.cuda(0)
        np.testing.assert_array_equal(
            params_gpu0[0].numpy(),
            [
                [0, 1, 2, 3],
                [4, 5, 6, 7],
            ],
        )
        np.testing.assert_array_equal(
            params_gpu0[1].numpy(),
            [[0], [2]],
        )

        # Worker 1 runs in a different PID altogether, so
        # `debug_get_from_remote(1)` returns a new Tensor within the
        # calling scope's PID.
        params_gpu1 = params.debug_get_from_remote(1)
        assert params_gpu1[0].device == tvm.cpu()
        assert params_gpu1[1].device == tvm.cpu()
        np.testing.assert_array_equal(
            params_gpu1[0].numpy(),
            [
                [8, 9, 10, 11],
                [12, 13, 14, 15],
            ],
        )
        np.testing.assert_array_equal(
            params_gpu1[1].numpy(),
            [[1], [3]],
        )


if __name__ == "__main__":
    tvm.testing.main()