| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import pytest |
| import numpy as np |
| import tvm |
| import tvm.testing |
| |
| from tvm import relax |
| from tvm.script import relax as R |
| from tvm.relax.dpl import make_fused_bias_activation_pattern |
| from tvm.contrib.pickle_memoize import memoize |
| |
| |
# Relax test module: two back-to-back conv2d + ReLU layers on float32 tensors.
# Each layer has the conv2d -> relu shape that test_dnnl_offload matches and
# hands to the DNNL codegen.
@tvm.script.ir_module
class Conv2dReLUx2:
    @R.function
    def main(
        data: R.Tensor((1, 64, 56, 56), "float32"),
        weight1: R.Tensor((64, 64, 3, 3), "float32"),
        weight2: R.Tensor((64, 64, 3, 3), "float32"),
    ):
        with R.dataflow():
            # First layer: conv2d of `data` with `weight1`, padding (1, 1), then ReLU.
            conv1 = relax.op.nn.relu(relax.op.nn.conv2d(data, weight1, padding=(1, 1)))
            # Second layer: consumes conv1; conv2d with `weight2`, padding (0, 0), then ReLU.
            conv2 = relax.op.nn.relu(relax.op.nn.conv2d(conv1, weight2, padding=(0, 0)))
            # Only conv2 is marked as an output of the dataflow block.
            R.output(conv2)

        return conv2
| |
| |
# Look up the DNNL codegen entry point. With allow_missing=True the lookup
# yields a falsy value instead of raising when this TVM build does not
# register "relax.ext.dnnl".
has_dnnl = tvm.get_global_func("relax.ext.dnnl", allow_missing=True)

# Skip marker for tests that require the DNNL codegen.
dnnl_enabled = pytest.mark.skipif(
    not has_dnnl,
    reason="DNNL not enabled.",
)

# Module-level pytest marks: every test in this file is skipped without DNNL.
pytestmark = [dnnl_enabled]
| |
| |
def build_and_run(mod, inputs, legalize=False):
    """Compile ``mod`` for LLVM on CPU, run its ``main``, and return a NumPy array.

    Parameters
    ----------
    mod
        The module handed to ``tvm.compile``.
    inputs
        Iterable of host arrays; each is converted to a CPU tensor and passed
        positionally to ``main``.
    legalize : bool
        Value stored under the ``relax.transform.apply_legalize_ops`` key of
        the pass context used during compilation.

    Returns
    -------
    The result of ``main`` converted with ``.numpy()``.
    """
    device = tvm.cpu()
    llvm_target = tvm.target.Target("llvm")
    device_args = [tvm.runtime.tensor(host_array, device) for host_array in inputs]

    pass_config = {"relax.transform.apply_legalize_ops": legalize}
    with tvm.transform.PassContext(config=pass_config):
        executable = tvm.compile(mod, llvm_target)

    machine = relax.VirtualMachine(executable, device)
    result = machine["main"](*device_args)
    return result.numpy()
| |
| |
def test_dnnl_offload():
    """Check that offloading conv2d + ReLU to DNNL matches the reference build.

    The reference output comes from compiling ``Conv2dReLUx2`` with
    ``legalize=True``; the candidate output comes from the module after the
    fuse / merge / codegen passes. Both are compared with a 1e-3 tolerance.
    """
    conv2d_relu_pattern = make_fused_bias_activation_pattern(
        "relax.nn.conv2d", with_bias=False, activation="relax.nn.relu"
    )

    offload_passes = [
        relax.transform.FuseOpsByPattern([("dnnl.conv2d_relu", conv2d_relu_pattern)]),
        relax.transform.MergeCompositeFunctions(),
        relax.transform.RunCodegen(),
    ]
    offload_pipeline = tvm.transform.Sequential(offload_passes)

    # NOTE(review): the memoize cache entry presumably depends on the key
    # string, the wrapped function's name and any captured locals, so the
    # name stays `get_ref` and the body references only module-level names.
    # Confirm against tvm.contrib.pickle_memoize.
    @memoize("relax.tests.test_codegen_dnnl.conv2d_relu_x2")
    def get_ref():
        # Drawn in the order data, weight1, weight2.
        shapes = [(1, 64, 56, 56), (64, 64, 3, 3), (64, 64, 3, 3)]
        arrays = [np.random.randn(*shape).astype("float32") for shape in shapes]
        return arrays, build_and_run(Conv2dReLUx2, arrays, legalize=True)

    inputs, expected = get_ref()

    actual = build_and_run(offload_pipeline(Conv2dReLUx2), inputs)

    tvm.testing.assert_allclose(actual, expected, rtol=1e-3, atol=1e-3)
| |
| |
if __name__ == "__main__":
    # Run this file through pytest so the module-level `pytestmark` skip is
    # honoured; calling the test function directly would bypass it and fail
    # on builds without DNNL.
    tvm.testing.main()