/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file gemm_lib.cc
* \brief Sample 2D gemm custom operator implementation library file
*/
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "mxnet/lib_api.h"
using namespace mxnet::ext;
// main matrix multiplication routine: C (n x m) = A (n x k) * B (k x m), row-major
void gemm(const float* A, const float* B, float* C,
          const unsigned n, const unsigned k, const unsigned m) {
  unsigned i, j, kk;
  for (i = 0; i < n; i++) {
    for (j = 0; j < m; j++) {
      C[i*m+j] = 0;
      for (kk = 0; kk < k; kk++) {
        C[i*m+j] += A[i*k+kk] * B[kk*m+j];
      }
    }
  }
}
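// e.g. with n = 1, k = 2, m = 1 the routine reduces to a dot product:
// C[0] = A[0]*B[0] + A[1]*B[1]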
// writes the transpose of the m x n (row-major) matrix A into the n x m matrix At
void transpose(const float* A, float* At, const unsigned n, const unsigned m) {
  unsigned i, j;
  for (i = 0; i < n; i++) {
    for (j = 0; j < m; j++) {
      At[i*m+j] = A[j*n+i];
    }
  }
}
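// e.g. for the 2 x 3 matrix A = [[1,2,3],[4,5,6]], transpose(A, At, 3, 2)
// yields the 3 x 2 matrix At = [[1,4],[2,5],[3,6]]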
/*
* Executes C = A * B
* inputs[0] = A; inputs[1] = B; outputs[0] = C
*/
MXReturnValue forward(const std::unordered_map<std::string, std::string>& attrs,
                      std::vector<MXTensor>* inputs,
                      std::vector<MXTensor>* outputs,
                      const OpResource& res) {
  // simple example of using the runtime data type
  if (inputs->at(0).dtype == kFloat32) {
    typedef float DType;
    // extract data pointers from tensors
    // if using the dltensor representation, the lines below can be changed to
    // something like: DType* A = reinterpret_cast<DType*>(inputs[0].dltensor.data);
    DType* A = inputs->at(0).data<DType>();
    DType* B = inputs->at(1).data<DType>();
    DType* C = outputs->at(0).data<DType>();
    // read tensor shapes
    unsigned n = inputs->at(0).shape[0];
    unsigned k = inputs->at(0).shape[1];
    unsigned m = inputs->at(1).shape[1];
    gemm(A, B, C, n, k, m);
  } else {
    // inferType only accepts float32, so this branch should be unreachable
    MX_ERROR_MSG << "Expected float32 inputs to forward";
    return MX_FAIL;
  }
  return MX_SUCCESS;
}
/*
* Executes dA = dC * B.T; Executes dB = A.T * dC
***** gradient inputs
* inputs[0] = dC
***** original inputs
* inputs[1] = A; inputs[2] = B
***** original outputs
* inputs[3] = C
***** gradient outputs
* outputs[0] = dA; outputs[1] = dB
*/
MXReturnValue backward(const std::unordered_map<std::string, std::string>& attrs,
                       std::vector<MXTensor>* inputs,
                       std::vector<MXTensor>* outputs,
                       const OpResource& res) {
  // extract data pointers from tensors
  float* dC = inputs->at(0).data<float>();
  float* A = inputs->at(1).data<float>();
  float* B = inputs->at(2).data<float>();
  float* dA = outputs->at(0).data<float>();
  float* dB = outputs->at(1).data<float>();
  // read tensor shapes
  unsigned n = inputs->at(1).shape[0];
  unsigned k = inputs->at(1).shape[1];
  unsigned m = inputs->at(2).shape[1];
  // allocate temporary workspace memory through the resource manager;
  // when multiple arrays are needed, it is better to request one big pool
  void *workspace = res.alloc_cpu((k*n + m*k) * sizeof(float));
  float *At = static_cast<float*>(workspace);
  float *Bt = static_cast<float*>(workspace) + (k*n);
  transpose(A, At, k, n);     // At (k x n) = A.T, where A is (n x k)
  transpose(B, Bt, m, k);     // Bt (m x k) = B.T, where B is (k x m)
  gemm(dC, Bt, dA, n, m, k);  // dA (n x k) = dC (n x m) * B.T (m x k)
  gemm(At, dC, dB, k, n, m);  // dB (k x m) = A.T (k x n) * dC (n x m)
  return MX_SUCCESS;
}
// the operator takes 2 inputs (A, B) and produces 1 output (C)
MXReturnValue parseAttrs(const std::unordered_map<std::string, std::string>& attrs,
                         int* num_in, int* num_out) {
  *num_in = 2;
  *num_out = 1;
  return MX_SUCCESS;
}
MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attrs,
                        std::vector<int>* intypes,
                        std::vector<int>* outtypes) {
  // validate inputs
  if (intypes->size() != 2) {
    MX_ERROR_MSG << "Expected 2 inputs to inferType";
    return MX_FAIL;
  }
  for (unsigned i = 0; i < intypes->size(); i++) {
    if (intypes->at(i) != kFloat32) {
      MX_ERROR_MSG << "Expected input " << i << " to have float32 type";
      return MX_FAIL;
    }
  }
  outtypes->at(0) = intypes->at(0);
  return MX_SUCCESS;
}
MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& attrs,
                         std::vector<std::vector<unsigned int>>* inshapes,
                         std::vector<std::vector<unsigned int>>* outshapes) {
  // validate inputs
  if (inshapes->size() != 2) {
    MX_ERROR_MSG << "Expected 2 inputs to inferShape";
    return MX_FAIL;
  }
  if (inshapes->at(0).size() != 2 || inshapes->at(1).size() != 2) {
    MX_ERROR_MSG << "Expected 2D matrices for both inputs to inferShape";
    return MX_FAIL;
  }
  unsigned n = inshapes->at(0)[0];
  unsigned k = inshapes->at(0)[1];
  unsigned kk = inshapes->at(1)[0];
  unsigned m = inshapes->at(1)[1];
  if (k != kk) {
    MX_ERROR_MSG << "Expected first input axis 1 to equal second input axis 0";
    return MX_FAIL;
  }
  outshapes->at(0) = {n, m};
  return MX_SUCCESS;
}
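// e.g. inshapes = {{2, 3}, {3, 2}} passes the checks above and produces
// outshapes = {{2, 2}}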
REGISTER_OP(my_gemm)
.setForward(forward, "cpu")
.setBackward(backward, "cpu")
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape);
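/*
 * Example usage from Python (a sketch, not part of this file; it assumes the
 * library has been compiled to 'libgemm_lib.so'):
 *
 *   import mxnet as mx
 *   mx.library.load('libgemm_lib.so')
 *   a = mx.nd.random.uniform(shape=(2, 3))
 *   b = mx.nd.random.uniform(shape=(3, 2))
 *   c = mx.nd.my_gemm(a, b)   # invokes forward() above
 */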
/* ------------------------------------------------------------------------- */
class MyStatefulGemm : public CustomStatefulOp {
 public:
  explicit MyStatefulGemm(int count,
                          std::unordered_map<std::string, std::string> attrs)
    : count(count), attrs_(std::move(attrs)) {}

  MXReturnValue Forward(std::vector<MXTensor>* inputs,
                        std::vector<MXTensor>* outputs,
                        const OpResource& op_res) override {
    std::cout << "Info: keyword + number of forward: " << ++count << std::endl;
    return forward(attrs_, inputs, outputs, op_res);
  }

  MXReturnValue Backward(std::vector<MXTensor>* inputs,
                         std::vector<MXTensor>* outputs,
                         const OpResource& op_res) override {
    return backward(attrs_, inputs, outputs, op_res);
  }

 private:
  int count;
  const std::unordered_map<std::string, std::string> attrs_;
};
MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs,
                            const MXContext& ctx,
                            const std::vector<std::vector<unsigned int> >& in_shapes,
                            const std::vector<int> in_types,
                            CustomStatefulOp** op_inst) {
  // testing passing of keyword arguments
  int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0;
  // create the stateful operator instance; the framework takes ownership of
  // the allocated pointer
  *op_inst = new MyStatefulGemm(count, attrs);
  std::cout << "Info: stateful operator created" << std::endl;
  return MX_SUCCESS;
}
MXReturnValue mutateInputs(const std::unordered_map<std::string, std::string>& attrs,
                           std::vector<int>* input_indices) {
  // input_indices->push_back(1);  // uncomment to mark input 1 as mutable (in-place)
  return MX_SUCCESS;
}
REGISTER_OP(state_gemm)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setMutateInputs(mutateInputs)
.setCreateOpState(createOpState, "cpu");
// called by MXNet when the library is loaded; version encodes
// major*10000 + minor*100 + patch, so 10700 corresponds to MXNet 1.7.0
MXReturnValue initialize(int version) {
  if (version >= 10700) {
    std::cout << "MXNet version " << version << " supported" << std::endl;
    return MX_SUCCESS;
  } else {
    MX_ERROR_MSG << "MXNet version " << version << " not supported";
    return MX_FAIL;
  }
}