/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relax/backend/vm/lower_runtime_builtin.cc
* \brief Lowers most builtin functions and packed calls.
*/
#include <tvm/ffi/cast.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/backend.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/type.h>
#include <tvm/runtime/data_type.h>
#include <tvm/tir/op.h>
namespace tvm {
namespace relax {
// This pass lowers most ops to VM-specific builtins.
// TODO(relax-team): revisit after PrimValue.
class LowerRuntimeBuiltinMutator : public ExprMutator {
public:
using ExprMutator::VisitExpr_;
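  // Dispatch: after post-order mutation of the call, rewrite each recognized
  // builtin op to its VM lowering; ops carrying an FLowerBuiltin attribute are
  // lowered through that attribute, and everything else is returned unchanged.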
Expr VisitExpr_(const CallNode* call_node) final {
static const auto& lower_builtin_fmap = Op::GetAttrMap<FLowerBuiltin>("FLowerBuiltin");
// post-order mutation
Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
if (call->op == call_tir_dyn_op_) {
return CallTIRDyn(call);
} else if (call->op == reshape_op_) {
return Reshape(call);
} else if (call->op == shape_of_op_) {
return ShapeOf(call);
} else if (call->op == tensor_to_shape_op_) {
return TensorToShape(call);
} else if (call->op == call_py_func_op_) {
return CallPyFunc(call);
} else if (call->op == to_vdevice_op_) {
return ToDevice(call);
} else if (call->op == make_closure_op_) {
return MakeClosure(call);
} else if (call->op == invoke_closure_op_) {
return InvokeClosure(call);
} else if (call->op == alloc_tensor_op_) {
      TVM_FFI_THROW(InternalError) << "LowerRuntimeBuiltin encountered " << call->op
                                   << " in expression " << ffi::GetRef<Call>(call_node) << ". "
                                   << "This operation should have been lowered earlier "
                                   << "using the 'relax.transform.LowerAllocTensor' pass.";
} else if (call->op == mem_alloc_storage_op_) {
return MakeMemAllocStorage(call);
} else if (call->op == mem_alloc_tensor_op_) {
return MakeMemAllocTensor(call);
} else if (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_) {
return MakeMemKillObject(call);
} else if (const auto* op_node = call->op.as<OpNode>()) {
Op op = ffi::GetRef<Op>(op_node);
if (lower_builtin_fmap.count(op)) {
return lower_builtin_fmap[op](builder_, call);
}
}
return call;
}
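
  // Lower relax.memory.alloc_storage(size, virtual_device_index, storage_scope, ...)
  // to relax.vm.alloc_storage. A uint8 dtype is supplied explicitly so the VM
  // treats the storage as a plain byte buffer.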
Expr MakeMemAllocStorage(const Call& call) {
PrimValue runtime_device_index = Downcast<PrimValue>(call->args[1]);
StringImm storage_scope = Downcast<StringImm>(call->args[2]);
DataTypeImm output_dtype = DataTypeImm(DataType::UInt(8));
return Call(vm_alloc_storage_op_,
{call->args[0], runtime_device_index, output_dtype, storage_scope}, Attrs());
}
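
  // Lower relax.memory.alloc_tensor(storage, offset, shape, dtype[, extra]) to
  // relax.vm.alloc_tensor, forwarding the optional fifth argument unchanged.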
Expr MakeMemAllocTensor(const Call& call) {
PrimValue offset = Downcast<PrimValue>(call->args[1]);
DataTypeImm dtype = Downcast<DataTypeImm>(call->args[3]);
ffi::Array<Expr> call_args = {call->args[0], offset, call->args[2], dtype};
    if (call->args.size() == 5) {
      call_args.push_back(call->args[4]);
    }
return Call(vm_alloc_tensor_op_, call_args, Attrs());
}
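
  // Both relax.memory.kill_storage and relax.memory.kill_tensor lower to the
  // single relax.vm.kill_object op.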
Expr MakeMemKillObject(const Call& call) {
TVM_FFI_ICHECK_EQ(call->args.size(), 1);
return Call(vm_kill_object_op_, {call->args[0]}, Attrs());
}
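
  // Lower relax.vm.call_tir_dyn(gvar, (args...)) to vm.builtin.call_tir_dyn,
  // flattening the argument tuple into a plain argument list.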
Expr CallTIRDyn(const Call& call_node) {
TVM_FFI_ICHECK(call_node->args.size() == 2);
TVM_FFI_ICHECK(call_node->args[0]->IsInstance<GlobalVarNode>());
TVM_FFI_ICHECK(call_node->args[1]->IsInstance<TupleNode>());
ffi::Array<Expr> args;
auto tir_args = Downcast<Tuple>(call_node->args[1]);
args.push_back(call_node->args[0]);
for (Expr arg : tir_args->fields) {
args.push_back(arg);
}
return Call(builtin_call_tir_dyn_, args, Attrs(), {void_sinfo_});
}
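
  // Lower relax.reshape to vm.builtin.reshape; requires the shape argument to
  // carry ShapeStructInfo so the VM can evaluate it at runtime.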
Expr Reshape(const Call& call_node) {
TVM_FFI_ICHECK(call_node->args.size() == 2);
TVM_FFI_ICHECK(call_node->struct_info_.defined());
auto arg = call_node->args[1];
    TVM_FFI_CHECK(arg->struct_info_->IsInstance<ShapeStructInfoNode>(), TypeError)
        << "LowerRuntimeBuiltin expects the shape argument of R.reshape "
        << "to be a ShapeExpr, or a Var bound to a ShapeExpr. "
        << "However, in expression " << call_node << ", the shape argument " << arg
        << " has struct info " << arg->struct_info_;
return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
}
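
  // Lower relax.shape_of to the vm.builtin.shape_of packed function.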
Expr ShapeOf(const Call& call_node) {
TVM_FFI_ICHECK(call_node->args.size() == 1);
TVM_FFI_ICHECK(call_node->struct_info_.defined());
return Call(builtin_shape_of_, call_node->args, Attrs(), {GetStructInfo(call_node)});
}
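
  // Lower relax.tensor_to_shape to the vm.builtin.tensor_to_shape packed function.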
Expr TensorToShape(const Call& call_node) {
TVM_FFI_ICHECK(call_node->args.size() == 1);
TVM_FFI_ICHECK(call_node->struct_info_.defined());
return Call(builtin_tensor_to_shape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
}
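
  // Lower relax.call_py_func to vm.builtin.call_py_func, packing the function
  // name and the argument tuple into a single tuple argument.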
Expr CallPyFunc(const Call& call_node) {
TVM_FFI_ICHECK(call_node->args.size() == 2);
TVM_FFI_ICHECK(call_node->struct_info_.defined());
    // Create a tuple holding the function name and the arguments tuple.
ffi::Array<Expr> tuple_fields;
tuple_fields.push_back(call_node->args[0]); // function name
tuple_fields.push_back(call_node->args[1]); // arguments tuple
auto combined_tuple = Tuple(tuple_fields);
// Direct call to vm.builtin.call_py_func
return Call(builtin_call_py_func_, {combined_tuple}, call_node->attrs, call_node->sinfo_args,
call_node->span);
}
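
  // Lower relax.to_vdevice to vm.builtin.to_device, expanding the destination
  // VDevice into (device_type, device_id, memory_scope) arguments.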
Expr ToDevice(const Call& call_node) {
// TODO(yongwww): replace ToVDeviceAttrs with related Expr
TVM_FFI_ICHECK(call_node->args.size() == 1);
TVM_FFI_ICHECK(call_node->struct_info_.defined());
    const auto* attrs = call_node->attrs.as<ToVDeviceAttrs>();
    TVM_FFI_ICHECK(attrs != nullptr) << "relax.to_vdevice expects ToVDeviceAttrs";
ffi::Array<Expr> args;
args.push_back(call_node->args[0]);
// Get the DLDeviceType and device_id from VDevice
VDevice vdev = attrs->dst_vdevice;
int dev_type = vdev->target->GetTargetDeviceType();
int dev_id = vdev->vdevice_id;
StringImm storage_scope = StringImm(vdev->memory_scope);
args.push_back(PrimValue::Int64(dev_type));
args.push_back(PrimValue::Int64(dev_id));
args.push_back(storage_scope);
return Call(builtin_to_device_, args, call_node->attrs, {GetStructInfo(call_node)});
}
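
  // Lower relax.make_closure(gvar, (captured...)) to vm.builtin.make_closure,
  // flattening the captured-value tuple into a plain argument list.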
Expr MakeClosure(const Call& call_node) {
TVM_FFI_ICHECK(call_node->args.size() == 2);
TVM_FFI_ICHECK(call_node->args[0]->IsInstance<GlobalVarNode>());
TVM_FFI_ICHECK(call_node->args[1]->IsInstance<TupleNode>());
ffi::Array<Expr> args;
auto func = call_node->args[0];
auto closure_args = Downcast<Tuple>(call_node->args[1]);
args.push_back(func);
for (Expr arg : closure_args->fields) {
args.push_back(arg);
}
return Call(builtin_make_closure_, args, Attrs(), {object_sinfo_});
}
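
  // Lower relax.invoke_closure to vm.builtin.invoke_closure. The builtin needs
  // access to the VM state, so it is routed through relax.call_builtin_with_ctx.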
Expr InvokeClosure(const Call& call_node) {
TVM_FFI_ICHECK(call_node->args.size() == 2);
TVM_FFI_ICHECK(call_node->args[0]->IsInstance<VarNode>());
TVM_FFI_ICHECK(call_node->args[1]->IsInstance<TupleNode>());
ffi::Array<Expr> args;
args.push_back(call_node->args[0]);
    // Flatten the arguments tuple for invoke_closure.
auto invoke_closure_args = Downcast<Tuple>(call_node->args[1]);
for (Expr arg : invoke_closure_args->fields) {
args.push_back(arg);
}
return Call(call_builtin_with_ctx_op_, {builtin_invoke_closure_, Tuple(args)}, Attrs(),
{object_sinfo_});
}
const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
const StructInfo object_sinfo_ = ObjectStructInfo();
const StructInfo void_sinfo_ = TupleStructInfo(ffi::Array<StructInfo>({}));
  // Ops to pattern-match against.
const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
const Op& reshape_op_ = Op::Get("relax.reshape");
const Op& shape_of_op_ = Op::Get("relax.shape_of");
const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
const Op& call_py_func_op_ = Op::Get("relax.call_py_func");
const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
const Op& make_closure_op_ = Op::Get("relax.make_closure");
const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
const Op& alloc_tensor_op_ = Op::Get("relax.builtin.alloc_tensor");
const Op& mem_alloc_storage_op_ = Op::Get("relax.memory.alloc_storage");
const Op& mem_alloc_tensor_op_ = Op::Get("relax.memory.alloc_tensor");
const Op& mem_kill_storage_op_ = Op::Get("relax.memory.kill_storage");
const Op& mem_kill_tensor_op_ = Op::Get("relax.memory.kill_tensor");
  // Ops and extern functions to lower to.
const Op& vm_alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
const Op& vm_alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
const Op& vm_kill_object_op_ = Op::Get("relax.vm.kill_object");
// Function to compute allocated shape.
const ExternFunc builtin_compute_alloc_shape_{"vm.builtin.compute_alloc_shape"};
const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"};
const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"};
const ExternFunc builtin_tensor_to_shape_{"vm.builtin.tensor_to_shape"};
const ExternFunc builtin_call_py_func_{"vm.builtin.call_py_func"};
const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
};
Expr LowerRuntimeBuiltin(const Expr& e) { return LowerRuntimeBuiltinMutator().VisitExpr(e); }
namespace transform {
Pass LowerRuntimeBuiltin() {
auto pass_func = [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(LowerRuntimeBuiltin(f));
};
return CreateFunctionPass(pass_func, 0, "LowerRuntimeBuiltin", {});
}
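
// A minimal usage sketch (assuming a Relax IRModule `mod` is already built):
// the registered pass can be applied like any other function pass, e.g.
//
//   IRModule mod = ...;  // module whose functions contain runtime builtins
//   mod = relax::transform::LowerRuntimeBuiltin()(mod);
//
// after which calls such as relax.memory.alloc_storage appear as their
// relax.vm.* / vm.builtin.* equivalents.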
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("relax.transform.LowerRuntimeBuiltin", LowerRuntimeBuiltin);
}
} // namespace transform
} // namespace relax
} // namespace tvm