# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
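"""Tests for Relax MetaSchedule tuning passes: MetaScheduleTuneIRMod,
MetaScheduleTuneTIR, MetaScheduleApplyDatabase, and the DefaultGPUSchedule
fallback."""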
import tempfile

import tvm
import tvm.testing
import tvm.meta_schedule as ms
from tvm import relax
from tvm.ir import transform
from tvm.ir.module import IRModule
from tvm.ir.transform import PassContext
from tvm.script import relax as R
from tvm.script import tir as T

target = tvm.target.Target("llvm --num-cores=16")
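

# InputModule: a Relax `main` that calls two TIR kernels via call_tir -- a
# 32x32 matmul (the second operand is accessed as B[j, k], i.e. transposed)
# followed by an elementwise ReLU.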
@tvm.script.ir_module
class InputModule:
    @T.prim_func
    def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
        T.func_attr({"global_symbol": "tir_matmul"})
        k = T.int32()
        A = T.match_buffer(x, (32, 32))
        B = T.match_buffer(y, (32, 32))
        C = T.match_buffer(z, (32, 32))
        for i0, j0, k0 in T.grid(32, 32, 32):
            with T.block():
                i, j, k = T.axis.remap("SSR", [i0, j0, k0])
                with T.init():
                    C[i, j] = 0.0
                C[i, j] += A[i, k] * B[j, k]

    @T.prim_func
    def tir_relu(x: T.handle, y: T.handle):
        T.func_attr({"global_symbol": "tir_relu"})
        A = T.match_buffer(x, (32, 32))
        B = T.match_buffer(y, (32, 32))
        for i, j in T.grid(32, 32):
            with T.block():
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = T.max(A[vi, vj], 0.0)

    @R.function
    def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor:
        cls = InputModule
        with R.dataflow():
            lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32"))
            lv1 = R.call_tir(cls.tir_relu, (lv0,), R.Tensor((32, 32), dtype="float32"))
            R.output(lv1)
        return lv1


# TODO(@sunggg): determine how to pass MS database object across different passes.
# PassContext might be an option, but we already have TuningAPI database.
# (MS database and TuningAPI database will be unified in the future)
# For now, we only support default JSON database config.
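# Tunes the whole IRModule, then applies the tuned schedules recorded in the
# work_dir JSON database; the scheduled module must differ from the input.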
def test_ms_tuning_irmodule():
    mod = InputModule
    assert isinstance(mod, IRModule)
    with tempfile.TemporaryDirectory() as work_dir:
        """
        # TODO(@sunggg): revisit when ready
        with target, PassContext(trace=Trace(mod), opt_level=0):
            tuning_pass = relax.transform.MetaScheduleTuneIRMod(
                params={}, work_dir=work_dir, max_trials_global=4
            )
            out_mod = tuning_pass(mod)
            assert PassContext.current().get_trace_stack_size() == 1
            assert PassContext.current().get_current_trace().size == 1
            tvm.ir.assert_structural_equal(mod, out_mod)
        """
        with target, PassContext(opt_level=0):
            tuning_pass = relax.transform.MetaScheduleTuneIRMod(
                params={}, work_dir=work_dir, max_trials_global=4
            )
            out_mod = tuning_pass(mod)
            application_pass = relax.transform.MetaScheduleApplyDatabase(work_dir)
            out_mod = application_pass(mod)
            assert not tvm.ir.structural_equal(mod, out_mod)


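# Tunes individual PrimFuncs rather than the whole IRModule, then checks that
# `op_names` filtering records tuning results only for the matmul kernel.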
def test_ms_tuning_primfunc():
    mod = InputModule
    assert isinstance(mod, IRModule)
    with tempfile.TemporaryDirectory() as work_dir:
        """
        # TODO(@sunggg): revisit when ready
        with target, PassContext(trace=Trace(mod), opt_level=0):
            tuning_pass = relax.transform.MetaScheduleTuneTIR(
                work_dir=work_dir, max_trials_global=4
            )
            out_mod = tuning_pass(mod)
            assert PassContext.current().get_trace_stack_size() == 1
            # TODO (@sunggg): Need to determine how to track subgraph-level tuning traces.
            # Currently, we don't track this, so the trace size is not updated. Revisit this later.
            tvm.ir.assert_structural_equal(mod, out_mod)
        """
        with target, PassContext(opt_level=0):
            tuning_pass = relax.transform.MetaScheduleTuneTIR(
                work_dir=work_dir, max_trials_global=4
            )
            out_mod = tuning_pass(mod)
            application_pass = relax.transform.MetaScheduleApplyDatabase(work_dir)
            out_mod = application_pass(mod)
            assert not tvm.ir.structural_equal(mod, out_mod)
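
    # With `op_names=["matmul"]` only tir_matmul is tuned; max_trials_per_task=2
    # caps the trials for that task, so exactly two tuning records are expected.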
    with tempfile.TemporaryDirectory() as work_dir:
        with target, PassContext(opt_level=0):
            tuning_pass = relax.transform.MetaScheduleTuneIRMod(
                params={},
                work_dir=work_dir,
                max_trials_global=4,
                max_trials_per_task=2,
                op_names=["matmul"],
            )
            tuning_pass(mod)
            db = ms.database.JSONDatabase(
                work_dir + "/database_workload.json", work_dir + "/database_tuning_record.json"
            )
            assert len(db.get_all_tuning_records()) == 2
            for rec in db.get_all_tuning_records():
                assert rec.workload.mod["main"].attrs["global_symbol"] == "tir_matmul"


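# The module expected after tvm.tir.transform.DefaultGPUSchedule: the loop
# nests of both kernels are fused and bound to blockIdx.x / threadIdx.x.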
@tvm.script.ir_module
class DefaultScheduledModule:
    @T.prim_func
    def tir_matmul(
        A: T.Buffer((32, 32), "float32"),
        B: T.Buffer((32, 32), "float32"),
        C: T.Buffer((32, 32), "float32"),
    ):
        T.func_attr({"global_symbol": "tir_matmul", "tir.is_scheduled": True})
        # with T.block("root"):
        for i0_j0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
            for i0_j0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                for k0 in range(32):
                    with T.block(""):
                        i = T.axis.spatial(32, (i0_j0_fused_0 * 1024 + i0_j0_fused_1) // 32)
                        j = T.axis.spatial(32, (i0_j0_fused_0 * 1024 + i0_j0_fused_1) % 32)
                        k = T.axis.reduce(32, k0)
                        T.reads(A[i, k], B[j, k])
                        T.writes(C[i, j])
                        with T.init():
                            C[i, j] = T.float32(0)
                        C[i, j] = C[i, j] + A[i, k] * B[j, k]

    @T.prim_func
    def tir_relu(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32), "float32")):
        T.func_attr({"global_symbol": "tir_relu", "tir.is_scheduled": True})
        # with T.block("root"):
        for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
            for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block(""):
                    vi = T.axis.spatial(32, (i_j_fused_0 * 1024 + i_j_fused_1) // 32)
                    vj = T.axis.spatial(32, (i_j_fused_0 * 1024 + i_j_fused_1) % 32)
                    T.reads(A[vi, vj])
                    T.writes(B[vi, vj])
                    B[vi, vj] = T.max(A[vi, vj], T.float32(0))

    @R.function
    def main(
        x: R.Tensor((32, 32), dtype="float32"), w: R.Tensor((32, 32), dtype="float32")
    ) -> R.Tensor((32, 32), dtype="float32"):
        with R.dataflow():
            lv0 = R.call_tir(
                DefaultScheduledModule.tir_matmul,
                (x, w),
                out_sinfo=R.Tensor((32, 32), dtype="float32"),
            )
            lv1 = R.call_tir(
                DefaultScheduledModule.tir_relu,
                (lv0,),
                out_sinfo=R.Tensor((32, 32), dtype="float32"),
            )
            R.output(lv1)
        return lv1


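# With max_trials_global=0 no tuning records are produced, so the database has
# nothing to apply; DefaultGPUSchedule then acts as the fallback and must yield
# DefaultScheduledModule.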
def test_ms_database_apply_fallback():
    mod = InputModule
    target_cuda = tvm.target.Target("nvidia/geforce-rtx-3090-ti")
    assert isinstance(mod, IRModule)
    with tempfile.TemporaryDirectory() as work_dir:
        """
        with target_cuda, PassContext(trace=Trace(mod), opt_level=0):
            tuning_pass = relax.transform.MetaScheduleTuneTIR(
                work_dir=work_dir, max_trials_global=0
            )
            out_mod = tuning_pass(mod)
            tvm.ir.assert_structural_equal(mod, out_mod)
        """
        with target_cuda, PassContext(opt_level=0):
            tuning_pass = relax.transform.MetaScheduleTuneTIR(
                work_dir=work_dir, max_trials_global=0
            )
            out_mod = tuning_pass(mod)
            default_pass = tvm.tir.transform.DefaultGPUSchedule()
            out_mod = default_pass(mod)
            tvm.ir.assert_structural_equal(out_mod, DefaultScheduledModule)


if __name__ == "__main__":
    tvm.testing.main()