| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| /*! |
| * \file src/runtime/vm/builtin.cc |
| */ |
| #include <tvm/ffi/any.h> |
| #include <tvm/ffi/container/array.h> |
| #include <tvm/ffi/container/shape.h> |
| #include <tvm/ffi/function.h> |
| #include <tvm/ffi/memory.h> |
| #include <tvm/ffi/reflection/registry.h> |
| #include <tvm/runtime/data_type.h> |
| #include <tvm/runtime/device_api.h> |
| #include <tvm/runtime/logging.h> |
| #include <tvm/runtime/memory/memory_manager.h> |
| #include <tvm/runtime/tensor.h> |
| #include <tvm/runtime/vm/builtin.h> |
| #include <tvm/runtime/vm/bytecode.h> |
| #include <tvm/runtime/vm/vm.h> |
| |
| #include <unordered_map> |
| |
| namespace tvm { |
| namespace runtime { |
| namespace vm { |
| |
| using tvm::runtime::Tensor; |
| |
| //------------------------------------------------- |
| // Shape/StructInfo handling. |
| //------------------------------------------------- |
| /*! |
| * \brief Builtin function to allocate shape heap. |
| * \param ctx_ptr The context module pointer. |
| * \param size the size of the heap. |
| * \return An allocate Tensor as shape heap. |
| */ |
| Tensor AllocShapeHeap(void* ctx_ptr, int64_t size) { |
| VirtualMachine* vm = static_cast<VirtualMachine*>(ctx_ptr); |
| // use host allocator, which is always last element. |
| size_t host_device_index = vm->devices.size() - 1; |
| // specially handle hexagon on-device RT. |
| // TODO(relax-team): visit and consider other possible choices. |
| if (vm->devices[0].device_type == kDLHexagon) { |
| host_device_index = 0; |
| } else { |
| TVM_FFI_ICHECK_EQ(vm->devices[host_device_index].device_type, kDLCPU); |
| } |
| auto* alloc = vm->allocators[host_device_index]; |
| return alloc->Empty({size}, DLDataType{kDLInt, 64, 1}, vm->devices[host_device_index]); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("vm.builtin.alloc_shape_heap", AllocShapeHeap); |
| } |
| |
| /*! |
| * \brief Builtin match R.Prim function. |
| * |
| * \param input_value The runtime value provided by the user |
| * |
| * \param heap The VM storage for symbolic shapes |
| * |
| * \param code_value The op code, defined in MatchShapeCode, |
| * indicating how this value should be interpreted. |
| * |
| * \param reg The register, if using kStoreToHeap or |
| * kAssertEqualToLoad, or a literal value if using kAssertEqualToImm |
| * |
| * \param err_ctx An optional string used in error messages, providing |
| * additional context |
| * |
| * \sa MatchShape |
| */ |
| void MatchPrimValue(int64_t input_value, DLTensor* heap, int code_value, int64_t reg, |
| ffi::Optional<ffi::String> err_ctx) { |
| int64_t* heap_data = heap == nullptr ? nullptr : static_cast<int64_t*>(heap->data); |
| MatchShapeCode code = static_cast<MatchShapeCode>(code_value); |
| |
| if (code == MatchShapeCode::kAssertEqualToImm) { |
| TVM_FFI_CHECK_EQ(input_value, reg, RuntimeError) |
| << err_ctx.value_or("") << " match_cast error, " |
| << " PrimValue mismatch to specified constant."; |
| } else if (code == MatchShapeCode::kStoreToHeap) { |
| heap_data[reg] = input_value; |
| } else if (code == MatchShapeCode::kNoOp) { |
| } else if (code == MatchShapeCode::kAssertEqualToLoad) { |
| TVM_FFI_CHECK_EQ(input_value, heap_data[reg], RuntimeError) |
| << err_ctx.value_or("") << " match_cast error, " |
| << " PrimValue mismatch to a previous populated value."; |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unknown match shape code: " << static_cast<int>(code); |
| } |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("vm.builtin.match_prim_value", MatchPrimValue); |
| } |
| |
| /*! |
| * \brief Builtin match shape function. |
| * \param args The packed function arguments. |
| * \param rv The return value. |
| * |
| * \sa MatchShapeCode |
| */ |
| void MatchShape(ffi::PackedArgs args, ffi::Any* rv) { |
| // input shape the first argument can take in tensor, DLTensor* or shape. |
| ffi::Shape input_shape; |
| if (auto opt_tensor = args[0].as<Tensor>()) { |
| input_shape = opt_tensor.value().Shape(); |
| } else if (auto opt_dltensor = args[0].try_cast<DLTensor*>()) { |
| DLTensor* ptr = opt_dltensor.value(); |
| input_shape = ffi::Shape(ptr->shape, ptr->shape + ptr->ndim); |
| } else { |
| input_shape = args[0].cast<ffi::Shape>(); |
| } |
| auto heap = args[1].try_cast<DLTensor*>(); |
| int64_t* heap_data = heap.has_value() ? static_cast<int64_t*>((*heap)->data) : nullptr; |
| int64_t size = args[2].cast<int64_t>(); |
| const int64_t kBeginCode = 3; |
| TVM_FFI_ICHECK_LE(kBeginCode + size * 2, args.size()); |
| // a function that lazily get context for error reporting |
| const int64_t kErrorContextOffset = kBeginCode + size * 2; |
| ffi::Optional<ffi::String> err_ctx = args[kErrorContextOffset].cast<ffi::String>(); |
| |
| TVM_FFI_CHECK_EQ(input_shape.size(), size, RuntimeError) |
| << err_ctx.value_or("") << " match_cast shape size mismatch."; |
| |
| for (int64_t i = 0; i < size; ++i) { |
| MatchShapeCode code = static_cast<MatchShapeCode>(args[kBeginCode + i * 2].cast<int>()); |
| int64_t reg = args[kBeginCode + i * 2 + 1].cast<int64_t>(); |
| |
| if (code == MatchShapeCode::kAssertEqualToImm) { |
| TVM_FFI_CHECK_EQ(input_shape[i], reg, RuntimeError) |
| << err_ctx.value_or("") << " match_cast error, " |
| << " shape[" << i << "]" |
| << " mismatch to specified constant."; |
| } else if (code == MatchShapeCode::kStoreToHeap) { |
| heap_data[reg] = input_shape[i]; |
| } else if (code == MatchShapeCode::kNoOp) { |
| } else { |
| TVM_FFI_ICHECK(code == MatchShapeCode::kAssertEqualToLoad); |
| TVM_FFI_CHECK_EQ(input_shape[i], heap_data[reg], RuntimeError) |
| << err_ctx.value_or("") << " match_cast error, " |
| << " shape[" << i << "]" |
| << " mismatch to a previous populated value."; |
| } |
| } |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def_packed("vm.builtin.match_shape", MatchShape); |
| } |
| |
| /*! |
| * \brief Builtin make prim value function. |
| * \param heap The shape heap to use |
| * \param shape_code The shape code of the value |
| * \param rv The return value. |
| * |
| * \sa MakeShape |
| */ |
| int64_t MakePrimValue(DLTensor* heap, int shape_code, int64_t reg) { |
| // NOTE: heap can be nullptr |
| int64_t* heap_data = heap == nullptr ? nullptr : static_cast<int64_t*>(heap->data); |
| |
| MakeShapeCode code = static_cast<MakeShapeCode>(shape_code); |
| if (code == MakeShapeCode::kUseImm) { |
| return reg; |
| } else if (code == MakeShapeCode::kLoadShape) { |
| return heap_data[reg]; |
| } else { |
| TVM_FFI_THROW(InternalError) << "Invalid shape code: " << shape_code; |
| } |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("vm.builtin.make_prim_value", MakePrimValue); |
| } |
| |
| /*! |
| * \brief Builtin make shape function. |
| * \param args The packed function arguments. |
| * \param rv The return value. |
| * |
| * \sa MakeShapeCode |
| */ |
| void MakeShape(ffi::PackedArgs args, ffi::Any* rv) { |
| // NOTE: heap can be nullptr |
| auto heap = args[0].try_cast<DLTensor*>(); |
| int64_t* heap_data = heap.has_value() ? static_cast<int64_t*>((*heap)->data) : nullptr; |
| int64_t size = args[1].cast<int64_t>(); |
| const int64_t kBeginCode = 2; |
| |
| std::vector<int64_t> shape(size); |
| |
| for (int64_t i = 0; i < size; ++i) { |
| MakeShapeCode code = static_cast<MakeShapeCode>(args[kBeginCode + i * 2].cast<int>()); |
| int64_t reg = args[kBeginCode + i * 2 + 1].cast<int64_t>(); |
| if (code == MakeShapeCode::kUseImm) { |
| shape[i] = reg; |
| } else { |
| TVM_FFI_ICHECK(code == MakeShapeCode::kLoadShape); |
| shape[i] = heap_data[reg]; |
| } |
| } |
| *rv = ffi::Shape(std::move(shape)); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def_packed("vm.builtin.make_shape", MakeShape); |
| } |
| |
| /*! |
| * \brief Builtin function to check if arg is Tensor(dtype, ndim) |
| * \param arg The input argument. |
| * \param ndim Expected ndim of the Tensor, can be -1 (indicate unknown). |
| * \param dtype The expected content data type. |
| * \param err_ctx Additional context if error occurs. |
| */ |
| void CheckTensorInfo(ffi::PackedArgs args, ffi::Any* rv) { |
| ffi::AnyView arg = args[0]; |
| int ndim = args[1].cast<int>(); |
| DataType dtype; |
| ffi::Optional<ffi::String> err_ctx; |
| |
| if (args.size() == 3) { |
| dtype = DataType::Void(); |
| err_ctx = args[2].cast<ffi::Optional<ffi::String>>(); |
| } else { |
| dtype = args[2].cast<DataType>(); |
| err_ctx = args[3].cast<ffi::Optional<ffi::String>>(); |
| } |
| |
| auto opt_ptr = arg.try_cast<DLTensor*>(); |
| TVM_FFI_CHECK(opt_ptr.has_value(), TypeError) |
| << err_ctx.value_or("") << " expect a Tensor but get " << arg.GetTypeKey(); |
| |
| DLTensor* ptr = opt_ptr.value(); |
| if (ndim != -1) { |
| TVM_FFI_CHECK(ptr->ndim == ndim, ValueError) |
| << err_ctx.value_or("") << " expect Tensor with ndim " << ndim << " but get " << ptr->ndim; |
| } |
| |
| if (dtype != DataType::Void()) { |
| TVM_FFI_CHECK(DataType(ptr->dtype) == dtype, ValueError) |
| << err_ctx.value_or("") << " expect Tensor with dtype " << dtype << " but get " |
| << DataType(ptr->dtype); |
| } |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def_packed("vm.builtin.check_tensor_info", CheckTensorInfo); |
| } |
| |
| /*! |
| * \brief Builtin function to check if arg is Shape(ndim) |
| * \param arg The input argument. |
| * \param ndim Expected size of the shape, can be -1 (indicate unknown). |
| * \param err_ctx Additional context if error occurs. |
| */ |
| void CheckShapeInfo(ObjectRef arg, int ndim, ffi::Optional<ffi::String> err_ctx) { |
| // a function that lazily get context for error reporting |
| auto* ptr = arg.as<ffi::Shape::ContainerType>(); |
| TVM_FFI_CHECK(ptr != nullptr, TypeError) |
| << err_ctx.value_or("") << " expect a Shape but get " << arg->GetTypeKey(); |
| if (ndim != -1) { |
| TVM_FFI_CHECK(ptr->size == static_cast<uint64_t>(ndim), ValueError) |
| << err_ctx.value_or("") << " expect Shape with ndim " << ndim << " but get " << ptr->size; |
| } |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("vm.builtin.check_shape_info", CheckShapeInfo); |
| } |
| |
| /*! |
| * \brief Builtin function to check if arg is PrimValue(dtype) |
| * \param arg The input argument. |
| * \param dtype Expected dtype of the PrimValue. Can be DataType::Void() for unknown dtype. |
| * \param err_ctx Additional context if error occurs. |
| */ |
| void CheckPrimValueInfo(ffi::AnyView arg, DataType dtype, ffi::Optional<ffi::String> err_ctx) { |
| if (auto opt_obj = arg.as<ObjectRef>()) { |
| TVM_FFI_THROW(TypeError) << err_ctx.value_or("") << ", expected dtype " << dtype |
| << ", but received ObjectRef of type " |
| << opt_obj.value()->GetTypeKey(); |
| } else if (dtype.is_bool()) { |
| arg.cast<bool>(); |
| } else if (dtype.is_int()) { |
| arg.cast<int64_t>(); |
| } else if (dtype.is_uint()) { |
| arg.cast<uint64_t>(); |
| } else if (dtype.is_float()) { |
| arg.cast<double>(); |
| } else if (dtype.is_handle()) { |
| arg.cast<void*>(); |
| } else { |
| TVM_FFI_THROW(TypeError) << err_ctx.value_or("") << ", unsupported dtype " << dtype; |
| } |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("vm.builtin.check_prim_value_info", CheckPrimValueInfo); |
| } |
| |
| /*! |
| * \brief Builtin function to check if arg is Tuple with size elements. |
| * \param arg The input argument. |
| * \param size The expected size of the tuple. |
| * \param err_ctx Additional context if error occurs. |
| */ |
| void CheckTupleInfo(ObjectRef arg, int64_t size, ffi::Optional<ffi::String> err_ctx) { |
| // a function that lazily get context for error reporting |
| auto* ptr = arg.as<ffi::ArrayObj>(); |
| TVM_FFI_CHECK(ptr != nullptr, TypeError) |
| << err_ctx.value_or("") << " expect a Tuple but get " << arg->GetTypeKey(); |
| TVM_FFI_CHECK(static_cast<int64_t>(ptr->size()) == size, ValueError) |
| << err_ctx.value_or("") << " expect a Tuple with " << size << " elements, " |
| << " but get a Tuple with " << ptr->size() << " elements."; |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("vm.builtin.check_tuple_info", CheckTupleInfo); |
| } |
| |
| /*! |
| * \brief Builtin function to check if arg is a callable function. |
| * \param arg The input argument. |
| * \param err_ctx Additional context if error occurs. |
| */ |
| void CheckFuncInfo(ObjectRef arg, ffi::Optional<ffi::String> err_ctx) { |
| // a function that lazily get context for error reporting |
| bool is_func = arg.as<ffi::Function::ContainerType>() || arg.as<VMClosure::ContainerType>(); |
| TVM_FFI_CHECK(is_func, TypeError) |
| << err_ctx.value_or("") << " expect a Function but get " << arg->GetTypeKey(); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("vm.builtin.check_func_info", CheckFuncInfo); |
| } |
| |
| //------------------------------------------------- |
| // Storage management. |
| //------------------------------------------------- |
| Storage VMAllocStorage(void* ctx_ptr, ffi::Shape buffer_shape, Index device_index, |
| DLDataType dtype_hint, ffi::String mem_scope) { |
| VirtualMachine* vm = static_cast<VirtualMachine*>(ctx_ptr); |
| |
| TVM_FFI_ICHECK_LT(device_index, vm->devices.size()) |
| << "The device index is out of VM physical devices list"; |
| |
| if (device_index == -1) { |
| // Allocate on host. Host is always the last element of vm->devices. |
| device_index = vm->devices.size() - 1; |
| } |
| |
| auto* alloc = vm->allocators[device_index]; |
| TVM_FFI_ICHECK(alloc) << "Did you forget to init the VirtualMachine with devices?"; |
| |
| auto buffer = alloc->Alloc(vm->devices[device_index], buffer_shape, dtype_hint, mem_scope); |
| |
| return Storage(buffer, alloc); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def("vm.builtin.alloc_storage", VMAllocStorage) |
| .def_packed("vm.builtin.alloc_tensor", [](ffi::PackedArgs args, ffi::Any* rv) { |
| Storage sobj = args[0].cast<Storage>(); |
| int64_t offset = args[1].cast<int64_t>(); |
| ffi::Shape shape = args[2].cast<ffi::Shape>(); |
| DataType dtype = args[3].cast<DataType>(); |
| if (args.size() == 5) { |
| ffi::String scope = args[4].cast<ffi::String>(); |
| *rv = sobj->AllocTensorScoped(offset, shape, dtype, scope); |
| } else { |
| *rv = sobj->AllocTensor(offset, shape, dtype); |
| } |
| }); |
| } |
| |
| //------------------------------------------------- |
| // Closure function handling, calling convention |
| //------------------------------------------------- |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def_packed("vm.builtin.make_closure", |
| [](ffi::PackedArgs args, ffi::Any* rv) { |
| VMClosure clo = args[0].cast<VMClosure>(); |
| std::vector<ffi::Any> saved_args; |
| saved_args.resize(args.size() - 1); |
| for (size_t i = 0; i < saved_args.size(); ++i) { |
| saved_args[i] = args[i + 1]; |
| } |
| auto impl = VMClosure::BindLastArgs(clo->impl, saved_args); |
| *rv = VMClosure(clo->func_name, impl); |
| }) |
| .def_packed("vm.builtin.invoke_closure", |
| [](ffi::PackedArgs args, ffi::Any* rv) { |
| // args[0]: vm; args[1]: closure; args[2, 3, ...]: function arguments |
| VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); |
| ObjectRef vm_closure = args[1].cast<ObjectRef>(); |
| vm->InvokeClosurePacked(vm_closure, args.Slice(2), rv); |
| }) |
| .def_packed("vm.builtin.call_tir_dyn", [](ffi::PackedArgs args, ffi::Any* rv) { |
| ffi::Function func = args[0].cast<ffi::Function>(); |
| ffi::Shape to_unpack = args[args.size() - 1].cast<ffi::Shape>(); |
| size_t num_tensor_args = args.size() - 2; |
| |
| std::vector<ffi::AnyView> packed_args(num_tensor_args + to_unpack.size()); |
| std::copy(args.data() + 1, args.data() + args.size() - 1, packed_args.data()); |
| |
| for (size_t i = 0; i < to_unpack.size(); ++i) { |
| packed_args[i + num_tensor_args] = to_unpack[i]; |
| } |
| func.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), rv); |
| }); |
| } |
| |
| //------------------------------------- |
| // Python function call support |
| //------------------------------------- |
| |
| // Global registry for Python functions |
| static std::unordered_map<std::string, ffi::Function> py_func_registry; |
| |
| /*! |
| * \brief Clear the Python function registry on shutdown |
| */ |
| void ClearPyFuncRegistry() { py_func_registry.clear(); } |
| |
| /*! |
| * \brief Register a Python function for call_py_func |
| * \param name The function name |
| * \param func The Python function wrapped as ffi::Function |
| */ |
| void RegisterPyFunc(const std::string& name, ffi::Function func) { py_func_registry[name] = func; } |
| |
| /*! |
| * \brief Get a registered Python function |
| * \param name The function name |
| * \return The Python function |
| */ |
| ffi::Function GetPyFunc(const std::string& name) { |
| auto it = py_func_registry.find(name); |
| if (it == py_func_registry.end()) { |
| TVM_FFI_THROW(InternalError) << "Python function '" << name << "' not found in registry"; |
| } |
| return it->second; |
| } |
| |
| /*! |
| * \brief Call a Python function from VM |
| * \param args The packed function arguments (tuple containing function name and arguments) |
| * \param rv The return value |
| */ |
| void CallPyFunc(ffi::PackedArgs args, ffi::Any* rv) { |
| // args[0] should be a tuple containing (func_name, args_tuple) |
| if (args.size() != 1) { |
| TVM_FFI_THROW(InternalError) << "vm.builtin.call_py_func expects exactly 1 argument (tuple)"; |
| } |
| |
| auto tuple_arg = args[0].cast<ffi::Array<ffi::Any>>(); |
| if (tuple_arg.size() != 2) { |
| TVM_FFI_THROW(InternalError) |
| << "vm.builtin.call_py_func tuple should contain (func_name, args)"; |
| } |
| |
| // Get function name |
| std::string func_name = tuple_arg[0].cast<ffi::String>(); |
| |
| // Get arguments tuple |
| auto func_args = tuple_arg[1].cast<ffi::Array<ffi::Any>>(); |
| |
| // Look up Python function in registry |
| ffi::Function py_func = GetPyFunc(func_name); |
| |
| // Call the Python function with the arguments |
| std::vector<ffi::AnyView> py_args_vec(func_args.begin(), func_args.end()); |
| ffi::PackedArgs py_args(py_args_vec.data(), py_args_vec.size()); |
| py_func.CallPacked(py_args, rv); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def_packed("vm.builtin.call_py_func", CallPyFunc) |
| .def("vm.builtin.register_py_func", RegisterPyFunc) |
| .def("vm.builtin.get_py_func", GetPyFunc) |
| .def("vm.builtin.clear_py_func_registry", ClearPyFuncRegistry); |
| } |
| |
| //------------------------------------- |
| // Builtin runtime operators. |
| //------------------------------------- |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def_method("vm.builtin.shape_of", |
| [](ffi::Any any) -> ffi::Shape { |
| if (auto opt_tensor = any.try_cast<Tensor>()) { |
| return opt_tensor.value().Shape(); |
| } else if (auto opt_dltensor = any.try_cast<DLTensor*>()) { |
| DLTensor* ptr = opt_dltensor.value(); |
| return ffi::Shape(ptr->shape, ptr->shape + ptr->ndim); |
| } else { |
| TVM_FFI_THROW(TypeError) |
| << "vm.builtin.shape_of expects a Tensor or DLTensor*, but get " |
| << any.GetTypeKey(); |
| } |
| }) |
| .def("vm.builtin.copy", [](ffi::Any a) -> ffi::Any { return a; }) |
| .def("vm.builtin.reshape", |
| [](ffi::Any any, ffi::Shape new_shape) { |
| if (auto opt_tensor = any.try_cast<Tensor>()) { |
| Tensor data = opt_tensor.value(); |
| return data.CreateView(new_shape, data->dtype); |
| } else if (auto opt_dltensor = any.try_cast<DLTensor*>()) { |
| DLTensor* ptr = opt_dltensor.value(); |
| auto tmp = std::make_unique<DLManagedTensor>(); |
| tmp->dl_tensor = *ptr; |
| tmp->manager_ctx = nullptr; |
| tmp->deleter = nullptr; |
| Tensor data = Tensor::FromDLPack(tmp.release()); |
| return data.CreateView(new_shape, data->dtype); |
| } else { |
| TVM_FFI_THROW(TypeError) |
| << "vm.builtin.reshape expects a Tensor or DLTensor*, but get " |
| << any.GetTypeKey(); |
| } |
| }) |
| .def("vm.builtin.null_value", []() -> std::nullptr_t { return nullptr; }) |
| .def_packed("vm.builtin.to_device", [](ffi::PackedArgs args, ffi::Any* rv) { |
| Tensor data = args[0].cast<Tensor>(); |
| int dev_type = args[1].cast<int>(); |
| int dev_id = args[2].cast<int>(); |
| Device dst_device = {(DLDeviceType)dev_type, dev_id}; |
| ffi::String mem_scope = "global"; |
| if (args.size() == 4) { |
| mem_scope = args[3].cast<ffi::String>(); |
| } |
| *rv = data.CopyTo(dst_device, mem_scope); |
| }); |
| } |
| |
| /*! |
| * \brief Load the scalar value in cond and return the result value. |
| * \param cond The condition |
| * \return Bool |
| */ |
| bool ReadIfCond(ffi::AnyView cond) { |
| if (auto opt_int = cond.try_cast<bool>()) { |
| return opt_int.value(); |
| } |
| Tensor arr = cond.cast<tvm::runtime::Tensor>(); |
| if (arr->device.device_type != kDLCPU) { |
| arr = arr.CopyTo(DLDevice{kDLCPU, 0}); |
| } |
| TVM_FFI_ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt || |
| arr->dtype.code == kDLBool); |
| int64_t result; |
| switch (arr->dtype.bits) { |
| case 1: { |
| result = reinterpret_cast<int8_t*>(arr->data)[0]; |
| break; |
| } |
| case 8: { |
| result = reinterpret_cast<int8_t*>(arr->data)[0]; |
| break; |
| } |
| case 16: { |
| result = reinterpret_cast<int16_t*>(arr->data)[0]; |
| break; |
| } |
| case 32: { |
| result = reinterpret_cast<int32_t*>(arr->data)[0]; |
| break; |
| } |
| case 64: { |
| result = reinterpret_cast<int64_t*>(arr->data)[0]; |
| break; |
| } |
| default: |
| TVM_FFI_THROW(InternalError) << "Unknown scalar int type: " << DLDataTypeToString(arr->dtype); |
| throw; |
| } |
| return result != 0; |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("vm.builtin.read_if_cond", ReadIfCond); |
| } |
| |
| //------------------------------------- |
| // Debugging API |
| //------------------------------------- |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def_packed( |
| "vm.builtin.invoke_debug_func", [](ffi::PackedArgs args, ffi::Any* rv) -> void { |
| TVM_FFI_ICHECK_GE(args.size(), 3); |
| int num_args = args.size() - 3; |
| ObjectRef io_effect = args[0].cast<ObjectRef>(); |
| TVM_FFI_CHECK(!io_effect.defined(), ValueError) |
| << "IOEffect is expected to be lowered to None."; |
| ffi::String debug_func_name = args[1].cast<ffi::String>(); |
| const auto debug_func = tvm::ffi::Function::GetGlobal(debug_func_name); |
| TVM_FFI_CHECK(debug_func.has_value(), ValueError) |
| << debug_func_name << " is not found. " |
| << "Use the decorator `@tvm.register_global_func(\"" << debug_func_name |
| << "\")` to register it."; |
| ffi::String line_info = args[2].cast<ffi::String>(); |
| std::vector<ffi::AnyView> call_args(num_args + 1); |
| { |
| call_args[0] = line_info; |
| for (int i = 0; i < num_args; ++i) { |
| call_args[i + 1] = args[i + 3]; |
| } |
| } |
| debug_func->CallPacked(ffi::PackedArgs(call_args.data(), call_args.size()), rv); |
| *rv = io_effect; |
| }); |
| } |
| |
| //------------------------------------- |
| // Data structure API |
| //------------------------------------- |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def("vm.builtin.tuple_getitem", |
| [](ffi::Array<ffi::Any> arr, int64_t index) { return arr[index]; }) |
| .def("vm.builtin.tuple_reset_item", |
| [](const ffi::ArrayObj* arr, int64_t index) { |
| const_cast<ffi::ArrayObj*>(arr)->SetItem(index, nullptr); |
| }) |
| .def_packed("vm.builtin.make_tuple", |
| [](ffi::PackedArgs args, ffi::Any* rv) { |
| ffi::Array<ffi::Any> arr; |
| for (int i = 0; i < args.size(); ++i) { |
| arr.push_back(args[i]); |
| } |
| *rv = arr; |
| }) |
| .def("vm.builtin.tensor_to_shape", |
| [](Tensor data) { |
| Tensor arr = data; |
| if (data->device.device_type != kDLCPU) { |
| arr = data.CopyTo(DLDevice{kDLCPU, 0}); |
| } |
| |
| TVM_FFI_ICHECK_EQ(arr->ndim, 1); |
| TVM_FFI_ICHECK_EQ(arr->dtype.code, kDLInt); |
| |
| std::vector<int64_t> out_shape; |
| for (int i = 0; i < arr.Shape()[0]; ++i) { |
| int64_t result; |
| switch (arr->dtype.bits) { |
| case 16: { |
| result = reinterpret_cast<int16_t*>(arr->data)[i]; |
| break; |
| } |
| case 32: { |
| result = reinterpret_cast<int32_t*>(arr->data)[i]; |
| break; |
| } |
| case 64: { |
| result = reinterpret_cast<int64_t*>(arr->data)[i]; |
| break; |
| } |
| default: |
| TVM_FFI_THROW(InternalError) |
| << "Unknown scalar int type: " << DLDataTypeToString(arr->dtype); |
| throw; |
| } |
| out_shape.push_back(result); |
| } |
| return ffi::Shape(out_shape); |
| }) |
| .def("vm.builtin.ensure_zero_offset", [](Tensor data) { |
| if (data->byte_offset == 0) { |
| return data; |
| } |
| auto* device_api = DeviceAPI::Get(data->device); |
| if (device_api->SupportsDevicePointerArithmeticsOnHost() && |
| data->byte_offset % tvm::runtime::kAllocAlignment == 0) { |
| DLManagedTensor* dl_tensor = data.ToDLPack(); |
| dl_tensor->dl_tensor.data = |
| reinterpret_cast<char*>(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset; |
| dl_tensor->dl_tensor.byte_offset = 0; |
| return Tensor::FromDLPack(dl_tensor); |
| } else { |
| auto new_array = Tensor::Empty(data.Shape(), data->dtype, data->device); |
| new_array.CopyFrom(data); |
| return new_array; |
| } |
| }); |
| } |
| |
| } // namespace vm |
| } // namespace runtime |
| } // namespace tvm |
| |
| //------------------------------------------------- |
| // AnyList C runtime API: keep in relax for now. |
| //-------------------------------------------------- |
| extern "C" { |
| /*! |
| * \brief Backend function to get anylist item and set into Packed Func call arg stack. |
| * |
| * \param anylist The handle to the anylist, backed by ffi::Any* |
| * \param int The index. |
| * \param args The args stack. |
| * \param arg_offset The offset of argument. |
| * \return 0 when no error is thrown, -1 when failure happens |
| */ |
| TVM_DLL int TVMBackendAnyListSetPackedArg(void* anylist, int index, TVMFFIAny* args, |
| int arg_offset); |
| /*! |
| * \brief Backend function to get anylist item and set into Packed Func call arg stack. |
| * |
| * \param anylist The handle to the anylist, backed by ffi::Any* |
| * \param int The index. |
| */ |
| TVM_DLL int TVMBackendAnyListResetItem(void* anylist, int index); |
| |
| /*! |
| * \brief Backend function to set anylist item by moving from packed func return. |
| * |
| * \param anylist The handle to the anylist, backed by ffi::Any* |
| * \param int The index. |
| * \param args The args stack. |
| * \param type_codes The type codes stack. |
| * \param arg_offset The offset of argument. |
| * \return 0 when no error is thrown, -1 when failure happens. |
| */ |
| TVM_DLL int TVMBackendAnyListMoveFromPackedReturn(void* anylist, int index, TVMFFIAny* args, |
| int ret_offset); |
| |
| int TVMBackendAnyListSetPackedArg(void* anylist, int index, TVMFFIAny* args, int arg_offset) { |
| using namespace tvm::runtime; |
| TVM_FFI_SAFE_CALL_BEGIN(); |
| auto* list = static_cast<TVMFFIAny*>(anylist); |
| args[arg_offset] = list[index]; |
| TVM_FFI_SAFE_CALL_END(); |
| } |
| |
| int TVMBackendAnyListResetItem(void* anylist, int index) { |
| using namespace tvm::runtime; |
| TVM_FFI_SAFE_CALL_BEGIN(); |
| auto* list = static_cast<tvm::ffi::Any*>(anylist); |
| list[index] = nullptr; |
| TVM_FFI_SAFE_CALL_END(); |
| } |
| |
| int TVMBackendAnyListMoveFromPackedReturn(void* anylist, int index, TVMFFIAny* args, |
| int ret_offset) { |
| using namespace tvm::runtime; |
| TVM_FFI_SAFE_CALL_BEGIN(); |
| auto* list = static_cast<tvm::ffi::Any*>(anylist); |
| list[index] = tvm::ffi::details::AnyUnsafe::MoveTVMFFIAnyToAny(&args[ret_offset]); |
| TVM_FFI_SAFE_CALL_END(); |
| } |
| } // extern "C" |