// blob: e78636a2ff2face22e69939b050cd579c95be4ee [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_webgpu.cc
*/
#include "codegen_webgpu.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/cast.h>
#include <tvm/ffi/extra/json.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/support/io.h>
#include <tvm/tirx/builtin.h>
#include <tvm/tirx/transform.h>
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "../../arith/pattern_match.h"
#include "../../runtime/file_utils.h"
#include "../../runtime/metadata.h"
#include "../../runtime/thread_storage_scope.h"
#include "../../support/bytes_io.h"
#include "../build_common.h"
#include "webgpu_fallback_module.h"
namespace tvm {
namespace codegen {
// WebGPU Info
// WebGPU Info
struct WebGPUWorkGroupInfo {
  // Workgroup (thread block) size along x/y/z; each dimension defaults to 1.
  int workgroup_size[3] = {1, 1, 1};
  // Whether the kernel body references blockIdx.z (thread rank 0, dim index 2).
  bool has_block_index_z{false};
  // Set of buffer handles that (may) have write access.
  std::unordered_set<Var> write_access_set;
};
/*!
 * \brief Visitor that scans a kernel body and collects workgroup launch info:
 *        the constant workgroup size, whether blockIdx.z is referenced, and
 *        the set of buffers that may be written.
 */
class WebGPUWorkgroupInfoCollector : public StmtExprVisitor {
 public:
  /*!
   * \brief Collect workgroup information from a kernel body.
   * \param stmt The kernel body to analyze.
   * \return The collected WebGPUWorkGroupInfo.
   */
  static WebGPUWorkGroupInfo Collect(const Stmt& stmt) {
    WebGPUWorkgroupInfoCollector collector;
    collector(stmt);
    return collector.info_;
  }

 private:
  void VisitExpr_(const VarNode* op) final {
    StmtExprVisitor::VisitExpr_(op);
    // A handle that appears as a bare expression is conservatively treated
    // as writable (its address may escape into an opaque call).
    Var buffer_var = ffi::GetRef<Var>(op);
    if (buffer_var.dtype().is_handle()) {
      info_.write_access_set.insert(buffer_var);
    }
  }
  void VisitStmt_(const BufferStoreNode* op) final {
    StmtExprVisitor::VisitStmt_(op);
    // A store always marks the underlying buffer as written.
    info_.write_access_set.insert(op->buffer->data);
  }
  void VisitStmt_(const AttrStmtNode* op) final {
    // record workgroup size from thread_extent annotations
    if (op->attr_key == tirx::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag.length() != 0) {
        runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
        if (ts.rank == 1) {
          // threadIdx.{x,y,z}: a constant extent defines the workgroup size.
          TVM_FFI_ICHECK_GE(ts.dim_index, 0) << "vthread should have been optimized out by here";
          TVM_FFI_ICHECK_LT(ts.dim_index, 3);
          auto* sizeptr = op->value.as<tirx::IntImmNode>();
          TVM_FFI_ICHECK(sizeptr) << "CodeGenWebGPU: only allows constant thread group size, "
                                  << "get " << op->value;
          // workgroup_size is int[3]; cast to the destination type directly
          // instead of routing through uint32_t.
          info_.workgroup_size[ts.dim_index] = static_cast<int>(sizeptr->value);
        } else if (ts.rank == 0) {
          // Track blockIdx.z usage; AddFunction rejects it because the z
          // dimension is repurposed for spreading a large blockIdx.x.
          if (ts.dim_index == 2) {
            info_.has_block_index_z = true;
          }
        }
      }
    }
    // normal traversal
    StmtExprVisitor::VisitStmt_(op);
  }
  WebGPUWorkGroupInfo info_;
};
std::string CodeGenWebGPU::Finish() {
  // `enable` directives must precede all declarations in WGSL.
  if (enable_fp16_) {
    header_stream << "enable f16;\n\n";
  }
  if (enable_subgroups_) {
    header_stream << "enable subgroups;\n\n";
  }
  // Concatenate the sections in declaration order: header, declarations,
  // forward declarations, then the function bodies.
  std::ostringstream result;
  result << header_stream.str() << decl_stream.str() << this->fwd_decl_stream.str()
         << stream.str();
  return result.str();
}
void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f);
// analyze the data;
for (Var arg : f->params) {
if (arg.dtype().is_handle()) {
alloc_storage_scope_[arg.get()] = "global";
}
}
}
CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {
  // Subgroup intrinsics are gated on the target's supports_subgroups attribute.
  Bool supports_subgroups = target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false));
  enable_subgroups_ = supports_subgroups;
}
runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_readonly_decl) {
  // Emit one WGSL compute entry point for `f` and return its FunctionInfo
  // (name, argument dtypes, launch-param tags).
  //
  // Binding layout of the generated shader:
  //   @binding(0..k-1): one var<storage> array per handle parameter
  //   @binding(k):      one var<uniform> struct holding all POD parameters
  //                     plus packGridDimX (for recovering large grids).
  //
  // `skip_readonly_decl` forces every buffer to read_write instead of
  // inferring read-only access from the body.
  // clear previous generated state.
  this->InitFuncState(f);
  // Reserve WGSL keywords so generated identifiers never collide with them.
  name_supply_->ReserveName("var");
  name_supply_->ReserveName("let");
  name_supply_->ReserveName("const");
  name_supply_->ReserveName("std");
  name_supply_->ReserveName("storage");
  name_supply_->ReserveName("uniform");
  name_supply_->ReserveName("workgroup");
  name_supply_->ReserveName("private");
  name_supply_->ReserveName("function");
  name_supply_->ReserveName("read");
  name_supply_->ReserveName("read_write");
  // skip the first underscore, so SSA variable starts from
  name_supply_->FreshName("v_");
  // Setup the thread group info. These exact names must be available since
  // the generated builtins and BindThreadIndex reference them verbatim.
  TVM_FFI_ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
  TVM_FFI_ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
  TVM_FFI_ICHECK_EQ(name_supply_->FreshName("gridDim"), "gridDim");
  // add to alloc buffer type.
  auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
  TVM_FFI_ICHECK(global_symbol.has_value())
      << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute";
  header_stream << "//----------------------------------------\n"
                << "// Function: " << global_symbol.value() << "\n"
                << "//----------------------------------------\n";
  ffi::String func_name = global_symbol.value();
  ffi::Array<DLDataType> func_arg_types;
  ffi::Array<ffi::String> func_launch_param_tags;
  // Workgroup size and per-buffer write access are inferred from the body.
  WebGPUWorkGroupInfo info = WebGPUWorkgroupInfoCollector::Collect(f->body);
  std::vector<Var> pod_args;
  int num_buffer = 0;
  // add param_access modes info to launch params,
  // serialized as e.g. "paramWriteAccess:[0,1,0]" (1 = writable).
  std::ostringstream os_param_access;
  os_param_access << "paramWriteAccess:[";
  // setup buffer arguments
  for (Var arg : f->params) {
    DataType t = arg.dtype();
    func_arg_types.push_back(t);
    if (t.is_handle()) {
      // Handle params become storage buffer bindings; the element type comes
      // from the pointer type annotation.
      auto* ptr = arg->type_annotation.as<PointerTypeNode>();
      TVM_FFI_ICHECK(ptr)
          << "All handles passed to the CodeGenWebGPU must have a type_annotation as a "
             "PointerType, "
          << "and must point to a PrimType";
      auto* prim = ptr->element_type.as<PrimTypeNode>();
      TVM_FFI_ICHECK(prim)
          << "All handles passed to the CodeGenWebGPU must have a type_annotation as a "
             "PointerType, "
          << "and must point to a PrimType";
      DataType value_storage_type = prim->dtype;
      if (value_storage_type == DataType::Bool()) {
        // We need a physically addressable buffer type to support boolean tensors.
        // The loaded byte is cast to bool inside the LoadNode visitor below.
        value_storage_type = boolean_storage_type_.with_lanes(value_storage_type.lanes());
      }
      std::string vid = AllocVarID(arg.get());
      std::string access_mode;
      if (num_buffer != 0) {
        os_param_access << ",";
      }
      if (skip_readonly_decl || info.write_access_set.count(arg)) {
        access_mode = "read_write";
        os_param_access << "1";
      } else {
        access_mode = "read";
        os_param_access << "0";
      }
      // add extra access mode info to launch params
      this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") "
                        << "var<storage, " << access_mode << "> " << vid << " : array<";
      this->PrintType(value_storage_type, this->decl_stream);
      this->decl_stream << ">;\n";
    } else {
      pod_args.push_back(arg);
    }
  }
  // Store all pod arguments in a single buffer of int32
  // do bitcast to change to other data types
  // always pass gridDimX in to get around of the 65535 gridDim
  // restrictions in some platforms
  std::string type_pod_args = name_supply_->FreshName("PODArgs");
  std::string val_pod_args = name_supply_->FreshName("podArgs");
  std::string packGridDimX = name_supply_->FreshName("packGridDimX");
  this->decl_stream << "\nstruct " << type_pod_args << " {\n";
  for (size_t i = 0; i < pod_args.size(); ++i) {
    Var v = pod_args[i];
    TVM_FFI_ICHECK(!v.dtype().is_handle());
    std::string vid = AllocVarID(v.get());
    // Only 32-bit scalar POD types are representable in the uniform struct.
    if (v.dtype() == DataType::Int(32)) {
      this->decl_stream << " " << vid << ": i32";
    } else if (v.dtype() == DataType::UInt(32)) {
      this->decl_stream << " " << vid << ": u32";
    } else if (v.dtype() == DataType::Float(32)) {
      this->decl_stream << " " << vid << ": f32";
    } else {
      TVM_FFI_THROW(InternalError) << "Do not support pod argument type " << v.dtype();
    }
    this->decl_stream << ",\n";
    // value ref: POD params are later printed as fields of the uniform struct.
    std::ostringstream vref;
    vref << val_pod_args << "." << vid;
    var_idmap_[v.get()] = vref.str();
  }
  this->decl_stream << " " << packGridDimX << ": u32\n}\n";
  this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") "
                    << "var<uniform> " << val_pod_args << " : " << type_pod_args << ";\n\n";
  // setup thread tags and param access in launch param tags;
  if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(tirx::attr::kKernelLaunchParams)) {
    for (const auto& thread_tag : opt.value()) {
      func_launch_param_tags.push_back(thread_tag);
    }
  }
  os_param_access << "]";
  func_launch_param_tags.push_back(os_param_access.str());
  // blockIdx.z is reserved for spreading a large blockIdx.x (see
  // BindThreadIndex), so kernels may not reference it directly.
  TVM_FFI_ICHECK(!info.has_block_index_z)
      << "blockIdx.z is not supported in WebGPU to accomodate large blockIdx.x";
  // annotate the workgroup size collected from the body
  this->stream << "@compute @workgroup_size(" << info.workgroup_size[0] << ", "
               << info.workgroup_size[1] << ", " << info.workgroup_size[2] << ")\n";
  // add to alloc buffer type.
  // Function header with the WGSL builtin inputs.
  this->stream << "fn " << func_name << "(\n"
               << " @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
               << " @builtin(num_workgroups) gridDim : vec3<u32>,\n"
               << " @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
               << ") {\n";
  // skip out of bound grids: when a large x grid is spread over z, the last
  // z-slice may contain workgroups past the logical end; guard on packGridDimX.
  this->stream << " if (blockIdx.z * gridDim.x + blockIdx.x > "  // NOLINT(*)
               << val_pod_args << "." << packGridDimX << ") { return; }\n";
  // the function scope.
  int func_scope = this->BeginScope();
  this->PrintStmt(f->body);
  this->EndScope(func_scope);
  this->PrintIndent();
  this->stream << "}\n\n";
  return runtime::FunctionInfo(std::move(func_name), std::move(func_arg_types),
                               std::move(func_launch_param_tags), {});
}
void CodeGenWebGPU::BindThreadIndex(const IterVar& iv) {
  // Bind a launch-thread IterVar to its WGSL builtin expression, cast to the
  // IterVar's dtype (the dtype name printed first acts as a constructor cast).
  TVM_FFI_ICHECK(!var_idmap_.count(iv->var.get()));
  std::ostringstream os;
  PrintType(iv->var.dtype(), os);
  if (iv->thread_tag == "blockIdx.x") {
    // WebGPU have restriction to limit the maximum size of blockId.x to be 65535
    // We allow runtime to spread the load out to blockIdx.z so it can be a large number.
    os << "(blockIdx.z * gridDim.x + blockIdx.x)";
    // NOTE: the previous version built an unused `tidx` copy here and called
    // os.str() twice; compute the expression string once.
    std::string aggregated_bidx = SSAGetID(os.str(), iv->var.dtype());
    var_idmap_[iv->var.get()] = aggregated_bidx;
  } else {
    os << "(" << iv->thread_tag << ")";
    std::string tidx = os.str();
    this->MarkConst(tidx);
    var_idmap_[iv->var.get()] = tidx;
  }
}
void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
  // Print the WGSL spelling of dtype `t` into `os`.
  //   scalars: bool, f16/f32, uN, iN (with 64-bit widths rejected)
  //   vectors: vec{2,3,4}<T>, except int8x4 which is packed into a u32
  // Side effect: printing any f16 sets enable_fp16_ so Finish() emits the
  // `enable f16;` directive.
  int lanes = t.lanes();
  if (t.is_handle()) {
    TVM_FFI_THROW(InternalError) << "Cannot print handle type in WebGPU";
  }
  if (t.is_void()) {
    os << "void";
    return;
  }
  if (t == DataType::Bool()) {
    os << "bool";
    return;
  }
  if (lanes != 1) {
    TVM_FFI_ICHECK(lanes >= 2 && lanes <= 4)
        << "CodeGenWebGPU: only allows vector with lanes in {2, 3, 4}";
    // Currently WebGPU doesn't support `i8` and an `int8x4` is represented as a `u32`.
    if (t.is_int() && t.bits() == 8 && lanes == 4) {
      os << "u32";
      return;
    }
    // Open the vector wrapper; the matching ">" is appended at the end.
    os << "vec" << lanes << "<";
  }
  if (t.is_float()) {
    TVM_FFI_ICHECK(t.bits() == 16 || t.bits() == 32) << "CodeGenWebGPU: only support f16 or f32";
    if (t.bits() == 16) {
      // Using f16 requires enable directive
      enable_fp16_ = true;
    }
    os << "f" << t.bits();
  } else if (t.is_uint()) {
    TVM_FFI_ICHECK(t.bits() != 64) << "CodeGenWebGPU: do not support u64";
    os << "u" << t.bits();
  } else if (t.is_int()) {
    TVM_FFI_ICHECK(t.bits() != 64) << "CodeGenWebGPU: do not support i64";
    os << "i" << t.bits();
  } else {
    TVM_FFI_THROW(InternalError) << "CodeGenWebGPU: Cannot convert type " << t << " to WebGPU type";
  }
  if (lanes != 1) {
    // Close the vec{2,3,4}< wrapper opened above.
    os << ">";
  }
}
void CodeGenWebGPU::PrintStorageSync(const CallNode* op) {
const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
this->PrintIndent();
this->stream << "workgroupBarrier();\n";
} else if (sync == "shared") {
this->PrintIndent();
this->stream << "workgroupBarrier();\n";
} else if (sync == "global") {
TVM_FFI_THROW(InternalError) << "global barrier not supported";
}
}
void CodeGenWebGPU::PrintSSAAssign(const std::string& target, const std::string& src,
                                   DataType type) {
  // SSA values are immutable, so they map to WGSL `let` bindings with an
  // explicit type annotation.
  std::ostringstream type_os;
  PrintType(type, type_os);
  stream << "let " << target << " : " << type_os.str() << " = " << src << ";\n";
}
void CodeGenWebGPU::PrintVecElemLoad(const std::string& vec, DataType t, int i,
std::ostream& os) { // NOLINT(*)
os << vec << "[" << i << "]";
}
void CodeGenWebGPU::PrintVecElemStore(const std::string& vec, DataType t, int i,
                                      const std::string& value) {
  // Assign a single lane of the vector via indexed access.
  this->PrintIndent();
  stream << vec << '[' << i << "] = " << value << ";\n";
}
void CodeGenWebGPU::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
int lanes = op->dtype.lanes();
PrintType(op->dtype, os);
os << "(";
for (int i = 0; i < lanes; ++i) {
if (i != 0) os << ", ";
os << v;
}
os << ')';
}
PrimExpr CodeGenWebGPU::EnforceU32(PrimExpr value) {
  // Cast to u32 (lane count preserved); WGSL requires u32 for shift amounts.
  DataType u32_type = DataType::UInt(32, value.dtype().lanes());
  return cast(u32_type, value);
}
void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
  // Lower builtins that need WGSL-specific spellings; anything unrecognized
  // is delegated to the generic C-family printer.
  if (op->op.same_as(builtin::reinterpret())) {
    // generate bitcast<TYPE>(ARG)
    os << "bitcast<";
    this->PrintType(op->dtype, os);
    os << ">(";
    this->PrintExpr(op->args[0], os);
    os << ")";
  } else if (op->op.same_as(builtin::shift_right())) {
    os << '(';
    this->PrintExpr(op->args[0], os);
    os << ">>";
    // WebGPU requires shift bits to be u32.
    this->PrintExpr(EnforceU32(op->args[1]), os);
    os << ')';
  } else if (op->op.same_as(builtin::shift_left())) {
    os << '(';
    this->PrintExpr(op->args[0], os);
    os << "<<";
    // WebGPU requires shift bits to be u32.
    this->PrintExpr(EnforceU32(op->args[1]), os);
    os << ')';
  } else if (op->op.same_as(builtin::if_then_else())) {
    // conditional that skips eval if cond evals to false
    // Lowered as `var result; if (cond) { result = then; } else { result = else; }`
    // so only the taken branch's sub-expressions are evaluated; the expression
    // itself just references `result`.
    std::string result = name_supply_->FreshName("condval");
    std::string cond = PrintExpr(op->args[0]);
    this->PrintIndent();
    this->stream << "var " << result << " : ";
    PrintType(op->dtype, this->stream);
    this->stream << ";\n";
    this->PrintIndent();
    this->stream << "if (" << cond << ") {\n";
    {
      // then-branch: sub-expressions must be printed inside this scope.
      int then_scope = this->BeginScope();
      std::string true_val = PrintExpr(op->args[1]);
      this->PrintIndent();
      this->stream << result << " = " << true_val << ";\n} else {\n";
      this->EndScope(then_scope);
    }
    {
      // else-branch, closed with "}" after the assignment.
      int else_scope = this->BeginScope();
      std::string false_val = PrintExpr(op->args[2]);
      this->PrintIndent();
      this->stream << result << " = " << false_val << ";\n}\n";
      this->EndScope(else_scope);
    }
    os << result;
  } else if (op->op.same_as(builtin::dp4a())) {
    // generate `dot4I8Packed(vec1, vec2) + acc` for the builtin `dp4a`
    os << "dot4I8Packed(";
    this->PrintExpr(op->args[0], os);
    os << ", ";
    this->PrintExpr(op->args[1], os);
    os << ") + ";
    this->PrintExpr(op->args[2], os);
  } else {
    CodeGenC::VisitExpr_(op, os);
  }
}
void CodeGenWebGPU::VisitExpr_(const CastNode* op, std::ostream& os) {  // NOLINT(*)
  // WGSL casts look like constructor calls: TYPE(expr).
  std::string operand = PrintExpr(op->value);
  PrintType(op->dtype, os);
  os << "(" << operand << ")";
}
void CodeGenWebGPU::VisitExpr_(const SelectNode* op, std::ostream& os) {  // NOLINT(*)
  // WGSL select(f, t, cond) returns t when cond holds. Print the operands in
  // the same order the chained form would (false, true, condition).
  std::string false_branch = PrintExpr(op->false_value);
  std::string true_branch = PrintExpr(op->true_value);
  std::string condition = PrintExpr(op->condition);
  os << "select(" << false_branch << ", " << true_branch << ", " << condition << ")";
}
void CodeGenWebGPU::VisitExpr_(const LetNode* op, std::ostream& os) {  // NOLINT(*)
  // Bind op->var to its value (SSA substitution, or an explicit WGSL `let`),
  // print the body with the binding active, then remove the binding.
  // use ssa form.
  if (print_ssa_form_) {
    std::string value = PrintExpr(op->value);
    TVM_FFI_ICHECK(!var_idmap_.count(op->var.get()));
    var_idmap_[op->var.get()] = value;
  } else {
    PrintIndent();
    std::string value = PrintExpr(op->value);
    this->stream << "let " << AllocVarID(op->var.get()) << " : ";
    PrintType(op->var.dtype(), this->stream);
    this->stream << " = " << value << ";\n";
  }
  os << PrintExpr(op->body);
  // Pop the defined var from var_idmap when exiting its scope.
  // We do this because it is hard to completely avoid a same LetNode appearing
  // at different places.
  bool removed = var_idmap_.erase(op->var.get());
  TVM_FFI_ICHECK(removed);
}
void CodeGenWebGPU::VisitExpr_(const IntImmNode* op, std::ostream& os) {  // NOLINT(*)
  // 32-bit immediates use WGSL literal suffixes (i/u); other widths are
  // written as an explicit constructor cast.
  if (op->dtype.bits() != 32) {
    this->PrintType(op->dtype, os);
    os << "(" << op->value << ")";
    return;
  }
  std::ostringstream literal;
  if (op->dtype.is_int()) {
    literal << op->value << "i";
  } else {
    TVM_FFI_ICHECK(op->dtype.is_uint());
    literal << op->value << "u";
  }
  std::string text = literal.str();
  this->MarkConst(text);
  os << text;
}
void CodeGenWebGPU::VisitExpr_(const FloatImmNode* op, std::ostream& os) {  // NOLINT(*)
  // Emit a scientific-notation literal with the WGSL width suffix.
  char suffix = '\0';
  switch (op->dtype.bits()) {
    case 32:
      suffix = 'f';
      break;
    case 16:
      // Using f16 requires the `enable f16;` directive emitted by Finish().
      enable_fp16_ = true;
      suffix = 'h';
      break;
    default:
      TVM_FFI_THROW(InternalError) << "Unsupported floating point bits " << op->dtype.bits();
  }
  std::ostringstream literal;
  literal << std::scientific << op->value << suffix;
  MarkConst(literal.str());
  os << literal.str();
}
void CodeGenWebGPU::VisitExpr_(const BufferLoadNode* op, std::ostream& os) {  // NOLINT(*)
  // NOTE: direct impl of load/store for correctness
  // Each printing stmt must stand on their own after all preprocessing steps
  // to ensure correctness in the case of nested-expression
  // do not try to lift common printings from each case
  TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported.";
  TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported.";
  DataType value_dtype = op->dtype;
  PrimExpr index = op->indices[0];
  Var buffer_var = op->buffer->data;
  DataType element_dtype = op->buffer->dtype;
  int lanes = op->dtype.lanes();
  std::string buffer_vid = GetVarID(buffer_var.get());
  if (value_dtype.lanes() == element_dtype.lanes()) {
    // Direct buffer loading
    // Special handle bool loading: bool buffers use a numeric storage type
    // (see the parameter setup), so wrap the load in a bool(...) conversion.
    if (value_dtype == DataType::Bool()) {
      this->PrintType(value_dtype, os);
      os << "(";
    } else {
      TVM_FFI_ICHECK(value_dtype == element_dtype);
    }
    TVM_FFI_ICHECK_EQ(index.dtype().lanes(), 1);
    os << buffer_vid << "[" << this->PrintExpr(index) << "]";
    // Special handle bool loading
    if (value_dtype == DataType::Bool()) {
      os << ")";
    }
  } else {
    // Vector load from scalar buffer
    TVM_FFI_ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array";
    TVM_FFI_ICHECK(value_dtype.element_of() == element_dtype)
        << "WebGPU vector loading requires base type to match";
    arith::PVar<PrimExpr> base;
    if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) {
      // Contiguous (ramp) index — emit per-lane loads from base:
      // vec3<f32>(buf[base + 0], buf[base + 1], buf[base + 2]);
      std::string base_vid = SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype());
      PrintType(element_dtype.with_lanes(value_dtype.lanes()), os);
      os << "(";
      for (int i = 0; i < lanes; ++i) {
        if (i != 0) os << ", ";
        os << buffer_vid << "[" << base_vid << " + " << i << "]";
      }
      os << ")";
    } else {
      // Gather through a vector-valued index:
      // vec3<f32>(buf[index[0]], buf[index[1]], buf[index[2]]);
      std::string index_vid = SSAGetID(PrintExpr(index), index.dtype());
      PrintType(element_dtype.with_lanes(value_dtype.lanes()), os);
      os << "(";
      for (int i = 0; i < lanes; ++i) {
        if (i != 0) os << ", ";
        os << buffer_vid << "[" << index_vid << "[" << i << "]]";
      }
      os << ")";
    }
  }
}
void CodeGenWebGPU::VisitStmt_(const BindNode* op) {
  // Statement-level binding: same lowering as the LetNode visitor, except the
  // binding is not popped here (it stays live for the rest of the scope).
  // use ssa form.
  if (print_ssa_form_) {
    std::string value = PrintExpr(op->value);
    TVM_FFI_ICHECK(!var_idmap_.count(op->var.get()));
    var_idmap_[op->var.get()] = value;
  } else {
    PrintIndent();
    std::string value = PrintExpr(op->value);
    this->stream << "let " << AllocVarID(op->var.get()) << " : ";
    PrintType(op->var.dtype(), this->stream);
    this->stream << " = " << value << ";\n";
  }
}
void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) {
  // Direct implementation of buffer store, mirroring the BufferLoad visitor:
  // matching-lane stores emit one assignment; vector stores into a scalar
  // buffer are unrolled into per-lane assignments.
  TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported.";
  TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer store is not supported.";
  DataType value_dtype = op->value.dtype();
  DataType element_dtype = op->buffer->dtype;
  PrimExpr index = op->indices[0];
  Var buffer_var = op->buffer->data;
  std::string buffer_vid = GetVarID(buffer_var.get());
  if (value_dtype.lanes() == element_dtype.lanes()) {
    // must execute print expr first
    // so we won't have recursive append to stream
    std::string index_vid = PrintExpr(index);
    std::string value_vid = PrintExpr(op->value);
    // now print the assignment line.
    this->PrintIndent();
    stream << buffer_vid << "[" << index_vid << "] = ";
    // special explicit conversion of bool: bool buffers use a numeric storage
    // element type, so convert the value before storing.
    if (value_dtype == DataType::Bool()) {
      PrintType(element_dtype, stream);
      stream << "(";
    } else {
      TVM_FFI_ICHECK(value_dtype == element_dtype);
    }
    stream << value_vid;
    // Special handle bool store
    if (value_dtype == DataType::Bool()) {
      stream << ")";
    }
    stream << ";\n";
  } else {
    // Vector store into scalar buffer
    // (fixed diagnostics: previous messages said "load"/"stire" here)
    TVM_FFI_ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector store scalar array";
    TVM_FFI_ICHECK(value_dtype.element_of() == element_dtype)
        << "WebGPU vector store requires base type to match";
    std::string value_vid = PrintExpr(op->value);
    arith::PVar<PrimExpr> base;
    if (arith::ramp(base, 1, value_dtype.lanes()).Match(index)) {
      // Contiguous (ramp) index:
      // buf[base + 0] = value[0]
      // buf[base + 1] = value[1]
      std::string base_vid = SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype());
      for (int i = 0; i < value_dtype.lanes(); ++i) {
        this->PrintIndent();
        stream << buffer_vid << "[" << base_vid << " + " << i << "] = " << value_vid << "[" << i
               << "];\n";
      }
    } else {
      // Scatter through a vector-valued index:
      // buf[index[0]] = value[0]
      // buf[index[1]] = value[1]
      std::string index_vid = SSAGetID(PrintExpr(index), index.dtype());
      for (int i = 0; i < value_dtype.lanes(); ++i) {
        this->PrintIndent();
        stream << buffer_vid << "[" << index_vid << "[" << i << "]] = " << value_vid << "[" << i
               << "];\n";
      }
    }
  }
}
void CodeGenWebGPU::VisitStmt_(const AllocBufferNode* op) {
  // Emit storage for an allocation with a constant shape:
  //   shared scope -> module-level var<workgroup> (goes to decl_stream)
  //   local scope  -> function-local `var` array (goes to stream)
  TVM_FFI_ICHECK(op->buffer.defined());
  std::string vid = AllocVarID(op->buffer->data.get());
  // Flatten the shape into a single element count; all dims must be constant.
  size_t constant_size = 1;
  for (const auto& dim : op->buffer->shape) {
    const IntImmNode* dim_imm = dim.as<IntImmNode>();
    TVM_FFI_ICHECK(dim_imm) << "Can only handle constant size stack allocation for now";
    constant_size *= dim_imm->value;
  }
  TVM_FFI_ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
  if (storage_scope.rank == runtime::StorageRank::kShared) {
    // Workgroup memory must be declared at module scope in WGSL.
    this->decl_stream << "var<workgroup> " << vid << " : array<";
    PrintType(op->buffer->dtype, this->decl_stream);
    this->decl_stream << ", " << constant_size << ">;\n";
  } else if (storage_scope.rank == runtime::StorageRank::kLocal) {
    this->PrintIndent();
    this->stream << "var " << vid << " : array<";
    PrintType(op->buffer->dtype, this->stream);
    this->stream << ", " << constant_size << ">;\n";
  } else {
    TVM_FFI_THROW(InternalError) << "WebGPU: Do not support storage scope: "
                                 << storage_scope.to_string();
  }
}
void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
  // Lower to a WGSL counted loop:
  //   for (var v : T = begin; v < end; v++ or v += step) { body }
  // The exclusive end is simplified to min + extent when min is nonzero.
  std::string begin_str = PrintExpr(op->min);
  PrimExpr end = is_zero(op->min) ? op->extent : arith::Analyzer().Simplify(op->min + op->extent);
  std::string end_str = PrintExpr(end);
  std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : "";
  std::string vid = AllocVarID(op->loop_var.get());
  PrintIndent();
  stream << "for (var " << vid << " : ";
  PrintType(op->loop_var.dtype(), stream);
  stream << " = " << begin_str << "; " << vid << " < " << end_str << "; " << vid
         << (step_str.empty() ? std::string("++") : " += " + step_str) << ") {\n";
  int loop_scope = BeginScope();
  PrintStmt(op->body);
  this->EndScope(loop_scope);
  PrintIndent();
  stream << "}\n";
}
void CodeGenWebGPU::VisitStmt_(const AssertStmtNode* op) {
// skip assert — AssertStmt is a leaf, nothing to emit.
}
void CodeGenWebGPU::VisitStmt_(const WhileNode* op) {
  // The condition expression may need helper statements of its own, so emit
  // an unconditional loop with an explicit break-on-negated-condition check
  // at the top instead of `while (cond)`.
  PrintIndent();
  stream << "while (true) {\n";
  int loop_scope = BeginScope();
  std::string cond_str = PrintExpr(op->condition);
  PrintIndent();
  stream << "if (!(" << cond_str << ")) { break; }\n";
  PrintStmt(op->body);
  this->EndScope(loop_scope);
  PrintIndent();
  stream << "}\n";
}
//-------------------------------------------------
// Build logic.
//-------------------------------------------------
//
// The "C++ side" canonical WebGPU module is `WebGPUFallbackModuleNode` in
// src/target/webgpu/webgpu_fallback_module.{h,cc} — there is no native
// WebGPU runtime in the C++ tree (the real receiver is the wasm runtime
// in web/emcc/webgpu_runtime.cc).
ffi::Module BuildWebGPU(IRModule mod, Target target) {
  // Build entry point: lower every kernel PrimFunc in `mod` to WGSL and
  // bundle the per-kernel sources (smap) plus function metadata (fmap) into
  // a fallback WebGPU module.
  mod = tirx::transform::PointerValueTypeRewrite()(std::move(mod));
  bool output_ssa = false;
  bool skip_readonly_decl = false;
  ffi::Map<ffi::String, ffi::Bytes> smap;
  ffi::Map<ffi::String, runtime::FunctionInfo> fmap;
  std::ostringstream source_maker;
  // narrow all i64 to i32 — WGSL has no 64-bit integer types
  mod = tirx::transform::ForceNarrowIndexToInt32()(std::move(mod));
  // Iterate by const reference: `auto kv` would copy every (GlobalVar,
  // BaseFunc) pair per iteration (performance-for-range-copy).
  for (const auto& kv : mod->functions) {
    CodeGenWebGPU cg(target);
    TVM_FFI_ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenWebGPU: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    TVM_FFI_ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
        << "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
    auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
    TVM_FFI_ICHECK(global_symbol.has_value())
        << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute";
    std::string f_name = global_symbol.value();
    cg.Init(output_ssa);
    fmap.Set(f_name, cg.AddFunction(f, skip_readonly_decl));
    std::string code = cg.Finish();
    source_maker << "// Function: " << f_name << "\n" << code << "\n";
    smap.Set(f_name, ffi::Bytes(std::move(code)));
  }
  // The aggregated WGSL source dump is preserved in the in-memory source
  // map keyed by "wgsl" — only used by InspectSource and never serialized.
  ffi::Map<ffi::String, ffi::String> source;
  source.Set("wgsl", source_maker.str());
  return target::WebGPUModuleCreateWithFallback(std::move(smap), ffi::String("wgsl"),
                                                std::move(fmap), std::move(source));
}
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  // Register the build entry so `target.build.webgpu` dispatches to BuildWebGPU.
  refl::GlobalDef().def("target.build.webgpu",
                        [](IRModule mod, Target target) { return BuildWebGPU(mod, target); });
}
} // namespace codegen
} // namespace tvm