// blob: 2ca6f8c71093f28431f939ad5ed31cc4ae27a345 [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file transformer.cc
* \brief CPU implementation of the operators used in Transformer
*/
#include <mxnet/base.h>
#include "./transformer-inl.h"
#include "../tensor/elemwise_unary_op.h"
namespace mxnet {
namespace op {
// Register the InterleavedMatMulParam parameter struct, which every
// interleaved_matmul operator below parses through ParamParser.
// NOTE(review): the struct itself is not declared in this file; presumably it
// lives in transformer-inl.h -- confirm.
DMLC_REGISTER_PARAMETER(InterleavedMatMulParam);
static bool InterleavedMatMulSelfAttQKShape(const NodeAttrs& attrs,
mxnet::ShapeVector* in_shape,
mxnet::ShapeVector* out_shape) {
const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 1U) << "Input:[queries_keys_values] currently have, "
<< in_shape->size() << " inputs";
auto qkv_shape = in_shape->at(0);
CHECK_EQ(qkv_shape.ndim(), 3U)
<< "Input queries_keys_values should be 3D in seq_length-batch-proj_dim, "
<< "currently is: " << qkv_shape.ndim() << "D";
out_shape->resize(1);
SHAPE_ASSIGN_CHECK(*out_shape, 0,
mxnet::TShape({params.heads * qkv_shape[1], qkv_shape[0], qkv_shape[0]}));
return true;
}
static bool InterleavedMatMulSelfAttValAttShape(const NodeAttrs& attrs,
mxnet::ShapeVector* in_shape,
mxnet::ShapeVector* out_shape) {
CHECK_EQ(in_shape->size(), 2U) << "Input:[queries_keys_values, attention] currently have, "
<< in_shape->size() << " inputs";
auto qkv_shape = in_shape->at(0);
auto att_shape = in_shape->at(1);
CHECK_EQ(qkv_shape.ndim(), 3U)
<< "Input queries_keys_values should be 3D in seq_length-batch-3*proj_dim, "
<< "currently is: " << qkv_shape.ndim() << "D";
CHECK_EQ(att_shape.ndim(), 3U)
<< "Input attention should be 3D in batch-seq_length-seq_length, "
<< "currently is: " << att_shape.ndim() << "D";
CHECK_EQ(qkv_shape[0], att_shape[1])
<< "queries_keys_values.shape[0] and attention.shape[1] should be the same, "
<< "currently are " << qkv_shape[0] << " and " << att_shape[1];
CHECK_EQ(qkv_shape[0], att_shape[2])
<< "queries_keys_values.shape[0] and attention.shape[2] should be the same, "
<< "currently are " << qkv_shape[0] << " and " << att_shape[2];
CHECK_EQ(qkv_shape[2] % 3, 0)
<< "queries_keys_values.shape[2] should be a multiple of 3, "
<< "currently is " << qkv_shape[2];
SHAPE_ASSIGN_CHECK(*out_shape, 0,
mxnet::TShape({qkv_shape[0], qkv_shape[1], qkv_shape[2] / 3}));
return true;
}
/*!
 * \brief Shape inference for the encoder-decoder queries x keys product.
 *
 * Requires two 3D inputs (queries, keys_values), checks that the last dim of
 * keys_values is twice the last dim of queries and that both share dim 1,
 * then assigns (queries[1] * heads, queries[0], keys_values[0]) to output 0.
 *
 * \param attrs node attributes holding the parsed InterleavedMatMulParam
 * \param in_shape input shapes: [queries, keys_values]
 * \param out_shape output shapes; entry 0 is filled in
 * \return true once the output shape has been assigned
 */
static bool InterleavedMatMulEncDecQKShape(const NodeAttrs& attrs,
                                           mxnet::ShapeVector* in_shape,
                                           mxnet::ShapeVector* out_shape) {
  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), 2U) << "Input:[queries, keys_values], currently have "
                                 << in_shape->size() << " inputs";
  auto q_shape = in_shape->at(0);
  auto kv_shape = in_shape->at(1);
  CHECK_EQ(q_shape.ndim(), 3U) << "Input queries should be 3D in seq_length-batch-proj_dim, "
                               << "currently is " << q_shape.ndim() << "D";
  // This check is on keys_values; the message previously named "queries",
  // which pointed users at the wrong argument.
  CHECK_EQ(kv_shape.ndim(), 3U)
    << "Input keys_values should be 3D in seq_length-batch-2*proj_dim, "
    << "currently is " << kv_shape.ndim() << "D";
  CHECK_EQ(q_shape[2] * 2, kv_shape[2])
    << "keys_values.shape[2] should be equal to queries.shape[2] * 2, "
    << "currently are: " << kv_shape[2] << " and " << q_shape[2];
  CHECK_EQ(q_shape[1], kv_shape[1])
    << "queries.shape[1] should be equal to keys_values.shape[1], "
    << "currently are: " << q_shape[1] << " and " << kv_shape[1];
  SHAPE_ASSIGN_CHECK(*out_shape, 0,
    mxnet::TShape({q_shape[1] * params.heads, q_shape[0], kv_shape[0]}));
  return true;
}
/*!
 * \brief Shape inference for the encoder-decoder attention x values product.
 *
 * Requires two 3D inputs (keys_values, attention), checks that
 * attention[2] == keys_values[0] and attention[0] == keys_values[1] * heads,
 * then assigns (attention[1], keys_values[1], keys_values[2] / 2) to output 0.
 *
 * \param attrs node attributes holding the parsed InterleavedMatMulParam
 * \param in_shape input shapes: [keys_values, attention]
 * \param out_shape output shapes; entry 0 is filled in
 * \return true once the output shape has been assigned
 */
static bool InterleavedMatMulEncDecValAttShape(const NodeAttrs& attrs,
                                               mxnet::ShapeVector* in_shape,
                                               mxnet::ShapeVector* out_shape) {
  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), 2U) << "Input: [keys_values, attention], currently have "
                                 << in_shape->size() << " inputs";
  auto kv_shape = in_shape->at(0);
  auto att_shape = in_shape->at(1);
  CHECK_EQ(kv_shape.ndim(), 3U)
    << "Input keys_values should be 3D in seq_length-batch-2*proj_dim, "
    << "currently is " << kv_shape.ndim() << "D";
  CHECK_EQ(att_shape.ndim(), 3U)
    << "Input attention should be 3D in batch-seq_length-seq_length, "
    << "currently is " << att_shape.ndim() << "D";
  CHECK_EQ(kv_shape[0], att_shape[2])
    << "keys_values.shape[0] should be equal to attention.shape[2], currently are "
    << kv_shape[0] << " and " << att_shape[2];
  // Report the two quantities that are actually compared; the message used to
  // print attention.shape[2] and the unscaled keys_values.shape[1].
  CHECK_EQ(kv_shape[1] * params.heads, att_shape[0]) << "attention.shape[0] "
    << "should be equal to keys_values.shape[1] * heads, currently are: "
    << att_shape[0] << " and " << kv_shape[1] * params.heads;
  SHAPE_ASSIGN_CHECK(*out_shape, 0,
    mxnet::TShape({att_shape[1], kv_shape[1], kv_shape[2] / 2}));
  return true;
}
// Forward registration of the self-attention queries x keys operator.
// One input (the interleaved projections), one output. Only metadata is
// attached here: names, shape/type inference and the gradient node. No FCompute
// is registered in this block.
NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_qk)
.describe(R"code(Compute the matrix multiplication between the projections of
queries and keys in multihead attention use as self attention.
the input must be a single tensor of interleaved projections
of queries, keys and values following the layout:
(seq_length, batch_size, num_heads * head_dim * 3)
the equivalent code would be:
tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
q_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3))
q_proj = mx.nd.reshape(q_proj, shape=(-1, 0, 0), reverse=True)
q_proj = mx.nd.contrib.div_sqrt_dim(q_proj)
k_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3))
k_proj = mx.nd.reshape(k_proj, shape=(-1, 0, 0), reverse=True)
output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True)
This Op is GPU only
)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<InterleavedMatMulParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
  return std::vector<std::string>{"queries_keys_values"};
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
  return std::vector<std::string>{"output"};
})
.set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulSelfAttQKShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
// The gradient is delegated to the dedicated backward node registered under
// the name given here.
.set_attr<nnvm::FGradient>("FGradient",
  ElemwiseGradUseIn{"_backward_interleaved_matmul_selfatt_qk"})
.add_argument("queries_keys_values", "NDArray-or-Symbol", "Interleaved queries, keys and values")
.add_arguments(InterleavedMatMulParam::__FIELDS__());
// Backward node referenced by _contrib_interleaved_matmul_selfatt_qk's
// FGradient: 2 inputs, 1 output. Only metadata is declared here; no compute
// function is attached in this block.
NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_qk)
.set_attr_parser(ParamParser<InterleavedMatMulParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_num_outputs(1)
.set_num_inputs(2);
// Forward registration of the self-attention attention x values operator.
// Two inputs (interleaved projections, attention weights), one output. Only
// metadata is attached here; no FCompute is registered in this block.
//
// The docstring's reference code was corrected to agree with the shape
// functions in this file: the attention's leading dim carries
// batch_size * num_heads (as produced by the matching qk operator), the
// batch_dot must not transpose v_proj, and the final transpose has to yield a
// (seq_length, batch_size, ...) result.
NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_valatt)
.describe(R"code(Compute the matrix multiplication between the projections of
values and the attention weights in multihead attention use as self attention.
the inputs must be a tensor of interleaved projections
of queries, keys and values following the layout:
(seq_length, batch_size, num_heads * head_dim * 3)
and the attention weights following the layout:
(batch_size * num_heads, seq_length, seq_length)
the equivalent code would be:
tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
v_proj = mx.nd.transpose(tmp[:,:,:,2,:], axes=(1, 2, 0, 3))
v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True)
output = mx.nd.batch_dot(attention, v_proj)
output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True)
output = mx.nd.transpose(output, axes=(2, 0, 1, 3))
output = mx.nd.reshape(output, shape=(0, 0, -1))
This Op is GPU only
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<InterleavedMatMulParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
  return std::vector<std::string>{"queries_keys_values", "attention"};
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
  return std::vector<std::string>{"output"};
})
.set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulSelfAttValAttShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
// The gradient is delegated to the dedicated backward node registered under
// the name given here.
.set_attr<nnvm::FGradient>("FGradient",
  ElemwiseGradUseIn{"_backward_interleaved_matmul_selfatt_valatt"})
.add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved")
.add_argument("attention", "NDArray-or-Symbol", "Attention maps")
.add_arguments(InterleavedMatMulParam::__FIELDS__());
// Backward node referenced by _contrib_interleaved_matmul_selfatt_valatt's
// FGradient: 3 inputs, 2 outputs. Only metadata is declared here; no compute
// function is attached in this block.
NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_valatt)
.set_attr_parser(ParamParser<InterleavedMatMulParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_num_outputs(2)
.set_num_inputs(3);
// Forward registration of the encoder-decoder queries x keys operator.
// Two inputs (queries, interleaved keys/values), one output. Only metadata is
// attached here; no FCompute is registered in this block.
//
// The docstring's reference code was corrected: the 3D queries tensor is split
// into heads before the 4-axis transpose, "reshap" is spelled "reshape", and
// the interleave order is stated as keys then values (index 0 is the key).
NNVM_REGISTER_OP(_contrib_interleaved_matmul_encdec_qk)
.describe(R"code(Compute the matrix multiplication between the projections of
queries and keys in multihead attention use as encoder-decoder.
the inputs must be a tensor of projections of queries following the layout:
(seq_length, batch_size, num_heads * head_dim)
and a tensor of interleaved projections of keys and values following the layout:
(seq_length, batch_size, num_heads * head_dim * 2)
the equivalent code would be:
q_proj = mx.nd.reshape(queries, shape=(0, 0, num_heads, -1))
q_proj = mx.nd.transpose(q_proj, axes=(1, 2, 0, 3))
q_proj = mx.nd.reshape(q_proj, shape=(-1, 0, 0), reverse=True)
q_proj = mx.nd.contrib.div_sqrt_dim(q_proj)
tmp = mx.nd.reshape(keys_values, shape=(0, 0, num_heads, 2, -1))
k_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3))
k_proj = mx.nd.reshape(k_proj, shape=(-1, 0, 0), reverse=True)
output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True)
This Op is GPU only
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<InterleavedMatMulParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
  return std::vector<std::string>{"queries", "keys_values"};
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
  return std::vector<std::string>{"output"};
})
.set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulEncDecQKShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
// The gradient is delegated to the dedicated backward node registered under
// the name given here.
.set_attr<nnvm::FGradient>("FGradient",
  ElemwiseGradUseIn{"_backward_interleaved_matmul_encdec_qk"})
.add_argument("queries", "NDArray-or-Symbol", "Queries")
.add_argument("keys_values", "NDArray-or-Symbol", "Keys and values interleaved")
.add_arguments(InterleavedMatMulParam::__FIELDS__());
// Backward node referenced by _contrib_interleaved_matmul_encdec_qk's
// FGradient: 3 inputs, 2 outputs. Only metadata is declared here; no compute
// function is attached in this block.
NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_qk)
.set_attr_parser(ParamParser<InterleavedMatMulParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_num_outputs(2)
.set_num_inputs(3);
// Forward registration of the encoder-decoder attention x values operator.
// Two inputs (interleaved keys/values, attention weights), one output. Only
// metadata is attached here; no FCompute is registered in this block.
//
// The docstring's reference code was corrected: it now reshapes the actual
// input (keys_values, interleave factor 2) instead of queries_keys_values with
// factor 3, states the batch_size * num_heads leading dim that the shape
// function enforces on attention, drops the dimensionally invalid
// transpose_b, and transposes to the (seq_length, batch_size, ...) result the
// shape function reports.
NNVM_REGISTER_OP(_contrib_interleaved_matmul_encdec_valatt)
.describe(R"code(Compute the matrix multiplication between the projections of
values and the attention weights in multihead attention use as encoder-decoder.
the inputs must be a tensor of interleaved projections of
keys and values following the layout:
(seq_length, batch_size, num_heads * head_dim * 2)
and the attention weights following the layout:
(batch_size * num_heads, seq_length, seq_length)
the equivalent code would be:
tmp = mx.nd.reshape(keys_values, shape=(0, 0, num_heads, 2, -1))
v_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3))
v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True)
output = mx.nd.batch_dot(attention, v_proj)
output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True)
output = mx.nd.transpose(output, axes=(2, 0, 1, 3))
output = mx.nd.reshape(output, shape=(0, 0, -1))
This Op is GPU only
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<InterleavedMatMulParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
  return std::vector<std::string>{"keys_values", "attention"};
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
  return std::vector<std::string>{"output"};
})
.set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulEncDecValAttShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
// The gradient is delegated to the dedicated backward node registered under
// the name given here.
.set_attr<nnvm::FGradient>("FGradient",
  ElemwiseGradUseIn{"_backward_interleaved_matmul_encdec_valatt"})
.add_argument("keys_values", "NDArray-or-Symbol", "Keys and values interleaved")
.add_argument("attention", "NDArray-or-Symbol", "Attention maps")
.add_arguments(InterleavedMatMulParam::__FIELDS__());
// Backward node referenced by _contrib_interleaved_matmul_encdec_valatt's
// FGradient: 3 inputs, 2 outputs. Only metadata is declared here; no compute
// function is attached in this block.
NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_valatt)
.set_attr_parser(ParamParser<InterleavedMatMulParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_num_outputs(2)
.set_num_inputs(3);
// div_sqrt_dim: unary operator with a CPU kernel (DivSqrtDimForward_<cpu>).
// Its FGradient names the operator itself as the gradient node.
MXNET_OPERATOR_REGISTER_UNARY(_contrib_div_sqrt_dim)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_contrib_div_sqrt_dim"})
.set_attr<FCompute>("FCompute<cpu>", DivSqrtDimForward_<cpu>)
.describe(R"code(Rescale the input by the square root of the channel dimension.
out = data / sqrt(data.shape[-1])
)code" ADD_FILELINE);
} // namespace op
} // namespace mxnet