# blob: 208d584e99a52e45be942188e2199262c09ad60f [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import pytest
import tvm
import tvm.testing
from tvm.script import tir as T
# TVMScript kernel: accumulates A[vi, vk] * B[vk, vj] into C[vi, vj] over a
# 128 x 128 x 128 iteration grid. The block is named "matmul" so that a
# schedule can look it up by that name.
@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Bind each opaque handle to a 128x128 buffer; no dtype is given, so the
    # match_buffer default applies.
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j, k in T.grid(128, 128, 128):
        with T.block("matmul"):
            # i and j are spatial axes; k is the reduction axis.
            vi = T.axis.spatial(128, i)
            vj = T.axis.spatial(128, j)
            vk = T.axis.reduce(128, k)
            with T.init():
                # Reset the accumulator at the start of each reduction.
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
@tvm.testing.requires_cuda
@pytest.mark.parametrize("f_preproc", ["", "l2_cache_flush_cuda"])
def test_time_evalutor_with_preproc(f_preproc: str):
    """Time the CUDA-built ``matmul`` kernel with an optional preprocessing hook.

    The kernel's two outer loops are bound to ``blockIdx.x`` and
    ``threadIdx.x``, the function is built for the ``cuda`` target, and a time
    evaluator is run with ``repeat=1000, number=1``. The mean result, scaled by
    1000 and labelled as milliseconds, is printed; nothing is asserted.

    Parameters
    ----------
    f_preproc : str
        Forwarded unchanged as the ``f_preproc`` argument of
        ``time_evaluator``. The parametrization covers the empty string and
        ``"l2_cache_flush_cuda"``.
    """
    schedule = tvm.tir.Schedule(
        tvm.IRModule.from_expr(matmul.with_attr("global_symbol", "main"))
    )
    # The reduction loop (third one) is left unbound.
    loop_i, loop_j, _ = schedule.get_loops(schedule.get_block("matmul"))
    schedule.bind(loop_i, "blockIdx.x")
    schedule.bind(loop_j, "threadIdx.x")

    kernel = tvm.tir.build(schedule.mod["main"], target="cuda")
    device = tvm.cuda(0)
    timer = kernel.time_evaluator(
        kernel.entry_name, device, repeat=1000, number=1, f_preproc=f_preproc
    )

    # Two random float32 inputs and a zero-filled float32 output, all on the GPU.
    lhs, rhs = (
        tvm.runtime.tensor(np.random.rand(128, 128).astype("float32"), device=device)
        for _ in range(2)
    )
    out = tvm.runtime.tensor(np.zeros((128, 128), dtype="float32"), device=device)

    mean_ms = timer(lhs, rhs, out).mean * 1000
    print("Evaluator (f_preproc={}):\t{:.5f}ms".format(f_preproc, mean_ms))
# When executed as a script, call the test function directly with the
# "l2_cache_flush_cuda" preprocessing option only.
if __name__ == "__main__":
    test_time_evalutor_with_preproc(f_preproc="l2_cache_flush_cuda")