// blob: 87cb69b4f4ceff94e29fbe20a010414fa0d89781 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Example package that uses TVM.
* \file tvm_ext.cc
*/
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>
using namespace tvm;
using namespace tvm::tir;
using namespace tvm::runtime;
namespace tvm_ext {
/*!
 * \brief Example NDArray subclass that carries one extra integer.
 *
 * An external library that wants to extend NDArray this way should:
 *
 * 1) Derive from TVM's NDArray and from its NDArray container,
 *
 * 2) Follow the new object protocol so the new NDArray is a reference class,
 *
 * 3) On the Python frontend, inherit `tvm.nd.NDArray` and
 *    register the type using tvm.register_object.
 */
class NDSubClass : public tvm::runtime::NDArray {
 public:
  /*! \brief Container extending NDArray::Container with an integer payload. */
  class SubContainer : public NDArray::Container {
   public:
    /*!
     * \brief Store the payload and tag the object with this class's runtime type index.
     * \param additional_info The integer to keep alongside the array container.
     */
    SubContainer(int additional_info) : additional_info_(additional_info) {
      type_index_ = SubContainer::RuntimeTypeIndex();
    }
    /*! \brief The extra integer payload. */
    int additional_info_{0};
    static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
    static constexpr const char* _type_key = "tvm_ext.NDSubClass";
    TVM_DECLARE_FINAL_OBJECT_INFO(SubContainer, NDArray::Container);
  };

  /*! \brief Deleter installed on containers created by NDSubClass(int). */
  static void SubContainerDeleter(Object* obj) { delete static_cast<SubContainer*>(obj); }

  NDSubClass() {}

  /*! \brief Wrap an existing object pointer. */
  explicit NDSubClass(ObjectPtr<Object> n) : NDArray(n) {}

  /*!
   * \brief Allocate a new SubContainer holding \p additional_info.
   *
   * The container is heap-allocated here and released through
   * SubContainerDeleter, which is installed before ownership is handed to data_.
   */
  explicit NDSubClass(int additional_info) {
    auto* container = new SubContainer(additional_info);
    container->SetDeleter(SubContainerDeleter);
    data_ = GetObjectPtr<Object>(container);
  }

  /*!
   * \brief Create a new NDSubClass whose payload is the sum of both payloads.
   * \param other The right-hand operand.
   * \return A freshly allocated NDSubClass; CHECK-fails if either side is null.
   */
  NDSubClass AddWith(const NDSubClass& other) const {
    auto* a = static_cast<SubContainer*>(get_mutable());
    auto* b = static_cast<SubContainer*>(other.get_mutable());
    CHECK(a != nullptr && b != nullptr);
    const int combined = a->additional_info_ + b->additional_info_;
    return NDSubClass(combined);
  }

  /*!
   * \brief Read the integer payload.
   * \return The stored additional_info_; CHECK-fails on a null reference.
   */
  int get_additional_info() const {
    auto* self = static_cast<SubContainer*>(get_mutable());
    CHECK(self != nullptr);
    return self->additional_info_;
  }

  using ContainerType = SubContainer;
};
// Register NDSubClass::SubContainer (type key "tvm_ext.NDSubClass") with TVM's object type system.
TVM_REGISTER_OBJECT_TYPE(NDSubClass::SubContainer);
/*!
* \brief Introduce additional extension data structures
* by sub-classing TVM's object system.
*
* IntVectorObj derives from Object and holds a std::vector<int>;
* it is declared as a final object type under the key "tvm_ext.IntVector".
*/
class IntVectorObj : public Object {
public:
/*! \brief The stored integers. */
std::vector<int> vec;
/*! \brief Type key string used for this object type. */
static constexpr const char* _type_key = "tvm_ext.IntVector";
TVM_DECLARE_FINAL_OBJECT_INFO(IntVectorObj, Object);
};
/*!
* \brief Int vector reference class.
*
* ObjectRef subclass whose reference methods are generated by
* TVM_DEFINE_OBJECT_REF_METHODS with IntVectorObj as the underlying object type.
*/
class IntVector : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(IntVector, ObjectRef, IntVectorObj);
};
// Register IntVectorObj (type key "tvm_ext.IntVector") with TVM's object type system.
TVM_REGISTER_OBJECT_TYPE(IntVectorObj);
} // namespace tvm_ext
namespace tvm_ext {
/*!
 * \brief tvm_ext.ivec_create(v0, v1, ...): build an IntVector from the call arguments.
 *
 * Each argument is converted to int and appended, in call order, to a newly
 * allocated IntVectorObj. The result is returned wrapped in an IntVector reference.
 */
TVM_REGISTER_GLOBAL("tvm_ext.ivec_create").set_body([](TVMArgs args, TVMRetValue* rv) {
  auto ivec = tvm::runtime::make_object<IntVectorObj>();
  const int num_values = args.size();
  for (int idx = 0; idx < num_values; ++idx) {
    const int value = args[idx];
    ivec->vec.push_back(value);
  }
  *rv = IntVector(ivec);
});
/*!
 * \brief tvm_ext.ivec_get(ivec, index): read one element of an IntVector.
 *
 * args[0] is converted to IntVector and args[1] to int; the element at that
 * position is stored in *rv.
 *
 * The index comes from the caller, so it is validated before use:
 * a null IntVector or an index outside [0, size) fails a CHECK instead of
 * reading out of bounds (std::vector::operator[] performs no range check).
 */
TVM_REGISTER_GLOBAL("tvm_ext.ivec_get").set_body([](TVMArgs args, TVMRetValue* rv) {
  IntVector p = args[0];
  const int index = args[1];
  CHECK(p.defined()) << "tvm_ext.ivec_get: IntVector argument is null";
  // Test the sign first so the cast to an unsigned type cannot wrap a negative index.
  CHECK(index >= 0 && static_cast<size_t>(index) < p->vec.size())
      << "tvm_ext.ivec_get: index " << index << " is out of range for IntVector of size "
      << p->vec.size();
  *rv = p->vec[index];
});
/*!
 * \brief tvm_ext.bind_add(f, b): bind an integer as the first argument of f.
 *
 * Returns a PackedFunc g such that g(x) evaluates f(b, x). Both f and b are
 * captured by value, so the returned function does not refer to the caller's arguments.
 */
TVM_REGISTER_GLOBAL("tvm_ext.bind_add").set_body([](TVMArgs args_, TVMRetValue* rv_) {
  PackedFunc target = args_[0];
  int bound = args_[1];
  auto forward = [target, bound](TVMArgs args, TVMRetValue* rv) {
    *rv = target(bound, args[0]);
  };
  *rv_ = PackedFunc(forward);
});
/*!
 * \brief tvm_ext.sym_add(a, b): return the symbolic expression a + b.
 *
 * Both arguments are converted to tir Var; the result of operator+ on them is returned.
 */
TVM_REGISTER_GLOBAL("tvm_ext.sym_add").set_body([](TVMArgs args, TVMRetValue* rv) {
  Var lhs = args[0];
  Var rhs = args[1];
  auto sum = lhs + rhs;
  *rv = sum;
});
/*!
 * \brief device_api.ext_dev: device API for the extension device.
 *
 * Forwards to the global function "device_api.cpu" and returns whatever it
 * returns. Registry::Get yields nullptr when a name has not been registered,
 * so the lookup is checked before the call rather than dereferenced blindly.
 */
TVM_REGISTER_GLOBAL("device_api.ext_dev").set_body([](TVMArgs args, TVMRetValue* rv) {
  const auto* cpu_api = tvm::runtime::Registry::Get("device_api.cpu");
  CHECK(cpu_api != nullptr)
      << "device_api.ext_dev: global function device_api.cpu is not registered";
  *rv = (*cpu_api)();
});
/*!
 * \brief tvm_ext.nd_create(additional_info): create an NDSubClass holding the given int.
 *
 * After the assignment, the return value is checked to carry the
 * kTVMNDArrayHandle type code.
 */
TVM_REGISTER_GLOBAL("tvm_ext.nd_create").set_body([](TVMArgs args, TVMRetValue* rv) {
  const int info = args[0];
  NDSubClass created(info);
  *rv = created;
  CHECK_EQ(rv->type_code(), kTVMNDArrayHandle);
});
/*!
 * \brief tvm_ext.nd_add_two(a, b): combine two NDSubClass values with AddWith.
 *
 * Returns a new NDSubClass produced by a.AddWith(b).
 */
TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two").set_body([](TVMArgs args, TVMRetValue* rv) {
  NDSubClass lhs = args[0];
  NDSubClass rhs = args[1];
  NDSubClass combined = lhs.AddWith(rhs);
  *rv = combined;
});
/*!
 * \brief tvm_ext.nd_get_additional_info(a): return the int stored in an NDSubClass.
 */
TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info").set_body([](TVMArgs args, TVMRetValue* rv) {
  NDSubClass nd = args[0];
  const int info = nd.get_additional_info();
  *rv = info;
});
} // namespace tvm_ext
// External function exposed to runtime: returns its argument plus one.
extern "C" float TVMTestAddOne(float y) {
  const float incremented = y + 1;
  return incremented;
}
// This callback approach allows TVM to extract the functions an extension
// declares: the caller passes in a registration function, and the extension
// invokes it once per function it provides (here, a single "mul").
// This way can be helpful when we want to use a header only
// minimum version of TVM Runtime.
// Returns 0 after registering.
extern "C" int TVMExtDeclare(TVMFunctionHandle pregister) {
  auto* register_fn = static_cast<PackedFunc*>(pregister);
  // Integer multiply of the first two arguments.
  auto multiply = [](TVMArgs args, TVMRetValue* rv) {
    int lhs = args[0];
    int rhs = args[1];
    *rv = lhs * rhs;
  };
  (*register_fn)("mul", PackedFunc(multiply));
  return 0;
}