blob: e87daf6c80e231f137e2002372833a453d6acbe3 [file] [log] [blame]
``mx.symbol.linalg_gemm2``
====================================================
Description
----------------------
Performs general matrix multiplication.
Input are tensors *A*, *B*, each of dimension *n >= 2* and having the same shape
on the leading *n-2* dimensions.
If *n=2*, the BLAS3 function *gemm* is performed:
*out* = *alpha* \* *op*\ (*A*) \* *op*\ (*B*)
Here *alpha* is a scalar parameter and *op()* is either the identity or the matrix
transposition (depending on *transpose_a*, *transpose_b*).
If *n>2*, *gemm* is performed separately for a batch of matrices. The column indices of the matrices
are given by the last dimensions of the tensors, the row indices by the axis specified with the *axis*
parameter. By default, the trailing two dimensions will be used for matrix encoding.
For a non-default axis parameter, the operation performed is equivalent to a series of swapaxes/gemm/swapaxes
calls. For example let *A*, *B* be 5 dimensional tensors. Then gemm(*A*, *B*, axis=1) is equivalent to
the following without the overhead of the additional swapaxis operations::
A1 = swapaxes(A, dim1=1, dim2=3)
B1 = swapaxes(B, dim1=1, dim2=3)
C = gemm2(A1, B1)
C = swapaxis(C, dim1=1, dim2=3)
When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE
and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use
pseudo-float16 precision (float32 math with float16 I/O) precision in order to use
Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.
.. note:: The operator supports float32 and float64 data types only.
**Example**::
Single matrix multiply
A = [[1.0, 1.0], [1.0, 1.0]]
B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
gemm2(A, B, transpose_b=True, alpha=2.0)
= [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]]
Batch matrix multiply
A = [[[1.0, 1.0]], [[0.1, 0.1]]]
B = [[[1.0, 1.0]], [[0.1, 0.1]]]
gemm2(A, B, transpose_b=True, alpha=2.0)
= [[[4.0]], [[0.04 ]]]
Usage
----------
.. code:: r
mx.symbol.linalg_gemm2(...)
Arguments
------------------
+----------------------------------------+------------------------------------------------------------+
| Argument | Description |
+========================================+============================================================+
| ``A`` | NDArray-or-Symbol. |
| | |
| | Tensor of input matrices |
+----------------------------------------+------------------------------------------------------------+
| ``B`` | NDArray-or-Symbol. |
| | |
| | Tensor of input matrices |
+----------------------------------------+------------------------------------------------------------+
| ``transpose.a`` | boolean, optional, default=0. |
| | |
| | Multiply with transposed of first input (A). |
+----------------------------------------+------------------------------------------------------------+
| ``transpose.b`` | boolean, optional, default=0. |
| | |
| | Multiply with transposed of second input (B). |
+----------------------------------------+------------------------------------------------------------+
| ``alpha`` | double, optional, default=1. |
| | |
| | Scalar factor multiplied with A*B. |
+----------------------------------------+------------------------------------------------------------+
| ``axis`` | int, optional, default='-2'. |
| | |
| | Axis corresponding to the matrix row indices. |
+----------------------------------------+------------------------------------------------------------+
| ``name`` | string, optional. |
| | |
| | Name of the resulting symbol. |
+----------------------------------------+------------------------------------------------------------+
Value
----------
``out`` The result mx.symbol
Link to Source Code: http://github.com/apache/incubator-mxnet/blob/1.6.0/src/operator/tensor/la_op.cc#L163