# blob: f386f8f2f8d0d5347256d62ea1619fee1394d24a
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import numpy as np
import tvm
import tvm.testing
from tvm import relax
from tvm.script import relax as R
from tvm.relax.dpl import make_fused_bias_activation_pattern
from tvm.contrib.pickle_memoize import memoize
# Relax workload for the DNNL offload test: two chained conv2d -> relu layers.
# Both layers have the conv2d+relu form matched by the "dnnl.conv2d_relu"
# pattern built in test_dnnl_offload. Only ``#`` comments are used inside the
# class because its body is parsed as TVMScript.
@tvm.script.ir_module
class Conv2dReLUx2:
    @R.function
    def main(
        data: R.Tensor((1, 64, 56, 56), "float32"),
        weight1: R.Tensor((64, 64, 3, 3), "float32"),
        weight2: R.Tensor((64, 64, 3, 3), "float32"),
    ):
        with R.dataflow():
            # padding=(1, 1) with a 3x3 kernel keeps the 56x56 spatial size
            # (assumes conv2d's default stride of 1 -- TODO confirm).
            conv1 = relax.op.nn.relu(relax.op.nn.conv2d(data, weight1, padding=(1, 1)))
            # No padding on the second layer, so the spatial size shrinks
            # (to 54x54 under the same stride assumption).
            conv2 = relax.op.nn.relu(relax.op.nn.conv2d(conv1, weight2, padding=(0, 0)))
            R.output(conv2)
        return conv2
# Probe for the DNNL Relax codegen entry point. The second argument lets the
# lookup tolerate a missing function (returning a falsy value) rather than
# raising, so this works on TVM builds without DNNL.
has_dnnl = tvm.get_global_func("relax.ext.dnnl", True)
# Skip marker used when the DNNL codegen is not registered.
dnnl_enabled = pytest.mark.skipif(
    not has_dnnl,
    reason="DNNL not enabled.",
)
# Module-level marks apply to every test in this file when run under pytest.
pytestmark = [dnnl_enabled]
def build_and_run(mod, inputs, legalize=False):
    """Compile ``mod`` for the LLVM CPU target and execute its ``main`` function.

    Parameters
    ----------
    mod : tvm.IRModule
        Relax module to compile; it must define a ``main`` function.
    inputs : sequence of numpy.ndarray
        Host arrays, passed positionally to ``main`` after being copied
        to the CPU device.
    legalize : bool
        Forwarded as the ``relax.transform.apply_legalize_ops`` PassContext
        option while compiling.

    Returns
    -------
    numpy.ndarray
        The result of ``main``, converted back to NumPy.
    """
    llvm_target = tvm.target.Target("llvm")
    cpu_device = tvm.cpu()
    device_args = [tvm.runtime.tensor(arr, cpu_device) for arr in inputs]

    pass_options = {"relax.transform.apply_legalize_ops": legalize}
    with tvm.transform.PassContext(config=pass_options):
        executable = tvm.compile(mod, llvm_target)

    machine = relax.VirtualMachine(executable, cpu_device)
    entry = machine["main"]
    result = entry(*device_args)
    return result.numpy()
def test_dnnl_offload():
    """Offload conv2d+relu to DNNL and compare against the legalized reference."""
    conv_relu = make_fused_bias_activation_pattern(
        "relax.nn.conv2d", with_bias=False, activation="relax.nn.relu"
    )
    # Offload pipeline: pattern-based fusion under the "dnnl.conv2d_relu" name,
    # merging of the resulting composite functions, then external codegen.
    offload_passes = [
        relax.transform.FuseOpsByPattern([("dnnl.conv2d_relu", conv_relu)]),
        relax.transform.MergeCompositeFunctions(),
        relax.transform.RunCodegen(),
    ]
    offload = tvm.transform.Sequential(offload_passes)

    # NOTE(review): pickle_memoize appears to derive its cache file from the
    # key plus the decorated function's name -- keep both stable.
    @memoize("relax.tests.test_codegen_dnnl.conv2d_relu_x2")
    def get_ref():
        # Random draws happen in the order data, weight1, weight2.
        shapes = [(1, 64, 56, 56), (64, 64, 3, 3), (64, 64, 3, 3)]
        arrays = [np.random.randn(*shape).astype("float32") for shape in shapes]
        # Reference result: compile the unmodified module with legalization on.
        return arrays, build_and_run(Conv2dReLUx2, arrays, legalize=True)

    inputs, expected = get_ref()
    actual = build_and_run(offload(Conv2dReLUx2), inputs)
    tvm.testing.assert_allclose(actual, expected, rtol=1e-3, atol=1e-3)
if __name__ == "__main__":
    # Run via pytest so the module-level ``pytestmark`` skip (DNNL missing) is
    # honored; calling test_dnnl_offload() directly would bypass it and fail
    # on builds without DNNL support.
    tvm.testing.main()