/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/target/compilation_config.cc
* \brief Implementation of \p CompilationConfig for collecting \p Targets.
*/

#include <tvm/runtime/device_api.h>
#include <tvm/target/compilation_config.h>

#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_set>

namespace tvm {

TVM_REGISTER_NODE_TYPE(CompilationConfigNode);

void CompilationConfigNode::VisitAttrs(AttrVisitor* v) {
v->Visit("host_target", &host_target);
v->Visit("primitive_targets", &primitive_targets);
v->Visit("default_primitive_virtual_device", &default_primitive_virtual_device);
v->Visit("host_virtual_device", &host_virtual_device);
v->Visit("optional_homogenous_target", &optional_homogeneous_target);
// NOTE: The virtual_device_cache_ is not accessible via FFI.
}
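
// Returns the first primitive target whose device type matches \p device_type. Fails with a
// diagnostic listing the available device types and targets if no such target was given.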
Target CompilationConfigNode::FindPrimitiveTargetForDeviceOrFail(DLDeviceType device_type) const {
ICHECK_GT(device_type, 0) << "Invalid device type";
auto itr = std::find_if(
primitive_targets.begin(), primitive_targets.end(),
[device_type](const Target& target) { return target->GetTargetDeviceType() == device_type; });
if (itr == primitive_targets.end()) {
std::stringstream msg;
msg << "No target is specified for device type " << device_type
<< ". The available device types and targets are:" << std::endl;
for (const auto& target : primitive_targets) {
msg << " " << target->GetTargetDeviceType() << "-> " << target->ToDebugString() << std::endl;
}
LOG(FATAL) << msg.str();
}
return *itr;
}
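
// Returns the first primitive target of kind \p kind_name, or nullopt if the kind is not
// registered or no such target was given.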
Optional<Target> CompilationConfigNode::FindPrimitiveTargetForKind(
const std::string& kind_name) const {
Optional<TargetKind> opt_kind = TargetKind::Get(kind_name);
if (!opt_kind.defined()) {
VLOG(1) << "No such target kind for '" << kind_name << "'";
return {};
}
auto itr =
std::find_if(primitive_targets.begin(), primitive_targets.end(),
[kind_name](const Target& target) { return target->kind->name == kind_name; });
if (itr == primitive_targets.end()) {
VLOG(1) << "No target available matching kind '" << kind_name << "'";
return {};
}
return *itr;
}
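
// If \p target is pointer-equal or structurally equal to the host target or one of the primitive
// targets, returns that already-known target. Otherwise returns \p target unchanged.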
Target CompilationConfigNode::CanonicalTarget(const Target& target) const {
// Fast path -- object identity.
if (target == host_target) {
return target;
}
for (const auto& primitive_target : primitive_targets) {
if (target == primitive_target) {
return target;
}
}
// Slow path -- structural equality. We have so few targets it does not seem worth building an
// index.
if (StructuralEqual()(target, host_target)) {
return host_target;
}
for (const auto& primitive_target : primitive_targets) {
if (StructuralEqual()(target, primitive_target)) {
return primitive_target;
}
}
// No match.
return target;
}
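
// Returns the unique (cached) VirtualDevice matching \p virtual_device after canonicalizing its
// target, or, if it has no target, defaulting the target from the device type.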
VirtualDevice CompilationConfigNode::CanonicalVirtualDevice(
const VirtualDevice& virtual_device) const {
// Targets need special handling.
Target target = virtual_device->target;
if (target.defined()) {
// It is possible the given target object was constructed by the user, but was then
// rewritten on the way into the CompilationConfig. So 'canonicalize' it by replacing
// the given target with one structurally equal to one already known in the config if
// possible.
Target canon_target = CanonicalTarget(target);
if (canon_target != target) {
VLOG(1) << "Canonicalized target " << canon_target->ToDebugString();
}
target = canon_target;
} else if (virtual_device->device_type() != kInvalidDeviceType) {
// Since no target was given, choose one with a matching device type.
// This is the one place where we allow device types to imply targets.
target = FindPrimitiveTargetForDeviceOrFail(virtual_device->device_type());
VLOG(1) << "Defaulted to target " << target->ToDebugString();
}
// else: the target will remain unknown.
// Redirect to an existing structurally equal virtual device.
return virtual_device_cache_.Unique(VirtualDevice(virtual_device->device_type(),
virtual_device->virtual_device_id, target,
virtual_device->memory_scope));
}
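
// Establishes all the config fields from \p raw_targets: the host target (RULES A-C below), the
// host virtual device, the primitive targets (each rewritten to share the host), the default
// primitive device type (RULES D-F below) and its virtual device, and the optional homogeneous
// target used by legacy passes.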
void CompilationConfigNode::Init(const transform::PassContext& pass_ctx,
const Array<Target>& raw_targets) {
VLOG_CONTEXT << "CompilationConfig";
CHECK_GT(raw_targets.size(), 0U) << "Require at least one target";
//
// Decide on the host target.
//
// Any targets which could act as a host?
auto hosting_itr = std::find_if(raw_targets.begin(), raw_targets.end(), [](const Target& target) {
    // TODO(tvm-team): The kDLHexagon device can act as a host. We can remove kDLHexagon
    // here once kDLHexagon has been refactored to kDLCPU.
return target->GetTargetDeviceType() == kDLCPU || target->GetTargetDeviceType() == kDLHexagon;
});
// Any targets with their host field set?
auto has_host_itr = std::find_if(raw_targets.begin(), raw_targets.end(),
[](const Target& target) { return target->host.defined(); });
if (has_host_itr != raw_targets.end()) {
// RULE A: If any raw target has a host, use the first such host for all the primitive
// targets.
host_target = Target((*has_host_itr)->GetHost().value(), /*host=*/Target());
VLOG(1) << "The target " << (*has_host_itr)->ToDebugString() << " supplies a host target "
<< host_target->ToDebugString() << " of device type "
<< host_target->GetTargetDeviceType();
} else if (hosting_itr != raw_targets.end()) {
    // RULE B: If any raw target is for a device which could act as a host, use the first such
    // target as the host.
host_target = Target(*hosting_itr, /*host=*/Target());
VLOG(1) << "Using target " << host_target->ToDebugString() << " of CPU-like device type "
<< host_target->GetTargetDeviceType() << " as the host target";
} else {
// RULE C: Otherwise, create a default CPU host target.
host_target = MakeDefaultCPUTarget();
VLOG(1) << "Created a default target " << host_target->ToDebugString() << " of device type "
<< host_target->GetTargetDeviceType() << " for the host target";
}
ICHECK(host_target.defined());
ICHECK(!host_target->host.defined());
if (host_target->GetTargetDeviceType() != kDLCPU) {
// I think we're on thin ice here until we've audited the code base for assumed CPU hosts.
VLOG(1) << "The host target is not a CPU. This is probably not going to work.";
}
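  // Illustrative examples of the rules above: for raw targets {cuda -host=llvm, llvm} RULE A
  // applies and the host is the llvm target taken from the cuda target's host field; for
  // {cuda, llvm} RULE B applies and the llvm target itself becomes the host; for just {cuda}
  // RULE C applies and a default llvm (or stackvm) host is created.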
//
// Establish the host VirtualDevice.
//
host_virtual_device = virtual_device_cache_.Unique(
VirtualDevice(static_cast<DLDeviceType>(host_target->GetTargetDeviceType()),
/*virtual_device_id=*/0, host_target));
ICHECK(host_virtual_device.defined());
ICHECK(host_virtual_device->target.defined());
//
// Now that we've settled on a host, we can set it as the host on all the raw targets.
//
primitive_targets.clear();
primitive_targets.reserve(raw_targets.size());
for (const auto& raw_target : raw_targets) {
if (raw_target->host.defined() && !StructuralEqual()(raw_target->host, host_target)) {
VLOG(1) << "The target " << raw_target->ToDebugString()
<< " already has a host which disagrees with the desired host target. It "
<< "will be overridden.";
}
primitive_targets.push_back(Target(raw_target, host_target));
}
ICHECK_GT(primitive_targets.size(), 0U);
//
// Check the primitive_targets are ordered correctly re Target::IsExternalCodegenFor,
// and make sure no two targets share a kind name.
//
// TODO(mbs): We could just sort the list, but given all the implicit defaulting for backwards
// compat it seems we should avoid making this any more magical than necessary. But revisit
// if usability suffers.
std::unordered_set<DLDeviceType> primitive_target_device_types;
std::unordered_set<std::string> kind_names;
for (const auto& target : primitive_targets) {
primitive_target_device_types.emplace(static_cast<DLDeviceType>(target->GetTargetDeviceType()));
    CHECK(kind_names.emplace(target->kind->name).second)
        << "Multiple targets have been given for the same device kind '" << target->kind->name
        << "'";
}
for (DLDeviceType device_type : primitive_target_device_types) {
Target first_primitive_target;
for (const auto& current_primitive_target : primitive_targets) {
if (current_primitive_target->GetTargetDeviceType() != device_type) {
continue;
}
if (!first_primitive_target.defined()) {
first_primitive_target = current_primitive_target;
        // Note it is valid for the only target for a device type to be an external codegen target.
} else {
        CHECK(current_primitive_target.IsExternalCodegenFor(first_primitive_target))
            << "When given multiple targets for the device type " << device_type
            << ", the first must be for non-external codegen and all subsequent targets must be "
               "for external codegen. However, the first given was "
            << first_primitive_target->ToDebugString() << " and a subsequent one was "
            << current_primitive_target->ToDebugString();
}
}
}
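  // For example, a 'cuda' target followed by an external codegen target registered for the same
  // CUDA device type (such as 'cutlass' or 'tensorrt', assuming those kinds are registered in
  // this build) is accepted, whereas the reverse order is rejected.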
//
// Decide on the default device type for primitives.
//
DLDeviceType default_primitive_device_type;
Optional<Integer> opt_fallback_dev = pass_ctx->GetConfig<Integer>("relay.fallback_device_type");
if (opt_fallback_dev) {
// RULE D: Respect the PassContext setting if given.
const int64_t v = opt_fallback_dev.value()->value;
CHECK_GT(v, 0)
<< "The 'relay.fallback_device_type' pass attribute is set to an invalid device type " << v;
default_primitive_device_type = static_cast<DLDeviceType>(v);
VLOG(1) << "Using the 'relay.fallback_device_type' pass attribute "
<< default_primitive_device_type
<< " as the default device type for all primitive operations";
} else if (primitive_target_device_types.size() == 1) {
    // RULE E: Since only one device type is in use there is no choice to make.
default_primitive_device_type = *primitive_target_device_types.begin();
VLOG(1) << "All primitive targets have the device type " << default_primitive_device_type
<< " so that is also the default device type for all primitive operations.";
} else {
// RULE F: Fallback to CPU.
default_primitive_device_type = kDLCPU;
VLOG(1) << "Using " << default_primitive_device_type
<< " as the default device type for all primitive operations";
}
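  // Illustrative examples of the rules above: setting 'relay.fallback_device_type' to 2 in the
  // PassContext forces kDLCUDA via RULE D; with only 'cuda' primitive targets and no override
  // RULE E picks kDLCUDA; with both 'cuda' and 'llvm' primitive targets and no override RULE F
  // falls back to kDLCPU.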
//
  // Establish the default primitive VirtualDevice, choosing a known Target to match the device
  // type. We do not create a default target; it must already exist as a primitive target.
//
default_primitive_virtual_device = CanonicalVirtualDevice(
VirtualDevice::ForDeviceType(default_primitive_device_type, /*virtual_device_id=*/0));
ICHECK(default_primitive_virtual_device.defined());
ICHECK(default_primitive_virtual_device->target.defined());
  // Legacy: Some passes only support homogeneous compilation and expect the target to be
  // given by the global target context. Make this easy to detect.
optional_homogeneous_target =
primitive_targets.size() == 1 ? *primitive_targets.begin() : Target();
}
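
// Returns a fresh default host target: 'llvm' when the LLVM backend is registered in this build,
// otherwise 'stackvm'.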
/* static */ Target CompilationConfigNode::MakeDefaultCPUTarget() {
if (runtime::Registry::Get("codegen.LLVMModuleCreate")) {
// LLVM is available.
// TODO(mbs): More robust extension mechanism?
return Target("llvm");
} else {
// LLVM is not available.
// TODO(mbs): Already deprecated?
return Target("stackvm");
}
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CompilationConfigNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = ref.as<CompilationConfigNode>();
p->stream << "Primitive targets:";
for (const auto& target : node->primitive_targets) {
p->stream << std::endl
<< " " << target->GetTargetDeviceType() << " |-> " << target->ToDebugString();
}
p->stream << std::endl
<< "Default primitive virtual device: " << node->default_primitive_virtual_device;
p->stream << std::endl << "Host virtual device: " << node->host_virtual_device;
});

CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx,
const Array<Target>& raw_targets) {
auto node = make_object<CompilationConfigNode>();
node->Init(pass_ctx, raw_targets);
data_ = std::move(node);
}
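
// Exposed via the FFI so a CompilationConfig can be constructed from a PassContext and raw
// targets, e.g. from the Python side.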
TVM_REGISTER_GLOBAL("target.MakeCompilationConfig")
.set_body_typed([](const transform::PassContext& pass_ctx,
const Array<Target>& raw_targets) -> CompilationConfig {
return CompilationConfig(pass_ctx, raw_targets);
});

}  // namespace tvm