// Copyright 2016 bluss
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use std::vec::*;
use std::cmp::min;
use std::mem::size_of;
use util::range_chunk;
use util::round_up_to;
use kernel::GemmKernel;
use kernel::Element;
use sgemm_kernel;
use dgemm_kernel;
use rawpointer::PointerExt;

/// General matrix multiplication (f32)
///
/// C ← α A B + β C
///
/// + m, k, n: dimensions
/// + a, b, c: pointer to the first element in the matrix
/// + A: m by k matrix
/// + B: k by n matrix
/// + C: m by n matrix
/// + rs<em>x</em>: row stride of *x*
/// + cs<em>x</em>: col stride of *x*
///
/// Strides for A and B may be arbitrary. Strides for C must not result in
/// elements that alias each other; for example, they cannot be zero.
///
/// If β is zero, then C does not need to be initialized.
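///
/// # Example
///
/// A minimal usage sketch with row-major matrices (this assumes the crate is
/// in scope as `matrixmultiply`):
///
/// ```
/// let a = [1., 2., 3.,
///          4., 5., 6.];        // A: 2 × 3, row-major
/// let b = [7., 8.,
///          9., 10.,
///          11., 12.];          // B: 3 × 2, row-major
/// let mut c = [0.; 4];         // C: 2 × 2, row-major
/// unsafe {
///     // C ← 1.0 · A B + 0.0 · C
///     matrixmultiply::sgemm(2, 3, 2,
///                           1.0,
///                           a.as_ptr(), 3, 1,
///                           b.as_ptr(), 2, 1,
///                           0.0,
///                           c.as_mut_ptr(), 2, 1);
/// }
/// assert_eq!(c, [58., 64.,
///                139., 154.]);
/// ```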
pub unsafe fn sgemm(
m: usize, k: usize, n: usize,
alpha: f32,
a: *const f32, rsa: isize, csa: isize,
b: *const f32, rsb: isize, csb: isize,
beta: f32,
c: *mut f32, rsc: isize, csc: isize)
{
gemm_loop::<sgemm_kernel::Gemm>(
m, k, n,
alpha,
a, rsa, csa,
b, rsb, csb,
beta,
c, rsc, csc)
}

/// General matrix multiplication (f64)
///
/// C ← α A B + β C
///
/// + m, k, n: dimensions
/// + a, b, c: pointer to the first element in the matrix
/// + A: m by k matrix
/// + B: k by n matrix
/// + C: m by n matrix
/// + rs<em>x</em>: row stride of *x*
/// + cs<em>x</em>: col stride of *x*
///
/// Strides for A and B may be arbitrary. Strides for C must not result in
/// elements that alias each other; for example, they cannot be zero.
///
/// If β is zero, then C does not need to be initialized.
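///
/// # Example
///
/// A minimal usage sketch with all matrices stored column-major, i.e. row
/// stride 1 and column stride 2 (this assumes the crate is in scope as
/// `matrixmultiply`):
///
/// ```
/// let a = [1., 3., 2., 4.];    // A: 2 × 2, columns [1, 3] and [2, 4]
/// let b = [5., 7., 6., 8.];    // B: 2 × 2, columns [5, 7] and [6, 8]
/// let mut c = [0.; 4];         // C: 2 × 2, column-major
/// unsafe {
///     // C ← 1.0 · A B + 0.0 · C
///     matrixmultiply::dgemm(2, 2, 2,
///                           1.0,
///                           a.as_ptr(), 1, 2,
///                           b.as_ptr(), 1, 2,
///                           0.0,
///                           c.as_mut_ptr(), 1, 2);
/// }
/// assert_eq!(c, [19., 43., 22., 50.]);   // column-major [[19, 22], [43, 50]]
/// ```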
pub unsafe fn dgemm(
m: usize, k: usize, n: usize,
alpha: f64,
a: *const f64, rsa: isize, csa: isize,
b: *const f64, rsb: isize, csb: isize,
beta: f64,
c: *mut f64, rsc: isize, csc: isize)
{
gemm_loop::<dgemm_kernel::Gemm>(
m, k, n,
alpha,
a, rsa, csa,
b, rsb, csb,
beta,
c, rsc, csc)
}

/// Ensure that GemmKernel parameters are supported
/// (alignment, microkernel size).
///
/// This function is optimized out for a supported configuration.
#[inline(always)]
fn ensure_kernel_params<K>()
where K: GemmKernel
{
let mr = K::mr();
let nr = K::nr();
assert!(mr > 0 && mr <= 8);
assert!(nr > 0 && nr <= 8);
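    // 8 * 4 * 8 = 256 bytes: the microkernel's output tile must fit in the
    // mask buffer allocated in gemm_packed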
assert!(mr * nr * size_of::<K::Elem>() <= 8 * 4 * 8);
assert!(K::align_to() <= 32);
    // one row/col of the kernel limits the max alignment we can provide
let max_align = size_of::<K::Elem>() * min(mr, nr);
assert!(K::align_to() <= max_align);
}

/// Implement matrix multiply using packed buffers and a microkernel
/// strategy; the type parameter `K` is the GEMM microkernel.
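///
/// The blocking follows the usual packed GEMM scheme: loop 5 splits the n
/// dimension into NC-wide column blocks of B, loop 4 splits the k dimension
/// into KC-deep panels that are packed into B~, and loop 3 splits the m
/// dimension into MC-tall row blocks of A that are packed into A~ before
/// `gemm_packed` runs the microkernel over the packed buffers.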
unsafe fn gemm_loop<K>(
m: usize, k: usize, n: usize,
alpha: K::Elem,
a: *const K::Elem, rsa: isize, csa: isize,
b: *const K::Elem, rsb: isize, csb: isize,
beta: K::Elem,
c: *mut K::Elem, rsc: isize, csc: isize)
where K: GemmKernel
{
debug_assert!(m * n == 0 || (rsc != 0 && csc != 0));
let knc = K::nc();
let kkc = K::kc();
let kmc = K::mc();
ensure_kernel_params::<K>();
let (mut packv, bp_offset) = packing_vec::<K>(m, k, n);
let app = make_aligned_vec_ptr(K::align_to(), &mut packv);
let bpp = app.offset(bp_offset);
// LOOP 5: split n into nc parts
for (l5, nc) in range_chunk(n, knc) {
dprint!("LOOP 5, {}, nc={}", l5, nc);
let b = b.stride_offset(csb, knc * l5);
let c = c.stride_offset(csc, knc * l5);
        // LOOP 4: split k into kc parts
for (l4, kc) in range_chunk(k, kkc) {
dprint!("LOOP 4, {}, kc={}", l4, kc);
let b = b.stride_offset(rsb, kkc * l4);
let a = a.stride_offset(csa, kkc * l4);
debug!(for elt in &mut packv { *elt = <_>::one(); });
// Pack B -> B~
pack(kc, nc, K::nr(), bpp, b, csb, rsb);
// LOOP 3: split m into mc parts
for (l3, mc) in range_chunk(m, kmc) {
dprint!("LOOP 3, {}, mc={}", l3, mc);
let a = a.stride_offset(rsa, kmc * l3);
let c = c.stride_offset(rsc, kmc * l3);
// Pack A -> A~
pack(kc, mc, K::mr(), app, a, rsa, csa);
// First time writing to C, use user's `beta`, else accumulate
let betap = if l4 == 0 { beta } else { <_>::one() };
// LOOP 2 and 1
gemm_packed::<K>(nc, kc, mc,
alpha,
app, bpp,
betap,
c, rsc, csc);
}
}
}
}

/// Loops 1 and 2 around the µ-kernel
///
/// + app: packed A (A~)
/// + bpp: packed B (B~)
/// + nc: columns of packed B
/// + kc: columns of packed A / rows of packed B
/// + mc: rows of packed A
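///
/// Packed A~ is stored as consecutive MR × kc micropanels and packed B~ as
/// consecutive kc × NR micropanels, so the micropanel pair for iteration
/// (l1, l2) sits at the fixed offsets `kc * MR * l1` and `kc * NR * l2` used
/// below.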
unsafe fn gemm_packed<K>(nc: usize, kc: usize, mc: usize,
alpha: K::Elem,
app: *const K::Elem, bpp: *const K::Elem,
beta: K::Elem,
c: *mut K::Elem, rsc: isize, csc: isize)
where K: GemmKernel,
{
let mr = K::mr();
let nr = K::nr();
    // make a mask buffer that fits an 8 x 8 f32 or an 8 x 4 f64 kernel tile,
    // plus slack for 32-byte alignment
assert!(mr * nr * size_of::<K::Elem>() <= 256 && K::align_to() <= 32);
let mut mask_buf = [0u8; 256 + 31];
let mask_ptr = align_ptr(32, mask_buf.as_mut_ptr()) as *mut K::Elem;
// LOOP 2: through micropanels in packed `b`
for (l2, nr_) in range_chunk(nc, nr) {
let bpp = bpp.stride_offset(1, kc * nr * l2);
let c = c.stride_offset(csc, nr * l2);
// LOOP 1: through micropanels in packed `a` while `b` is constant
for (l1, mr_) in range_chunk(mc, mr) {
let app = app.stride_offset(1, kc * mr * l1);
let c = c.stride_offset(rsc, mr * l1);
// GEMM KERNEL
// NOTE: For the rust kernels, it performs better to simply
// always use the masked kernel function!
if K::always_masked() || nr_ < nr || mr_ < mr {
masked_kernel::<_, K>(kc, alpha, &*app, &*bpp,
beta, &mut *c, rsc, csc,
mr_, nr_, mask_ptr);
continue;
} else {
K::kernel(kc, alpha, app, bpp, beta, c, rsc, csc);
}
}
}
}

/// Allocate a vector of uninitialized data to be used for both packing buffers.
///
/// + A~ needs to be KC x MC
/// + B~ needs to be KC x NC
/// but we can make them smaller if the matrix is smaller than this (just ensure
/// we have rounded up to a multiple of the kernel size).
///
/// Return the packing vector and the offset to the start of B~ within it.
unsafe fn packing_vec<K>(m: usize, k: usize, n: usize) -> (Vec<K::Elem>, isize)
where K: GemmKernel,
{
let m = min(m, K::mc());
let k = min(k, K::kc());
let n = min(n, K::nc());
    // round up m and n to multiples of the kernel sizes mr and nr respectively
let apack_size = k * round_up_to(m, K::mr());
let bpack_size = k * round_up_to(n, K::nr());
let nelem = apack_size + bpack_size;
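    // layout of the single allocation: [ A~ (apack_size) | B~ (bpack_size) ]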
let mut v = Vec::with_capacity(nelem);
v.set_len(nelem);
    dprint!("packed nelem={}, apack={}, bpack={}, m={} k={} n={}",
            nelem, apack_size, bpack_size, m, k, n);
    // the max alignment requirement is at most min(MR, NR) * size_of::<Elem>()
    // (see ensure_kernel_params); because apack_size is a multiple of MR,
    // the start of B~ is aligned as well
(v, apack_size as isize)
}

/// Align a pointer into the vec. Will reallocate to fit and shift the pointer
/// forwards if needed. This invalidates any previous pointers into the vec.
unsafe fn make_aligned_vec_ptr<U>(align_to: usize, v: &mut Vec<U>) -> *mut U {
let mut ptr = v.as_mut_ptr();
if align_to != 0 {
if v.as_ptr() as usize % align_to != 0 {
let cap = v.capacity();
v.reserve_exact(cap + align_to / size_of::<U>() - 1);
ptr = align_ptr(align_to, v.as_mut_ptr());
}
}
ptr
}

/// Offset the pointer forwards so that it is aligned to `align_to` bytes.
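///
/// For example, for f32 (4 bytes per element) and `align_to` = 32: a pointer
/// that is 8 bytes past a 32-byte boundary is advanced by (32 - 8) / 4 = 6
/// elements.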
unsafe fn align_ptr<U>(align_to: usize, mut ptr: *mut U) -> *mut U {
if align_to != 0 {
let cur_align = ptr as usize % align_to;
if cur_align != 0 {
ptr = ptr.offset(((align_to - cur_align) / size_of::<U>()) as isize);
}
}
ptr
}

/// Pack matrix into `pack`
///
/// + kc: length of the micropanel
/// + mc: number of rows/columns in the matrix to be packed
/// + mr: kernel rows/columns that we round up to
/// + pack: packing buffer
/// + a: matrix to pack
/// + rsa: row stride
/// + csa: column stride
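///
/// The packed output is a sequence of micropanels: for each group of `mr`
/// rows, the `kc` columns are written contiguously, one `mr`-long column at a
/// time, and a final partial group (if `mc` is not a multiple of `mr`) is
/// padded with zeros.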
unsafe fn pack<T>(kc: usize, mc: usize, mr: usize, pack: *mut T,
a: *const T, rsa: isize, csa: isize)
where T: Element
{
let mut pack = pack;
for ir in 0..mc/mr {
let row_offset = ir * mr;
for j in 0..kc {
for i in 0..mr {
*pack = *a.stride_offset(rsa, i + row_offset)
.stride_offset(csa, j);
pack.inc();
}
}
}
let zero = <_>::zero();
// Pad with zeros to multiple of kernel size (uneven mc)
let rest = mc % mr;
if rest > 0 {
let row_offset = (mc/mr) * mr;
for j in 0..kc {
for i in 0..mr {
if i < rest {
*pack = *a.stride_offset(rsa, i + row_offset)
.stride_offset(csa, j);
} else {
*pack = zero;
}
pack.inc();
}
}
}
}

/// Call the GEMM kernel with a "masked" output C.
///
/// Simply redirect the MR by NR kernel output to the passed
/// in `mask_buf`, and copy the non-masked region to the real
/// C.
///
/// + rows: number of rows of the kernel output that are unmasked
/// + cols: number of columns of the kernel output that are unmasked
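///
/// For example, with an 8 × 8 f32 kernel and a 5 × 3 remainder block at the
/// edge of C, the full 8 × 8 tile is written to `mask_buf` and only its
/// top-left 5 × 3 part is scaled and accumulated into C.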
#[inline(never)]
unsafe fn masked_kernel<T, K>(k: usize, alpha: T,
a: *const T,
b: *const T,
beta: T,
c: *mut T, rsc: isize, csc: isize,
rows: usize, cols: usize,
mask_buf: *mut T)
where K: GemmKernel<Elem=T>, T: Element,
{
let mr = K::mr();
let nr = K::nr();
// use column major order for `mask_buf`
K::kernel(k, T::one(), a, b, T::zero(), mask_buf, 1, mr as isize);
let mut ab = mask_buf;
for j in 0..nr {
for i in 0..mr {
if i < rows && j < cols {
let cptr = c.offset(rsc * i as isize + csc * j as isize);
if beta.is_zero() {
*cptr = T::zero(); // initialize C
} else {
(*cptr).scale_by(beta);
}
(*cptr).scaled_add(alpha, *ab);
}
ab.inc();
}
}
}