blob: 700639c10a1bed7220fca4a80d8e1cfe7b959c9a [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2D int8 schedule on ARM"""
import tvm
from tvm import te
from tvm.ir import register_intrin_lowering
def gemm_4x4_int8_int8_int32(M, N, K, unroll, in_type):
    """
    Int8 4x4 matrix multiplication and accumulation using a sequence of
    umull -> uadalp -> umull2 -> uadalp instructions. This function
    takes two arrays of int8 data type A[4][K] and B[4][K], and produces
    a 4x4 matrix which is equal to A*B'.

    The pseudo code is as follows.

    .. code-block:: c

        void gemm_4x4_int8_int8_int32(int8 A[4][K], int8 B[4][K], int32 C[4][4]){
            for (int i = 0; i < 4; i++){
                for (int j = 0; j < 4; j++){
                    for (int k = 0; k < K; k++){
                        C[i][j] += A[i][k] * B[j][k]
                    }
                }
            }
        }

    Notes:
        * The tiling strategy is picked to maximize register usage.
        * A and B are declared with shape (K // 16, rows, 16): the reduction
          dimension is split into K // 16 tiles of 16 elements, while the
          compute reduces over range(K) and indexes A[k // 16, ...]. K is
          therefore expected to be a multiple of 16.

    Parameters
    ----------
    M : int
        Number of valid rows of A (and of the output), at most 4. Rows of A
        at index >= M are replaced by zero vectors and the corresponding
        output rows (other than row 0) are not stored.
    N : int
        Number of valid rows of B (i.e. columns of the output), at most 4.
        Rows of B at index >= N are replaced by zero vectors and only
        min(N, 4) int32 lanes of each output row are stored.
    K : int
        Length of the reduction dimension (columns of A and B)
    unroll : bool
        Unroll the loop accumulation if True
    in_type : str, {'uint8', 'int8'}
        Element type of A and B. Selects the unsigned (umull/uaddlp) or
        signed (smull/saddlp) instruction variants.

    Returns
    -------
    intrin : TensorIntrin
        The ARM uint8/int8 TensorIntrin that can be used in tensorizing schedule
    """
    assert in_type in ["uint8", "int8"]
    A = te.placeholder((K // 16, te.var("m"), 16), dtype=in_type, name="A")
    B = te.placeholder((K // 16, te.var("n"), 16), dtype=in_type, name="B")
    dtype_vec = in_type + "x16"
    idxm = tvm.tir.indexmod

    k = te.reduce_axis((0, K), "k")
    C = te.compute(
        (te.var("m"), te.var("n")),
        lambda x, y: te.sum(
            A[k // 16, x, idxm(k, 16)].astype("int32") * B[k // 16, y, idxm(k, 16)].astype("int32"),
            axis=k,
        ),
        name="C",
    )

    a_buffer = tvm.tir.decl_buffer(
        A.shape,
        dtype=in_type,
        name="a_buffer",
        offset_factor=1,
        strides=[te.var("sa_1"), te.var("sa_2"), 1],
    )

    b_buffer = tvm.tir.decl_buffer(
        B.shape,
        dtype=in_type,
        name="b_buffer",
        offset_factor=1,
        strides=[te.var("sb_1"), te.var("sb_2"), 1],
    )

    c_buffer = tvm.tir.decl_buffer(
        C.shape, dtype="int32", name="c_buffer", offset_factor=1, strides=[te.var("sc"), 1]
    )

    # Intrinsics used in the following algorithm
    umull_intrin = "llvm.aarch64.neon.umull" if in_type == "uint8" else "llvm.aarch64.neon.smull"
    uaddlp_intrin = "llvm.aarch64.neon.uaddlp" if in_type == "uint8" else "llvm.aarch64.neon.saddlp"
    addp_intrin = "llvm.aarch64.neon.addp"

    def uadalp(a, b):
        """Add pairwise long and accumulate.

        Parameters:
        ----------
        a: int32x4 vector (the accumulator)
        b: int16x8 vector

        Returns:
        --------
        return a int32x4 vector

        Pseudocode:
        ----------
            a + (b0+b1, b2+b3, b4+b5, b6+b7)
        """
        return a + tvm.tir.call_llvm_pure_intrin(
            "int32x4", uaddlp_intrin, tvm.tir.const(1, "uint32"), b
        )

    def umull(a, b):
        """Widening multiply of the *high* halves (lanes 8..15).

        Note: despite the name, this helper extracts the high halves with
        ``tir.vectorhigh``, while ``umull2`` works on the low halves. Every
        product is summed into the same accumulator, so the order in which
        the halves are processed does not change the result.

        Parameters:
        ----------
        a: int8x16 vector
        b: int8x16 vector

        Returns:
        --------
        return a int16x8 vector

        Pseudocode:
        ----------
            c = (a8*b8, a9*b9, a10*b10, a11*b11, a12*b12, a13*b13, a14*b14, a15*b15)
        """
        a_high = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", a)
        b_high = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", b)
        c = tvm.tir.call_llvm_pure_intrin(
            "int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_high, b_high
        )
        return c

    def umull2(a, b):
        """Widening multiply of the *low* halves (lanes 0..7).

        Parameters:
        ----------
        a: int8x16 vector
        b: int8x16 vector

        Returns:
        --------
        return a int16x8 vector

        Pseudocode:
        ----------
            c = (a0*b0, a1*b1, a2*b2, a3*b3, a4*b4, a5*b5, a6*b6, a7*b7)
        """
        a_low = tvm.tir.call_intrin("int8x8", "tir.vectorlow", a)
        b_low = tvm.tir.call_intrin("int8x8", "tir.vectorlow", b)
        c = tvm.tir.call_llvm_pure_intrin(
            "int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_low, b_low
        )
        return c

    def addp(a, b):
        """Add two vectors in pairs

        Parameters:
        ----------
        a: int32x4 vector
        b: int32x4 vector

        Returns:
        --------
        return a int32x4 vector

        Pseudocode:
        ----------
            c = (a0+a1, a2+a3, b0+b1, b2+b3)
        """
        return tvm.tir.call_llvm_pure_intrin(
            "int32x4", addp_intrin, tvm.tir.const(2, "uint32"), a, b
        )

    def accumulation_loop(M, N, ins, acc, tile_idx):
        """Internal tile accumulation. This function
        takes two arrays of int8 data type A[tile_idx][4][16] and B[tile_idx][4][16], produces
        a 4x4 matrix which is equal to A*B' and accumulates into C[4][4]

        The pseudo code is as follows.

        .. code-block:: c

            void gemm_4x4_int8_int8_int32(int8 A[tile_idx][4][K],
                                          int8 B[tile_idx][4][K],
                                          int32 C[4][4]){
                for (int i = 0; i < 4; i++){
                    for (int j = 0; j < 4; j++){
                        for (int k = 0; k < 16; k++){
                            C[i][j] += A[tile_idx][i][k] * B[tile_idx][j][k]
                        }
                    }
                }
            }

        Notes:
            * The tiling strategy is picked to maximize register usage.
            * acc[4 * r + c] accumulates (as an int32x4 of partial sums) the
              products of row r of A with row c of B.

        Parameters:
        ----------
        M : int
            Number of valid rows of A; missing rows are zero vectors
        N : int
            Number of valid rows of B; missing rows are zero vectors
        ins : list of tvm.tir.buffer
            Input buffers
        acc : tvm.tir.ir_builder.BufferVar
            Bank of register accumulators
        tile_idx : int or loop variable
            Index of a sub-tile of A and B in A[tile_idx][:][:] and B[tile_idx][:][:].
            Please note that 0 <= tile_idx < K//16
        """
        zero = tvm.tir.const(0, "int8x16")

        def load_rows(buf, valid_rows):
            # Row 0 is always loaded; rows 1..3 only if they exist, otherwise
            # they are zero vectors that contribute nothing to the sums.
            rows = [buf.vload([tile_idx, 0, 0], dtype_vec)]
            for r in range(1, 4):
                rows.append(buf.vload([tile_idx, r, 0], dtype_vec) if valid_rows > r else zero)
            return rows

        a_rows = load_rows(ins[0], M)
        b_rows = load_rows(ins[1], N)

        # A is processed two rows at a time: rows (0, 1) feed acc[0:8] and
        # rows (2, 3) feed acc[8:16]. For each pair, all the umull products
        # are accumulated first, then all the umull2 products (this is the
        # same store order as a fully hand-unrolled sequence).
        for first_row in (0, 2):
            for multiply in (umull, umull2):
                for r in (first_row, first_row + 1):
                    for c in range(4):
                        idx = 4 * r + c
                        acc[idx] = uadalp(acc[idx], multiply(a_rows[r], b_rows[c]))

    def _intrin_func(ins, outs):
        def _instr():
            ib = tvm.tir.ir_builder.create()
            # Allocate a local buffer (possibly translates to registers)
            acc = ib.allocate("int32x4", 16, name="accs", scope="local")

            # Initialization
            for i in range(0, 16):
                acc[i] = tvm.tir.const(0, "int32x4")

            if unroll:
                for i in range(0, int(K // 16)):
                    accumulation_loop(M, N, ins, acc, i)
            else:
                with ib.for_range(0, K // 16, name="i") as i:
                    accumulation_loop(M, N, ins, acc, i)

            # Final accumulations
            # acc[4*r + c] contains the partial accumulations of element C[r][c]
            #
            # In particular:
            # acc[4*r] contains the partial sums of a[r,0:K].*b[0,0:K] -> (a,b,c,d)
            # acc[4*r+1] contains the partial sums of a[r, 0:K].*b[1,0:K] -> (e,f,g,h)
            # acc[4*r+2] contains the partial sums of a[r, 0:K].*b[2,0:K] -> (i,j,k,l)
            # acc[4*r+3] contains the partial sums of a[r, 0:K].*b[3,0:K] -> (m,n,o,p)
            #
            # Please note that 0<= r, c < 4
            for base in (0, 4, 8, 12):
                acc[base] = addp(acc[base], acc[base + 1])  # (a+b, c+d, e+f, g+h)
                acc[base + 1] = addp(acc[base + 2], acc[base + 3])  # (i+j, k+l, m+n, o+p)
                # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
                acc[base] = addp(acc[base], acc[base + 1])

            # Lane c of acc[4*r] now holds C[r][c]. Narrow the vectors to the
            # number of valid output columns before storing.
            row_sums = [acc[4 * r] for r in range(4)]
            if N > 3:
                out_dtype = None
            elif N > 2:
                out_dtype = "int32x3"
            elif N > 1:
                out_dtype = "int32x2"
            else:
                out_dtype = "int32"
            if out_dtype is not None:
                row_sums = [
                    tvm.tir.call_intrin(out_dtype, "tir.reinterpret", value) for value in row_sums
                ]

            # Store the result: row 0 always, the other rows only if A has them.
            for r, value in enumerate(row_sums):
                if r == 0 or M > r:
                    ib.emit(outs[0].vstore([r, 0], value))
            return ib.get()

        # Only a body statement is provided (no separate reset/update).
        return _instr()

    buffer_params = {"offset_factor": 1}
    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds={A: a_buffer, B: b_buffer, C: c_buffer},
        default_buffer_params=buffer_params,
    )
def dot_int8_int8_int32_neon_82(int32_lanes, dtype="uint"):
    """
    Int8 dot product by every 4 elements using ARM v8.2 udot/sdot.
    This function takes two arrays of int8 datatype -- data[4] and
    kernel[int32_lanes][4] -- and computes a dot product of data[4] with every
    4 elements of kernels, resulting in output[int32_lanes] of uint32 (dtype="uint")
    or int32 (dtype="int") datatype.
    The pseudo code is as follows.

    .. code-block:: c

        void dot_int8_int8_int32(int8 data[4], int8 kernel[16][4], int32 output[16]){
            for (int i = 0; i < int32_lanes; i++){
                out[i] = 0;
                for (int k = 0; k < 4; k++){
                    out[i] += data[k] * kernel[i][k]
                }
            }
        }

    Physically, the kernel array sits in a vector register and
    the data[4] is broadcasted to another vector register. This
    function returns a TensorIntrin that can be used to tensorize
    a schedule. The intrinsic provides three statements: the body
    (computes the output from scratch), the reset (zeroes the output)
    and the update (accumulates into the output).

    Parameters
    ----------
    int32_lanes : int
        How many int32/uint32 to produce
    dtype : str, optional, {"uint", "int"}
        Whether it works on unsigned int or signed int

    Returns
    -------
    intrin : TensorIntrin
        The ARM uint8/int8 TensorIntrin that can be used in tensorizing schedule
    """
    num_int8_elements = 4  # 4 int8 elements in int32

    data = te.placeholder((num_int8_elements,), dtype="%s8" % dtype, name="data")
    kernel = te.placeholder((int32_lanes, num_int8_elements), dtype="%s8" % dtype, name="kernel")

    k = te.reduce_axis((0, num_int8_elements), name="k")
    C = te.compute(
        (int32_lanes,),
        lambda i: te.sum(
            data[k].astype("%s32" % dtype) * kernel[i, k].astype("%s32" % dtype), axis=k
        ),
        name="C",
    )

    a_buffer = tvm.tir.decl_buffer(
        data.shape, dtype="%s8" % dtype, name="a_buffer", offset_factor=1, strides=[1]
    )
    b_buffer = tvm.tir.decl_buffer(
        kernel.shape,
        dtype="%s8" % dtype,
        name="b_buffer",
        offset_factor=1,
        strides=[te.var("s"), 1],
    )

    def _intrin_func(ins, outs):
        def _instr(index):
            ib = tvm.tir.ir_builder.create()
            if index == 1:
                ib.emit(outs[0].vstore(0, tvm.tir.const(0, "%s32x%d" % (dtype, int32_lanes))))
                return ib.get()

            dtype_a = "%s8x%d" % (dtype, num_int8_elements)
            dtype_b = "%s8x%d" % (dtype, int32_lanes * num_int8_elements)
            dtype_c = "%s32x%d" % (dtype, int32_lanes)

            a_int8 = ins[0].vload([0], dtype_a)
            re_int32 = tvm.tir.call_intrin("%s32" % dtype, "tir.reinterpret", a_int8)
            # broadcast a: the 4 bytes of data are seen as one 32-bit word,
            # replicated int32_lanes times and reinterpreted back to bytes
            vec_ai32 = re_int32.astype(dtype_c)

            vec_a = tvm.tir.call_intrin(dtype_b, "tir.reinterpret", vec_ai32)
            vec_b = ins[1].vload([0, 0], dtype_b)
            if index == 0:
                # The body computes the result from scratch, so it must not
                # depend on what the output buffer currently holds (this
                # matches the body of dot_int8_int8_int32_neon).
                vec_c = tvm.tir.const(0, dtype_c)
            else:
                vec_c = outs[0].vload([0], dtype_c)

            inst = "udot" if dtype == "uint" else "sdot"
            inst = "llvm.aarch64.neon.%s.v%di32.v%di8" % (
                inst,
                int32_lanes,
                int32_lanes * num_int8_elements,
            )
            vdot = tvm.tir.call_llvm_pure_intrin(
                dtype_c, inst, tvm.tir.const(3, "uint32"), vec_c, vec_a, vec_b
            )
            ib.emit(outs[0].vstore(0, vdot))
            return ib.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    buffer_params = {"offset_factor": 1}
    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds={data: a_buffer, kernel: b_buffer},
        default_buffer_params=buffer_params,
    )
def dot_int8_int8_int32_neon():
    """
    Int8 dot product using smull, saddlp and addp instructions

    .. code-block:: c

        void dot_int8_int8_int32(int8 data[4], int8 kernel[4][4], int32 output[4]){
            for (int i = 0; i < 4; i++){
                out[i] = 0;
                for (int k = 0; k < 4; k++){
                    out[i] += data[k] * kernel[i][k]
                }
            }
        }

    We use the smull and saddlp instructions to compute the dot product.
    smull : int8x8 -> int8x8 -> int16x8 elementwise widening multiplication
    saddlp: int16x8 -> int32x4 pairwise (widening) addition of adjacent elements

    Data is broadcast across the register

    int8 elements
    |        data       |        data       |
    |  0 |  1 |  2 |  3 |  4 |  5 |  6 |  7 |

                        smull

    int8 elements
    |     kernel[i]     |    kernel[i+1]    |
    |  0 |  1 |  2 |  3 |  4 |  5 |  6 |  7 |

                          =

    int16 elements
    |  data * kernel[i] | data * kernel[i+1] |
    |  0 |  1 |  2 |  3 |  4 |  5 |  6 |  7  |

                        saddlp =

    int32 elements
    | partial sum(data * kernel[i]) | partial sum(data * kernel[i+1]) |
    |       0       |       1       |       2        |       3        |

    We apply the above kernel twice (on the low and on the high half of the
    kernel register) and use addp to compute the second set of pairwise additions

    int32 elements (narrowed so they fit on a line)
    | psum d*k[i] | psum d*k[i+1] |        | psum d*k[i+2] | psum d*k[i+3] |
    |   0  |   1  |   2  |   3    |  addp  |   4  |   5    |   6  |   7    |
                                     =
    |sum d*ki |sum d*ki1|sum d*ki2|sum d*ki3|
    |    0    |    1    |    2    |    3    |

    Returns
    -------
    intrin : TensorIntrin
        The ARM int8 TensorIntrin (body, reset, update) that can be used in
        tensorizing schedule
    """
    int32_lanes = 4  # 4 int32 lanes = 128
    num_int8_elements = 4  # 4 int8 elements in int32
    data = te.placeholder((num_int8_elements,), dtype="int8", name="data")
    kernel = te.placeholder((int32_lanes, num_int8_elements), dtype="int8", name="kernel")
    k = te.reduce_axis((0, num_int8_elements), name="k")
    C = te.compute(
        (int32_lanes,),
        lambda i: te.sum(data[k].astype("int32") * kernel[i, k].astype("int32"), axis=k),
        name="C",
    )

    a_buffer = tvm.tir.decl_buffer(
        data.shape, dtype="int8", name="a_buffer", offset_factor=1, strides=[1]
    )
    b_buffer = tvm.tir.decl_buffer(
        kernel.shape, dtype="int8", name="b_buffer", offset_factor=1, strides=[te.var("ldw"), 1]
    )

    def _intrin_func(ins, outs):
        def _instr(index):
            int_8xl = "int8x8"
            int_32xl = "int32x4"
            ib = tvm.tir.ir_builder.create()
            if index == 1:
                ib.emit(outs[0].vstore(0, tvm.tir.const(0, int_32xl)))
                return ib.get()

            # this broadcasts data to the vector size: the 4 data bytes are
            # seen as one int32, replicated twice and reinterpreted as int8x8
            a_int8 = ins[0].vload([0], "int8x4")
            re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8)
            vec_ai32 = re_int32.astype("int32x2")
            vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32)

            vec_b = ins[1].vload([0, 0], "int8x16")

            def pairwise_add_mul(extract_half):
                # Multiply data with one half (two kernel rows) of vec_b and
                # pairwise-add the int16 products into int32 partial sums.
                vec_b_half = tvm.tir.call_intrin("int8x8", extract_half, vec_b)
                multiply = tvm.tir.call_llvm_pure_intrin(
                    "int16x8",
                    "llvm.aarch64.neon.smull.v8i16",  # signed widening elementwise multiplication
                    tvm.tir.const(2, "uint32"),
                    vec_a,
                    vec_b_half,
                )
                pairwise_reduction = tvm.tir.call_llvm_pure_intrin(
                    "int32x4",
                    "llvm.aarch64.neon.saddlp.v4i32.v8i16",
                    tvm.tir.const(1, "uint32"),
                    multiply,
                )
                return pairwise_reduction

            pair_1 = pairwise_add_mul("tir.vectorlow")
            pair_2 = pairwise_add_mul("tir.vectorhigh")
            quad_reduction = tvm.tir.call_llvm_pure_intrin(
                "int32x4",
                "llvm.aarch64.neon.addp.v4i32",
                tvm.tir.const(2, "uint32"),
                pair_1,
                pair_2,
            )
            if index == 0:
                # body: overwrite the output
                ib.emit(outs[0].vstore(0, quad_reduction))
            else:
                # update: accumulate into the output
                ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], int_32xl)))
            return ib.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    buffer_params = {"offset_factor": 1}
    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds={data: a_buffer, kernel: b_buffer},
        default_buffer_params=buffer_params,
    )
def select_word(vec, lane, dtype_vec):
    """
    Broadcast one 4-byte word of a 16-lane 8-bit vector to all four word slots.

    The input vector is viewed as four 32-bit words; the word at index
    ``lane`` is copied into every word slot and the result is reinterpreted
    back to ``dtype_vec``. In pseudo-code:

        v = [x0, ..., x15]
        vsub(lane) = [x(4*lane), x(4*lane+1), x(4*lane+2), x(4*lane+3)]
        replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]

    Note that 0 <= lane < 4.

    Parameters
    ----------
    vec : tvm.tir.Expr
        int8x16 (or uint8x16) vector expression
    lane : int
        index of the 32-bit word to replicate
    dtype_vec : str
        data type of the returned vector (e.g., int8x16)

    Returns
    ----------
    output : tvm.tir.Expr
        vector whose four words are all equal to word ``lane`` of ``vec``
    """
    as_words = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
    splat = tvm.tir.Shuffle([as_words], [lane] * 4)
    return tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", splat)
def gemm_acc_4x4_int8_int8_int32(dtype):
    """
    Int8 4x4 matrix multiplication and accumulation using sdot/udot
    instructions. This function takes two arrays of int8 datatype
    -- A[4][4] and B[4][4] and produces a 4x4 matrix
    which is equal to A*B'.

    The pseudo code is as follows.

    .. code-block:: c

        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
            for (int i = 0; i < 4; i++){
                for (int j = 0; j < 4; j++){
                    for (int k = 0; k < 4; k++){
                        C[i][j] += A[i][k] * B[j][k]
                    }
                }
            }
        }

    Notes:
        * The tiling strategy is picked to maximize register usage.
        * The intrinsic provides (body, reset, update) statements: the body
          computes C from scratch, the reset zeroes C and the update
          accumulates A*B' into C.

    Parameters
    ----------
    dtype : str, {"uint8", "int8"}
        Whether it works on unsigned int or signed int

    Returns
    -------
    intrin : TensorIntrin
        The Arm TensorIntrin that can be used in tensorizing schedule
    """
    assert dtype in ["uint8", "int8"]
    # This needs to be a variable number of "rows" since TVM
    # "thinks" I only need to compute one row because of
    # padding
    A = te.placeholder((te.var("rows"), 4), dtype, name="A")
    B = te.placeholder((4, 4), dtype, name="B")
    dtype_vec = dtype + "x16"

    k = te.reduce_axis((0, 4), name="k")
    C = te.compute(
        (te.var("rows"), 4),
        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
        name="C",
    )

    aa_buffer = tvm.tir.decl_buffer(
        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
    )
    bb_buffer = tvm.tir.decl_buffer(
        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
    )
    cc_buffer = tvm.tir.decl_buffer(
        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
    )

    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"

    def _intrin_func(ins, outs):
        def _instr(index):
            ib = tvm.tir.ir_builder.create()
            if index == 1:
                for i in range(0, 4):
                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
                return ib.get()
            # Load 16 contiguous elements starting at A[0, 0]
            # (NOTE(review): this covers 4 rows only if A's row stride is 4).
            # vec_a = [a, b, c, d,
            #          e, f, g, h,
            #          l, m, n, o,
            #          p, q, r, s];
            vec_a = ins[0].vload([0, 0], dtype_vec)

            # Replicate 4 times the i-th row of A. For instance,
            # vec_aa[0] = [a, b, c, d,
            #              a, b, c, d,
            #              a, b, c, d,
            #              a, b, c, d,];
            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]

            # Load all the elements of B. Since C = A*B', each 4-element
            # group c of vec_b is the row B[c, :] (i.e. column c of B').
            # Using the element indices of B' for illustration:
            # vec_b = [0, 4, 8, 12,
            #          1, 5, 9, 13,
            #          2, 6, 10, 14,
            #          3, 7, 11, 15,];
            vec_b = ins[1].vload([0, 0], dtype_vec)

            # Execute the dot product
            for i in range(0, 4):
                if index == 0:
                    # The body computes C from scratch: start from zero
                    # instead of the current content of the output.
                    vec_c = tvm.tir.const(0, "int32x4")
                else:
                    vec_c = outs[0].vload([i, 0], "int32x4")
                # Compute the product between the i-th row of A
                # and all the rows of B. sdot/udot split the input vectors
                # into 4 groups of 4 elements and compute the dot product
                # of corresponding groups, accumulating each result into
                # one lane of an int32x4 register.
                #
                # For instance, for i=0, we have:
                # sdot(vec_aa[0], vec_b) = [a*0+b*4+c*8+d*12,
                #                           a*1+b*5+c*9+d*13,
                #                           a*2+b*6+c*10+d*14,
                #                           a*3+b*7+c*11+d*15]
                vdot = tvm.tir.call_llvm_intrin(
                    "int32x4",
                    llvm_intrin,
                    tvm.tir.const(3, "uint32"),
                    vec_c,
                    vec_b,
                    vec_aa[i],
                )

                # Store the result
                ib.emit(outs[0].vstore([i, 0], vdot))

            return ib.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    buffer_params = {"offset_factor": 1}
    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
        default_buffer_params=buffer_params,
    )
def gemm_acc_nx16_int8_int8_int32(dtype, rows):
    """
    Int8 nx16 matrix multiplication and accumulation using sdot/udot instructions
    This function takes two arrays of int8 datatype -- A[n][16] and
    B[4][16][4] and produces a rowsx16 matrix which is equal to A*B'
    The pseudo code is as follows.

    .. code-block:: c

        void mmla_nx16_int8_int8_int32(int8 A[n][16], int8 B[4][16][4], int32 output[n][16]){
            for (int i = 0; i < n; i++){
                for (int j = 0; j < 16; j++){
                    for (int k = 0; k < 16; k++){
                        out[i][j] += A[i][k] * B[k//4][j][k%4]
                    }
                }
            }
        }

    Notes:
        * The tile size of B is 16x4. Since the reduction variable k moves between 0 and 16
          we need 4 tiles of B to compute a single row of the output. The first 4 values of
          k will be fetched from B[0][j][k], the second batch of 4 from B[1][j][k] and so on
        * The tiling strategy is picked to maximize register usage.
        * The intrinsic provides (body, reset, update) statements: the body
          computes the output from scratch, the reset zeroes it and the
          update accumulates into it.

    Parameters
    ----------
    dtype : str, {"uint8", "int8"}
        Whether it works on unsigned int or signed int
    rows : int
        Number of the output rows "n"

    Returns
    -------
    intrin : TensorIntrin
        The Arm TensorIntrin that can be used in tensorizing schedule
    """
    assert dtype in ["uint8", "int8"]
    A = te.placeholder((rows, 16), dtype, name="A")
    B = te.placeholder((4, 16, 4), dtype, name="B")
    dtype_vec = dtype + "x16"
    idxm = tvm.tir.indexmod
    k = te.reduce_axis((0, 16), name="k")
    C = te.compute(
        (rows, 16),
        lambda i, j: te.sum(
            A[i, k].astype("int32") * B[k // 4, j, idxm(k, 4)].astype("int32"), axis=k
        ),
        name="C",
    )

    aa_buffer = tvm.tir.decl_buffer(
        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
    )
    bb_buffer = tvm.tir.decl_buffer(
        B.shape,
        dtype,
        name="bb_buffer",
        offset_factor=1,
        strides=[te.var("sb0"), te.var("sb1"), 1],
    )
    cc_buffer = tvm.tir.decl_buffer(
        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
    )

    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"

    def _intrin_func(ins, outs):
        def _instr(index):
            ib = tvm.tir.ir_builder.create()
            if index == 1:
                for i in range(0, rows):
                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x16")))
                return ib.get()
            # Iterate on the number of rows of the output
            for row in range(0, rows):
                # Load the 16 elements of A[row, :]
                # vec_a = [a, b, c, d, e, f, g, h, l, m, n, o, p, q, r, s];
                vec_a = ins[0].vload([row, 0], dtype_vec)
                # Iterate over each of the 4 (rows x 4) column tiles of the output
                for j in range(0, 4):
                    # Accumulate over each of the 4 (16x4) tiles contained in B
                    for i in range(0, 4):
                        # Replicate a single 4-element group of A (A[row, 4*i:4*i+4])
                        vec_aa = select_word(vec_a, i, dtype_vec)

                        # Load 16 contiguous elements starting at B[i, 4*j, 0];
                        # NOTE(review): these are B[i, 4*j:4*j+4, 0:4] when the
                        # two inner dimensions of B are contiguous.
                        # Using the element indices of B' for illustration:
                        # vec_b = [0, 16, 32, 48,
                        #          1, 17, 33, 49,
                        #          2, 18, 34, 50,
                        #          3, 19, 35, 51,];
                        vec_b = ins[1].vload([i, 4 * j, 0], dtype_vec)

                        # Accumulate in the correct part of the output. The
                        # body computes the output from scratch, so its first
                        # B tile starts from zero instead of reading the output.
                        if index == 0 and i == 0:
                            vec_c = tvm.tir.const(0, "int32x4")
                        else:
                            vec_c = outs[0].vload([row, 4 * j], "int32x4")

                        # Compute the dot product between the rowsx4 tile
                        # from A and the 4x4 tile from B
                        #
                        # For instance, for i=0, we have:
                        # sdot(vec_aa[0], vec_b) = [a*0+b*16+c*32+d*48,
                        #                           a*1+b*17+c*33+d*49,
                        #                           a*2+b*18+c*34+d*50,
                        #                           a*3+b*19+c*35+d*51]
                        vdot = tvm.tir.call_llvm_intrin(
                            "int32x4",
                            llvm_intrin,
                            tvm.tir.const(3, "uint32"),
                            vec_c,
                            vec_b,
                            vec_aa,
                        )
                        ib.emit(outs[0].vstore([row, 4 * j], vdot))
            return ib.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    buffer_params = {"offset_factor": 1}
    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
        default_buffer_params=buffer_params,
    )
def smlal_int16_int32():
    """
    Intrinsic that loads two int16x8 vectors and multiply-accumulates their
    elementwise products into an int32x8 output, one vector half at a time.
    The pseudo-code for the algorithm is as follows:

        vec_a = vload(A, "int16x8")
        vec_b = vload(B, "int16x8")

        vec_c[0:4] += vec_a[0:4]*vec_b[0:4] //  -> smlal instruction
        vec_c[4:8] += vec_a[4:8]*vec_b[4:8] //  -> smlal2 instruction

    The lower (0:4) and higher (4:8) halves are accumulated separately
    (widening multiply + add) and then recombined into one int32x8 store.
    The intrinsic provides (body, reset, update) statements; the body and
    the update both accumulate into the output, the reset zeroes it.
    """
    lanes = 8
    A = te.placeholder((lanes,), dtype="int16", name="A")
    B = te.placeholder((lanes, 1), dtype="int16", name="B")
    C = te.compute(
        (lanes,),
        lambda i: A[i].astype("int32") * B[i, 0].astype("int32"),
        name="C",
    )

    a_buffer = tvm.tir.decl_buffer(
        A.shape, dtype="int16", name="a_buffer", offset_factor=1, strides=[1]
    )
    b_buffer = tvm.tir.decl_buffer(
        B.shape, dtype="int16", name="b_buffer", offset_factor=1, strides=[te.var("sb"), 1]
    )
    c_buffer = tvm.tir.decl_buffer(
        C.shape, dtype="int32", name="c_buffer", offset_factor=1, strides=[1]
    )

    def _intrin_func(ins, outs):
        def _instr(index):
            ib = tvm.tir.ir_builder.create()
            if index == 1:
                ib.emit(outs[0].vstore(0, tvm.tir.const(0, "int32x8")))
                return ib.get()

            lhs = ins[0].vload([0], "int16x8")
            rhs = ins[1].vload([0, 0], "int16x8")

            def accumulate_half(extract, offset):
                # outs[offset:offset+4] + widening product of one half of lhs*rhs
                partial = outs[0].vload([offset], "int32x4")
                lhs_half = tvm.tir.call_intrin("int16x4", extract, lhs)
                rhs_half = tvm.tir.call_intrin("int16x4", extract, rhs)
                product = tvm.tir.call_llvm_pure_intrin(
                    "int32x4",
                    "llvm.aarch64.neon.smull",
                    tvm.tir.const(2, "uint32"),
                    lhs_half,
                    rhs_half,
                )
                return partial + product

            upper = accumulate_half("tir.vectorhigh", 4)
            lower = accumulate_half("tir.vectorlow", 0)

            # Glue both halves into one int32x8 value; since a NEON register
            # is 128 bits wide, this store presumably becomes two stores.
            merged = tvm.tir.call_intrin("int32x8", "tir.vectorcombine", lower, upper)
            ib.emit(outs[0].vstore(0, merged))
            return ib.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds={A: a_buffer, B: b_buffer, C: c_buffer},
        default_buffer_params={"offset_factor": 1},
    )
def gemm_acc_2x2_int8_int8_int32(dtype):
    """
    Int8 2x2 matrix multiplication using smmla/ummla instructions
    This function takes two arrays of int8 datatype -- A[2][8] and
    B[2][8] and produces a 2x2 matrix which is equal to A*B'
    The pseudo code is as follows.

    .. code-block:: c

        void mmla_2x2_int8_int8_int32(int8 A[2][8], int8 B[2][8], int32 C[2][2]){
            for (int i = 0; i < 2; i++){
                for (int j = 0; j < 2; j++){
                    for (int k = 0; k < 8; k++){
                        C[i][j] += A[i][k] * B[j][k]
                    }
                }
            }
        }

    The intrinsic provides (body, reset, update) statements: the body
    computes C from scratch, the reset zeroes C and the update accumulates
    A*B' into C.

    Parameters
    ----------
    dtype : str, {"uint8", "int8"}
        Whether it works on unsigned int or signed int

    Returns
    -------
    intrin : TensorIntrin
        The Arm TensorIntrin that can be used in tensorizing schedule
    """
    assert dtype in ["uint8", "int8"]
    A = te.placeholder((2, 8), dtype, name="A")
    B = te.placeholder((2, 8), dtype, name="B")
    dtype_vec = dtype + "x16"

    k = te.reduce_axis((0, 8), name="k")
    C = te.compute(
        (2, 2),
        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
        name="C",
    )

    aa_buffer = tvm.tir.decl_buffer(
        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
    )
    bb_buffer = tvm.tir.decl_buffer(
        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
    )
    cc_buffer = tvm.tir.decl_buffer(
        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
    )

    llvm_intrin = "llvm.aarch64.neon.smmla" if dtype == "int8" else "llvm.aarch64.neon.ummla"

    def _intrin_func(ins, outs):
        def _instr(index):
            ib = tvm.tir.ir_builder.create()
            if index == 1:
                ib.emit(outs[0].vstore([0, 0], tvm.tir.const(0, "int32x4")))
                return ib.get()
            # Load in vec_a the two rows of A
            # vec_a = [a, b, c, d, e, f, g, h;
            #          i, j, k, l, m, n, o, p,]
            vec_a = ins[0].vload([0, 0], dtype_vec)
            # Load in vec_b the two rows of B
            # vec_b = [0, 2, 4, 6, 8, 10, 12, 14;
            #          1, 3, 5, 7, 9, 11, 13, 15,]
            vec_b = ins[1].vload([0, 0], dtype_vec)

            # Execute the matrix multiplication via (s/u)mmla:
            # vec_c = [a*0 + b*2 + c*4 + d*6 +e*8 + f*10 + g*12 + h*14;
            #          a*1 + b*3 + c*5 + d*7 +e*9 + f*11 + g*13 + h*15;
            #          i*0 + j*2 + k*4 + l*6 +m*8 + n*10 + o*12 + p*14;
            #          i*1 + j*3 + k*5 + l*7 +m*9 + n*11 + o*13 + p*15]
            if index == 0:
                # The body computes C from scratch: start from zero instead
                # of the current content of the output.
                vec_c = tvm.tir.const(0, "int32x4")
            else:
                vec_c = outs[0].vload([0, 0], "int32x4")
            vmmla = tvm.tir.call_llvm_intrin(
                "int32x4",
                llvm_intrin,
                tvm.tir.const(3, "uint32"),
                vec_c,
                vec_a,
                vec_b,
            )
            # Store the result
            ib.emit(outs[0].vstore([0, 0], vmmla))
            return ib.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    buffer_params = {"offset_factor": 1}
    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
        default_buffer_params=buffer_params,
    )
def _q_multiply_shift_arm(op):
    """
    Implementation of q_multiply_shift_arm through arm intrinsics
    sqrdmulh and srshl when q == 31.

    Please note that this is introducing a small round-up error for
    some corner cases. This is because we are rounding twice instead
    of only once. I.e.:

        * original q_multiply_shift: round(x*y*2^-s)
        * arm q_multiply_shift: round(round(x*y)*2^-s)

    Parameters
    ----------
    op : tvm.tir.Call
        The ``tir.q_multiply_shift`` call; its arguments are (x, y, q, s).

    Returns
    -------
    expr : tvm.tir.PrimExpr
        The lowered expression, or ``op`` unchanged when this lowering does
        not apply (x is not an int32x4 vector, or q != 31).
    """
    x = op.args[0]
    y = op.args[1]
    q = op.args[2]
    s = op.args[3]

    # Don't use this intrinsic if we don't have a int32x4 vector
    # or if we are not multiplying q31 numbers
    if x.dtype != "int32x4" or q.value != 31:
        return op

    # Case 1, shift is negative: sqrdmulh computes the rounded high half of
    # 2*x*y, then srshl applies a rounding shift by s (a right shift, since
    # s < 0). No pre-shift "fixup" of the sqrdmulh result is applied, which
    # is where the double rounding described above comes from.
    sqrdmulh = tvm.tir.call_llvm_intrin(
        op.dtype, "llvm.aarch64.neon.sqrdmulh", tvm.tir.const(2, "uint32"), x, y
    )
    out_1 = tvm.tir.call_llvm_intrin(
        op.dtype, "llvm.aarch64.neon.srshl", tvm.tir.const(2, "uint32"), sqrdmulh, s
    )

    # Case 2, shift is positive: pre-scale x by 2^s, then sqrdmulh
    x = x * (1 << (s))
    out_2 = tvm.tir.call_llvm_intrin(
        op.dtype, "llvm.aarch64.neon.sqrdmulh", tvm.tir.const(2, "uint32"), x, y
    )

    # Select depending on the shift
    return tvm.tir.Select(s < 0, out_1, out_2)
# Hook the Arm-specific lowering above into "tir.q_multiply_shift" for
# aarch64 LLVM targets.
register_intrin_lowering(
    "tir.q_multiply_shift",
    level=99,
    target="llvm.aarch64",
    f=_q_multiply_shift_arm,
)