| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import numpy as np |
| |
| import tvm |
| import tvm.testing |
| from tvm import te |
| from tvm.contrib import hipblas |
| |
| |
def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5):
    """Check ``hipblas.matmul`` on ROCm against a NumPy ``np.dot`` reference.

    Builds ``C = A @ B`` with ``A`` of shape (1024, 128) and ``B`` of shape
    (128, 236) through the hipBLAS extern, compiles it for the ``rocm``
    target, runs it on ``tvm.rocm(0)`` and compares the output with
    ``np.dot`` evaluated on the inputs cast to the output dtype.

    Parameters
    ----------
    in_dtype : str
        Element type of the ``A`` and ``B`` placeholders.
    out_dtype : str
        Element type requested for the output ``C``.
    rtol : float
        Relative tolerance passed to ``tvm.testing.assert_allclose``.

    The check is skipped (a message is printed and the function returns)
    when the ``tvm.contrib.hipblas.matmul`` packed function is not registered.
    """
    # Check availability before building anything.
    if not tvm.get_global_func("tvm.contrib.hipblas.matmul", True):
        print("skip because extern function is not available")
        return

    n, l, m = 1024, 128, 236
    A = te.placeholder((n, l), name="A", dtype=in_dtype)
    B = te.placeholder((l, m), name="B", dtype=in_dtype)
    C = hipblas.matmul(A, B, dtype=out_dtype)

    # Each output element is a sum of `l` products. With inputs drawn from
    # [0, 128) a typical element is about l * 64 * 64 = 524288, far above the
    # float16 maximum (65504). Both the device result and the NumPy reference
    # would overflow to inf, and assert_allclose would accept inf == inf, so the
    # comparison would check nothing. Draw float16-output inputs from [0, 1)
    # instead (the same range verify_batch_matmul uses for float inputs).
    high = 1.0 if str(C.dtype) == "float16" else 128.0

    dev = tvm.rocm(0)
    f = tvm.compile(te.create_prim_func([A, B, C]), target="rocm")
    a = tvm.runtime.tensor(np.random.uniform(0, high, size=(n, l)).astype(A.dtype), dev)
    b = tvm.runtime.tensor(np.random.uniform(0, high, size=(l, m)).astype(B.dtype), dev)
    c = tvm.runtime.tensor(np.zeros((n, m), dtype=C.dtype), dev)
    f(a, b, c)
    tvm.testing.assert_allclose(
        c.numpy(), np.dot(a.numpy().astype(C.dtype), b.numpy().astype(C.dtype)), rtol=rtol
    )
| |
| |
def roundoff(v, d):
    """Round ``v`` up to the nearest multiple of ``d``.

    Intended for non-negative integer ``v`` and positive integer ``d``;
    the result is returned as a Python ``int``.
    NOTE(review): not referenced by the tests visible in this file.
    """
    # floor((v + d - 1) / d) is the number of d-sized chunks needed to cover v.
    chunks = np.floor((v + d - 1) / d)
    return int(chunks * d)
| |
| |
def verify_batch_matmul(Ashape, Bshape, Cshape, in_dtype, out_dtype, rtol=1e-5):
    """Check ``hipblas.batch_matmul`` on ROCm against ``np.matmul``.

    Parameters
    ----------
    Ashape, Bshape, Cshape : tuple
        Shapes of the two operands and of the output buffer allocated for
        the result. ``Bshape`` may use a leading 1, as the callers in this
        file do.
    in_dtype : str
        Element type of both input placeholders.
    out_dtype : str
        Element type requested for the output.
    rtol : float
        Relative tolerance passed to ``tvm.testing.assert_allclose``.

    Integer inputs are drawn from [1, 10); all other inputs from [0, 1).
    """
    A = te.placeholder(Ashape, name="A", dtype=in_dtype)
    B = te.placeholder(Bshape, name="B", dtype=in_dtype)
    C = hipblas.batch_matmul(A, B, dtype=out_dtype)

    dev = tvm.rocm(0)
    func = tvm.compile(te.create_prim_func([A, B, C]), target="rocm")

    # A is always drawn before B so the random stream is consumed in the same order.
    if "int" not in in_dtype:
        a_host = np.random.uniform(size=Ashape).astype(A.dtype)
        b_host = np.random.uniform(size=Bshape).astype(B.dtype)
    else:
        a_host = np.random.uniform(1, 10, size=Ashape).astype(in_dtype)
        b_host = np.random.uniform(1, 10, size=Bshape).astype(in_dtype)

    a_dev = tvm.runtime.tensor(a_host, dev)
    b_dev = tvm.runtime.tensor(b_host, dev)
    c_dev = tvm.runtime.tensor(np.zeros(Cshape, dtype=C.dtype), dev)
    func(a_dev, b_dev, c_dev)

    lhs = a_dev.numpy().astype(C.dtype)
    rhs = b_dev.numpy().astype(C.dtype)
    expected = np.matmul(lhs, rhs).astype(C.dtype)
    tvm.testing.assert_allclose(c_dev.numpy(), expected, rtol=rtol)
| |
| |
@tvm.testing.requires_rocm
def test_matmul_add():
    """Run the hipBLAS matmul check over each (input, output) dtype pair."""
    # (in_dtype, out_dtype, extra keyword arguments for verify_matmul_add)
    dtype_cases = (
        ("float", "float", {"rtol": 1e-3}),
        ("float16", "float", {}),
        ("float16", "float16", {"rtol": 1e-2}),
        ("int8", "int32", {}),
    )
    for src_dtype, dst_dtype, extra in dtype_cases:
        verify_matmul_add(src_dtype, dst_dtype, **extra)
| |
| |
@tvm.testing.requires_rocm
def test_batch_matmul():
    """Run the hipBLAS batch matmul check over the dtype/broadcast combinations.

    Prints a message and returns early when the
    ``tvm.contrib.hipblas.batch_matmul`` packed function is not registered.
    """
    if not tvm.get_global_func("tvm.contrib.hipblas.batch_matmul", True):
        print("skip because extern function is not available")
        return

    a_shape = (16, 1024, 128)
    c_shape = (16, 1024, 236)
    b_batched = (16, 128, 236)
    b_broadcast = (1, 128, 236)  # leading 1 in the batch dimension

    # (B shape, in_dtype, out_dtype, rtol); 1e-5 is verify_batch_matmul's default.
    cases = (
        (b_batched, "float", "float", 1e-5),
        (b_broadcast, "float", "float", 1e-5),
        (b_batched, "float16", "float", 1e-5),
        (b_broadcast, "float16", "float", 1e-5),
        (b_batched, "float16", "float16", 1e-2),
        (b_broadcast, "float16", "float16", 1e-2),
        (b_batched, "int8", "int32", 1e-5),
    )
    for b_shape, src_dtype, dst_dtype, tol in cases:
        verify_batch_matmul(a_shape, b_shape, c_shape, src_dtype, dst_dtype, rtol=tol)
| |
| |
if __name__ == "__main__":
    # Allow running this file directly; tvm.testing.main() presumably hands off
    # to the test runner for this module — confirm against tvm.testing.
    tvm.testing.main()