blob: f2af290f3e221a2a8ebc8fadb634891c2eaf3648 [file] [log] [blame]
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
* \file
* \brief Mutate conv2d operator to sparse conv2d operator
#include <tvm/ir/expr.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
namespace tvm {
namespace relay {
// Search conv2d op weight name from Expr
class Conv2dOpWeightVisitor : private ExprVisitor {
Conv2dOpWeightVisitor() : conv2d_op_(Op::Get("nn.conv2d")) {}
Array<String> Search(const Expr& expr) {
return memo_;
void VisitExpr_(const CallNode* n) final {
if (n->op == conv2d_op_) {
const auto weight = n->args[1].as<VarNode>();
if (weight) {
for (const auto& arg : n->args) {
// Cache op
const Op& conv2d_op_;
Array<String> memo_;
}; // SearchConv2dOpWeight
// Runs a fresh Conv2dOpWeightVisitor over `e` and returns whatever its
// Search() method reports.
Array<String> SearchConv2dOpWeight(const Expr& e) {
  Conv2dOpWeightVisitor visitor;
  return visitor.Search(e);
}
// Mutate ```nn.conv2d``` to ```nn.sparse_conv2d```
class Conv2dToSparseConv2dMutator : public ExprRewriter {
Conv2dToSparseConv2dMutator(const Array<ObjectRef>& weight_name,
const Array<Array<PrimExpr>>& weight_shape, const String& layout,
int kernel_size)
: conv2d_op_(Op::Get("nn.conv2d")), sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")) {
ICHECK_EQ(weight_name.size(), weight_shape.size());
layout_ = layout;
kernel_size_ = kernel_size;
for (size_t i = 0; i < weight_name.size(); ++i) {
std::string k = weight_name[i].as<runtime::StringObj>()->data;
const auto& ws = weight_shape[i];
std::vector<int> v(ws.size());
for (size_t j = 0; j < ws.size(); ++j) {
v[j] = ws[j].as<IntImmNode>()->value;
target_weights_.emplace(k, v);
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
if (pre->op == conv2d_op_) {
const auto weight = pre->args[1].as<VarNode>();
if (weight) {
if (target_weights_.count(weight->name_hint())) {
const auto& prefix = weight->name_hint();
const auto& ws =;
const auto data =<CallNode>()->args[0];
relay::TensorType ws_data_type, ws_indices_type, ws_indptr_type;
if (ws.size() == 5) {
ws_data_type = relay::TensorType({,,}, DataType::Float(32));
ws_indices_type = relay::TensorType({}, DataType::Int(32));
ws_indptr_type = relay::TensorType({}, DataType::Int(32));
} else if (ws.size() == 4) {
ws_data_type = relay::TensorType({,}, DataType::Float(32));
ws_indices_type = relay::TensorType({}, DataType::Int(32));
ws_indptr_type = relay::TensorType({}, DataType::Int(32));
Var weight_data(prefix + ".data", ws_data_type);
Var weight_indices(prefix + ".indices", ws_indices_type);
Var weight_indptr(prefix + ".indptr", ws_indptr_type);
auto attrs = make_object<SparseConv2DAttrs>();
attrs->layout = std::move(layout_);
attrs->kernel_size = Array<IndexExpr>{kernel_size_, kernel_size_};
return Call(sparse_conv2d_op_, {data, weight_data, weight_indices, weight_indptr},
return post;
// Cached op
const Op& conv2d_op_;
const Op& sparse_conv2d_op_;
std::unordered_map<std::string, std::vector<int>> target_weights_;
String layout_;
int kernel_size_;
}; // class Conv2dToSparseConv2dAlter
/*!
 * \brief Apply Conv2dToSparseConv2dMutator to an expression.
 *
 * A mutator is built from the given weight names/shapes, layout and kernel
 * size, and handed to PostOrderRewrite together with `e`.
 *
 * \param e            Expression to rewrite.
 * \param weight_name  Names of the conv2d weight variables to convert.
 * \param weight_shape One shape per entry of `weight_name` (sizes must match;
 *                     the mutator's constructor checks this with ICHECK_EQ).
 * \param layout       Layout string forwarded to the mutator.
 * \param kernel_size  Kernel size forwarded to the mutator.
 * \return The expression produced by PostOrderRewrite.
 */
Expr Conv2dToSparse(const Expr& e, const Array<ObjectRef>& weight_name,
                    const Array<Array<PrimExpr>>& weight_shape, const String& layout,
                    int kernel_size) {
  Conv2dToSparseConv2dMutator rewriter(weight_name, weight_shape, layout, kernel_size);
  return PostOrderRewrite(e, &rewriter);
}
/*!
 * \brief Helper for unpack_to_tuple: builds a tuple from arr[Is]... .
 *
 * \tparam elemTy Element type of the array.
 * \tparam Is     Indices to read, in order.
 * \param arr     Pointer to at least max(Is)+1 readable elements.
 * \return std::tuple holding copies of the selected elements.
 */
template <typename elemTy, size_t... Is>
auto unpack_to_tuple_internal(elemTy* arr, std::index_sequence<Is...>) {
  return std::make_tuple(arr[Is]...);
}

/*!
 * \brief Copy the first N elements of an array into a std::tuple.
 *
 * The result can be assigned to std::tie(...) to unpack a fixed number of
 * array entries into named variables.
 *
 * \tparam N      Number of elements to unpack (must be non-negative).
 * \tparam elemTy Element type of the array.
 * \param arr     Pointer to at least N readable elements; the caller is
 *                responsible for that bound, it is not checked here.
 * \return std::tuple with N elements, element i being a copy of arr[i].
 */
template <int N, typename elemTy>
auto unpack_to_tuple(elemTy* arr) {
  static_assert(N >= 0, "unpack_to_tuple: N must be non-negative");
  return unpack_to_tuple_internal(arr, std::make_index_sequence<N>{});
}
/*!
 * \brief Index range [0, dim) usable in range-based for loops.
 *
 * Dereferencing the loop iterator yields an `iterpoint`, a (value, limit)
 * pair. Points combine with `operator/` into a flattened index:
 * `*(i / j)` equals `(*i) * j.lim + (*j)` and the combined limit is
 * `i.lim * j.lim`, so `*(a / b / c)` is the row-major offset of (a, b, c).
 */
struct Range {
  size_t dim;
  explicit Range(size_t d) : dim(d) {}

  /*! \brief Current index `val` together with the exclusive bound `lim` of its range. */
  struct iterpoint {
    size_t val, lim;
    iterpoint(size_t v1, size_t v2) : val(v1), lim(v2) {}

    /*! \brief The (possibly flattened) index value. */
    size_t operator*() const { return val; }

    /*! \brief Treat `*this` as the outer and `rhs` as the inner index and flatten them. */
    iterpoint operator/(const iterpoint& rhs) const {
      return iterpoint(val * rhs.lim + rhs.val, lim * rhs.lim);
    }
  };

  /*! \brief Minimal forward iterator; only what range-based for requires. */
  struct iterator {
    size_t val, lim;
    iterator(size_t v1, size_t v2) : val(v1), lim(v2) {}

    // Only the position is compared; the limit is the same for begin() and end().
    bool operator!=(const iterator& rhs) const { return val != rhs.val; }

    iterator& operator++() {
      ++val;
      return *this;
    }

    iterpoint operator*() const { return iterpoint(val, lim); }
  };

  // const-qualified so that const Range objects can be iterated as well.
  iterator begin() const { return iterator(0, dim); }
  iterator end() const { return iterator(dim, dim); }
};
// Mutate ```nn.conv2d``` to ```nn.sparse_conv2d```
class Conv2dToSparseConv2dMutator2 : public ExprRewriter {
Conv2dToSparseConv2dMutator2(const String& layout, int kernel_size, int blockH, int blockW,
double sparse_thresh)
: sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")),
dev_cpu0_{DLDeviceType::kDLCPU, 0},
sparse_thresh_(sparse_thresh) {}
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
// check op type & attrs
const auto pre_attrs = pre-><Conv2DAttrs>();
if (!pre_attrs || pre_attrs->data_layout != layout_ ||
pre_attrs->strides[0].as<IntImmNode>()->value != 1 ||
pre_attrs->kernel_size[0].as<IntImmNode>()->value != kernel_size_)
return post;
// check constant weight
const auto pre_weight_node = pre->args[1].as<ConstantNode>();
if (!pre_weight_node) return post;
// check weight dtype & shape
auto&& pre_weight = pre_weight_node->data;
auto dtype = pre_weight.DataType(), itype = runtime::DataType::Int(32);
ICHECK(dtype.code() == DataType::kFloat && dtype.bits() == 32); // float32 only
auto pre_weight_shape = unpack_to_tuple<4>(pre_weight.Shape().data());
int O, I, H, W;
if (layout_ == "NCHW") {
std::tie(O, I, H, W) = pre_weight_shape;
} else { // NHWC
std::tie(H, W, I, O) = pre_weight_shape;
int CO = O, CI = H * W * I;
// copy to vector
std::vector<float> pre_weight_data(CO * CI);
pre_weight.CopyToBytes(, pre_weight_data.size() * sizeof(float));
if (layout_ == "NHWC") {
std::vector<float> tmp(pre_weight_data.size());
for (auto i : Range(CO))
for (auto j : Range(CI)) tmp[*(i / j)] = pre_weight_data[*(j / i)];
std::swap(tmp, pre_weight_data);
// convert to BSR
std::vector<float> wdata, block(blockH_ * blockW_);
std::vector<int32_t> windices, windptr;
for (auto bh : Range(CO / blockH_)) {
for (auto bw : Range(CI / blockW_)) {
int cntnnz = 0;
for (auto i : Range(blockH_))
for (auto j : Range(blockW_)) {
auto tmp = pre_weight_data[*(bh / i / bw / j)];
if (tmp) cntnnz++;
block[*(i / j)] = tmp;
if (cntnnz) {
wdata.insert(wdata.end(), block.begin(), block.end());
double sprate = 1 - 1.0 * wdata.size() / pre_weight_data.size();
if (sprate < sparse_thresh_) return post;
// constrct return data
int nnz = windices.size();
auto weight_data = runtime::NDArray::Empty({nnz, blockH_, blockW_}, dtype, dev_cpu0_);
auto weight_indices = runtime::NDArray::Empty({nnz}, itype, dev_cpu0_);
auto weight_indptr = runtime::NDArray::Empty({CO / blockH_ + 1}, itype, dev_cpu0_);
weight_data.CopyFromBytes(, wdata.size() * sizeof(float));
weight_indices.CopyFromBytes(, windices.size() * sizeof(int32_t));
weight_indptr.CopyFromBytes(, windptr.size() * sizeof(int32_t));
// construct return call
auto args = runtime::Array<relay::Expr>{<CallNode>()->args[0], Constant(weight_data),
Constant(weight_indices), Constant(weight_indptr)};
auto attrs = make_object<SparseConv2DAttrs>();
attrs->layout = layout_;
attrs->kernel_size = Array<IndexExpr>{kernel_size_, kernel_size_};
return Call(sparse_conv2d_op_, args, Attrs(attrs));
const Op& sparse_conv2d_op_;
DLDevice dev_cpu0_;
String layout_;
int kernel_size_, blockH_, blockW_;
double sparse_thresh_;
}; // class Conv2dToSparseConv2dMutator2
/*!
 * \brief Apply Conv2dToSparseConv2dMutator2 to an expression.
 *
 * A mutator is built from the given parameters and handed to PostOrderRewrite
 * together with `e`.
 *
 * \param e             Expression to rewrite.
 * \param layout        Data layout a conv2d call must use to be considered.
 * \param kernel_size   Kernel size a conv2d call must use to be considered.
 * \param blockH        Block height forwarded to the mutator.
 * \param blockW        Block width forwarded to the mutator.
 * \param sparse_thresh Sparsity threshold forwarded to the mutator.
 * \return The expression produced by PostOrderRewrite.
 */
Expr Conv2dToSparse2(const Expr& e, const String& layout, int kernel_size, int blockH, int blockW,
                     double sparse_thresh) {
  Conv2dToSparseConv2dMutator2 rewriter(layout, kernel_size, blockH, blockW, sparse_thresh);
  return PostOrderRewrite(e, &rewriter);
}
namespace transform {
// Convert a model with seperate weight info (already sparsified).
Pass Conv2dToSparse(const Array<ObjectRef>& weight_name, const Array<Array<PrimExpr>>& weight_shape,
const String& layout, int kernel_size) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
// Remove FreeVar warnings
auto f0 =
Downcast<Function>(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size));
Array<Var> sparse_params = FreeVars(f0);
auto f1 = WithFields(f0, sparse_params);
Array<Var> params = FreeVars(f1);
for (const auto& var : sparse_params) {
return WithFields(f1, params);
return CreateFunctionPass(pass_func, 4, "Conv2dToSparse", {"DeadCodeElimination"});
// Convert a model with frozen params (sparsified in the pass).
/*!
 * \brief Create the function pass wrapping the expression-level Conv2dToSparse2.
 *
 * The pass is registered with opt level 5 under the name "Conv2dToSparse2" and
 * lists "DeadCodeElimination" as a required pass. All arguments are captured
 * by value and forwarded unchanged for every function the pass visits.
 *
 * \param layout        Data layout a conv2d call must use to be considered.
 * \param kernel_size   Kernel size a conv2d call must use to be considered.
 * \param blockH        Block height forwarded to the rewrite.
 * \param blockW        Block width forwarded to the rewrite.
 * \param sparse_thresh Sparsity threshold forwarded to the rewrite.
 * \return The created pass.
 */
Pass Conv2dToSparse2(const String& layout, int kernel_size, int blockH, int blockW,
                     double sparse_thresh) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        // The rewrite returns an Expr; it is cast back to Function for the pass API.
        return Downcast<Function>(
            Conv2dToSparse2(f, layout, kernel_size, blockH, blockW, sparse_thresh));
      };
  return CreateFunctionPass(pass_func, 5, "Conv2dToSparse2", {"DeadCodeElimination"});
}
} // namespace transform
} // namespace relay
} // namespace tvm