| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| "Example code to perform int8 GEMM" |
import logging
import sys
import numpy as np
import tvm
import tvm.testing
from tvm import te
from tvm import autotvm
from tvm.topi.cuda.tensor_intrin import dp4a
| |
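# Run the AutoTVM search, or skip it and apply PRETUNED_INDEX, a config-space index
# presumably saved from an earlier tuning run.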
| DO_TUNING = True |
| PRETUNED_INDEX = 75333 |
| |
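# The dp4a tensor intrinsic computes a 4-element int8 dot product accumulated into
# int32; inputs and accumulator are all in register ("local") scope.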
| intrin_dp4a = dp4a("local", "local", "local") |
| |
| |
@autotvm.template("gemm_int8")
| def gemm_int8(n, m, l): |
| A = te.placeholder((n, l), name="A", dtype="int8") |
| B = te.placeholder((m, l), name="B", dtype="int8") |
| |
| k = te.reduce_axis((0, l), name="k") |
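    # C[i, j] = sum_k A[i, k] * B[j, k]; operands are upcast to int32 so the
    # reduction accumulates without int8 overflow.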
| C = te.compute( |
| (n, m), |
| lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k), |
| name="C", |
| ) |
| |
| cfg = autotvm.get_config() |
| s = te.create_schedule(C.op) |
| y, x = C.op.axis |
| |
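    # Stage the operands global -> shared memory (AA, BB) -> registers (AL, BL) and
    # accumulate the result in registers (CC) before the final write to C.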
| AA = s.cache_read(A, "shared", [C]) |
| BB = s.cache_read(B, "shared", [C]) |
| AL = s.cache_read(AA, "local", [C]) |
| BL = s.cache_read(BB, "local", [C]) |
| CC = s.cache_write(C, "local") |
| |
| k = CC.op.reduce_axis[0] |
| |
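    # Split the reduction axis so the innermost factor is exactly 4, the number of
    # int8 lanes consumed by a single dp4a instruction.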
| cfg.define_split( |
| "tile_k", |
| cfg.axis(k), |
| num_outputs=3, |
| filter=lambda entity: entity.size[2] == 4 and entity.size[0] * 2 >= entity.size[1], |
| ) |
| |
| ko, kt, ki = cfg["tile_k"].apply(s, CC, k) |
| |
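    # Map the innermost 4-wide reduction onto the dp4a intrinsic.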
| s[CC].tensorize(ki, intrin_dp4a) |
| |
| block_x = te.thread_axis("blockIdx.x") |
| block_y = te.thread_axis("blockIdx.y") |
| thread_x = te.thread_axis("threadIdx.x") |
| thread_y = te.thread_axis("threadIdx.y") |
| |
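    # Constrain the 4-way output splits: at most 16 virtual threads, an innermost
    # tile of at most 4, and at least as many blocks as virtual threads.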
    def block_size_filter(entity):
        return (
            entity.size[0] >= entity.size[1]
            and entity.size[1] <= 16
            and entity.size[3] <= 4
        )
| |
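    # Tile each output axis into (block, virtual thread, thread, inner) and bind the
    # tiles onto the CUDA grid and thread block.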
| cfg.define_split("tile_y", cfg.axis(y), num_outputs=4, filter=block_size_filter) |
| cfg.define_split("tile_x", cfg.axis(x), num_outputs=4, filter=block_size_filter) |
| by, tyz, ty, yi = cfg["tile_y"].apply(s, C, y) |
| bx, txz, tx, xi = cfg["tile_x"].apply(s, C, x) |
| |
| s[C].bind(by, block_y) |
| s[C].bind(bx, block_x) |
| s[C].bind(tyz, te.thread_axis("vthread")) |
| s[C].bind(txz, te.thread_axis("vthread")) |
| s[C].bind(ty, thread_y) |
| s[C].bind(tx, thread_x) |
| s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi) |
| |
| s[CC].compute_at(s[C], tx) |
| |
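    # Order the reduction loops so shared-memory loads happen at ko, register loads
    # at the unrolled kt, and the tensorized ki stays innermost.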
| yo, xo = CC.op.axis |
| s[CC].reorder(ko, kt, yo, xo, ki) |
| s[CC].unroll(kt) |
| |
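    # Register-level operand tiles: vectorize the int8 loads four at a time and
    # double-buffer them to overlap loading with computation.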
| for stage in [AL, BL]: |
| s[stage].compute_at(s[CC], kt) |
| _, xi = s[stage].split(stage.op.axis[1], factor=4) |
| s[stage].vectorize(xi) |
| s[stage].double_buffer() |
| |
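    # Shared-memory tiles: pad the row stride (a tunable knob, intended to reduce
    # bank conflicts), load cooperatively across the thread block, and vectorize
    # the loads 16 int8 elements (128 bits) at a time.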
| cfg.define_knob("storage_align", [16, 48]) |
| for stage in [AA, BB]: |
| s[stage].storage_align(s[stage].op.axis[0], cfg["storage_align"].val, 0) |
| s[stage].compute_at(s[CC], ko) |
| |
| fused = s[stage].fuse(*s[stage].op.axis) |
| ty, tx = s[stage].split(fused, nparts=cfg["tile_y"].size[2]) |
| tx, xi = s[stage].split(tx, nparts=cfg["tile_x"].size[2]) |
| _, xi = s[stage].split(xi, factor=16) |
| |
| s[stage].bind(ty, thread_y) |
| s[stage].bind(tx, thread_x) |
| s[stage].vectorize(xi) |
| |
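    # Let the tuner choose how aggressively the CUDA code generator unrolls loops.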
| cfg.define_knob("auto_unroll_max_step", [512, 1500]) |
| s[C].pragma(by, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) |
| s[C].pragma(by, "unroll_explicit", False) |
| |
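    # Record the op count (a multiply and an add per MAC) so the tuner reports GFLOPS.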
| cfg.add_flop(n * m * l * 2) |
| return s, [A, B, C] |
| |
| |
| if __name__ == "__main__": |
| N = 2048 |
| n = m = l = N |
| |
| logging.basicConfig(level=logging.DEBUG, stream=sys.stdout) |
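    # Extract the tuning task for the template and show its configuration space.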
    task = autotvm.task.create("gemm_int8", args=(n, m, l), target="cuda")
| print(task.config_space) |
| |
| measure_option = autotvm.measure_option( |
| builder=autotvm.LocalBuilder(), |
| runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4), |
| ) |
| |
| log_name = "gemm_int8.log" |
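    # Either search the space with the XGBoost tuner, logging every measured
    # candidate, or fall back to the pretuned config-space index.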
| if DO_TUNING: |
| tuner = autotvm.tuner.XGBTuner(task) |
| tuner.tune( |
| n_trial=1000, |
| measure_option=measure_option, |
| callbacks=[autotvm.callback.log_to_file(log_name)], |
| ) |
| |
| dispatch_context = autotvm.apply_history_best(log_name) |
| best_config = dispatch_context.query(task.target, task.workload) |
| print("\nBest config:") |
| print(best_config) |
| else: |
| config = task.config_space.get(PRETUNED_INDEX) |
| dispatch_context = autotvm.task.ApplyConfig(config) |
| print("Using pretuned config:") |
| print(config) |
| |
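    # Apply the chosen configuration, instantiate the schedule, and compile the CUDA kernel.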
| with dispatch_context: |
| with tvm.target.Target("cuda"): |
| s, arg_bufs = gemm_int8(n, m, l) |
| f = tvm.build(s, arg_bufs, "cuda", name="gemm_int8") |
| |
| ctx = tvm.context("cuda", 0) |
| |
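    # Random int8 inputs; the GPU result is checked against a NumPy int32 reference
    # (the kernel computes A x B^T).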
| a_np = np.random.randint(size=(n, l), low=-128, high=127, dtype="int8") |
| b_np = np.random.randint(size=(m, l), low=-128, high=127, dtype="int8") |
| |
| a = tvm.nd.array(a_np, ctx) |
| b = tvm.nd.array(b_np, ctx) |
| c = tvm.nd.array(np.zeros((n, m), dtype="int32"), ctx) |
| f(a, b, c) |
| |
| tvm.testing.assert_allclose( |
| c.asnumpy(), np.dot(a_np.astype("int32"), b_np.T.astype("int32")), rtol=1e-5 |
| ) |
| |
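    # Benchmark: average over num_runs kernel launches and report throughput in GOPS.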
| num_ops = 2 * l * m * n |
| num_runs = 1000 |
| timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs) |
| t = timer_f(a, b, c).mean |
| GOPS = num_ops / (t * 1e3) / 1e6 |
| print("average time cost of %d runs = %g ms, %g GOPS." % (num_runs, t * 1e3, GOPS)) |