# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import List, Tuple
import numpy as np
import pytest
import tvm
import tvm.testing
from tvm import relax
from tvm.ir import assert_structural_equal
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import core, modules, spec
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


def test_relu():
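    """ReLU should export a forward function structurally equal to the expected Relax IR."""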
@R.function
def forward(
x: R.Tensor((3, 3), dtype="float32"),
_io: R.Object,
) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
relu: R.Tensor((3, 3), dtype="float32") = R.nn.relu(x)
gv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)) = relu, (_io,)
R.output(gv1)
return gv1
mod = modules.ReLU()
tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3, 3), "float32")}}, debug=True)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_silu():
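    """SiLU should lower to R.nn.silu in the exported forward function."""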
@R.function
def forward(
x: R.Tensor((3, 3), dtype="float32"),
_io: R.Object,
) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
silu: R.Tensor((3, 3), dtype="float32") = R.nn.silu(x)
gv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)) = silu, (_io,)
R.output(gv1)
return gv1
mod = modules.SiLU()
tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3, 3), "float32")}}, debug=True)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_gelu():
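    """GELU should lower to R.nn.gelu in the exported forward function."""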
@R.function
def forward(
x: R.Tensor((3, 3), dtype="float32"),
_io: R.Object,
) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
gelu: R.Tensor((3, 3), dtype="float32") = R.nn.gelu(x)
gv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)) = gelu, (_io,)
R.output(gv1)
return gv1
mod = modules.GELU()
tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3, 3), "float32")}}, debug=True)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_identity():
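    """Identity should pass its input through unchanged."""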
@R.function
def forward(
x: R.Tensor((3, 3), dtype="float32"),
_io: R.Object,
) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
gv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)) = x, (_io,)
R.output(gv1)
return gv1
mod = modules.Identity()
tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3, 3), "float32")}}, debug=True)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_linear():
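    """Linear should lower to a weight transpose, a matmul, and a bias add."""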
@R.function
def forward(
x: R.Tensor((1, 4), dtype="float32"),
_io: R.Object,
weight: R.Tensor((8, 4), dtype="float32"),
bias: R.Tensor((8,), dtype="float32"),
) -> R.Tuple(R.Tensor((1, 8), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
permute_dims: R.Tensor((4, 8), dtype="float32") = R.permute_dims(weight, axes=None)
matmul: R.Tensor((1, 8), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
add: R.Tensor((1, 8), dtype="float32") = R.add(matmul, bias)
gv1: R.Tuple(R.Tensor((1, 8), dtype="float32"), R.Tuple(R.Object)) = add, (_io,)
R.output(gv1)
return gv1
mod = modules.Linear(4, 8)
tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((1, 4), "float32")}}, debug=True)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_conv1d():
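    """Conv1D with bias should lower to R.nn.conv1d plus a reshaped bias add."""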
@R.function
def forward(
x: R.Tensor((1, 3, 32), dtype="float32"),
_io: R.Object,
weight: R.Tensor((32, 3, 3), dtype="float32"),
bias: R.Tensor((32,), dtype="float32"),
) -> R.Tuple(R.Tensor((1, 32, 30), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
lv1: R.Tensor((1, 32, 30), dtype="float32") = R.nn.conv1d(
x,
weight,
strides=[1],
padding=[0, 0],
dilation=[1],
groups=1,
data_layout="NCW",
kernel_layout="OIW",
out_layout="NCW",
out_dtype="void",
)
lv2: R.Tensor((1, 32, 1), dtype="float32") = R.reshape(bias, R.shape([1, 32, 1]))
conv1d: R.Tensor((1, 32, 30), dtype="float32") = R.add(lv1, lv2)
gv1: R.Tuple(R.Tensor((1, 32, 30), dtype="float32"), R.Tuple(R.Object)) = conv1d, (_io,)
R.output(gv1)
return gv1
mod = modules.Conv1D(3, 32, 3, bias=True)
tvm_mod, _ = mod.export_tvm(
spec={
"forward": {
"x": spec.Tensor([1, 3, 32], "float32"),
}
},
debug=True,
)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_conv1d_transpose():
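    """ConvTranspose1D with bias should lower to R.nn.conv1d_transpose plus a reshaped bias add."""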
# fmt: off
@R.function
def forward(x: R.Tensor((1, 3, 30), dtype="float32"), _io: R.Object, weight: R.Tensor((3, 32, 3), dtype="float32"), bias: R.Tensor((32,), dtype="float32")) -> R.Tuple(R.Tensor((1, 32, 32), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
lv1: R.Tensor((1, 32, 32), dtype="float32") = R.nn.conv1d_transpose(x, weight, strides=[1], padding=[0, 0], output_padding=[0], dilation=[1], groups=1, data_layout="NCW", kernel_layout="IOW", out_layout="NCW", out_dtype="void")
lv2: R.Tensor((1, 32, 1), dtype="float32") = R.reshape(bias, R.shape([1, 32, 1]))
conv1d_transpose: R.Tensor((1, 32, 32), dtype="float32") = R.add(lv1, lv2)
gv1: R.Tuple(R.Tensor((1, 32, 32), dtype="float32"), R.Tuple(R.Object)) = conv1d_transpose, (_io,)
R.output(gv1)
return gv1
# fmt: on
mod = modules.ConvTranspose1D(3, 32, 3, bias=True)
tvm_mod, _ = mod.export_tvm(
spec={
"forward": {
"x": spec.Tensor([1, 3, 30], "float32"),
}
},
debug=True,
)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_layer_norm():
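    """LayerNorm should lower to R.nn.layer_norm over the last axis."""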
@R.function
def forward(
x: R.Tensor((2, 4, 8), dtype="float32"),
_io: R.Object,
weight: R.Tensor((8,), dtype="float32"),
bias: R.Tensor((8,), dtype="float32"),
) -> R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
layer_norm: R.Tensor((2, 4, 8), dtype="float32") = R.nn.layer_norm(
x, weight, bias, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True
)
gv1: R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), R.Tuple(R.Object)) = layer_norm, (
_io,
)
R.output(gv1)
return gv1
mod = modules.LayerNorm(8)
tvm_mod, _ = mod.export_tvm(
spec={"forward": {"x": spec.Tensor((2, 4, 8), "float32")}}, debug=True
)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_conv2d():
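    """Conv2D with bias should lower to R.nn.conv2d plus a reshaped bias add."""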
@R.function
def forward(
x: R.Tensor((1, 3, 32, 32), dtype="float32"),
_io: R.Object,
weight: R.Tensor((32, 3, 3, 3), dtype="float32"),
bias: R.Tensor((32,), dtype="float32"),
) -> R.Tuple(R.Tensor((1, 32, 30, 30), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
lv1: R.Tensor((1, 32, 30, 30), dtype="float32") = R.nn.conv2d(x, weight)
lv2: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(bias, R.shape([1, 32, 1, 1]))
conv2d: R.Tensor((1, 32, 30, 30), dtype="float32") = R.add(lv1, lv2)
gv1: R.Tuple(R.Tensor((1, 32, 30, 30), dtype="float32"), R.Tuple(R.Object)) = conv2d, (
_io,
)
R.output(gv1)
return gv1
mod = modules.Conv2D(3, 32, 3, bias=True)
tvm_mod, _ = mod.export_tvm(
spec={
"forward": {
"x": spec.Tensor([1, 3, 32, 32], "float32"),
}
},
debug=True,
)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_conv3d():
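    """Conv3D with bias should lower to R.nn.conv3d plus a reshaped bias add."""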
@R.function
def forward(
x: R.Tensor((1, 3, 32, 32, 32), dtype="float32"),
_io: R.Object,
weight: R.Tensor((32, 3, 3, 3, 3), dtype="float32"),
bias: R.Tensor((32,), dtype="float32"),
) -> R.Tuple(R.Tensor((1, 32, 30, 30, 30), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
lv1: R.Tensor((1, 32, 30, 30, 30), dtype="float32") = R.nn.conv3d(x, weight)
lv2: R.Tensor((1, 32, 1, 1, 1), dtype="float32") = R.reshape(
bias, R.shape([1, 32, 1, 1, 1])
)
conv3d: R.Tensor((1, 32, 30, 30, 30), dtype="float32") = R.add(lv1, lv2)
gv1: R.Tuple(
R.Tensor((1, 32, 30, 30, 30), dtype="float32"), R.Tuple(R.Object)
) = conv3d, (_io,)
R.output(gv1)
return gv1
mod = modules.Conv3D(3, 32, 3, bias=True)
tvm_mod, _ = mod.export_tvm(
spec={
"forward": {
"x": spec.Tensor([1, 3, 32, 32, 32], "float32"),
}
},
debug=True,
)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_conv2d_dynamic():
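    """Conv2D with symbolic input shapes should propagate shape variables (e.g. h - 2) through the IR."""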
@R.function
def forward(
x: R.Tensor(("n", "c", "h", "w"), dtype="float32"),
_io: R.Object,
weight: R.Tensor((32, "in_channels", 3, 3), dtype="float32"),
bias: R.Tensor((32,), dtype="float32"),
) -> R.Tuple(R.Tensor(("n", 32, "h - 2", "w - 2"), dtype="float32"), R.Tuple(R.Object)):
n = T.int64()
h = T.int64()
w = T.int64()
c = T.int64()
in_channels = T.int64()
R.func_attr({"num_input": 2})
with R.dataflow():
lv1: R.Tensor((n, 32, h - 2, w - 2), dtype="float32") = R.nn.conv2d(x, weight)
lv2: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(bias, R.shape([1, 32, 1, 1]))
conv2d: R.Tensor((n, 32, h - 2, w - 2), dtype="float32") = R.add(lv1, lv2)
gv1: R.Tuple(R.Tensor((n, 32, h - 2, w - 2), dtype="float32"), R.Tuple(R.Object)) = (
conv2d,
(_io,),
)
R.output(gv1)
return gv1
mod = modules.Conv2D(tvm.tir.Var("in_channels", "int64"), 32, 3, bias=True)
tvm_mod, _ = mod.export_tvm(
spec={
"forward": {
"x": spec.Tensor(["n", "c", "h", "w"], "float32"),
}
},
debug=True,
)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_rms_norm():
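    """RMSNorm without bias should lower to R.nn.rms_norm."""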
@R.function
def forward(
x: R.Tensor((2, 4, 8), dtype="float32"),
_io: R.Object,
weight: R.Tensor((8,), dtype="float32"),
) -> R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
rms_norm: R.Tensor((2, 4, 8), dtype="float32") = R.nn.rms_norm(
x, weight, axes=[2], epsilon=1.0000000000000001e-05
)
gv1: R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), R.Tuple(R.Object)) = rms_norm, (_io,)
R.output(gv1)
return gv1
mod = modules.RMSNorm(8, [2], bias=False)
tvm_mod, _ = mod.export_tvm(
spec={"forward": {"x": spec.Tensor((2, 4, 8), "float32")}}, debug=True
)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_group_norm():
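    """GroupNorm should lower to R.nn.group_norm with the configured number of groups."""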
@R.function
def forward(
x: R.Tensor((2, 4, 8), dtype="float32"),
_io: R.Object,
weight: R.Tensor((4,), dtype="float32"),
bias: R.Tensor((4,), dtype="float32"),
) -> R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
group_norm: R.Tensor((2, 4, 8), dtype="float32") = R.nn.group_norm(
x, weight, bias, num_groups=2, channel_axis=1, axes=[2]
)
gv1: R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), R.Tuple(R.Object)) = group_norm, (
_io,
)
R.output(gv1)
return gv1
mod = modules.GroupNorm(num_groups=2, num_channels=4)
tvm_mod, _ = mod.export_tvm(
spec={"forward": {"x": spec.Tensor((2, 4, 8), "float32")}}, debug=True
)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_embedding():
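    """Embedding should lower to a reshape, a take along axis 0, and a reshape back."""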
@R.function
def forward(
x: R.Tensor((1, 4), dtype="int32"),
_io: R.Object,
weight: R.Tensor((4, 8), dtype="float32"),
) -> R.Tuple(R.Tensor((1, 4, 8), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
reshape: R.Tensor((4,), dtype="int32") = R.reshape(x, R.shape([4]))
take: R.Tensor((4, 8), dtype="float32") = R.take(weight, reshape, axis=0)
reshape1: R.Tensor((1, 4, 8), dtype="float32") = R.reshape(take, R.shape([1, 4, 8]))
gv1: R.Tuple(R.Tensor((1, 4, 8), dtype="float32"), R.Tuple(R.Object)) = reshape1, (_io,)
R.output(gv1)
return gv1
mod = modules.Embedding(4, 8, "float32")
tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((1, 4), "int32")}}, debug=True)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_timestep_embedding():
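    """TimestepEmbedding should add the conditional projection to the sample, then apply linear, SiLU, linear."""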
@R.function
def forward(
sample: R.Tensor((32, 32), dtype="float32"),
condition: R.Tensor((32, 16), dtype="float32"),
_io: R.Object,
linear_1_weight: R.Tensor((32, 32), dtype="float32"),
linear_1_bias: R.Tensor((32,), dtype="float32"),
cond_proj_weight: R.Tensor((32, 16), dtype="float32"),
linear_2_weight: R.Tensor((32, 32), dtype="float32"),
linear_2_bias: R.Tensor((32,), dtype="float32"),
) -> R.Tuple(R.Tensor((32, 32), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 3})
with R.dataflow():
permute_dims: R.Tensor((16, 32), dtype="float32") = R.permute_dims(
cond_proj_weight, axes=None
)
matmul: R.Tensor((32, 32), dtype="float32") = R.matmul(
condition, permute_dims, out_dtype="void"
)
add: R.Tensor((32, 32), dtype="float32") = R.add(sample, matmul)
permute_dims1: R.Tensor((32, 32), dtype="float32") = R.permute_dims(
linear_1_weight, axes=None
)
matmul1: R.Tensor((32, 32), dtype="float32") = R.matmul(
add, permute_dims1, out_dtype="void"
)
add1: R.Tensor((32, 32), dtype="float32") = R.add(matmul1, linear_1_bias)
silu: R.Tensor((32, 32), dtype="float32") = R.nn.silu(add1)
permute_dims2: R.Tensor((32, 32), dtype="float32") = R.permute_dims(
linear_2_weight, axes=None
)
matmul2: R.Tensor((32, 32), dtype="float32") = R.matmul(
silu, permute_dims2, out_dtype="void"
)
add2: R.Tensor((32, 32), dtype="float32") = R.add(matmul2, linear_2_bias)
gv1: R.Tuple(R.Tensor((32, 32), dtype="float32"), R.Tuple(R.Object)) = add2, (_io,)
R.output(gv1)
return gv1
mod = modules.TimestepEmbedding(32, 32, cond_proj_dim=16)
tvm_mod, _ = mod.export_tvm(
spec={
"forward": {
"sample": spec.Tensor((32, 32), "float32"),
"condition": spec.Tensor((32, 16), "float32"),
}
},
debug=True,
)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_timesteps():
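    """Timesteps should lower to the sinusoidal embedding computation (concatenated sin and cos)."""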
@R.function
def forward(
x: R.Tensor((3,), dtype="float32"), _io: R.Object
) -> R.Tuple(R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
lv1: R.Tensor((3,), dtype="float32") = R.astype(x, dtype="float32")
lv2: R.Tensor((3, 1), dtype="float32") = R.expand_dims(lv1, axis=[1])
lv3: R.Tensor((5,), dtype="float32") = R.arange(
R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="float32"
)
lv4: R.Tensor((5,), dtype="float32") = R.multiply(
R.const(-9.2103404998779297, "float32"), lv3
)
lv5: R.Tensor((5,), dtype="float32") = R.divide(lv4, R.const(4, "float32"))
lv6: R.Tensor((5,), dtype="float32") = R.exp(lv5)
lv7: R.Tensor((1, 5), dtype="float32") = R.expand_dims(lv6, axis=[0])
lv8: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv7)
lv9: R.Tensor((3, 5), dtype="float32") = R.sin(lv8)
lv10: R.Tensor((3, 5), dtype="float32") = R.cos(lv8)
lv11: R.Tensor((3, 10), dtype="float32") = R.concat((lv9, lv10), axis=-1)
get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.astype(
lv11, dtype="float32"
)
gv1: R.Tuple(R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)) = (
get_timestep_embedding,
(_io,),
)
R.output(gv1)
return gv1
mod = modules.Timesteps(10)
tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3,), "float32")}}, debug=True)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_kv_cache():
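    """KVCache should lower to the vm.builtin attention KV-cache packed calls, with _initialize_effect creating the cache."""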
@I.ir_module
class Module:
@R.function
def _initialize_effect() -> R.Tuple(R.Object, R.Object):
with R.dataflow():
_io: R.Object = R.null_value()
lv: R.Tensor((8, 2, 4), dtype="float32") = R.zeros(
R.shape([8, 2, 4]), dtype="float32"
)
cache: R.Object = R.call_pure_packed(
"vm.builtin.attention_kv_cache_create",
lv,
R.shape([8, 2, 4]),
R.prim_value(0),
sinfo_args=[R.Object()],
)
lv1 = _io, cache
gv = lv1
R.output(gv)
return gv
@R.function
def forward(
x: R.Tensor((2, 4), dtype="float32"), _io: R.Object, cache: R.Object
) -> R.Tuple(R.Tensor((4, 2, 4), dtype="float32"), R.Tuple(R.Object, R.Object)):
R.func_attr({"num_input": 3})
with R.dataflow():
lv2: R.Object = R.call_inplace_packed(
"vm.builtin.attention_kv_cache_append",
cache,
x,
inplace_indices=[0],
sinfo_args=[R.Object()],
)
lv3: R.Tensor((4, 2, 4), dtype="float32") = R.call_pure_packed(
"vm.builtin.attention_kv_cache_view",
lv2,
R.shape([4, 2, 4]),
sinfo_args=(R.Tensor((4, 2, 4), dtype="float32"),),
)
gv1: R.Tuple(R.Tensor((4, 2, 4), dtype="float32"), R.Tuple(R.Object, R.Object)) = (
lv3,
(_io, lv2),
)
R.output(gv1)
return gv1
class KVCacheTest(modules.Module):
def __init__(self) -> None:
self.cache = modules.KVCache(8, [2, 4])
def forward(self, x: core.Tensor) -> core.Tensor:
self.cache.append(x)
return self.cache.view(4)
tvm_mod, _ = KVCacheTest().export_tvm(
spec={"forward": {"x": spec.Tensor((2, 4), "float32")}}, debug=True
)
assert_structural_equal(tvm_mod, Module, True)


def test_attention():
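    """Attention should lower to group norm, Q/K/V projections, fused attention, and the output projection."""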
@R.function
def forward(
hidden_states: R.Tensor((2, 4096, 640), dtype="float32"),
encoder_hidden_states: R.Tensor((2, 77, 2048), dtype="float32"),
_io: R.Object,
to_q_weight: R.Tensor((640, 640), dtype="float32"),
to_k_weight: R.Tensor((640, 2048), dtype="float32"),
to_v_weight: R.Tensor((640, 2048), dtype="float32"),
group_norm_weight: R.Tensor((640,), dtype="float32"),
group_norm_bias: R.Tensor((640,), dtype="float32"),
to_out_0_weight: R.Tensor((640, 640), dtype="float32"),
to_out_0_bias: R.Tensor((640,), dtype="float32"),
) -> R.Tuple(R.Tensor((2, 4096, 640), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 3})
with R.dataflow():
group_norm: R.Tensor((2, 4096, 640), dtype="float32") = R.nn.group_norm(
hidden_states,
group_norm_weight,
group_norm_bias,
num_groups=8,
channel_axis=2,
axes=[1],
epsilon=1.0000000000000001e-05,
center=True,
scale=True,
)
permute_dims: R.Tensor((640, 640), dtype="float32") = R.permute_dims(
to_q_weight, axes=None
)
matmul: R.Tensor((2, 4096, 640), dtype="float32") = R.matmul(
group_norm, permute_dims, out_dtype="void"
)
permute_dims1: R.Tensor((2048, 640), dtype="float32") = R.permute_dims(
to_k_weight, axes=None
)
matmul1: R.Tensor((2, 77, 640), dtype="float32") = R.matmul(
encoder_hidden_states, permute_dims1, out_dtype="void"
)
permute_dims2: R.Tensor((2048, 640), dtype="float32") = R.permute_dims(
to_v_weight, axes=None
)
matmul2: R.Tensor((2, 77, 640), dtype="float32") = R.matmul(
encoder_hidden_states, permute_dims2, out_dtype="void"
)
reshape: R.Tensor((2, 4096, 10, 64), dtype="float32") = R.reshape(
matmul, R.shape([2, 4096, 10, 64])
)
reshape1: R.Tensor((2, 77, 10, 64), dtype="float32") = R.reshape(
matmul1, R.shape([2, 77, 10, 64])
)
reshape2: R.Tensor((2, 77, 10, 64), dtype="float32") = R.reshape(
matmul2, R.shape([2, 77, 10, 64])
)
scaled_dot_product_attention: R.Tensor(
(2, 4096, 10, 64), dtype="float32"
) = R.nn.attention(reshape, reshape1, reshape2, scale=None, causal_mask=None)
reshape3: R.Tensor((2, 4096, 640), dtype="float32") = R.reshape(
scaled_dot_product_attention, R.shape([2, 4096, 640])
)
permute_dims3: R.Tensor((640, 640), dtype="float32") = R.permute_dims(
to_out_0_weight, axes=None
)
matmul3: R.Tensor((2, 4096, 640), dtype="float32") = R.matmul(
reshape3, permute_dims3, out_dtype="void"
)
add: R.Tensor((2, 4096, 640), dtype="float32") = R.add(matmul3, to_out_0_bias)
gv1: R.Tuple(R.Tensor((2, 4096, 640), dtype="float32"), R.Tuple(R.Object)) = add, (_io,)
R.output(gv1)
return gv1
mod = modules.Attention(query_dim=640, cross_attention_dim=2048, heads=10, norm_num_groups=8)
tvm_mod, _ = mod.export_tvm(
spec={
"forward": {
"hidden_states": spec.Tensor((2, 4096, 640), "float32"),
"encoder_hidden_states": spec.Tensor((2, 77, 2048), "float32"),
}
},
debug=True,
)
assert_structural_equal(tvm_mod["forward"], forward, True)


def test_nn_module_tuple_input():
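    """A Tuple-annotated forward argument should be exported as a Relax tuple parameter."""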
class Layer(nn.Module):
def __init__(self):
pass
def forward(self, x: Tuple[nn.Tensor, nn.Tensor]):
x0 = x[0]
x1 = x[1]
y0 = nn.add(x0, x1)
y1 = nn.subtract(x0, x1)
return (y0, y1)
# fmt: off
@R.function
def forward(x: R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
lv1: R.Tensor((10, 5), dtype="float32") = x[0]
lv2: R.Tensor((10, 5), dtype="float32") = x[1]
add: R.Tensor((10, 5), dtype="float32") = R.add(lv1, lv2)
subtract: R.Tensor((10, 5), dtype="float32") = R.subtract(lv1, lv2)
gv1: R.Tuple(R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), R.Tuple(R.Object)) = (add, subtract), (_io,)
R.output(gv1)
return gv1
# fmt: on
mod = Layer()
tvm_mod, _ = mod.export_tvm(
spec={
"forward": {
"x": (spec.Tensor([10, 5], dtype="float32"), spec.Tensor([10, 5], dtype="float32"))
}
},
debug=True,
)
assert_structural_equal(tvm_mod["forward"], forward)


def test_nn_module_list_input():
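    """A List-annotated forward argument should be exported as a Relax tuple parameter, like the tuple case."""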
class Layer(nn.Module):
def __init__(self):
pass
def forward(self, x: List[nn.Tensor]):
x0 = x[0]
x1 = x[1]
y0 = nn.add(x0, x1)
y1 = nn.subtract(x0, x1)
return [y0, y1]
# fmt: off
@R.function
def forward(x: R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
lv1: R.Tensor((10, 5), dtype="float32") = x[0]
lv2: R.Tensor((10, 5), dtype="float32") = x[1]
add: R.Tensor((10, 5), dtype="float32") = R.add(lv1, lv2)
subtract: R.Tensor((10, 5), dtype="float32") = R.subtract(lv1, lv2)
gv1: R.Tuple(R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), R.Tuple(R.Object)) = (add, subtract), (_io,)
R.output(gv1)
return gv1
# fmt: on
mod = Layer()
tvm_mod, _ = mod.export_tvm(
spec={
"forward": {
"x": [spec.Tensor([10, 5], dtype="float32"), spec.Tensor([10, 5], dtype="float32")]
}
},
debug=True,
)
assert_structural_equal(tvm_mod["forward"], forward)


def test_module_list():
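    """Nested ModuleLists should register parameters under dot-separated index prefixes."""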
class Module(nn.Module):
def __init__(self):
self.layers = nn.ModuleList(
[nn.ModuleList([nn.Linear(4, 4, bias=False) for _ in range(2)]) for _ in range(1)]
)
def forward(self, x: nn.Tensor):
return self.layers(x)
mod = Module()
named_params = dict(mod.named_parameters())
    assert ["layers.0.0.weight", "layers.0.1.weight"] == sorted(named_params.keys())


if __name__ == "__main__":
tvm.testing.main()