``mx.nd.linalg.gemm``
==========================================
Description
----------------------
Performs general matrix multiplication and accumulation.
The inputs are tensors *A*, *B*, *C*, each of dimension *n >= 2* and having the same shape
on the leading *n-2* dimensions.
If *n=2*, the BLAS3 function *gemm* is performed:
*out* = *alpha* \* *op*\ (*A*) \* *op*\ (*B*) + *beta* \* *C*
Here, *alpha* and *beta* are scalar parameters, and *op()* is either the identity or
matrix transposition (depending on *transpose_a*, *transpose_b*).
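
The formula above can be sketched in plain Python for the *n=2* case. This is only an illustrative reference, not the operator's implementation: ``gemm_reference`` and ``transpose`` are hypothetical helper names, matrices are assumed to be lists of rows, and the keyword names follow the *transpose_a* / *transpose_b* spelling used in the description.

```python
def transpose(m):
    """Return the transpose of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*m)]


def gemm_reference(a, b, c, transpose_a=False, transpose_b=False,
                   alpha=1.0, beta=1.0):
    """Compute alpha * op(a) * op(b) + beta * c for 2-D lists of floats."""
    op_a = transpose(a) if transpose_a else a
    op_b = transpose(b) if transpose_b else b
    if len(op_a[0]) != len(op_b):
        raise ValueError("inner dimensions of op(a) and op(b) differ")
    cols_b = transpose(op_b)  # columns of op(b), convenient for dot products
    return [
        [alpha * sum(x * y for x, y in zip(row, col)) + beta * c_ij
         for col, c_ij in zip(cols_b, c_row)]
        for row, c_row in zip(op_a, c)
    ]


# Inputs: a 2x2 matrix, a 3x2 matrix (transposed to 2x3), and a 2x3 matrix.
a = [[1.0, 1.0], [1.0, 1.0]]
b = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
c = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
result = gemm_reference(a, b, c, transpose_b=True, alpha=2.0, beta=10.0)
print(result)  # [[14.0, 14.0, 14.0], [14.0, 14.0, 14.0]]
```

Each output entry is 2.0 * (1 + 1) + 10.0 * 1 = 14.0, since every row of *A* and every column of *op(B)* holds two ones.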
If *n>2*, *gemm* is performed separately for a batch of matrices. The column indices of the matrices
are given by the last dimensions of the tensors, the row indices by the axis specified with the *axis*
parameter. By default, the trailing two dimensions will be used for matrix encoding.
For a non-default axis parameter, the operation performed is equivalent to a series of swapaxes/gemm/swapaxes
calls. For example let *A*, *B*, *C* be 5 dimensional tensors. Then gemm(*A*, *B*, *C*, axis=1) is equivalent
to the following without the overhead of the additional swapaxes operations::

    A1 = swapaxes(A, dim1=1, dim2=3)
    B1 = swapaxes(B, dim1=1, dim2=3)
    C1 = swapaxes(C, dim1=1, dim2=3)
    out = gemm(A1, B1, C1)
    out = swapaxes(out, dim1=1, dim2=3)

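
The equivalence described above can be checked numerically. The sketch below uses NumPy as a stand-in for the operator (an assumption: it does not call MXNet at all), with hypothetical shapes in which axis 1 holds the matrix rows and the last axis holds the columns. It compares the swapaxes route with a direct contraction over the shared inner dimension.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 5-D inputs: rows live on axis 1, columns on the last axis.
A = rng.standard_normal((2, 3, 4, 5, 6))  # batch of 3x6 matrices
B = rng.standard_normal((2, 6, 4, 5, 7))  # batch of 6x7 matrices
C = rng.standard_normal((2, 3, 4, 5, 7))  # batch of 3x7 matrices
alpha, beta = 2.0, 0.5

# Route 1: move the row axis next to the column axis, multiply, move it back.
A1 = np.swapaxes(A, 1, 3)
B1 = np.swapaxes(B, 1, 3)
C1 = np.swapaxes(C, 1, 3)
out_swapped = np.swapaxes(alpha * np.matmul(A1, B1) + beta * C1, 1, 3)

# Route 2: contract directly over k, keeping rows (r) on axis 1.
out_direct = alpha * np.einsum("arbck,akbcn->arbcn", A, B) + beta * C

print(out_swapped.shape)                     # (2, 3, 4, 5, 7)
print(np.allclose(out_swapped, out_direct))  # True
```

The output keeps the layout of *C*: rows on axis 1 and columns on the last axis.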
When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE
and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use
pseudo-float16 precision (float32 math with float16 I/O) in order to use
Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.
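
As a minimal sketch, the two environment variables named above could be set from a Python process as follows. Setting them before MXNet is loaded is an assumption made here as the conservative choice; they can equally be exported in the shell that launches the program.

```python
import os

# Opt in to Tensor Core use and to float32 -> pseudo-float16 conversion.
# Environment values are strings, so "1" rather than the integer 1.
os.environ["MXNET_CUDA_ALLOW_TENSOR_CORE"] = "1"
os.environ["MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION"] = "1"

# MXNet would be imported only after this point (left out so the
# sketch runs without MXNet installed).
```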
.. note:: The operator supports float32 and float64 data types only.
**Example**::

    Single matrix multiply-add
    A = [[1.0, 1.0], [1.0, 1.0]]
    B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
    C = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
    gemm(A, B, C, transpose_b=True, alpha=2.0, beta=10.0)
            = [[14.0, 14.0, 14.0], [14.0, 14.0, 14.0]]

    Batch matrix multiply-add
    A = [[[1.0, 1.0]], [[0.1, 0.1]]]
    B = [[[1.0, 1.0]], [[0.1, 0.1]]]
    C = [[[10.0]], [[0.01]]]
    gemm(A, B, C, transpose_b=True, alpha=2.0, beta=10.0)
            = [[[104.0]], [[0.14]]]

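
The batch result can be reproduced with NumPy as a cross-check. This is a sketch that mirrors the formula rather than calling the operator: ``np.matmul`` multiplies the trailing two dimensions for each batch entry, and ``np.swapaxes(B, -1, -2)`` plays the role of *transpose_b*.

```python
import numpy as np

# Two batch entries, each a 1x2 matrix for A and B and a 1x1 matrix for C.
A = np.array([[[1.0, 1.0]], [[0.1, 0.1]]])
B = np.array([[[1.0, 1.0]], [[0.1, 0.1]]])
C = np.array([[[10.0]], [[0.01]]])

# alpha * A * B^T + beta * C, applied independently to each batch entry.
out = 2.0 * np.matmul(A, np.swapaxes(B, -1, -2)) + 10.0 * C

print(out.shape)                                  # (2, 1, 1)
print(np.allclose(out, [[[104.0]], [[0.14]]]))    # True
```

For the first entry, 2.0 * (1 + 1) + 10.0 * 10.0 = 104.0; for the second, 2.0 * (0.01 + 0.01) + 10.0 * 0.01 = 0.14.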
Arguments
------------------
+----------------------------------------+------------------------------------------------------------+
| Argument                               | Description                                                |
+========================================+============================================================+
| ``A``                                  | NDArray-or-Symbol.                                         |
|                                        |                                                            |
|                                        | Tensor of input matrices.                                  |
+----------------------------------------+------------------------------------------------------------+
| ``B``                                  | NDArray-or-Symbol.                                         |
|                                        |                                                            |
|                                        | Tensor of input matrices.                                  |
+----------------------------------------+------------------------------------------------------------+
| ``C``                                  | NDArray-or-Symbol.                                         |
|                                        |                                                            |
|                                        | Tensor of input matrices.                                  |
+----------------------------------------+------------------------------------------------------------+
| ``transpose.a``                        | boolean, optional, default=0.                              |
|                                        |                                                            |
|                                        | Multiply with the transpose of the first input (A).        |
+----------------------------------------+------------------------------------------------------------+
| ``transpose.b``                        | boolean, optional, default=0.                              |
|                                        |                                                            |
|                                        | Multiply with the transpose of the second input (B).       |
+----------------------------------------+------------------------------------------------------------+
| ``alpha``                              | double, optional, default=1.                               |
|                                        |                                                            |
|                                        | Scalar factor multiplied with A*B.                         |
+----------------------------------------+------------------------------------------------------------+
| ``beta``                               | double, optional, default=1.                               |
|                                        |                                                            |
|                                        | Scalar factor multiplied with C.                           |
+----------------------------------------+------------------------------------------------------------+
| ``axis``                               | int, optional, default='-2'.                               |
|                                        |                                                            |
|                                        | Axis corresponding to the matrix rows.                     |
+----------------------------------------+------------------------------------------------------------+
Value
----------
``out`` The result mx.ndarray
Link to Source Code: http://github.com/apache/incubator-mxnet/blob/1.6.0/src/operator/tensor/la_op.cc#L89