blob: 112fc320973b4d385d8d1a5dd88a28bb19b82ea5 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""External function interface to MIOpen library."""
# pylint: disable-msg=C0103
import ctypes
import numpy as np
import tvm
import tvm._ffi
from tvm import te
def _get_np_int32_array_handle(arr):
"""Return a void_p handle for a numpy array
Parameters
----------
arr: numpy.NDArray
source numpy array
Returns
-------
ptr: ctypes.c_void_p
pointer to the data
"""
assert arr.dtype == np.int32
ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
return ctypes.cast(ptr, ctypes.c_void_p)
def conv2d_forward(
x,
w,
stride_h=1,
stride_w=1,
pad_h=0,
pad_w=0,
dilation_h=1,
dilation_w=1,
conv_mode=0,
data_type=1,
group_count=1,
):
"""Create an extern op that compute 2D convolution with MIOpen
Parameters
----------
x: Tensor
input feature map
w: Tensor
convolution weight
stride_h: int
height stride
stride_w: int
width stride
pad_h: int
height pad
pad_w: int
weight pad
dilation_h: int
height dilation
dilation_w: int
width dilation
conv_mode: int
0: miopenConvolution
1: miopenTranspose
data_type: int
0: miopenHalf (fp16)
1: miopenFloat (fp32)
group_count: int
number of groups
Returns
-------
y: Tensor
The result tensor
"""
assert 0 <= conv_mode <= 2, "0: miopenConvolution / 1: miopenTranspose / 2: miopenGroupConv"
if group_count > 1:
conv_mode = 2
oshape = np.zeros((len(x.shape)), dtype=np.int32)
xshape = x.shape
wshape = w.shape
setup_func = tvm._ffi.get_global_func("tvm.contrib.miopen.conv2d.setup")
algo = setup_func(
conv_mode,
data_type,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
xshape[0].value,
xshape[1].value,
xshape[2].value,
xshape[3].value,
wshape[0].value,
wshape[1].value,
wshape[2].value,
wshape[3].value,
group_count,
_get_np_int32_array_handle(oshape),
)
return te.extern(
list(oshape),
[x, w],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.miopen.conv2d.forward",
conv_mode,
data_type,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
algo,
ins[0],
ins[1],
outs[0],
),
name="y",
)