blob: 77e29fc7fdbb528062c1ef624b203eab5fff67a4 [file] [log] [blame]
/*!
* Copyright (c) 2017 by Contributors
* \file cuda/dense.h
* \brief CUDA schedule for dense operation
*/
#ifndef TOPI_CUDA_DENSE_H_
#define TOPI_CUDA_DENSE_H_
#include "tvm/tvm.h"
#include "tvm/build_module.h"
#include "topi/tags.h"
#include "topi/detail/array_utils.h"
#include "topi/nn/dense.h"
#include "topi/contrib/cublas.h"
#include "topi/generic/extern.h"
namespace topi {
using namespace tvm;
namespace cuda {
/*!
* \brief Implementation of dense for CUDA backend
*
* If the target's libs include "cublas", the matrix product is delegated to
* topi::contrib::cublas_matmul and the optional bias is added in a separate
* elementwise stage tagged kBroadcast. Otherwise the generic
* topi::nn::dense compute is returned.
*
* \param target The target device
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
*
* \return Tensor with shape [batch, out_dim]
*/
inline tvm::Tensor dense_cuda(const Target& target,
                              const tvm::Tensor& data,
                              const tvm::Tensor& weight,
                              const tvm::Tensor& bias) {
  // Unsigned literals: shape.size() is unsigned, so comparing against an int
  // literal inside CHECK_EQ triggers sign-compare warnings.
  CHECK_EQ(data->shape.size(), 2U) << "dense requires 2-D data";
  CHECK_EQ(weight->shape.size(), 2U) << "dense requires 2-D weight";
  if (bias.defined()) {
    CHECK_EQ(bias->shape.size(), 1U) << "dense requires 1-D bias";
  }

  if (!target->libs().count("cublas")) {
    return topi::nn::dense(data, weight, bias);
  }

  // The trailing (false, true) flags request the second operand transposed,
  // i.e. data * weight^T, matching the [out_dim, in_dim] weight layout.
  // NOTE(review): assumes cublas_matmul(A, B, transa, transb) ordering;
  // confirm against topi/contrib/cublas.h.
  auto matmul = topi::contrib::cublas_matmul(data, weight, false, true);
  if (!bias.defined()) {
    return matmul;
  }

  auto batch = data->shape[0];
  auto out_dim = weight->shape[0];
  // Bias is added in its own stage. It is tagged kBroadcast so that
  // broadcast-aware schedule traversals recognize it as elementwise.
  return tvm::compute({ batch, out_dim },
                      [&](Var i, Var j) {
                        return matmul(i, j) + bias(j);
                      }, "tensor", kBroadcast);
}
/*!
* \brief Create a CUDA schedule for dense
*
* If the target is "cuda" and its libs include "cublas", the generic extern
* schedule is returned (the dense was lowered to a cuBLAS call).
*
* Otherwise:
* - The dense reduce axis is split by a factor of 64 and rfactor'd.
* - The remaining reduce axis of the dense stage is bound to threadIdx.x.
* - The output's two axes are bound to blockIdx.y / blockIdx.x.
* - Stores are predicated on threadIdx.x == 0.
* - Broadcast-tagged ops reachable from outs[0] are inlined, except those that
*   are schedule outputs.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_dense(const Target &target, const Array<Tensor>& outs) {
  if (target->target_name == "cuda" &&
    target->libs().count("cublas")) {
    return topi::generic::schedule_extern(target, outs);
  }

  Array<Operation> out_ops;
  for (const auto& t : outs) {
    out_ops.push_back(t->op);
  }
  auto s = create_schedule(out_ops);

  // Schedules one dense op: partial sums via rfactor over a threadIdx.x-bound
  // reduce axis, with the final (or fused broadcast) output mapped to blocks.
  auto _schedule = [&](const Tensor& dense) {
    const int num_thread = 64;
    auto k = dense->op.as<ComputeOpNode>()->reduce_axis[0];
    IterVar ko, kf;
    s[dense].split(k, num_thread, &ko, &kf);
    auto dense_f = s.rfactor(dense, kf)[0];

    // If dense is not itself an output, it is computed inside the consumer
    // stage produced by outs[0].
    Tensor out;
    if (detail::contains(s->outputs, dense->op)) {
      out = dense;
    } else {
      out = outs[0]->op.output(0);
      s[dense].compute_at(s[out], s[out]->op.as<ComputeOpNode>()->axis[1]);
    }
    s[out].bind(s[out]->op.as<ComputeOpNode>()->axis[0], tvm::thread_axis(Range(), "blockIdx.y"));
    s[out].bind(s[out]->op.as<ComputeOpNode>()->axis[1], tvm::thread_axis(Range(), "blockIdx.x"));

    auto tx = s[dense]->op.as<ComputeOpNode>()->reduce_axis[0];
    auto thread_x = tvm::thread_axis(Range(), "threadIdx.x");
    s[dense].bind(tx, thread_x);
    s[dense_f].compute_at(s[dense], tx);
    // Only one lane writes the reduced value.
    s[dense].set_store_predicate(static_cast<Expr>(thread_x) == 0);
    s[out].set_store_predicate(static_cast<Expr>(thread_x) == 0);
  };

  // Ops already processed. An op can be reached through several consumers
  // (e.g. out = f(dense, dense)); scheduling the same dense twice would try to
  // split its reduce axis again, so each op is handled at most once.
  Array<Operation> visited;
  std::function<void(Operation)> traverse;
  traverse = [&](const Operation& op) {
    if (detail::contains(visited, op)) {
      return;
    }
    visited.push_back(op);

    // Inline all one-to-one-mapping operators except the last stage (output)
    if (is_broadcast(op->tag)) {
      if (!detail::contains(s->outputs, op)) {
        s[op].compute_inline();
      }
      for (const auto& tensor : op->InputTensors()) {
        // Skip placeholders (ops without inputs).
        if (tensor->op->InputTensors().size() > 0) {
          traverse(tensor->op);
        }
      }
    } else if (op->tag == "dense") {
      // The dense compute op itself
      auto dense = op.output(0);
      _schedule(dense);
    } else {
      LOG(ERROR) << "Unsupported operator " << op->tag;
    }
  };
  traverse(outs[0]->op);
  return s;
}
} // namespace cuda
} // namespace topi
#endif // TOPI_CUDA_DENSE_H_