blob: 8875046874e4133e62c32459d7b63643ddb0fa6b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* FFI registration code used for frontend testing purposes.
* \file ffi_testing.cc
*/
#include <tvm/ffi/container/variant.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/env_func.h>
#include <tvm/runtime/module.h>
#include <tvm/te/tensor.h>
#include <tvm/tir/expr.h>

#include <chrono>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>
namespace tvm {
// Attrs node used for testing the Python API.
struct TestAttrs : public AttrsNodeReflAdapter<TestAttrs> {
// Reflected as "axis"; registered with a default value of 10.
int axis;
// Reflected as "name"; registered without a default value.
ffi::String name;
// Reflected as "padding"; registered with a default value of [0, 0].
ffi::Array<PrimExpr> padding;
// Reflected as "func"; registered with a null TypedEnvFunc as its default.
TypedEnvFunc<int(int)> func;
/*!
 * \brief Register every field above as a read-only reflected field, with its
 *  description string and (where given) its default value.
 */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<TestAttrs>()
.def_ro("axis", &TestAttrs::axis, "axis field", refl::DefaultValue(10))
.def_ro("name", &TestAttrs::name, "name")
.def_ro("padding", &TestAttrs::padding, "padding of input",
refl::DefaultValue(ffi::Array<PrimExpr>({0, 0})))
.def_ro("func", &TestAttrs::func, "some random env function",
refl::DefaultValue(TypedEnvFunc<int(int)>(nullptr)));
}
// Declares the object type info: type key "attrs.TestAttrs", parent BaseAttrsNode.
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("attrs.TestAttrs", TestAttrs, BaseAttrsNode);
};
// Register TestAttrs' reflected fields during static initialization.
TVM_FFI_STATIC_INIT_BLOCK() { TestAttrs::RegisterReflection(); }
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("testing.GetShapeSize",
[](ffi::Shape shape) { return static_cast<int64_t>(shape.size()); })
.def("testing.GetShapeElem",
[](ffi::Shape shape, int idx) {
ICHECK_LT(idx, shape.size());
return shape[idx];
})
.def_packed("testing.test_wrap_callback",
[](ffi::PackedArgs args, ffi::Any* ret) {
ffi::Function pf = args[0].cast<ffi::Function>();
*ret = ffi::TypedFunction<void()>([pf]() { pf(); });
})
.def_packed("testing.test_wrap_callback_suppress_err",
[](ffi::PackedArgs args, ffi::Any* ret) {
ffi::Function pf = args[0].cast<ffi::Function>();
auto result = ffi::TypedFunction<void()>([pf]() {
try {
pf();
} catch (std::exception& err) {
}
});
*ret = result;
})
.def_packed("testing.test_check_eq_callback",
[](ffi::PackedArgs args, ffi::Any* ret) {
auto msg = args[0].cast<std::string>();
*ret = ffi::TypedFunction<void(int x, int y)>(
[msg](int x, int y) { CHECK_EQ(x, y) << msg; });
})
.def_packed("testing.device_test",
[](ffi::PackedArgs args, ffi::Any* ret) {
auto dev = args[0].cast<Device>();
int dtype = args[1].cast<int>();
int did = args[2].cast<int>();
CHECK_EQ(static_cast<int>(dev.device_type), dtype);
CHECK_EQ(static_cast<int>(dev.device_id), did);
*ret = dev;
})
.def_packed("testing.identity_cpp", [](ffi::PackedArgs args, ffi::Any* ret) {
const auto identity_func = tvm::ffi::Function::GetGlobal("testing.identity_py");
ICHECK(identity_func.has_value())
<< "AttributeError: \"testing.identity_py\" is not registered. Please check "
"if the python module is properly loaded";
*ret = (*identity_func)(args[0]);
});
}
/*!
 * \brief Error-raising helper, registered below as "testing.ErrorTest".
 *
 * Fails with a "ValueError: ..." message when x != y, and with an
 * "InternalError: ..." message when x == y == 1; returns normally otherwise.
 * The message prefixes mirror the exception kinds the frontend tests expect.
 */
void ErrorTest(int x, int y) {
  CHECK_EQ(x, y) << "ValueError: expect x and y to be equal.";
  // Past the check x == y, so only x == 1 reaches the fatal path.
  if (x != 1) return;
  LOG(FATAL) << "InternalError: cannot reach here";
}
TVM_FFI_STATIC_INIT_BLOCK() {
  // Expose ErrorTest under the global name "testing.ErrorTest".
  tvm::ffi::reflection::GlobalDef().def("testing.ErrorTest", ErrorTest);
}
/*!
 * \brief Module of kind "frontend_test" whose functions are added at runtime.
 *
 * GetFunction(kAddFunctionName) returns a function (name, pf) that stores `pf`
 * under `name`; any other name is looked up among the stored functions.
 */
class FrontendTestModuleNode : public ffi::ModuleObj {
 public:
  const char* kind() const final { return "frontend_test"; }
  // Name of the built-in entry used to register functions into this module.
  // A static constexpr data member is implicitly inline in C++17, so no
  // out-of-class definition is required.
  static constexpr const char* kAddFunctionName = "__add_function";
  ffi::Optional<ffi::Function> GetFunction(const ffi::String& name) final;

 private:
  // Functions registered through kAddFunctionName, keyed by name.
  std::unordered_map<std::string, ffi::Function> functions_;
};
/*!
 * \brief Look up `name` in this module.
 * \param name The function name; kAddFunctionName selects the registration entry.
 * \return For kAddFunctionName, a function (func_name, pf) that stores `pf` under
 *  `func_name` (func_name must not itself be kAddFunctionName). Otherwise the
 *  previously stored function, or std::nullopt if none was registered.
 */
ffi::Optional<ffi::Function> FrontendTestModuleNode::GetFunction(const ffi::String& name) {
  if (name == kAddFunctionName) {
    // The closure captures a strong reference to this module so that `this`
    // stays valid for as long as the closure is alive. The reference is only
    // taken on this path, not on every ordinary lookup.
    ffi::Module self_strong_ref = ffi::GetRef<ffi::Module>(this);
    return ffi::Function::FromTyped(
        [this, self_strong_ref](std::string func_name, ffi::Function pf) {
          CHECK_NE(func_name, kAddFunctionName)
              << "func_name: cannot be special function " << kAddFunctionName;
          functions_[func_name] = std::move(pf);
        });
  }
  auto it = functions_.find(name);
  if (it == functions_.end()) {
    return std::nullopt;
  }
  return it->second;
}
/*! \brief Create a FrontendTestModuleNode with no registered functions, wrapped as a Module. */
ffi::Module NewFrontendTestModule() {
  return ffi::Module(ffi::make_object<FrontendTestModuleNode>());
}
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("testing.FrontendTestModule", NewFrontendTestModule)
      // Block the calling thread for `timeout` seconds.
      .def(
          "testing.sleep_in_ffi",
          [](double timeout) {
            using Nanos = std::chrono::duration<int64_t, std::nano>;
            // Nothing to wait for on non-positive or NaN timeouts. Returning early
            // also avoids the undefined double -> int64_t conversion of NaN.
            if (!(timeout > 0)) {
              return;
            }
            const double nanos = timeout * 1e9;
            // Clamp so the conversion below stays in int64_t range (covers +inf too).
            const Nanos duration = nanos >= static_cast<double>(Nanos::max().count())
                                       ? Nanos::max()
                                       : Nanos(static_cast<int64_t>(nanos));
            std::this_thread::sleep_for(duration);
          })
      // Even x -> IntImm(int64, x / 2); odd x -> the string "argument was odd".
      .def("testing.ReturnsVariant",
           [](int x) -> ffi::Variant<ffi::String, IntImm> {
             if (x % 2 == 0) {
               return IntImm(DataType::Int(64), x / 2);
             } else {
               return ffi::String("argument was odd");
             }
           })
      // Return the type key of whichever alternative `arg` holds.
      .def("testing.AcceptsVariant",
           [](ffi::Variant<ffi::String, Integer> arg) -> ffi::String {
             if (auto opt_str = arg.as<ffi::String>()) {
               return ffi::StaticTypeKey::kTVMFFIStr;
             } else {
               return arg.get<Integer>().GetTypeKey();
             }
           })
      // Identity functions used to test argument conversion.
      .def("testing.AcceptsBool", [](bool arg) -> bool { return arg; })
      .def("testing.AcceptsInt", [](int arg) -> int { return arg; })
      .def("testing.AcceptsObjectRefArray", [](ffi::Array<Any> arg) -> Any { return arg[0]; })
      .def("testing.AcceptsMapReturnsValue",
           [](ffi::Map<Any, Any> map, Any key) -> Any { return map[key]; })
      .def("testing.AcceptsMapReturnsMap", [](ffi::Map<Any, Any> map) -> ObjectRef { return map; })
      .def("testing.AcceptsPrimExpr", [](PrimExpr expr) -> ObjectRef { return expr; })
      // Verify every element is a PrimExprNode, then return the array unchanged.
      .def("testing.AcceptsArrayOfPrimExpr",
           [](ffi::Array<PrimExpr> arr) -> ObjectRef {
             for (ObjectRef item : arr) {
               CHECK(item->IsInstance<PrimExprNode>()) << "Array contained " << item->GetTypeKey()
                                                       << " when it should contain PrimExpr";
             }
             return arr;
           })
      // Verify every element holds a PrimExpr or an ffi::Function, then return the array.
      .def("testing.AcceptsArrayOfVariant",
           [](ffi::Array<ffi::Variant<ffi::Function, PrimExpr>> arr) -> ObjectRef {
             for (auto item : arr) {
               CHECK(item.as<PrimExpr>() || item.as<ffi::Function>())
                   << "Array should contain either PrimExpr or ffi::Function";
             }
             return arr;
           })
      // Verify every value is a PrimExprNode, then return the map unchanged.
      .def("testing.AcceptsMapOfPrimExpr", [](ffi::Map<Any, PrimExpr> map) -> ObjectRef {
        for (const auto& kv : map) {
          ObjectRef value = kv.second;
          CHECK(value->IsInstance<PrimExprNode>())
              << "Map contained " << value->GetTypeKey() << " when it should contain PrimExpr";
        }
        return map;
      });
}
/**
 * Simple event logger that can be used for testing purposes.
 *
 * Each recorded event is stored with the time, in microseconds, elapsed since
 * the logger was constructed. ThreadLocal() hands out one instance per thread.
 */
class TestingEventLogger {
 public:
  struct Entry {
    ffi::String event;
    double time_us;
  };
  TestingEventLogger() { entries_.reserve(1024); }
  /*! \brief Append `event`, stamped with the time elapsed since construction. */
  void Record(ffi::String event) {
    // Convert through a floating-point microsecond duration rather than reading
    // the raw tick count: a clock's tick period is implementation-defined.
    const double time_us =
        std::chrono::duration<double, std::micro>(Clock::now() - start_).count();
    entries_.push_back(Entry{std::move(event), time_us});
  }
  /*! \brief Drop all recorded entries (the time origin is kept). */
  void Reset() { entries_.clear(); }
  /*! \brief Log each entry as "<event>\t<time> us" at INFO level. */
  void Dump() const {
    for (const Entry& e : entries_) {
      LOG(INFO) << e.event << "\t" << e.time_us << " us";
    }
  }
  /*! \brief Return the calling thread's logger, constructed on first use in that thread. */
  static TestingEventLogger* ThreadLocal() {
    thread_local TestingEventLogger inst;
    return &inst;
  }

 private:
  // steady_clock is monotonic, so recorded intervals cannot jump when the wall
  // clock is adjusted (high_resolution_clock may be an alias of system_clock).
  using Clock = std::chrono::steady_clock;
  Clock::time_point start_ = Clock::now();
  std::vector<Entry> entries_;
};
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      // Record an event on the calling thread's logger. The label is args[0] when
      // it converts to a string, otherwise "X".
      .def_packed("testing.record_event",
                  [](ffi::PackedArgs args, ffi::Any* rv) {
                    if (args.size() != 0) {
                      // Convert once and reuse the result instead of casting twice.
                      if (auto event = args[0].try_cast<ffi::String>()) {
                        TestingEventLogger::ThreadLocal()->Record(*event);
                        return;
                      }
                    }
                    TestingEventLogger::ThreadLocal()->Record("X");
                  })
      // Clear the calling thread's recorded events; any arguments are ignored.
      .def_packed(
          "testing.reset_events",
          [](ffi::PackedArgs args, ffi::Any* rv) { TestingEventLogger::ThreadLocal()->Reset(); })
      // Log the calling thread's recorded events.
      .def("testing.dump_events", []() { TestingEventLogger::ThreadLocal()->Dump(); });
}
} // namespace tvm