blob: b78d8f7d6e516b309b38c83077e4adb6ea956acf [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
#include "mps_utils.h"
namespace tvm {
namespace contrib {

using namespace runtime;

/*!
 * \brief Registers the packed function "tvm.contrib.mps.matmul".
 *
 * Packed arguments:
 *   args[0]  A       (DLTensor*)  left operand, 2-D, contiguous, float32
 *   args[1]  B       (DLTensor*)  right operand, 2-D, contiguous, float32
 *   args[2]  C       (DLTensor*)  result, 2-D, contiguous, float32
 *   args[3]  transa  (bool)       use A in transposed form
 *   args[4]  transb  (bool)       use B in transposed form
 *
 * Encodes C = op(A) * op(B) (alpha = 1, beta = 0) with MPSMatrixMultiplication
 * on the command queue obtained for A's device and commits the command buffer.
 * The function does not wait for the command buffer to complete and does not
 * write to the return slot.
 */
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def_packed("tvm.contrib.mps.matmul", [](ffi::PackedArgs args, ffi::Any* ret) {
    auto A = args[0].cast<DLTensor*>();
    auto B = args[1].cast<DLTensor*>();
    auto C = args[2].cast<DLTensor*>();
    bool transa = args[3].cast<bool>();
    bool transb = args[4].cast<bool>();

    // Only plain 2-D, contiguous, float32 matrices are supported.
    ICHECK_EQ(A->ndim, 2);
    ICHECK_EQ(B->ndim, 2);
    ICHECK_EQ(C->ndim, 2);
    ICHECK(ffi::IsContiguous(*C));
    ICHECK(ffi::IsContiguous(*B));
    ICHECK(ffi::IsContiguous(*A));
    ICHECK(TypeMatch(A->dtype, kDLFloat, 32));
    ICHECK(TypeMatch(B->dtype, kDLFloat, 32));
    ICHECK(TypeMatch(C->dtype, kDLFloat, 32));

    // Get Metal device API
    MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal();
    // ICHECK_EQ(A->device, B->device);
    // ICHECK_EQ(A->device, C->device);
    id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->device);
    id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(A->device);
    id<MTLCommandBuffer> cb = [queue commandBuffer];

    // Shapes of the operands exactly as they are laid out in memory.
    const NSUInteger rowsA = static_cast<NSUInteger>(A->shape[0]);
    const NSUInteger colsA = static_cast<NSUInteger>(A->shape[1]);
    const NSUInteger rowsB = static_cast<NSUInteger>(B->shape[0]);
    const NSUInteger colsB = static_cast<NSUInteger>(B->shape[1]);
    const NSUInteger rowsC = static_cast<NSUInteger>(C->shape[0]);
    const NSUInteger colsC = static_cast<NSUInteger>(C->shape[1]);

    // GEMM problem size after applying the requested transposes:
    //   op(A) is M x K, op(B) is K x N, C is M x N.
    const NSUInteger M = transa ? colsA : rowsA;
    const NSUInteger N = transb ? rowsB : colsB;
    const NSUInteger K = transb ? colsB : rowsB;
    ICHECK_EQ(transa ? rowsA : colsA, K);
    // The output buffer must match the result geometry described to MPS.
    ICHECK_EQ(rowsC, M);
    ICHECK_EQ(colsC, N);

    // All three tensors were checked to be float32 above, so one MPS data type
    // and one element size apply to every descriptor.
    MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype);
    const NSUInteger elem_bytes = sizeof(float);

    // mps a
    // The descriptor describes the stored matrix; the transpose is applied by
    // the kernel (transposeLeft), so the stored shape is used here, not M x K.
    MPSMatrixDescriptor* descA =
        [MPSMatrixDescriptor matrixDescriptorWithDimensions:rowsA
                                                    columns:colsA
                                                   rowBytes:colsA * elem_bytes
                                                   dataType:dtype];
    id<MTLBuffer> bufA = (__bridge id<MTLBuffer>)(A->data);
    MPSMatrix* matrixA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];

    // mps b
    // Same as for A: stored shape, with transposeRight handled by the kernel.
    MPSMatrixDescriptor* descB =
        [MPSMatrixDescriptor matrixDescriptorWithDimensions:rowsB
                                                    columns:colsB
                                                   rowBytes:colsB * elem_bytes
                                                   dataType:dtype];
    id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(B->data);
    MPSMatrix* matrixB = [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];

    // mps c
    MPSMatrixDescriptor* descC =
        [MPSMatrixDescriptor matrixDescriptorWithDimensions:M
                                                    columns:N
                                                   rowBytes:N * elem_bytes
                                                   dataType:dtype];
    id<MTLBuffer> bufC = (__bridge id<MTLBuffer>)(C->data);
    MPSMatrix* matrixC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC];

    // kernel
    // Allocate and initialise in one step; the object must not be initialised
    // with -init first and then re-initialised with the designated initialiser.
    MPSMatrixMultiplication* sgemm = [[MPSMatrixMultiplication alloc] initWithDevice:dev
                                                                       transposeLeft:transa
                                                                      transposeRight:transb
                                                                          resultRows:M
                                                                       resultColumns:N
                                                                     interiorColumns:K
                                                                               alpha:1.0f
                                                                                beta:0.0f];
    ICHECK(sgemm != nil);

    [sgemm encodeToCommandBuffer:cb leftMatrix:matrixA rightMatrix:matrixB resultMatrix:matrixC];
    // Submit the work; completion is not awaited here.
    [cb commit];
  });
}

}  // namespace contrib
}  // namespace tvm