blob: e3f35e9f0c196fa3a92014261df78ec2b141647b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file auto_scheduler/search_task.cc
* \brief Meta information and hardware parameters for a search task.
*/
#include <tvm/auto_scheduler/search_task.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h>
#include <utility>
namespace tvm {
namespace auto_scheduler {
// Register the two node classes used in this file via TVM_REGISTER_NODE_TYPE.
// NOTE(review): the macro is defined outside this file; presumably it hooks the
// types into TVM's node/reflection registry -- confirm against its definition.
TVM_REGISTER_NODE_TYPE(HardwareParamsNode);
TVM_REGISTER_NODE_TYPE(SearchTaskNode);
/*!
 * \brief Build a HardwareParams reference backed by a freshly allocated HardwareParamsNode.
 * \param num_cores Value stored in the node's `num_cores` field.
 * \param vector_unit_bytes Value stored in the node's `vector_unit_bytes` field.
 * \param cache_line_bytes Value stored in the node's `cache_line_bytes` field.
 * \note Only these three fields are assigned here; any other field of the node keeps
 *  whatever value the node's own construction gives it.
 */
HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes) {
  auto params_node = make_object<HardwareParamsNode>();
  // The three stores are independent of one another.
  params_node->cache_line_bytes = cache_line_bytes;
  params_node->vector_unit_bytes = vector_unit_bytes;
  params_node->num_cores = num_cores;
  // Hand ownership of the node over to this reference.
  data_ = std::move(params_node);
}
/*!
 * \brief Produce default hardware parameters for the given target.
 *
 * - CPU targets (device type kDLCPU): num_cores comes from
 *   tvm::runtime::threading::MaxConcurrency(); vector_unit_bytes and cache_line_bytes
 *   are both 64.
 * - GPU targets (device type kDLGPU): starts from HardwareParams(-1, 16, 64) and then
 *   fills the per-block limits and the warp size by querying the device API registered
 *   as "device_api.gpu" for device 0. max_vthread_extent is set to warp_size / 4.
 * - Any other device type triggers LOG(FATAL).
 *
 * \param target The target whose `kind->device_type` selects the defaults.
 * \param target_host Not read by this function.
 * \return The default HardwareParams for the target.
 */
HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target,
                                                            const Target& target_host) {
  const auto device_type = target->kind->device_type;

  if (device_type == kDLCPU) {
    return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64);
  }

  if (device_type != kDLGPU) {
    LOG(FATAL) << "No default hardware parameters for target: " << target;
    // Fallback value in case control continues past the fatal log.
    return HardwareParams();
  }

  // GPU: begin with fixed defaults, then overwrite the device-dependent fields.
  auto gpu_params = HardwareParams(-1, 16, 64);
  auto* gpu_node = gpu_params.CopyOnWrite();

  auto gpu_ctx = TVMContext{kDLGPU, 0};
  auto api_getter = tvm::runtime::Registry::Get("device_api.gpu");
  CHECK(api_getter != nullptr) << "Cannot find GPU device_api in registry";
  auto gpu_api =
      static_cast<tvm::runtime::DeviceAPI*>(((*api_getter)()).operator void*());

  // A single return-value slot is shared by every query, in the same order as before.
  tvm::runtime::TVMRetValue attr_value;
  auto query_attr = [&](tvm::runtime::DeviceAttrKind kind) -> tvm::runtime::TVMRetValue& {
    gpu_api->GetAttr(gpu_ctx, kind, &attr_value);
    return attr_value;
  };

  gpu_node->max_shared_memory_per_block =
      query_attr(tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock);
  gpu_node->max_registers_per_block =
      query_attr(tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock);
  gpu_node->max_threads_per_block =
      query_attr(tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock);
  gpu_node->warp_size = query_attr(tvm::runtime::DeviceAttrKind::kWarpSize);

  // Derived from the warp size that was just read.
  gpu_node->max_vthread_extent = gpu_node->warp_size / 4;
  return gpu_params;
}
/*!
 * \brief Build a SearchTask reference backed by a freshly allocated SearchTaskNode.
 * \param compute_dag Moved into the node's `compute_dag` field.
 * \param workload_key Moved into the node's `workload_key` field.
 * \param target Moved into the node's `target` field.
 * \param target_host Moved into the node's `target_host` field.
 * \param hardware_params Optional hardware parameters. When it holds a value, that value
 *  is stored; otherwise HardwareParamsNode::GetDefaultHardwareParams is called with the
 *  node's target and target_host to obtain defaults.
 */
SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target,
                       Target target_host, Optional<HardwareParams> hardware_params) {
  auto task_node = make_object<SearchTaskNode>();
  task_node->target_host = std::move(target_host);
  task_node->target = std::move(target);
  task_node->workload_key = std::move(workload_key);
  task_node->compute_dag = std::move(compute_dag);
  // The defaults are only computed when the caller supplied no parameters, and they
  // read the targets already stored on the node.
  task_node->hardware_params =
      hardware_params
          ? hardware_params.value()
          : HardwareParamsNode::GetDefaultHardwareParams(task_node->target,
                                                         task_node->target_host);
  data_ = std::move(task_node);
}
// Register a global function named "auto_scheduler.HardwareParams" that forwards its
// three integer arguments to the HardwareParams constructor.
TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams")
    .set_body_typed([](int n_cores, int vec_unit_bytes, int line_bytes) -> HardwareParams {
      return HardwareParams(n_cores, vec_unit_bytes, line_bytes);
    });
// Register a global function named "auto_scheduler.SearchTask" that forwards its
// arguments, unchanged and in order, to the SearchTask constructor.
TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask")
    .set_body_typed([](ComputeDAG dag, String key, Target tgt, Target tgt_host,
                       Optional<HardwareParams> hw_params) -> SearchTask {
      return SearchTask(dag, key, tgt, tgt_host, hw_params);
    });
} // namespace auto_scheduler
} // namespace tvm