# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Test graph builder && graph."""
import pytest
import torch
from torch.nn import Module
import tvm.testing
from tvm.contrib.msc.framework.torch.frontend import translate
from tvm.contrib.msc.core.utils.namespace import MSCFramework
from tvm.contrib.msc.core import utils as msc_utils
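# Each test below builds a small torch module, translates it with
# translate.from_torch, and compares graph.inspect() against a hand-written
# expectation dict ("inputs"/"outputs"/"nodes", plus "prims" for dynamic shapes).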
def verify_model(torch_model, input_info, expected):
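    """Translate a torch model to an MSC graph and check inspect() against expected."""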
graph, _ = translate.from_torch(torch_model, input_info)
inspect = graph.inspect()
    assert msc_utils.dict_equal(inspect, expected), "Inspect {} does not match expected {}".format(
inspect, expected
)
@pytest.mark.parametrize("dynamic", [True, False])
def test_conv1d(dynamic: bool):
"""test graph builder for conv1d"""
class Conv1D1(Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv1d(3, 6, 7, bias=True)
def forward(self, data):
return self.conv(data)
class Conv1D2(Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv1d(3, 6, 7, bias=False)
def forward(self, data):
return self.conv(data)
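    # "bz" is the symbolic batch var used by the dynamic variants; static runs
    # use a concrete batch of 1. A biased conv is fused into msc.conv1d_bias,
    # while the bias-free variant stays as plain nn.conv1d.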
bz = "bz" if dynamic else 1
expected1 = {
"inputs": [{"name": "inp_0", "shape": [bz, 3, 10], "dtype": "float32", "layout": "NCW"}],
"outputs": [{"name": "conv1d", "shape": [bz, 6, 4], "dtype": "float32", "layout": "NCW"}],
"nodes": {"total": 2, "input": 1, "msc.conv1d_bias": 1},
}
expected2 = {
"inputs": [{"name": "inp_0", "shape": [bz, 3, 10], "dtype": "float32", "layout": "NCW"}],
"outputs": [{"name": "conv1d", "shape": [bz, 6, 4], "dtype": "float32", "layout": "NCW"}],
"nodes": {"total": 2, "input": 1, "nn.conv1d": 1},
}
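    # Every symbolic dim is tracked as a prim: a single symbolic batch dim
    # adds one "shape" prim to the expected graph.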
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10], "float32")]
verify_model(Conv1D1(), input_info, expected1)
verify_model(Conv1D2(), input_info, expected2)
@pytest.mark.parametrize("dynamic", [True, False])
def test_conv2d(dynamic: bool):
"""test graph builder for conv2d"""
class Conv2D1(Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)
def forward(self, data):
return self.conv(data)
class Conv2D2(Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 6, 7, bias=False)
def forward(self, data):
return self.conv(data)
bz = "bz" if dynamic else 1
expected1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{
"name": "conv2d",
"shape": [bz, 6, 4, 4],
"dtype": "float32",
"layout": "NCHW",
}
],
"nodes": {"total": 2, "input": 1, "msc.conv2d_bias": 1},
}
expected2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{"name": "conv2d", "shape": [bz, 6, 4, 4], "dtype": "float32", "layout": "NCHW"}
],
"nodes": {"total": 2, "input": 1, "nn.conv2d": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(Conv2D1(), input_info, expected1)
verify_model(Conv2D2(), input_info, expected2)
@pytest.mark.parametrize("dynamic", [True, False])
def test_linear(dynamic: bool):
"""test graph builder for linear"""
class Dense1(Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 7, bias=True)
def forward(self, data):
return self.linear(data)
class Dense2(Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 7, bias=False)
def forward(self, data):
return self.linear(data)
class MatMul1(Module):
def forward(self, x, y):
return torch.matmul(x, y)
bz = "bz" if dynamic else 1
mdim = "mdim" if dynamic else 10
ndim = "ndim" if dynamic else 20
kdim = "kdim" if dynamic else 30
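    # Linear layers fuse into msc.linear(_bias); a bare matmul keeps separate
    # symbolic m/k/n dims, hence three shape prims in the dynamic case.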
expected1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{
"name": "matmul",
"shape": [bz, 3, 10, 7],
"dtype": "float32",
"layout": "NCHW",
}
],
"nodes": {"total": 2, "input": 1, "msc.linear_bias": 1},
}
expected2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{"name": "matmul", "shape": [bz, 3, 10, 7], "dtype": "float32", "layout": "NCHW"}
],
"nodes": {"total": 2, "input": 1, "msc.linear": 1},
}
expected3 = {
"inputs": [
{"name": "inp_0", "shape": [mdim, kdim], "dtype": "float32", "layout": "NC"},
{"name": "inp_1", "shape": [kdim, ndim], "dtype": "float32", "layout": "IO"},
],
"outputs": [{"name": "matmul", "shape": [mdim, ndim], "dtype": "float32", "layout": "NC"}],
"nodes": {"total": 3, "input": 2, "matmul": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2["prims"] = {"total": 1, "shape": 1}
expected3["prims"] = {"total": 3, "shape": 3}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(Dense1(), input_info, expected1)
verify_model(Dense2(), input_info, expected2)
verify_model(MatMul1(), [([mdim, kdim], "float32"), ([kdim, ndim], "float32")], expected3)
@pytest.mark.parametrize("dynamic", [True, False])
def test_bmm(dynamic: bool):
"""test graph builder for bmm"""
class BMM(Module):
def forward(self, x, y):
return torch.bmm(x, y)
bz = "bz" if dynamic else 1
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 128, 256], "dtype": "float32", "layout": "NCD"},
{"name": "inp_1", "shape": [bz, 256, 512], "dtype": "float32", "layout": "NIO"},
],
"outputs": [
{"name": "matmul", "shape": [bz, 128, 512], "dtype": "float32", "layout": "NCD"}
],
"nodes": {"total": 3, "input": 2, "matmul": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [((bz, 128, 256), "float32"), ((bz, 256, 512), "float32")]
verify_model(BMM(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_baddbmm(dynamic: bool):
"""test graph builder for baddbmm"""
class BAddBMM1(Module):
def forward(self, c, x, y):
return torch.baddbmm(c, x, y)
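    # With beta=0 the bias input is ignored (its layout stays empty) and
    # alpha=2 is materialized as a constant multiply after the matmul.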
class BAddBMM2(Module):
def forward(self, c, x, y):
return torch.baddbmm(c, x, y, alpha=2, beta=0)
bz = "bz" if dynamic else 1
expected1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 128, 512], "dtype": "float32", "layout": "NCD"},
{"name": "inp_1", "shape": [bz, 128, 256], "dtype": "float32", "layout": "NCD"},
{"name": "inp_2", "shape": [bz, 256, 512], "dtype": "float32", "layout": "NIO"},
],
"outputs": [{"name": "add", "shape": [bz, 128, 512], "dtype": "float32", "layout": "NCD"}],
"nodes": {"total": 5, "input": 3, "matmul": 1, "add": 1},
}
expected2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 128, 512], "dtype": "float32", "layout": ""},
{"name": "inp_1", "shape": [bz, 128, 256], "dtype": "float32", "layout": "NCD"},
{"name": "inp_2", "shape": [bz, 256, 512], "dtype": "float32", "layout": "NIO"},
],
"outputs": [
{"name": "multiply", "shape": [bz, 128, 512], "dtype": "float32", "layout": "NCD"}
],
"nodes": {"total": 6, "input": 3, "matmul": 1, "constant": 1, "multiply": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2["prims"] = {"total": 1, "shape": 1}
input_info = [
((bz, 128, 512), "float32"),
((bz, 128, 256), "float32"),
((bz, 256, 512), "float32"),
]
verify_model(BAddBMM1(), input_info, expected1)
verify_model(BAddBMM2(), input_info, expected2)
@pytest.mark.parametrize("dynamic", [True, False])
def test_relu(dynamic: bool):
"""test graph builder for relu"""
class ReLU(Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
def forward(self, data):
return self.relu(data)
class ReLU1(Module):
def forward(self, data):
return torch.nn.functional.relu(data)
bz = "bz" if dynamic else 1
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": "AB"}],
"outputs": [{"name": "relu", "shape": [bz, 10], "dtype": "float32", "layout": "AB"}],
"nodes": {"total": 2, "input": 1, "nn.relu": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 10], "float32")]
verify_model(ReLU(), input_info, expected)
verify_model(ReLU1(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_relu6(dynamic: bool):
"""test graph builder for relu6"""
class ReLU6(Module):
def __init__(self):
super().__init__()
self.relu6 = torch.nn.ReLU6()
def forward(self, data):
return self.relu6(data)
bz = "bz" if dynamic else 1
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "clip", "shape": [bz, 10], "dtype": "float32", "layout": ""}],
"nodes": {"total": 2, "input": 1, "clip": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 10], "float32")]
verify_model(ReLU6(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_maxpool2d(dynamic: bool):
"""test graph builder for maxpool2d"""
class MaxPool2d(Module):
def __init__(self):
super().__init__()
self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1])
def forward(self, data):
return self.pool(data)
class MaxPool2d2(Module):
def __init__(self):
super().__init__()
self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3])
def forward(self, data):
return self.pool(data)
class MaxPool2d3(Module):
def __init__(self):
super().__init__()
self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2)
def forward(self, data):
return self.pool(data)
bz = "bz" if dynamic else 1
expected1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{"name": "max_pool2d", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1},
}
expected2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{"name": "max_pool2d", "shape": [bz, 3, 4, 4], "dtype": "float32", "layout": "NCHW"}
],
"nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1},
}
expected3 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{"name": "max_pool2d", "shape": [bz, 3, 6, 6], "dtype": "float32", "layout": "NCHW"}
],
"nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2["prims"] = {"total": 1, "shape": 1}
expected3["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(MaxPool2d(), input_info, expected1)
verify_model(MaxPool2d2(), input_info, expected2)
verify_model(MaxPool2d3(), input_info, expected3)
@pytest.mark.parametrize("dynamic", [True, False])
def test_avgpool2d(dynamic: bool):
"""test graph builder for avgpool2d"""
class AvgPool2d(Module):
def __init__(self):
super().__init__()
self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1])
def forward(self, data):
return self.pool(data)
class AvgPool2d2(Module):
def __init__(self):
super().__init__()
self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True)
def forward(self, data):
return self.pool(data)
bz = "bz" if dynamic else 1
expected1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{"name": "avg_pool2d", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"nodes": {"total": 2, "input": 1, "nn.avg_pool2d": 1},
}
expected2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{"name": "avg_pool2d", "shape": [bz, 3, 6, 6], "dtype": "float32", "layout": "NCHW"}
],
"nodes": {"total": 2, "input": 1, "nn.avg_pool2d": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(AvgPool2d(), input_info, expected1)
verify_model(AvgPool2d2(), input_info, expected2)
@pytest.mark.parametrize("dynamic", [True, False])
def test_adaptive_avgpool2d(dynamic: bool):
"""test graph builder for adaptive_avgpool2d"""
class AdaptiveAvgPool2d0(Module):
def __init__(self):
super().__init__()
self.pool = torch.nn.AdaptiveAvgPool2d([10, 10])
def forward(self, data):
return self.pool(data)
bz = "bz" if dynamic else 1
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{
"name": "adaptive_avg_pool2d",
"shape": [bz, 3, 10, 10],
"dtype": "float32",
"layout": "NCHW",
}
],
"nodes": {"total": 2, "input": 1, "nn.adaptive_avg_pool2d": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(AdaptiveAvgPool2d0(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_flatten(dynamic: bool):
"""test graph builder for flatten"""
class Flatten(Module):
def __init__(self):
super().__init__()
self.f = torch.nn.Flatten(2, -1)
def forward(self, data):
return self.f(data)
bz = "bz" if dynamic else 1
dim = "dim" if dynamic else 10
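    # Flatten(2, -1) folds 10 * dim into one output dim; in dynamic mode the
    # folded dim shows up as the derived prim "MUL_3" (one Int and one Mul prim).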
out_dim = "MUL_3" if dynamic else 100
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 3, 10, dim], "dtype": "float32", "layout": ""}],
"outputs": [
{"name": "reshape", "shape": [bz, 3, out_dim], "dtype": "float32", "layout": ""}
],
"nodes": {"total": 2, "input": 1, "reshape": 1},
}
if dynamic:
expected["prims"] = {"total": 4, "shape": 2, "Int": 1, "Mul": 1}
input_info = [([bz, 3, 10, dim], "float32")]
verify_model(Flatten(), input_info, expected)
verify_model(torch.nn.Flatten(2, -1), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_batchnorm2d(dynamic: bool):
"""test graph builder for batchnorm2d"""
class BatchNorm2d(Module):
def __init__(self):
super().__init__()
self.batchnorm = torch.nn.BatchNorm2d(3)
def forward(self, data):
return self.batchnorm(data)
bz = "bz" if dynamic else 1
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{
"name": "batch_norm.0",
"shape": [bz, 3, 10, 10],
"dtype": "float32",
"layout": "NCHW",
}
],
"nodes": {"total": 3, "input": 1, "nn.batch_norm": 1, "get_item": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(BatchNorm2d(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_embedding(dynamic: bool):
"""test graph builder for embedding"""
class Embedding(Module):
def __init__(self):
super().__init__()
self.embedding = torch.nn.Embedding(10, 3)
def forward(self, data):
return self.embedding(data)
vocab = "vocab" if dynamic else 4
expected1 = {
"inputs": [{"name": "inp_0", "shape": [vocab], "dtype": "int64", "layout": "A"}],
"outputs": [{"name": "take", "shape": [vocab, 3], "dtype": "float32", "layout": "AB"}],
"nodes": {"total": 2, "input": 1, "msc.embedding": 1},
}
expected2 = {
"inputs": [{"name": "inp_0", "shape": [vocab, 5], "dtype": "int64", "layout": "AB"}],
"outputs": [
{
"name": "take",
"shape": [vocab, 5, 3],
"dtype": "float32",
"layout": "" if dynamic else "CBA",
}
],
"nodes": {"total": 2, "input": 1, "msc.embedding": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1}
verify_model(Embedding(), [([vocab], "int64")], expected1)
verify_model(Embedding(), [([vocab, 5], "int64")], expected2)
@pytest.mark.parametrize("dynamic", [True, False])
def test_dropout(dynamic: bool):
"""test graph builder for dropout"""
class Dropout1(Module):
def __init__(self):
super().__init__()
self.dropout = torch.nn.Dropout(0.5)
def forward(self, data):
return self.dropout(data)
class Dropout2(Module):
def forward(self, data):
return torch.dropout(data, 0.5, train=True)
bz = "bz" if dynamic else 1
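    # Dropout is an identity at inference time, so the translated graph keeps
    # only the input node and returns it unchanged.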
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}],
"nodes": {"total": 1, "input": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(Dropout1(), input_info, expected)
verify_model(Dropout2(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_layernorm(dynamic: bool):
"""test graph builder for layernorm"""
class LayerNorm(Module):
def __init__(self):
super().__init__()
self.layernorm = torch.nn.LayerNorm((10, 10))
def forward(self, data):
return self.layernorm(data)
bz = "bz" if dynamic else 1
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{"name": "layer_norm", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"nodes": {"total": 2, "input": 1, "nn.layer_norm": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(LayerNorm(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_functional_layernorm(dynamic: bool):
"""test graph builder for functional_layernorm"""
class LayerNorm(Module):
def __init__(self, shape):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(shape))
self.bias = torch.nn.Parameter(torch.zeros(shape))
def forward(self, data):
return torch.nn.functional.layer_norm(
data, self.weight.shape, self.weight, self.bias, 1e-5
)
bz = "bz" if dynamic else 1
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{"name": "layer_norm", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"nodes": {"total": 2, "input": 1, "nn.layer_norm": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(LayerNorm((10, 10)), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_cross_entropy(dynamic: bool):
"""test graph builder for cross_entropy"""
class CrossEntropy1(Module):
def __init__(self):
super().__init__()
self.loss = torch.nn.CrossEntropyLoss()
def forward(self, logits, targets):
return self.loss(logits, targets)
class CrossEntropy2(Module):
def __init__(self):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones((2,)))
self.loss = torch.nn.CrossEntropyLoss(weight=self.weight)
def forward(self, logits, targets):
return self.loss(logits, targets)
class CrossEntropy3(Module):
def __init__(self):
super().__init__()
self.loss = torch.nn.CrossEntropyLoss(ignore_index=1, reduction="sum")
def forward(self, logits, targets):
return self.loss(logits, targets)
bz = "bz" if dynamic else 1
expected1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 2], "dtype": "float32", "layout": ""},
{"name": "inp_1", "shape": [bz], "dtype": "int32", "layout": ""},
],
"outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}],
"nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1},
}
expected2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 2], "dtype": "float32", "layout": ""},
{"name": "inp_1", "shape": [bz], "dtype": "int32", "layout": ""},
],
"outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}],
"nodes": {"total": 5, "input": 2, "nn.log_softmax": 1, "constant": 1, "nn.nll_loss": 1},
}
expected3 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 2], "dtype": "float32", "layout": ""},
{"name": "inp_1", "shape": [bz], "dtype": "int32", "layout": ""},
],
"outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}],
"nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2["prims"] = {"total": 1, "shape": 1}
expected3["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 2], "float32"), ([bz], "int32")]
verify_model(CrossEntropy1(), input_info, expected1)
verify_model(CrossEntropy2(), input_info, expected2)
verify_model(CrossEntropy3(), input_info, expected3)
@pytest.mark.parametrize("dynamic", [True, False])
def test_functional_cross_entropy(dynamic: bool):
"""test graph builder for functional_cross_entropy"""
class CrossEntropy(Module):
def forward(self, logits, targets):
return torch.nn.functional.cross_entropy(logits, targets)
bz = "bz" if dynamic else 1
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": ""},
{"name": "inp_1", "shape": [bz], "dtype": "int32", "layout": ""},
],
"outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}],
"nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 10], "float32"), ([bz], "int32")]
verify_model(CrossEntropy(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_silu(dynamic: bool):
"""test graph builder for silu"""
class SiLU(Module):
def __init__(self):
super().__init__()
self.silu = torch.nn.SiLU()
def forward(self, data):
return self.silu(data)
class SiLU2(Module):
def forward(self, data):
return torch.nn.functional.silu(data)
bz = "bz" if dynamic else 1
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "silu", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 2, "input": 1, "nn.silu": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(SiLU(), input_info, expected)
verify_model(SiLU2(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_groupnorm(dynamic: bool):
"""test graph builder for groupnorm"""
class GroupNorm(Module):
def __init__(self):
super().__init__()
self.groupnorm = torch.nn.GroupNorm(3, 3)
def forward(self, data):
return self.groupnorm(data)
bz = "bz" if dynamic else 1
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{"name": "group_norm", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"nodes": {"total": 2, "input": 1, "nn.group_norm": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(GroupNorm(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_softmax(dynamic: bool):
"""test graph builder for softmax"""
class Softmax(Module):
def __init__(self):
super().__init__()
self.softmax = torch.nn.Softmax(dim=1)
def forward(self, data):
return self.softmax(data)
bz = "bz" if dynamic else 1
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "softmax", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 2, "input": 1, "nn.softmax": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(Softmax(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_binary(dynamic: bool):
"""test graph builder for binary"""
bz = "bz" if dynamic else 1
input_info1 = [([bz, 3, 10, 10], "float32"), ([bz, 3, 10, 10], "float32")]
input_info2 = [([bz, 3, 10, 10], "float32")]
# Add
class Add1(Module):
def forward(self, lhs, rhs):
return lhs + rhs
class Add2(Module):
def forward(self, lhs):
return lhs + 1.0
expected_add1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
{"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
],
"outputs": [
{"name": "add", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 3, "input": 2, "add": 1},
}
expected_add2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "add", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 3, "input": 1, "constant": 1, "add": 1},
}
if dynamic:
expected_add1["prims"] = {"total": 1, "shape": 1}
expected_add2["prims"] = {"total": 1, "shape": 1}
verify_model(Add1(), input_info1, expected_add1)
verify_model(Add2(), input_info2, expected_add2)
# Sub
class Sub1(Module):
def forward(self, lhs, rhs):
return lhs - rhs
class Sub2(Module):
def forward(self, lhs):
return lhs - 1.0
expected_sub1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
{"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
],
"outputs": [
{"name": "subtract", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 3, "input": 2, "subtract": 1},
}
expected_sub2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "subtract", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 3, "input": 1, "constant": 1, "subtract": 1},
}
if dynamic:
expected_sub1["prims"] = {"total": 1, "shape": 1}
expected_sub2["prims"] = {"total": 1, "shape": 1}
verify_model(Sub1(), input_info1, expected_sub1)
verify_model(Sub2(), input_info2, expected_sub2)
# Mul
class Mul1(Module):
def forward(self, lhs, rhs):
return lhs * rhs
class Mul2(Module):
def forward(self, lhs):
return lhs * 1.0
expected_mul1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
{"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
],
"outputs": [
{"name": "multiply", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 3, "input": 2, "multiply": 1},
}
expected_mul2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "multiply", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 3, "input": 1, "constant": 1, "multiply": 1},
}
if dynamic:
expected_mul1["prims"] = {"total": 1, "shape": 1}
expected_mul2["prims"] = {"total": 1, "shape": 1}
verify_model(Mul1(), input_info1, expected_mul1)
verify_model(Mul2(), input_info2, expected_mul2)
# True div
class TrueDiv1(Module):
def forward(self, lhs, rhs):
return lhs / rhs
class TrueDiv2(Module):
def forward(self, lhs):
return lhs / 1.0
expected_div1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
{"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
],
"outputs": [
{"name": "divide", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 3, "input": 2, "divide": 1},
}
expected_div2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "divide", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 3, "input": 1, "constant": 1, "divide": 1},
}
if dynamic:
expected_div1["prims"] = {"total": 1, "shape": 1}
expected_div2["prims"] = {"total": 1, "shape": 1}
verify_model(TrueDiv1(), input_info1, expected_div1)
verify_model(TrueDiv2(), input_info2, expected_div2)
# Floor div
class FloorDiv1(Module):
def forward(self, lhs, rhs):
return lhs // rhs
class FloorDiv2(Module):
def forward(self, lhs):
return lhs // 1.0
expected_floordiv1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
{"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
],
"outputs": [
{
"name": "floor_divide",
"shape": [bz, 3, 10, 10],
"dtype": "float32",
"layout": "ABCD",
}
],
"nodes": {"total": 3, "input": 2, "floor_divide": 1},
}
expected_floordiv2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{
"name": "floor_divide",
"shape": [bz, 3, 10, 10],
"dtype": "float32",
"layout": "ABCD",
}
],
"nodes": {"total": 3, "input": 1, "constant": 1, "floor_divide": 1},
}
if dynamic:
expected_floordiv1["prims"] = {"total": 1, "shape": 1}
expected_floordiv2["prims"] = {"total": 1, "shape": 1}
verify_model(FloorDiv1(), input_info1, expected_floordiv1)
verify_model(FloorDiv2(), input_info2, expected_floordiv2)
# Power
class Power1(Module):
def forward(self, lhs, rhs):
return lhs**rhs
class Power2(Module):
def forward(self, lhs):
return lhs**1.0
expected_power1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
{"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
],
"outputs": [
{"name": "power", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 3, "input": 2, "power": 1},
}
expected_power2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "power", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 3, "input": 1, "constant": 1, "power": 1},
}
if dynamic:
expected_power1["prims"] = {"total": 1, "shape": 1}
expected_power2["prims"] = {"total": 1, "shape": 1}
verify_model(Power1(), input_info1, expected_power1)
verify_model(Power2(), input_info2, expected_power2)
# LT
class LT1(Module):
def forward(self, lhs, rhs):
return lhs < rhs
class LT2(Module):
def forward(self, lhs):
return lhs < 1.0
expected_lt1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
{"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"},
],
"outputs": [{"name": "less", "shape": [bz, 3, 10, 10], "dtype": "bool", "layout": "ABCD"}],
"nodes": {"total": 3, "input": 2, "less": 1},
}
expected_lt2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [{"name": "less", "shape": [bz, 3, 10, 10], "dtype": "bool", "layout": "ABCD"}],
"nodes": {"total": 3, "input": 1, "constant": 1, "less": 1},
}
if dynamic:
expected_lt1["prims"] = {"total": 1, "shape": 1}
expected_lt2["prims"] = {"total": 1, "shape": 1}
verify_model(LT1(), input_info1, expected_lt1)
verify_model(LT2(), input_info2, expected_lt2)
@pytest.mark.parametrize("dynamic", [True, False])
def test_size(dynamic: bool):
"""test graph builder for size"""
class Size(Module):
def forward(self, data):
return data.size()
bz = "bz" if dynamic else 1
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "shape", "shape": [4], "dtype": "int32", "layout": "O"}],
"nodes": {"total": 2, "input": 1, "shape": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(Size(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_squeeze(dynamic: bool):
"""test graph builder for squeeze"""
class Squeeze1(Module):
def forward(self, data):
return data.squeeze(1)
class Squeeze2(Module):
def forward(self, data):
return data.squeeze()
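    # A dim-less squeeze() depends on which dims equal 1: with a symbolic bz
    # the result shape cannot be inferred statically (reported as []), while
    # the static case squeezes both unit dims.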
bz = "bz" if dynamic else 10
expected1 = {
"inputs": [{"name": "inp_0", "shape": [bz, 1, 4, 1], "dtype": "float32", "layout": "ADBC"}],
"outputs": [{"name": "squeeze", "shape": [bz, 4, 1], "dtype": "float32", "layout": "ABC"}],
"nodes": {"total": 2, "input": 1, "squeeze": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 1, 4, 1], "dtype": "float32", "layout": "ACBD"}
],
"outputs": [{"name": "squeeze", "shape": [], "dtype": "float32", "layout": "AB"}],
"nodes": {"total": 2, "input": 1, "squeeze": 1},
"prims": {"total": 1, "shape": 1},
}
else:
expected2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 1, 4, 1], "dtype": "float32", "layout": "ACBD"}
],
"outputs": [{"name": "squeeze", "shape": [bz, 4], "dtype": "float32", "layout": "AB"}],
"nodes": {"total": 2, "input": 1, "squeeze": 1},
}
input_info = [([bz, 1, 4, 1], "float32")]
verify_model(Squeeze1(), input_info, expected1)
verify_model(Squeeze2(), input_info, expected2)
@pytest.mark.parametrize("dynamic", [True, False])
def test_unsqueeze(dynamic: bool):
"""test graph builder for unsqueeze"""
class Unsqueeze1(Module):
def forward(self, data):
return data.unsqueeze(1)
class Unsqueeze2(Module):
def forward(self, data):
return data.unsqueeze(-1)
bz = "bz" if dynamic else 1
expected1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ACDE"}
],
"outputs": [
{
"name": "expand_dims",
"shape": [bz, 1, 3, 10, 10],
"dtype": "float32",
"layout": "ABCDE",
}
],
"nodes": {"total": 2, "input": 1, "expand_dims": 1},
}
expected2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCE"}
],
"outputs": [
{
"name": "expand_dims",
"shape": [bz, 3, 10, 10, 1],
"dtype": "float32",
"layout": "ABCDE",
}
],
"nodes": {"total": 2, "input": 1, "expand_dims": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(Unsqueeze1(), input_info, expected1)
verify_model(Unsqueeze2(), input_info, expected2)
@pytest.mark.parametrize("dynamic", [True, False])
def test_getattr(dynamic: bool):
"""test graph builder for getattr"""
class GetAttr1(Module):
def forward(self, data):
return data.shape
bz = "bz" if dynamic else 1
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "shape", "shape": [4], "dtype": "int32", "layout": "O"}],
"nodes": {"total": 2, "input": 1, "shape": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(GetAttr1(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_getitem(dynamic: bool):
"""test graph builder for getitem"""
class Slice1(Module):
def forward(self, x):
return x[0, 1::2, :, :3]
class Slice2(Module):
def forward(self, x):
return x[:, None, None, :, None]
bz = "bz" if dynamic else 1
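    # Slicing with a symbolic batch dim produces a derived extent prim
    # ("MIN_2"), so the dynamic expectation adds Int and Min prims.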
expected1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{
"name": "reshape",
"shape": ["MIN_2" if dynamic else 1, 1, 10, 3],
"dtype": "float32",
"layout": "ABCD",
}
],
"nodes": {"total": 3, "input": 1, "strided_slice": 1, "reshape": 1},
}
expected2 = {
"inputs": [{"name": "inp_0", "shape": [bz, 16], "dtype": "float32", "layout": "AB"}],
"outputs": [
{"name": "reshape", "shape": [bz, 1, 1, 16, 1], "dtype": "float32", "layout": "CDAEB"}
],
"nodes": {"total": 3, "input": 1, "strided_slice": 1, "reshape": 1},
}
if dynamic:
expected1["prims"] = {"total": 3, "shape": 1, "Int": 1, "Min": 1}
expected2["prims"] = {"total": 1, "shape": 1}
verify_model(Slice1(), [([bz, 3, 10, 10], "float32")], expected1)
verify_model(Slice2(), [([bz, 16], "float32")], expected2)
@pytest.mark.parametrize("dynamic", [True, False])
def test_unary(dynamic: bool):
"""test graph builder for unary"""
bz = "bz" if dynamic else 1
input_info = [([bz, 3, 10, 10], "float32")]
# sin
class Sin(Module):
def forward(self, data):
return torch.sin(data)
expected_sin = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "sin", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 2, "input": 1, "sin": 1},
}
if dynamic:
expected_sin["prims"] = {"total": 1, "shape": 1}
verify_model(Sin(), input_info, expected_sin)
# cos
class Cos(Module):
def forward(self, data):
return torch.cos(data)
expected_cos = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "cos", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 2, "input": 1, "cos": 1},
}
if dynamic:
expected_cos["prims"] = {"total": 1, "shape": 1}
verify_model(Cos(), input_info, expected_cos)
# exp
class Exp(Module):
def forward(self, data):
return torch.exp(data)
expected_exp = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "exp", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 2, "input": 1, "exp": 1},
}
if dynamic:
expected_exp["prims"] = {"total": 1, "shape": 1}
verify_model(Exp(), input_info, expected_exp)
# sqrt
class Sqrt(Module):
def forward(self, data):
return torch.sqrt(data)
expected_sqrt = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "sqrt", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 2, "input": 1, "sqrt": 1},
}
if dynamic:
expected_sqrt["prims"] = {"total": 1, "shape": 1}
verify_model(Sqrt(), input_info, expected_sqrt)
# sigmoid
class Sigmoid(Module):
def forward(self, data):
return torch.sigmoid(data)
expected_sigmoid = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "sigmoid", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 2, "input": 1, "sigmoid": 1},
}
if dynamic:
expected_sigmoid["prims"] = {"total": 1, "shape": 1}
verify_model(Sigmoid(), input_info, expected_sigmoid)
# round
class Round(Module):
def forward(self, data):
return torch.round(data)
expected_round = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "round", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 2, "input": 1, "round": 1},
}
if dynamic:
expected_round["prims"] = {"total": 1, "shape": 1}
verify_model(Round(), input_info, expected_round)
@pytest.mark.parametrize("dynamic", [True, False])
def test_gelu(dynamic: bool):
"""test graph builder for gelu"""
class Gelu(Module):
def forward(self, data):
return torch.nn.functional.gelu(data)
bz = "bz" if dynamic else 1
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "gelu", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 2, "input": 1, "nn.gelu": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(Gelu(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_tanh(dynamic: bool):
"""test graph builder for tanh"""
class Tanh(Module):
def forward(self, data):
return torch.tanh(data)
bz = "bz" if dynamic else 1
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "tanh", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 2, "input": 1, "tanh": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(Tanh(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_clamp(dynamic: bool):
"""test graph builder for clamp"""
class Clamp(Module):
def forward(self, data):
return torch.clamp(data, min=0.1, max=0.5)
bz = "bz" if dynamic else 1
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "clip", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}],
"nodes": {"total": 2, "input": 1, "clip": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(Clamp(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_interpolate(dynamic: bool):
"""test graph builder for interpolate"""
class Interpolate(Module):
def forward(self, data):
return torch.nn.functional.interpolate(data, (5, 5))
bz = "bz" if dynamic else 1
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{"name": "resize2d", "shape": [bz, 3, 5, 5], "dtype": "float32", "layout": "NCHW"}
],
"nodes": {"total": 2, "input": 1, "image.resize2d": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(Interpolate(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_addmm(dynamic: bool):
"""test graph builder for addmm"""
class Addmm(Module):
def forward(self, x_1, x_2, x_3):
return torch.addmm(x_1, x_2, x_3)
mdim = "mdim" if dynamic else 10
ndim = "ndim" if dynamic else 20
kdim = "kdim" if dynamic else 30
expected = {
"inputs": [
{"name": "inp_0", "shape": [mdim, ndim], "dtype": "float32", "layout": "NC"},
{"name": "inp_1", "shape": [mdim, kdim], "dtype": "float32", "layout": "NC"},
{"name": "inp_2", "shape": [kdim, ndim], "dtype": "float32", "layout": "IO"},
],
"outputs": [{"name": "add", "shape": [mdim, ndim], "dtype": "float32", "layout": "NC"}],
"nodes": {"total": 5, "input": 3, "matmul": 1, "add": 1},
}
if dynamic:
expected["prims"] = {"total": 3, "shape": 3}
input_info = [([mdim, ndim], "float32"), ([mdim, kdim], "float32"), ([kdim, ndim], "float32")]
verify_model(Addmm(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_split(dynamic: bool):
"""test graph builder for split"""
class Split1(Module):
def forward(self, data):
return torch.split(data, 1, dim=1)
class Split2(Module):
def forward(self, data):
return torch.split(data, [1, 2], dim=1)
bz = "bz" if dynamic else 1
expected1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "split_0", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"},
{"name": "split_1", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"},
{"name": "split_2", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"},
],
"nodes": {"total": 2, "input": 1, "split": 1},
}
expected2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "split_0", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"},
{"name": "split_1", "shape": [bz, 2, 10, 10], "dtype": "float32", "layout": "ABCD"},
],
"nodes": {"total": 2, "input": 1, "split": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(Split1(), input_info, expected1)
verify_model(Split2(), input_info, expected2)
@pytest.mark.parametrize("dynamic", [True, False])
def test_unbind(dynamic: bool):
"""test graph builder for unbind"""
class Unbind(Module):
def forward(self, data):
return torch.unbind(data, dim=1)
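    # unbind lowers to a split plus per-slice get_item and squeeze nodes,
    # gathered back into a tuple (9 nodes in total).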
bz = "bz" if dynamic else 1
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "tuple_0", "shape": [bz, 10, 10], "dtype": "float32", "layout": "ACD"},
{"name": "tuple_1", "shape": [bz, 10, 10], "dtype": "float32", "layout": "ACD"},
{"name": "tuple_2", "shape": [bz, 10, 10], "dtype": "float32", "layout": "ACD"},
],
"nodes": {"total": 9, "input": 1, "split": 1, "get_item": 3, "squeeze": 3, "tuple": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(Unbind(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_cumsum(dynamic: bool):
"""test graph builder for cumsum"""
class Cumsum(Module):
def forward(self, data):
return torch.cumsum(data, dim=1, dtype=torch.int32)
bz = "bz" if dynamic else 1
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "cumsum", "shape": [bz, 2, 3, 4], "dtype": "int32", "layout": ""}],
"nodes": {"total": 2, "input": 1, "cumsum": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 2, 3, 4], "float32")]
verify_model(Cumsum(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_chunk(dynamic: bool):
"""test graph builder for chunk"""
class Chunk(Module):
def forward(self, data):
return torch.chunk(data, 3, dim=1)
bz = "bz" if dynamic else 1
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"outputs": [
{"name": "split_0", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"},
{"name": "split_1", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"},
{"name": "split_2", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"},
],
"nodes": {"total": 2, "input": 1, "split": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 3, 10, 10], "float32")]
verify_model(Chunk(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_inplace_fill(dynamic: bool):
"""test graph builder for inplace_fill"""
class InplaceFill(Module):
def forward(self, data):
data.fill_(1.5)
return data
bz = "bz" if dynamic else 1
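    # With a static shape, fill_ folds into a single constant; with a symbolic
    # batch dim it stays a full op fed by the constant fill value.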
if dynamic:
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "full", "shape": [bz, 10], "dtype": "float32", "layout": ""}],
"nodes": {"total": 3, "input": 1, "constant": 1, "full": 1},
"prims": {"total": 1, "shape": 1},
}
else:
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "const", "shape": [bz, 10], "dtype": "float32", "layout": ""}],
"nodes": {"total": 2, "input": 1, "constant": 1},
}
verify_model(InplaceFill(), [([bz, 10], "float32")], expected)
def test_arange():
"""test graph builder for arange"""
class Arange(Module):
def forward(self):
return torch.arange(0, 20, dtype=torch.int32)
expected = {
"inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "const", "shape": [20], "dtype": "int32", "layout": ""}],
"nodes": {"total": 2, "input": 1, "constant": 1},
}
verify_model(Arange(), [([10, 10], "float32")], expected)
def test_empty():
"""test graph builder for empty"""
class Empty(Module):
def forward(self):
return torch.empty((10, 10), dtype=torch.float32)
expected = {
"inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "const", "shape": [10, 10], "dtype": "float32", "layout": ""}],
"nodes": {"total": 2, "input": 1, "constant": 1},
}
verify_model(Empty(), [([10, 10], "float32")], expected)
def test_tensor():
"""test graph builder for tensor"""
class Empty1(Module):
def forward(self):
return torch.tensor(3, dtype=torch.float32)
expected1 = {
"inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "const", "shape": [], "dtype": "float32", "layout": ""}],
"nodes": {"total": 2, "input": 1, "constant": 1},
}
class Empty2(Module):
def forward(self):
return torch.tensor(3)
expected2 = {
"inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "const", "shape": [], "dtype": "int64", "layout": ""}],
"nodes": {"total": 2, "input": 1, "constant": 1},
}
verify_model(Empty1(), [([10, 10], "float32")], expected1)
verify_model(Empty2(), [([10, 10], "float32")], expected2)
@pytest.mark.parametrize("dynamic", [True, False])
def test_tril(dynamic: bool):
"""test graph builder for tril"""
class Tril(Module):
def forward(self, data):
return torch.tril(data, 1)
class InplaceTril(Module):
def forward(self, data):
data.tril_(1)
return data
row = "row" if dynamic else 10
col = "col" if dynamic else 10
expected = {
"inputs": [{"name": "inp_0", "shape": [row, col], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "tril", "shape": [row, col], "dtype": "float32", "layout": ""}],
"nodes": {"total": 2, "input": 1, "tril": 1},
}
if dynamic:
expected["prims"] = {"total": 2, "shape": 2}
input_info = [([row, col], "float32")]
verify_model(Tril(), input_info, expected)
verify_model(InplaceTril(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_triu(dynamic: bool):
"""test graph builder for triu"""
class Triu(Module):
def forward(self, data):
return torch.triu(data, 1)
class InplaceTriu(Module):
def forward(self, data):
data.triu_(1)
return data
row = "row" if dynamic else 10
col = "col" if dynamic else 10
expected = {
"inputs": [{"name": "inp_0", "shape": [row, col], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "triu", "shape": [row, col], "dtype": "float32", "layout": ""}],
"nodes": {"total": 2, "input": 1, "triu": 1},
}
if dynamic:
expected["prims"] = {"total": 2, "shape": 2}
input_info = [([row, col], "float32")]
verify_model(Triu(), input_info, expected)
verify_model(InplaceTriu(), input_info, expected)
def test_new_ones():
"""test graph builder for new_ones"""
class NewOnes(Module):
def forward(self, x):
return x.new_ones(1, 2, 3)
expected = {
"inputs": [{"name": "inp_0", "shape": [1, 2, 3], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "const", "shape": [1, 2, 3], "dtype": "float32", "layout": ""}],
"nodes": {"total": 2, "input": 1, "constant": 1},
}
input_info = [([1, 2, 3], "float32")]
verify_model(NewOnes(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_expand(dynamic: bool):
"""test graph builder for expand"""
class Expand1(Module):
def forward(self, x):
return x.expand(4, 2, 3, 4)
class Expand2(Module):
def forward(self, x):
return x.expand(4, -1, -1, 4)
bz = "bz" if dynamic else 1
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": ""}],
"outputs": [
{"name": "broadcast_to", "shape": [4, 2, 3, 4], "dtype": "float32", "layout": ""}
],
"nodes": {"total": 2, "input": 1, "broadcast_to": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 2, 3, 4], "float32")]
verify_model(Expand1(), input_info, expected)
verify_model(Expand2(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_reduce(dynamic: bool):
"""test graph builder for reduce"""
# sum
class Sum(Module):
def forward(self, x):
return torch.sum(x, (2, 1))
bz = "bz" if dynamic else 1
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ACDB"}],
"outputs": [{"name": "sum", "shape": [bz, 4], "dtype": "float32", "layout": "AB"}],
"nodes": {"total": 2, "input": 1, "sum": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz, 2, 3, 4], "float32")]
verify_model(Sum(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_datatype(dynamic: bool):
"""test graph builder for datatype"""
bz = "bz" if dynamic else 1
input_info = [([bz, 2, 3, 4], "float32")]
# float
class ToFloat(Module):
def forward(self, x):
return x.float()
expected1 = {
"inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}],
"outputs": [
{"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 2, "input": 1, "astype": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
verify_model(ToFloat(), input_info, expected1)
# half
class ToHalf(Module):
def forward(self, x):
return x.half()
expected2 = {
"inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}],
"outputs": [
{"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float16", "layout": "ABCD"}
],
"nodes": {"total": 2, "input": 1, "astype": 1},
}
if dynamic:
expected2["prims"] = {"total": 1, "shape": 1}
verify_model(ToHalf(), input_info, expected2)
# type
class Type(Module):
def forward(self, x):
return x.type(torch.float32)
expected3 = {
"inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}],
"outputs": [
{"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 2, "input": 1, "astype": 1},
}
if dynamic:
expected3["prims"] = {"total": 1, "shape": 1}
    # type from attr
class TypeFromAttr(Module):
def forward(self, x):
return x.type(x.getattr("dtype"))
expected4 = {
"inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}],
"outputs": [
{"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 2, "input": 1, "astype": 1},
}
if dynamic:
expected4["prims"] = {"total": 1, "shape": 1}
# astype
class AsType(Module):
def forward(self, x):
return x.astype(torch.float32)
expected5 = {
"inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}],
"outputs": [
{"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 2, "input": 1, "astype": 1},
}
if dynamic:
expected5["prims"] = {"total": 1, "shape": 1}
verify_model(Type(), input_info, expected3)
verify_model(TypeFromAttr(), input_info, expected4)
verify_model(AsType(), input_info, expected5)
@pytest.mark.parametrize("dynamic", [True, False])
def test_permute(dynamic: bool):
"""test graph builder for permute"""
class Permute(Module):
def forward(self, x):
return x.permute(0, 3, 2, 1)
bz = "bz" if dynamic else 1
channel = "channel" if dynamic else 2
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, channel, 3, 4], "dtype": "float32", "layout": "ADCB"}
],
"outputs": [
{
"name": "permute_dims",
"shape": [bz, 4, 3, channel],
"dtype": "float32",
"layout": "ABCD",
}
],
"nodes": {"total": 2, "input": 1, "permute_dims": 1},
}
if dynamic:
expected["prims"] = {"total": 2, "shape": 2}
input_info = [([bz, channel, 3, 4], "float32")]
verify_model(Permute(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_reshape(dynamic: bool):
"""test graph builder for reshape"""
class Reshape(Module):
def forward(self, x):
return x.reshape(-1, 12)
bz = "bz" if dynamic else 1
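    # reshape(-1, 12) folds the leading dims; with a symbolic bz the folded
    # dim is the derived prim "MUL_2" (Int and Mul prims), otherwise 2.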
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": ""}],
"outputs": [
{
"name": "reshape",
"shape": ["MUL_2" if dynamic else 2, 12],
"dtype": "float32",
"layout": "",
}
],
"nodes": {"total": 2, "input": 1, "reshape": 1},
}
if dynamic:
expected["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1}
input_info = [([bz, 2, 3, 4], "float32")]
verify_model(Reshape(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_transpose(dynamic: bool):
"""test graph builder for transpose"""
class Transpose(Module):
def forward(self, x):
return x.transpose(1, 3)
bz = "bz" if dynamic else 1
hidden = "hidden" if dynamic else 4
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 2, 3, hidden], "dtype": "float32", "layout": "ADCB"}
],
"outputs": [
{
"name": "permute_dims",
"shape": [bz, hidden, 3, 2],
"dtype": "float32",
"layout": "ABCD",
}
],
"nodes": {"total": 2, "input": 1, "permute_dims": 1},
}
if dynamic:
expected["prims"] = {"total": 2, "shape": 2}
input_info = [([bz, 2, 3, hidden], "float32")]
verify_model(Transpose(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_view(dynamic: bool):
"""test graph builder for view"""
class View(Module):
def forward(self, x):
return x.view(-1, 12)
bz = "bz" if dynamic else 1
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": ""}],
"outputs": [
{
"name": "reshape",
"shape": ["MUL_2" if dynamic else 2, 12],
"dtype": "float32",
"layout": "",
}
],
"nodes": {"total": 2, "input": 1, "reshape": 1},
}
if dynamic:
expected["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1}
input_info = [([bz, 2, 3, 4], "float32")]
verify_model(View(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_keep_params(dynamic: bool):
"""test graph builder for keep_params"""
class Conv2D1(Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)
def forward(self, data):
return self.conv(data)
bz = "bz" if dynamic else 1
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{
"name": "conv2d",
"shape": [bz, 6, 4, 4],
"dtype": "float32",
"layout": "NCHW",
}
],
"nodes": {"total": 2, "input": 1, "msc.conv2d_bias": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
verify_model(Conv2D1(), [([bz, 3, 10, 10], "float32")], expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_unwrap_unit_return_tuple(dynamic: bool):
"""test graph builder for unwrap_unit_return_tuple"""
class Identity(Module):
def forward(self, x):
return (x,)
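    # A single-element tuple return is unwrapped: the graph still records a
    # tuple node, but exposes one plain output named "tuple".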
bz = "bz" if dynamic else 1
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "tuple", "shape": [bz, 256], "dtype": "float32", "layout": ""}],
"nodes": {"total": 2, "input": 1, "tuple": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
verify_model(Identity(), [([bz, 256], "float32")], expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_no_bind_return_tuple(dynamic: bool):
"""test graph builder for no_bind_return_tuple"""
class Identity(Module):
def forward(self, x, y):
return (x, y)
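    # A multi-element tuple return is not unwrapped; each field becomes its own
    # output (tuple_0, tuple_1).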
bz_x = "bz" if dynamic else 1
bz_y = "bz" if dynamic else 2
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz_x, 256], "dtype": "float32", "layout": ""},
{"name": "inp_1", "shape": [bz_y, 256], "dtype": "float32", "layout": ""},
],
"outputs": [
{"name": "tuple_0", "shape": [bz_x, 256], "dtype": "float32", "layout": ""},
{"name": "tuple_1", "shape": [bz_y, 256], "dtype": "float32", "layout": ""},
],
"nodes": {"total": 3, "input": 2, "tuple": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
input_info = [([bz_x, 256], "float32"), ([bz_y, 256], "float32")]
verify_model(Identity(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_argmax(dynamic: bool):
"""test graph builder for argmax"""
class Argmax1(Module):
def forward(self, data):
return torch.argmax(data, dim=-1)
class Argmax2(Module):
def forward(self, data):
return torch.argmax(data, dim=-1, keepdim=True)
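    # keepdim=True keeps the reduced axis as size 1 ([bz, 1] vs [bz]); the
    # result dtype is int64 either way.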
bz = "bz" if dynamic else 1
expected1 = {
"inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "argmax", "shape": [bz], "dtype": "int64", "layout": ""}],
"nodes": {"total": 2, "input": 1, "argmax": 1},
}
expected2 = {
"inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "argmax", "shape": [bz, 1], "dtype": "int64", "layout": ""}],
"nodes": {"total": 2, "input": 1, "argmax": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2["prims"] = {"total": 1, "shape": 1}
verify_model(Argmax1(), [([bz, 256], "float32")], expected1)
verify_model(Argmax2(), [([bz, 256], "float32")], expected2)
@pytest.mark.parametrize("dynamic", [True, False])
def test_argmin(dynamic: bool):
"""test graph builder for argmin"""
class Argmin1(Module):
def forward(self, data):
return torch.argmin(data)
class Argmin2(Module):
def forward(self, data):
return torch.argmin(data, keepdim=True)
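    # Without a dim argument, argmin reduces over the flattened tensor: the
    # plain result is a scalar ([]) and the keepdim=True result is [1, 1].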
bz = "bz" if dynamic else 1
expected1 = {
"inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "argmin", "shape": [], "dtype": "int64", "layout": ""}],
"nodes": {"total": 2, "input": 1, "argmin": 1},
}
expected2 = {
"inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "argmin", "shape": [1, 1], "dtype": "int64", "layout": ""}],
"nodes": {"total": 2, "input": 1, "argmin": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2["prims"] = {"total": 1, "shape": 1}
verify_model(Argmin1(), [([bz, 256], "float32")], expected1)
verify_model(Argmin2(), [([bz, 256], "float32")], expected2)
@pytest.mark.parametrize("dynamic", [True, False])
def test_to(dynamic: bool):
"""test graph builder for to"""
class To1(Module):
def forward(self, data):
return data.to(torch.float16)
class To2(Module):
def forward(self, data):
return data.to("cpu")
bz = "bz" if dynamic else 1
expected1 = {
"inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}],
"outputs": [{"name": "astype", "shape": [bz, 256], "dtype": "float16", "layout": "AB"}],
"nodes": {"total": 2, "input": 1, "astype": 1},
}
expected2 = {
"inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}],
"outputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}],
"nodes": {"total": 1, "input": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2["prims"] = {"total": 1, "shape": 1}
verify_model(To1(), [([bz, 256], "float32")], expected1)
verify_model(To2(), [([bz, 256], "float32")], expected2)
@pytest.mark.parametrize("dynamic", [True, False])
def test_mean(dynamic: bool):
"""test graph builder for mean"""
class Mean(Module):
def forward(self, data):
return data.mean(-1)
class MeanKeepDim(Module):
def forward(self, data):
return data.mean(-1, keepdim=True)
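    # mean(-1) drops the reduced axis (layout "A"); keepdim=True keeps the rank
    # and layout ("AB") with the reduced axis set to 1.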
bz = "bz" if dynamic else 1
expected1 = {
"inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}],
"outputs": [{"name": "mean", "shape": [bz], "dtype": "float32", "layout": "A"}],
"nodes": {"total": 2, "input": 1, "mean": 1},
}
expected2 = {
"inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}],
"outputs": [{"name": "mean", "shape": [bz, 1], "dtype": "float32", "layout": "AB"}],
"nodes": {"total": 2, "input": 1, "mean": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2["prims"] = {"total": 1, "shape": 1}
verify_model(Mean(), [([bz, 256], "float32")], expected1)
verify_model(MeanKeepDim(), [([bz, 256], "float32")], expected2)
@pytest.mark.parametrize("dynamic", [True, False])
def test_rsqrt(dynamic: bool):
"""test graph builder for rsqrt"""
class Rsqrt(Module):
def forward(self, data):
return torch.rsqrt(data)
bz = "bz" if dynamic else 1
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}],
"outputs": [{"name": "rsqrt", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}],
"nodes": {"total": 2, "input": 1, "rsqrt": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
verify_model(Rsqrt(), [([bz, 256], "float32")], expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_neg(dynamic: bool):
"""test graph builder for neg"""
class Neg(Module):
def forward(self, data):
return -data
bz = "bz" if dynamic else 1
expected = {
"inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}],
"outputs": [{"name": "negative", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}],
"nodes": {"total": 2, "input": 1, "negative": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
verify_model(Neg(), [([bz, 256], "float32")], expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_max(dynamic: bool):
"""test graph builder for max"""
class Max(Module):
def forward(self, x, y):
return torch.max(x, y)
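    # torch.max with two tensor arguments is the elementwise maximum, not a
    # reduction, hence the single "maximum" node.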
bz = "bz" if dynamic else 1
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"},
{"name": "inp_1", "shape": [bz, 256], "dtype": "float32", "layout": "AB"},
],
"outputs": [{"name": "maximum", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}],
"nodes": {"total": 3, "input": 2, "maximum": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
verify_model(Max(), [([bz, 256], "float32"), ([bz, 256], "float32")], expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_cat(dynamic: bool):
"""test graph builder for cat"""
class Cat1(Module):
def forward(self, data, data1, data2):
return torch.cat((data, data1, data2), dim=1)
class Cat2(Module):
def forward(self, data):
const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
return torch.cat((data, const1, const2), dim=1)
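    # cat along dim=1 sums that axis: 3 * dim is tracked as the symbolic dim
    # "MUL_3" when dynamic. In Cat2 the torch.ones tensors are folded into
    # constant nodes.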
bz = "bz" if dynamic else 1
dim = "dim" if dynamic else 3
input_info = [
([bz, dim, 10, 10], "float32"),
([bz, dim, 10, 10], "float32"),
([bz, dim, 10, 10], "float32"),
]
expected1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, dim, 10, 10], "dtype": "float32", "layout": ""},
{"name": "inp_1", "shape": [bz, dim, 10, 10], "dtype": "float32", "layout": ""},
{"name": "inp_2", "shape": [bz, dim, 10, 10], "dtype": "float32", "layout": ""},
],
"outputs": [
{
"name": "concat",
"shape": [bz, "MUL_3" if dynamic else 9, 10, 10],
"dtype": "float32",
"layout": "ABCD",
}
],
"nodes": {"total": 4, "input": 3, "concat": 1},
}
expected2 = {
"inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}],
"outputs": [
{"name": "concat", "shape": [1, 9, 10, 10], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 4, "input": 1, "constant": 2, "concat": 1},
}
if dynamic:
expected1["prims"] = {"total": 4, "shape": 2, "Int": 1, "Mul": 1}
verify_model(Cat1(), input_info, expected1)
verify_model(Cat2(), [([1, 3, 10, 10], "float32")], expected2)
@pytest.mark.parametrize("dynamic", [True, False])
def test_stack(dynamic: bool):
"""Test graph builder for stack."""
bz = "bz" if dynamic else 1
class Stack(Module):
def forward(self, data, data1, data2):
return torch.stack((data, data1, data2), dim=0)
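    # stack inserts a new leading axis of length 3, reported as "S" in the
    # output layout "SABCD".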
input_info = [
([bz, 3, 10, 10], "float32"),
([bz, 3, 10, 10], "float32"),
([bz, 3, 10, 10], "float32"),
]
expected = {
"inputs": [
{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""},
{"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""},
{"name": "inp_2", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""},
],
"outputs": [
{
"name": "stack",
"shape": [3, bz, 3, 10, 10],
"dtype": "float32",
"layout": "SABCD",
}
],
"nodes": {"total": 4, "input": 3, "stack": 1},
}
if dynamic:
expected["prims"] = {"total": 1, "shape": 1}
verify_model(Stack(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
def test_scatter(dynamic: bool):
"""test graph builder for scatter"""
bz = "bz" if dynamic else 20
class Scatter1(Module):
def __init__(self):
super().__init__()
self.index = msc_utils.random_data([(2, 5), "int64"], MSCFramework.TORCH, max_val=5)
def forward(self, data, src):
return data.scatter(dim=0, index=self.index, src=src)
class Scatter2(Module):
def forward(self, data, index, src):
return data.scatter(0, index, src)
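    # Scatter1 holds its index tensor on the module, so it shows up as a
    # constant node; Scatter2 takes the index as a graph input instead.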
expected1 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": ""},
{"name": "inp_1", "shape": [2, 5], "dtype": "float32", "layout": ""},
],
"outputs": [
{"name": "scatter_elements", "shape": [bz, 20], "dtype": "float32", "layout": ""}
],
"nodes": {"total": 4, "input": 2, "constant": 1, "scatter_elements": 1},
}
expected2 = {
"inputs": [
{"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": ""},
{"name": "inp_1", "shape": [2, 5], "dtype": "int64", "layout": ""},
{"name": "inp_2", "shape": [2, 5], "dtype": "float32", "layout": ""},
],
"outputs": [
{"name": "scatter_elements", "shape": [bz, 20], "dtype": "float32", "layout": ""}
],
"nodes": {"total": 4, "input": 3, "scatter_elements": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2["prims"] = {"total": 1, "shape": 1}
verify_model(Scatter1(), [([bz, 20], "float32"), ([2, 5], "float32")], expected1)
verify_model(
Scatter2(), [([bz, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")], expected2
)
@pytest.mark.parametrize("dynamic", [True, False])
def test_masked_scatter(dynamic: bool):
"""test graph builder for masked_scatter"""
dim = "dim" if dynamic else 5
class MaskedScatter1(Module):
def forward(self, data, mask, src):
return data.masked_scatter(mask, src)
class MaskedScatter2(Module):
def forward(self, data, mask, src):
return data.masked_scatter(mask, src)
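    # Per the expected node counts, masked_scatter is decomposed into a
    # cumsum/take/where sequence; the 2-D case also reshapes the operands,
    # adding three reshape nodes.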
expected1 = {
"inputs": [
{"name": "inp_0", "shape": [dim], "dtype": "float32", "layout": "A"},
{"name": "inp_1", "shape": [dim], "dtype": "bool", "layout": "A"},
{"name": "inp_2", "shape": [10], "dtype": "float32", "layout": "A"},
],
"outputs": [{"name": "where", "shape": [dim], "dtype": "float32", "layout": "A"}],
"nodes": {
"total": 8,
"input": 3,
"cumsum": 1,
"constant": 1,
"subtract": 1,
"take": 1,
"where": 1,
},
}
expected2 = {
"inputs": [
{
"name": "inp_0",
"shape": [2, dim],
"dtype": "float32",
"layout": "" if dynamic else "BA",
},
{
"name": "inp_1",
"shape": [2, dim],
"dtype": "bool",
"layout": "" if dynamic else "BA",
},
{
"name": "inp_2",
"shape": [3, dim],
"dtype": "float32",
"layout": "" if dynamic else "BA",
},
],
"outputs": [
{
"name": "where",
"shape": [2, dim],
"dtype": "float32",
"layout": "" if dynamic else "BA",
}
],
"nodes": {
"total": 11,
"input": 3,
"reshape": 3,
"cumsum": 1,
"constant": 1,
"subtract": 1,
"take": 1,
"where": 1,
},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
expected2["prims"] = {"total": 5, "shape": 1, "Int": 2, "Mul": 2}
verify_model(
MaskedScatter1(), [([dim], "float32"), ([dim], "bool"), ([10], "float32")], expected1
)
verify_model(
MaskedScatter2(),
[([2, dim], "float32"), ([2, dim], "bool"), ([3, dim], "float32")],
expected2,
)
@pytest.mark.parametrize("dynamic", [True, False])
def test_attention(dynamic: bool):
"""test graph builder for attention"""
# pylint: disable=import-outside-toplevel
import torch.nn.functional as F
seq = "seq" if dynamic else 128
class Attention1(Module):
def forward(self, q_data, k_data, v_data):
return F.scaled_dot_product_attention(q_data, k_data, v_data)
class Attention2(Module):
def forward(self, q_data, k_data, v_data):
return F.scaled_dot_product_attention(q_data, k_data, v_data, is_causal=True)
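    # Both variants map to a single msc.attention node; is_causal does not
    # change the inspected graph, so the two models share expected1.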
expected1 = {
"inputs": [
{"name": "inp_0", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"},
{"name": "inp_1", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"},
{"name": "inp_2", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"},
],
"outputs": [
{"name": "attention", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ABCD"}
],
"nodes": {"total": 4, "input": 3, "msc.attention": 1},
}
if dynamic:
expected1["prims"] = {"total": 1, "shape": 1}
input_info = [
([1, 8, seq, 64], "float32"),
([1, 8, seq, 64], "float32"),
([1, 8, seq, 64], "float32"),
]
verify_model(Attention1(), input_info, expected1)
verify_model(Attention2(), input_info, expected1)
class Attention3(Module):
def forward(self, q_data, k_data, v_data, mask):
return F.scaled_dot_product_attention(q_data, k_data, v_data, mask)
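    # Passing an explicit mask adds a fourth input and switches the output
    # name to attention_bias.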
expected2 = {
"inputs": [
{"name": "inp_0", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"},
{"name": "inp_1", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"},
{"name": "inp_2", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"},
{"name": "inp_3", "shape": [1, 8, seq, seq], "dtype": "float32", "layout": "ABCD"},
],
"outputs": [
{
"name": "attention_bias",
"shape": [1, 8, seq, 64],
"dtype": "float32",
"layout": "ABCD",
}
],
"nodes": {"total": 5, "input": 4, "msc.attention": 1},
}
if dynamic:
expected2["prims"] = {"total": 1, "shape": 1}
verify_model(
Attention3(),
[
([1, 8, seq, 64], "float32"),
([1, 8, seq, 64], "float32"),
([1, 8, seq, 64], "float32"),
([1, 8, seq, seq], "float32"),
],
expected2,
)
if __name__ == "__main__":
tvm.testing.main()