blob: aa02d097e96603985d3005fd1b293407c6c48f26 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/node/structural_hash.cc
*/
#include <dmlc/memory_io.h>
#include <tvm/ffi/extra/base64.h>
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/extra/structural_hash.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/access_path.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/profiling.h>
#include <tvm/target/codegen.h>
#include <algorithm>
#include <unordered_map>
#include "../support/base64.h"
#include "../support/str_escape.h"
#include "../support/utils.h"
namespace tvm {
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("node.StructuralHash",
[](const Any& object, bool map_free_vars) -> int64_t {
return ffi::StructuralHash::Hash(object, map_free_vars);
});
refl::TypeAttrDef<ffi::ModuleObj>()
.def("__data_to_json__",
[](const ffi::ModuleObj* node) {
std::string bytes = codegen::SerializeModuleToBytes(ffi::GetRef<ffi::Module>(node),
/*export_dso*/ false);
return ffi::Base64Encode(ffi::Bytes(bytes));
})
.def("__data_from_json__", [](const ffi::String& base64_bytes) {
ffi::Bytes bytes = ffi::Base64Decode(base64_bytes);
ffi::Module rtmod = codegen::DeserializeModuleFromBytes(bytes.operator std::string());
return rtmod;
});
refl::TypeAttrDef<runtime::Tensor::Container>()
.def("__data_to_json__",
[](const runtime::Tensor::Container* node) {
std::string blob;
dmlc::MemoryStringStream mstrm(&blob);
support::Base64OutStream b64strm(&mstrm);
runtime::SaveDLTensor(&b64strm, node);
b64strm.Finish();
return ffi::String(blob);
})
.def("__data_from_json__", [](const std::string& blob) {
dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob));
support::Base64InStream b64strm(&mstrm);
b64strm.InitPosition();
runtime::Tensor temp;
ICHECK(temp.Load(&b64strm));
return temp;
});
}
/*!
 * \brief Hash a value structurally.
 * \param object The value to hash.
 * \return The result of ffi::StructuralHash::Hash(object, false).
 */
uint64_t StructuralHash::operator()(const ffi::Any& object) const {
  // Second argument of Hash is passed as false here (the registered
  // "node.StructuralHash" function exposes it as `map_free_vars`).
  constexpr bool map_free_vars = false;
  return ffi::StructuralHash::Hash(object, map_free_vars);
}
/*!
 * \brief Helper deriving from ObjectRef that exposes a static accessor for
 *        obtaining an ObjectPtr<Object> from an ObjectRef.
 */
struct RefToObjectPtr : public ObjectRef {
  /*!
   * \brief Obtain the ObjectPtr held by a reference.
   * \param ref The reference to unwrap.
   * \return The value produced by ObjectUnsafe::ObjectPtrFromObjectRef<Object>.
   */
  static ObjectPtr<Object> Get(const ObjectRef& ref) {
    using ffi::details::ObjectUnsafe;
    return ObjectUnsafe::ObjectPtrFromObjectRef<Object>(ref);
  }
};
struct ReportNodeTrait {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<runtime::profiling::ReportNode>()
.def_ro("calls", &runtime::profiling::ReportNode::calls)
.def_ro("device_metrics", &runtime::profiling::ReportNode::device_metrics)
.def_ro("configuration", &runtime::profiling::ReportNode::configuration);
}
};
// Run the ReportNode reflection registration from a static init block.
TVM_FFI_STATIC_INIT_BLOCK() {
  ReportNodeTrait::RegisterReflection();
}

// ReprPrinter dispatch: a ReportNode is written to the stream as its AsTable() output.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<runtime::profiling::ReportNode>([](const ObjectRef& ref, ReprPrinter* printer) {
      using runtime::profiling::ReportNode;
      const auto* report = static_cast<const ReportNode*>(ref.get());
      printer->stream << report->AsTable();
    });
struct CountNodeTrait {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<runtime::profiling::CountNode>().def_ro("value",
&runtime::profiling::CountNode::value);
}
};
// Run the CountNode reflection registration from a static init block.
TVM_FFI_STATIC_INIT_BLOCK() {
  CountNodeTrait::RegisterReflection();
}

// ReprPrinter dispatch: prints "<type key>(<value>)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<runtime::profiling::CountNode>([](const ObjectRef& ref, ReprPrinter* printer) {
      using runtime::profiling::CountNode;
      const auto* count = static_cast<const CountNode*>(ref.get());
      printer->stream << count->GetTypeKey() << "(" << count->value << ")";
    });
struct DurationNodeTrait {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<runtime::profiling::DurationNode>().def_ro(
"microseconds", &runtime::profiling::DurationNode::microseconds);
}
};
// Run the DurationNode reflection registration from a static init block.
TVM_FFI_STATIC_INIT_BLOCK() {
  DurationNodeTrait::RegisterReflection();
}

// ReprPrinter dispatch: prints "<type key>(<microseconds>)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<runtime::profiling::DurationNode>([](const ObjectRef& ref,
                                                       ReprPrinter* printer) {
      using runtime::profiling::DurationNode;
      const auto* duration = static_cast<const DurationNode*>(ref.get());
      printer->stream << duration->GetTypeKey() << "(" << duration->microseconds << ")";
    });
struct PercentNodeTrait {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<runtime::profiling::PercentNode>().def_ro(
"percent", &runtime::profiling::PercentNode::percent);
}
};
// Run the PercentNode reflection registration from a static init block.
TVM_FFI_STATIC_INIT_BLOCK() {
  PercentNodeTrait::RegisterReflection();
}

// ReprPrinter dispatch: prints "<type key>(<percent>)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<runtime::profiling::PercentNode>([](const ObjectRef& ref,
                                                      ReprPrinter* printer) {
      using runtime::profiling::PercentNode;
      const auto* pct = static_cast<const PercentNode*>(ref.get());
      printer->stream << pct->GetTypeKey() << "(" << pct->percent << ")";
    });
struct RatioNodeTrait {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<runtime::profiling::RatioNode>().def_ro("ratio",
&runtime::profiling::RatioNode::ratio);
}
};
// Run the RatioNode reflection registration from a static init block.
TVM_FFI_STATIC_INIT_BLOCK() {
  RatioNodeTrait::RegisterReflection();
}

// ReprPrinter dispatch: prints "<type key>(<ratio>)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<runtime::profiling::RatioNode>([](const ObjectRef& ref, ReprPrinter* printer) {
      using runtime::profiling::RatioNode;
      const auto* ratio_node = static_cast<const RatioNode*>(ref.get());
      printer->stream << ratio_node->GetTypeKey() << "(" << ratio_node->ratio << ")";
    });
} // namespace tvm