# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
"""Tests for NCCL/RCCL"""
import tempfile
import numpy as np
import pytest
import tvm
import tvm.testing
from tvm import dlight as dl
from tvm import get_global_func
from tvm import relax as rx
from tvm.runtime import disco as di
from tvm.runtime.vm import VirtualMachine
from tvm.script import relax as R
_all_session_kinds = [di.ThreadedSession, di.ProcessSession]
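# "runtime.disco.compiled_ccl" reports which collective library TVM was built
# with ("nccl" on CUDA builds, "rccl" on ROCm builds), so these tests exercise
# whichever backend is available.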
_ccl = [get_global_func("runtime.disco.compiled_ccl")()]
def create_device_target(ccl):
if ccl == "nccl":
dev = tvm.cuda(0)
else:
dev = tvm.rocm(0)
target = tvm.target.Target.from_device(dev)
return (dev, target)
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_init(session_kind, ccl):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_allreduce(session_kind, ccl):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)
array_1 = np.arange(12, dtype="float32").reshape(3, 4)
array_2 = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
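    # Each worker is seeded with a different host array so that every
    # reduction op below yields a result distinguishable from either input.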
d_array = sess.empty((3, 4), "float32")
d_array.debug_copy_from(0, array_1)
d_array.debug_copy_from(1, array_2)
for op, np_op in [ # pylint: disable=invalid-name
("sum", np.add),
("prod", np.multiply),
("min", np.minimum),
("max", np.maximum),
("avg", lambda a, b: (a + b) * 0.5),
]:
dst_array = sess.empty((3, 4), "float32")
sess.allreduce(d_array, dst_array, op=op)
result = dst_array.debug_get_from_remote(0).numpy()
expected = np_op(array_1, array_2)
np.testing.assert_equal(result, expected)
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_group_allreduce(session_kind, ccl):
devices = [0, 1, 2, 3]
sess = session_kind(num_workers=len(devices), num_groups=2)
sess.init_ccl(ccl, *devices)
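    # With num_workers=4 and num_groups=2, workers {0, 1} and {2, 3} form two
    # separate groups; in_group=True restricts each allreduce to the workers
    # within the caller's own group.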
array_1 = np.arange(12, dtype="float32").reshape(3, 4)
array_2 = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
array_3 = np.arange(30, dtype="float32").reshape(5, 6)
array_4 = np.arange(start=1, stop=-29, step=-1, dtype="float32").reshape(5, 6)
d_array_1 = sess.empty((3, 4), "float32")
d_array_2 = sess.empty((5, 6), "float32")
d_array_1.debug_copy_from(0, array_1)
d_array_1.debug_copy_from(1, array_2)
d_array_2.debug_copy_from(2, array_3)
d_array_2.debug_copy_from(3, array_4)
for op, np_op in [ # pylint: disable=invalid-name
("sum", np.add),
("prod", np.multiply),
("min", np.minimum),
("max", np.maximum),
("avg", lambda a, b: (a + b) * 0.5),
]:
dst_array_1 = sess.empty((3, 4), "float32")
dst_array_2 = sess.empty((5, 6), "float32")
sess.allreduce(d_array_1, dst_array_1, op=op, in_group=True)
sess.allreduce(d_array_2, dst_array_2, op=op, in_group=True)
result_1 = dst_array_1.debug_get_from_remote(0).numpy()
result_2 = dst_array_2.debug_get_from_remote(2).numpy()
expected_1 = np_op(array_1, array_2)
expected_2 = np_op(array_3, array_4)
np.testing.assert_equal(result_1, expected_1)
np.testing.assert_equal(result_2, expected_2)
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_allgather(session_kind, ccl):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)
array = np.arange(36, dtype="float32")
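    # Each worker holds 18 of the 36 elements; allgather concatenates the
    # per-worker buffers, so every worker ends up with all 36 elements, viewed
    # here as a (3, 4, 3) tensor.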
d_src = sess.empty((3, 3, 2), "float32")
d_dst = sess.empty((3, 4, 3), "float32")
d_src.debug_copy_from(0, array[:18])
d_src.debug_copy_from(1, array[18:])
sess.allgather(d_src, d_dst)
np.testing.assert_equal(
d_dst.debug_get_from_remote(0).numpy(),
array.reshape(3, 4, 3),
)
np.testing.assert_equal(
d_dst.debug_get_from_remote(1).numpy(),
array.reshape(3, 4, 3),
)
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_group_allgather(session_kind, ccl):
devices = [0, 1, 2, 3]
sess = session_kind(num_workers=len(devices), num_groups=2)
sess.init_ccl(ccl, *devices)
array_1 = np.arange(36, dtype="float32")
array_2 = np.arange(48, dtype="float32")
d_src_1 = sess.empty((3, 3, 2), "float32")
d_dst_1 = sess.empty((3, 4, 3), "float32")
d_src_2 = sess.empty((2, 4, 3), "float32")
d_dst_2 = sess.empty((2, 6, 4), "float32")
d_src_1.debug_copy_from(0, array_1[:18])
d_src_1.debug_copy_from(1, array_1[18:])
d_src_2.debug_copy_from(2, array_2[:24])
d_src_2.debug_copy_from(3, array_2[24:])
sess.allgather(d_src_1, d_dst_1, in_group=True)
sess.allgather(d_src_2, d_dst_2, in_group=True)
np.testing.assert_equal(
d_dst_1.debug_get_from_remote(0).numpy(),
array_1.reshape(3, 4, 3),
)
np.testing.assert_equal(
d_dst_1.debug_get_from_remote(1).numpy(),
array_1.reshape(3, 4, 3),
)
np.testing.assert_equal(
d_dst_2.debug_get_from_remote(2).numpy(),
array_2.reshape(2, 6, 4),
)
np.testing.assert_equal(
d_dst_2.debug_get_from_remote(3).numpy(),
array_2.reshape(2, 6, 4),
)
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
@pytest.mark.parametrize("use_explicit_output", [True, False])
def test_broadcast(session_kind, ccl, use_explicit_output):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)
array = np.arange(12, dtype="float32").reshape(3, 4)
if use_explicit_output:
src_array = sess.empty((3, 4), "float32", worker0_only=True)
src_array.debug_copy_from(0, array)
dst_array = sess.empty((3, 4), "float32")
sess.broadcast_from_worker0(src_array, dst_array)
else:
dst_array = sess.broadcast(array)
result = dst_array.debug_get_from_remote(1).numpy()
np.testing.assert_equal(result, array)
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_group_broadcast(session_kind, ccl):
devices = [0, 1, 2, 3]
sess = session_kind(num_workers=len(devices), num_groups=2)
sess.init_ccl(ccl, *devices)
array_1 = np.arange(12, dtype="float32").reshape(3, 4)
array_2 = np.multiply(array_1, -1)
src_array = sess.empty((3, 4), "float32", worker0_only=True, in_group=True)
src_array.debug_copy_from(0, array_1)
src_array.debug_copy_from(2, array_2)
dst_array = sess.empty((3, 4), "float32")
sess.broadcast_from_worker0(src_array, dst_array)
result_1 = dst_array.debug_get_from_remote(1).numpy()
np.testing.assert_equal(result_1, array_1)
result_3 = dst_array.debug_get_from_remote(3).numpy()
np.testing.assert_equal(result_3, array_2)
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
@pytest.mark.parametrize("use_explicit_output", [True, False])
def test_scatter(session_kind, ccl, use_explicit_output, capfd):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)
array = np.arange(36, dtype="float32").reshape(2, 6, 3)
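    # Scatter splits the (2, 6, 3) array along its outermost dimension:
    # array[0] lands on worker 0 and array[1] on worker 1.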
if use_explicit_output:
d_src = sess.empty((2, 6, 3), "float32", worker0_only=True)
d_dst = sess.empty((6, 3), "float32")
d_src.debug_copy_from(0, array)
sess.scatter_from_worker0(d_src, d_dst)
else:
d_dst = sess.scatter(array)
np.testing.assert_equal(
d_dst.debug_get_from_remote(0).numpy(),
array[0, :, :],
)
np.testing.assert_equal(
d_dst.debug_get_from_remote(1).numpy(),
array[1, :, :],
)
captured = capfd.readouterr()
assert (
not captured.err
), "No warning messages should be generated from disco.Session.scatter_from_worker0"
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_group_scatter(session_kind, ccl, capfd):
devices = [0, 1, 2, 3]
sess = session_kind(num_workers=len(devices), num_groups=2)
sess.init_ccl(ccl, *devices)
array_1 = np.arange(36, dtype="float32").reshape(2, 6, 3)
array_2 = np.multiply(array_1, -1)
d_src = sess.empty((2, 6, 3), "float32", worker0_only=True, in_group=True)
d_src.debug_copy_from(0, array_1)
d_src.debug_copy_from(2, array_2)
d_dst = sess.empty((6, 3), "float32")
sess.scatter_from_worker0(d_src, d_dst)
np.testing.assert_equal(
d_dst.debug_get_from_remote(0).numpy(),
array_1[0, :, :],
)
np.testing.assert_equal(
d_dst.debug_get_from_remote(1).numpy(),
array_1[1, :, :],
)
np.testing.assert_equal(
d_dst.debug_get_from_remote(2).numpy(),
array_2[0, :, :],
)
np.testing.assert_equal(
d_dst.debug_get_from_remote(3).numpy(),
array_2[1, :, :],
)
captured = capfd.readouterr()
assert (
not captured.err
), "No warning messages should be generated from disco.Session.scatter_from_worker0"
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_scatter_with_implicit_reshape(session_kind, ccl, capfd):
"""Scatter may perform an implicit reshape
Scattering elements to the workers requires the total number of
elements to be divisible by the number of workers. It does not
necessarily correspond to scattering across the outermost
dimension. Here, the number of workers (2) and the outermost
dimension (3) are not divisible, but the scatter may still be
performed.
This is only allowed when the caller explicitly uses the
`sess.scatter_from_worker0` method, and is not allowed in
`sess.scatter` method. Because the `sess.scatter` method may
perform an allocation on the disco workers, it requires that the
scatter occur across the outermost dimension.
"""
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)
array = np.arange(36, dtype="float32").reshape(3, 4, 3)
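    # 3 * 4 * 3 = 36 elements in total; each of the two workers receives 18 of
    # them, reinterpreted as a (3, 3, 2) tensor, even though the outermost
    # dimension (3) cannot be split evenly across two workers.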
d_src = sess.empty((3, 4, 3), "float32", worker0_only=True)
d_dst = sess.empty((3, 3, 2), "float32")
d_src.debug_copy_from(0, array)
sess.scatter_from_worker0(d_src, d_dst)
np.testing.assert_equal(
d_dst.debug_get_from_remote(0).numpy(),
array.flat[:18].reshape(3, 3, 2),
)
np.testing.assert_equal(
d_dst.debug_get_from_remote(1).numpy(),
array.flat[18:].reshape(3, 3, 2),
)
captured = capfd.readouterr()
assert (
not captured.err
), "No warning messages should be generated from disco.Session.scatter_from_worker0"
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_gather(session_kind, ccl, capfd):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)
array = np.arange(36, dtype="float32")
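    # Gather concatenates the 18 elements held by each worker onto worker 0,
    # producing a single (3, 4, 3) tensor there.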
d_src = sess.empty((3, 3, 2), "float32")
d_dst = sess.empty((3, 4, 3), "float32", worker0_only=True)
d_src.debug_copy_from(0, array[:18])
d_src.debug_copy_from(1, array[18:])
sess.gather_to_worker0(d_src, d_dst)
np.testing.assert_equal(
d_dst.debug_get_from_remote(0).numpy(),
array.reshape(3, 4, 3),
)
captured = capfd.readouterr()
assert (
not captured.err
), "No warning messages should be generated from disco.Session.gather_to_worker0"
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_group_gather(session_kind, ccl, capfd):
devices = [0, 1, 2, 3]
sess = session_kind(num_workers=len(devices), num_groups=2)
sess.init_ccl(ccl, *devices)
array_1 = np.arange(36, dtype="float32")
array_2 = np.multiply(array_1, -1)
d_src = sess.empty((3, 3, 2), "float32")
d_dst = sess.empty((3, 4, 3), "float32", worker0_only=True, in_group=True)
d_src.debug_copy_from(0, array_1[:18])
d_src.debug_copy_from(1, array_1[18:])
d_src.debug_copy_from(2, array_2[:18])
d_src.debug_copy_from(3, array_2[18:])
sess.gather_to_worker0(d_src, d_dst)
np.testing.assert_equal(
d_dst.debug_get_from_remote(0).numpy(),
array_1.reshape(3, 4, 3),
)
np.testing.assert_equal(
d_dst.debug_get_from_remote(2).numpy(),
array_2.reshape(3, 4, 3),
)
captured = capfd.readouterr()
assert (
not captured.err
), "No warning messages should be generated from disco.Session.gather_to_worker0"
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_send_to_next_group_receive_from_prev_group(session_kind, ccl):
devices = [0, 1, 2, 3]
sess = session_kind(num_workers=len(devices), num_groups=2)
sess.init_ccl(ccl, *devices)
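    # The tensors placed on workers 0 and 1 (the first group) are sent to the
    # corresponding workers 2 and 3 (the second group), where they should
    # arrive unchanged.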
array_1 = np.arange(12, dtype="float32").reshape(3, 4)
array_2 = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
d_array = sess.empty((3, 4), "float32")
d_array.debug_copy_from(0, array_1)
d_array.debug_copy_from(1, array_2)
sess.get_global_func("runtime.disco." + ccl + ".test_send_to_next_group_recv_from_prev_group")(
d_array
)
result_1 = d_array.debug_get_from_remote(2).numpy()
result_2 = d_array.debug_get_from_remote(3).numpy()
np.testing.assert_equal(result_1, array_1)
np.testing.assert_equal(result_2, array_2)
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_worker2_send_to_worker0(session_kind, ccl):
devices = [0, 1, 2, 3]
sess = session_kind(num_workers=len(devices), num_groups=2)
sess.init_ccl(ccl, *devices)
array = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
d_array = sess.empty((3, 4), "float32")
d_array.debug_copy_from(2, array)
sess.get_global_func("runtime.disco." + ccl + ".test_worker2_sends_to_worker0")(d_array)
result = d_array.debug_get_from_remote(0).numpy()
np.testing.assert_equal(result, array)
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_mlp(session_kind, ccl): # pylint: disable=too-many-locals
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)
# pylint: disable=invalid-name
@tvm.script.ir_module
class MLP: # pylint: disable=too-few-public-methods
@R.function
def main(
x: R.Tensor((128, 128), "float32"),
W1: R.Tensor((128, 128), "float32"),
W2: R.Tensor((128, 128), "float32"),
) -> R.Tensor((128, 128), "float32"):
R.func_attr({"global_symbol": "main"})
with R.dataflow():
lv0: R.Tensor((128, 128), "float32") = R.matmul(x, W1)
lv1: R.Tensor((128, 128), "float32") = R.nn.gelu(lv0)
lv2: R.Tensor((128, 128), "float32") = R.matmul(lv1, W2)
R.output(lv2)
return lv2
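    # ShardedMLP computes the same function with 1D tensor parallelism: W1 is
    # split column-wise and W2 row-wise across the two workers.  Since gelu is
    # elementwise, each worker's partial result gelu(x @ W1_shard) @ W2_shard
    # sums (via allreduce) to gelu(x @ W1) @ W2.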
@tvm.script.ir_module
class ShardedMLP: # pylint: disable=too-few-public-methods
@R.function
def main(
x: R.Tensor((128, 128), "float32"),
W1: R.Tensor((128, 64), "float32"), # shard along axis 1
W2: R.Tensor((64, 128), "float32"), # shard along axis 0
) -> R.Tensor((128, 128), "float32"):
R.func_attr({"global_symbol": "main"})
with R.dataflow():
broadcast_x: R.Tensor((128, 128), "float32") = R.ccl.broadcast_from_worker0(x)
lv0: R.Tensor((128, 64), "float32") = R.matmul(broadcast_x, W1)
lv1: R.Tensor((128, 64), "float32") = R.nn.gelu(lv0)
lv2: R.Tensor((128, 128), "float32") = R.matmul(lv1, W2)
lv3: R.Tensor((128, 128), "float32") = R.ccl.allreduce(lv2, "sum")
R.output(lv3)
return lv3
# pylint: enable=invalid-name
dev, target = create_device_target(ccl)
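    # Build with the default "zero" Relax pipeline, then let dlight attach its
    # default GPU schedules so the module can run on the target device.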
def relax_build(mod, target):
with target:
mod = rx.get_pipeline("zero")(mod) # pylint: disable=no-value-for-parameter
mod = dl.ApplyDefaultSchedule( # pylint: disable=not-callable
dl.gpu.Matmul(),
dl.gpu.GEMV(),
dl.gpu.Reduction(),
dl.gpu.GeneralReduction(),
dl.gpu.Fallback(),
)(mod)
return tvm.compile(mod, target=target)
# pylint: disable=invalid-name
X = np.random.randn(128, 128).astype("float32")
W1 = np.random.randn(128, 128).astype("float32")
W2 = np.random.randn(128, 128).astype("float32")
Y_expected = VirtualMachine(relax_build(MLP, target), device=dev)["main"](
tvm.runtime.tensor(X, device=dev),
tvm.runtime.tensor(W1, device=dev),
tvm.runtime.tensor(W2, device=dev),
).numpy()
with tempfile.TemporaryDirectory() as tmpdir:
path = tmpdir + "/test.so"
relax_build(ShardedMLP, target).export_library(path)
mod = sess.load_vm_module(path)
d_X = sess.empty((128, 128), "float32")
d_W1 = sess.empty((128, 64), "float32")
d_W2 = sess.empty((64, 128), "float32")
d_X.debug_copy_from(0, X)
d_W1.debug_copy_from(0, W1[:, :64])
d_W1.debug_copy_from(1, W1[:, 64:])
d_W2.debug_copy_from(0, W2[:64, :])
d_W2.debug_copy_from(1, W2[64:, :])
d_Y = mod["main"](d_X, d_W1, d_W2)
Y_result = tvm.runtime.empty((128, 128), "float32", device=dev)
sess.copy_from_worker_0(Y_result, d_Y)
sess.sync_worker_0()
Y_result = Y_result.numpy()
# pylint: enable=invalid-name
np.testing.assert_allclose(Y_result, Y_expected, rtol=1e-4, atol=1e-4)
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_attention(session_kind, ccl): # pylint: disable=too-many-locals,too-many-statements
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)
# pylint: disable=invalid-name
@tvm.script.ir_module
class Attention: # pylint: disable=too-few-public-methods
@R.function
def main( # pylint: disable=too-many-locals
x: R.Tensor((1, 10, 128), "float32"),
Wq: R.Tensor((128, 512), "float32"),
Wk: R.Tensor((128, 512), "float32"),
Wv: R.Tensor((128, 512), "float32"),
Wo: R.Tensor((512, 128), "float32"),
        ) -> R.Tensor((1, 10, 128), "float32"):
R.func_attr({"global_symbol": "main"})
with R.dataflow():
# q
lv0: R.Tensor((1, 10, 512), "float32") = R.matmul(x, Wq)
lv1: R.Tensor((1, 10, 8, 64), "float32") = R.reshape(lv0, [1, 10, 8, 64])
lv2: R.Tensor((1, 8, 10, 64), "float32") = R.permute_dims(lv1, [0, 2, 1, 3])
# k
lv3: R.Tensor((1, 10, 512), "float32") = R.matmul(x, Wk)
lv4: R.Tensor((1, 10, 8, 64), "float32") = R.reshape(lv3, [1, 10, 8, 64])
lv5: R.Tensor((1, 8, 10, 64), "float32") = R.permute_dims(lv4, [0, 2, 1, 3])
# v
lv6: R.Tensor((1, 10, 512), "float32") = R.matmul(x, Wv)
lv7: R.Tensor((1, 10, 8, 64), "float32") = R.reshape(lv6, [1, 10, 8, 64])
lv8: R.Tensor((1, 8, 10, 64), "float32") = R.permute_dims(lv7, [0, 2, 1, 3])
# softmax(q @ k / sqrt(dk))
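                # (the scale 1/8 used below equals 1/sqrt(dk) with dk = 64)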
lv9: R.Tensor((1, 8, 64, 10), "float32") = R.permute_dims(lv5, [0, 1, 3, 2])
lv10: R.Tensor((1, 8, 10, 10), "float32") = R.matmul(lv2, lv9)
lv11: R.Tensor((1, 8, 10, 10), "float32") = R.multiply(
lv10, R.const(1 / 8, "float32")
)
lv12: R.Tensor((1, 8, 10, 10), "float32") = R.nn.softmax(lv11, axis=-1)
# attn_weight @ v
lv13: R.Tensor((1, 8, 10, 64), "float32") = R.matmul(lv12, lv8)
lv14: R.Tensor((1, 10, 8, 64), "float32") = R.permute_dims(lv13, [0, 2, 1, 3])
lv15: R.Tensor((1, 10, 512), "float32") = R.reshape(lv14, [1, 10, 512])
# attn_output @ o
lv16: R.Tensor((1, 10, 128), "float32") = R.matmul(lv15, Wo)
R.output(lv16)
return lv16
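    # ShardedAttention splits the 8 attention heads across the two workers
    # (4 heads, i.e. 256 of the 512 hidden units, per worker) and shards Wo
    # row-wise; the final allreduce("sum") combines the per-worker partial
    # outputs into the full (1, 10, 128) result.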
@tvm.script.ir_module
class ShardedAttention: # pylint: disable=too-few-public-methods
@R.function
def main( # pylint: disable=too-many-locals
x: R.Tensor((1, 10, 128), "float32"),
Wq: R.Tensor((128, 256), "float32"), # shard along axis 1
Wk: R.Tensor((128, 256), "float32"), # shard along axis 1
Wv: R.Tensor((128, 256), "float32"), # shard along axis 1
Wo: R.Tensor((256, 128), "float32"), # shard along axis 0
        ) -> R.Tensor((1, 10, 128), "float32"):
R.func_attr({"global_symbol": "main"})
with R.dataflow():
broadcast_x: R.Tensor((1, 10, 128), "float32") = R.ccl.broadcast_from_worker0(x)
# q
lv0: R.Tensor((1, 10, 256), "float32") = R.matmul(broadcast_x, Wq)
lv1: R.Tensor((1, 10, 4, 64), "float32") = R.reshape(lv0, [1, 10, 4, 64])
lv2: R.Tensor((1, 4, 10, 64), "float32") = R.permute_dims(lv1, [0, 2, 1, 3])
# k
lv3: R.Tensor((1, 10, 256), "float32") = R.matmul(broadcast_x, Wk)
lv4: R.Tensor((1, 10, 4, 64), "float32") = R.reshape(lv3, [1, 10, 4, 64])
lv5: R.Tensor((1, 4, 10, 64), "float32") = R.permute_dims(lv4, [0, 2, 1, 3])
# v
lv6: R.Tensor((1, 10, 256), "float32") = R.matmul(broadcast_x, Wv)
lv7: R.Tensor((1, 10, 4, 64), "float32") = R.reshape(lv6, [1, 10, 4, 64])
lv8: R.Tensor((1, 4, 10, 64), "float32") = R.permute_dims(lv7, [0, 2, 1, 3])
# softmax(q @ k / sqrt(dk))
lv9: R.Tensor((1, 4, 64, 10), "float32") = R.permute_dims(lv5, [0, 1, 3, 2])
lv10: R.Tensor((1, 4, 10, 10), "float32") = R.matmul(lv2, lv9)
lv11: R.Tensor((1, 4, 10, 10), "float32") = R.multiply(
lv10, R.const(1 / 8, "float32")
)
lv12: R.Tensor((1, 4, 10, 10), "float32") = R.nn.softmax(lv11, axis=-1)
# attn_weight @ v
lv13: R.Tensor((1, 4, 10, 64), "float32") = R.matmul(lv12, lv8)
lv14: R.Tensor((1, 10, 4, 64), "float32") = R.permute_dims(lv13, [0, 2, 1, 3])
lv15: R.Tensor((1, 10, 256), "float32") = R.reshape(lv14, [1, 10, 256])
# attn_output @ o
lv16: R.Tensor((1, 10, 128), "float32") = R.matmul(lv15, Wo)
lv17: R.Tensor((1, 10, 128), "float32") = R.ccl.allreduce(lv16, "sum")
R.output(lv17)
return lv17
# pylint: enable=invalid-name
dev, target = create_device_target(ccl)
def relax_build(mod, target):
with target:
mod = rx.get_pipeline("zero")(mod) # pylint: disable=no-value-for-parameter
mod = dl.ApplyDefaultSchedule( # pylint: disable=not-callable
dl.gpu.Matmul(),
dl.gpu.GEMV(),
dl.gpu.Reduction(),
dl.gpu.GeneralReduction(),
dl.gpu.Fallback(),
)(mod)
return tvm.compile(mod, target=target)
# pylint: disable=invalid-name
X = np.random.randn(1, 10, 128).astype("float32")
Wq = np.random.randn(128, 512).astype("float32")
Wk = np.random.randn(128, 512).astype("float32")
Wv = np.random.randn(128, 512).astype("float32")
Wo = np.random.randn(512, 128).astype("float32")
Y_expected = VirtualMachine(relax_build(Attention, target), device=dev)["main"](
tvm.runtime.tensor(X, device=dev),
tvm.runtime.tensor(Wq, device=dev),
tvm.runtime.tensor(Wk, device=dev),
tvm.runtime.tensor(Wv, device=dev),
tvm.runtime.tensor(Wo, device=dev),
).numpy()
with tempfile.TemporaryDirectory() as tmpdir:
path = tmpdir + "/test.so"
relax_build(ShardedAttention, target).export_library(path)
mod = sess.load_vm_module(path)
d_X = sess.empty((1, 10, 128), "float32")
d_Wq = sess.empty((128, 256), "float32")
d_Wk = sess.empty((128, 256), "float32")
d_Wv = sess.empty((128, 256), "float32")
d_Wo = sess.empty((256, 128), "float32")
d_X.debug_copy_from(0, X)
d_Wq.debug_copy_from(0, Wq[:, :256])
d_Wq.debug_copy_from(1, Wq[:, 256:])
d_Wk.debug_copy_from(0, Wk[:, :256])
d_Wk.debug_copy_from(1, Wk[:, 256:])
d_Wv.debug_copy_from(0, Wv[:, :256])
d_Wv.debug_copy_from(1, Wv[:, 256:])
d_Wo.debug_copy_from(0, Wo[:256, :])
d_Wo.debug_copy_from(1, Wo[256:, :])
d_Y = mod["main"](d_X, d_Wq, d_Wk, d_Wv, d_Wo)
Y_result = tvm.runtime.empty((1, 10, 128), "float32", device=dev)
sess.copy_from_worker_0(Y_result, d_Y)
sess.sync_worker_0()
Y_result = Y_result.numpy()
# pylint: enable=invalid-name
np.testing.assert_allclose(Y_result, Y_expected, rtol=1e-3, atol=1e-3)
if __name__ == "__main__":
tvm.testing.main()