blob: 759fc3d300424ac447330eb7bd2d631018807b39 [file] [log] [blame]
"""Interface to runtime cuda kernel compile module."""
from __future__ import absolute_import
import ctypes
from .base import _LIB, NDArrayHandle, RtcHandle, mx_uint, c_array, check_call
class Rtc(object):
    """MXRtc object in mxnet.

    This class allows you to write CUDA kernels in Python
    and call them with NDArray.

    Parameters
    ----------
    name : str
        Name of the kernel.
    inputs : tuple of (str, mxnet.ndarray)
        List of input names and ndarray.
    outputs : tuple of (str, mxnet.ndarray)
        List of output names and ndarray.
    kernel : str
        The actual kernel code.
        Note that this is only the body of the kernel, i.e.
        after { and before }. Rtc will decorate the kernel.
        For example, if ``name = "mykernel"`` and
            inputs = [('x', mx.nd.zeros((10,)))]
            outputs = [('y', mx.nd.zeros((10,)))]
            kernel = "y[threadIdx.x] = x[threadIdx.x];",
        then the compiled kernel will be:
            extern "C" __global__ mykernel(float *x, float *y) {
                const int x_ndim = 1;
                const int x_dims = { 10 };
                const int y_ndim = 1;
                const int y_dims = { 10 };
                y[threadIdx.x] = x[threadIdx.x];
            }

    Notes
    -----
    ``name``, ``kernel`` and the input/output names may be given either as
    text or as bytes; text is encoded to UTF-8 before it is handed to the
    C library, because ``ctypes.c_char_p`` only accepts bytes on Python 3.
    """

    @staticmethod
    def _as_bytes(text):
        """Return ``text`` as a byte string, UTF-8 encoding it if needed."""
        # On Python 2 ``bytes`` is ``str``, so native strings pass through
        # untouched; on Python 3 only real ``bytes`` objects do.
        if isinstance(text, bytes):
            return text
        return text.encode('utf-8')

    @staticmethod
    def _name_array(named_arrays):
        """Build a ``char**`` argument from the names of (name, ndarray) pairs."""
        names = [Rtc._as_bytes(pair[0]) for pair in named_arrays]
        return ctypes.cast(c_array(ctypes.c_char_p, names),
                           ctypes.POINTER(ctypes.c_char_p))

    @staticmethod
    def _handle_array(ndarrays):
        """Build an ``NDArrayHandle*`` argument from a sequence of NDArrays."""
        handles = [nd.handle for nd in ndarrays]
        return ctypes.cast(c_array(NDArrayHandle, handles),
                           ctypes.POINTER(NDArrayHandle))

    def __init__(self, name, inputs, outputs, kernel):
        self.handle = RtcHandle()
        input_names = self._name_array(inputs)
        output_names = self._name_array(outputs)
        input_nds = self._handle_array([pair[1] for pair in inputs])
        output_nds = self._handle_array([pair[1] for pair in outputs])
        # The C library fills in self.handle; the return code is routed
        # through check_call.
        check_call(_LIB.MXRtcCreate(ctypes.c_char_p(self._as_bytes(name)),
                                    mx_uint(len(inputs)),
                                    mx_uint(len(outputs)),
                                    input_names,
                                    output_names,
                                    input_nds,
                                    output_nds,
                                    ctypes.c_char_p(self._as_bytes(kernel)),
                                    ctypes.byref(self.handle)))

    def __del__(self):
        # Hand the handle back to the C library when this object is collected.
        check_call(_LIB.MXRtcFree(self.handle))

    def push(self, inputs, outputs, grid_dims, block_dims):
        """Run the kernel.

        Parameters
        ----------
        inputs : list of NDArray
            List of inputs. Can contain different NDArrays than those used for the constructor,
            but its elements must have the same shapes and appear in the same order.
        outputs : list of NDArray
            List of outputs. Can contain different ndarrays than used for the constructor,
            but must have the same shapes and appear in the same order.
        grid_dims : tuple of 3 uint
            Grid dimension for kernel launch.
        block_dims : tuple of 3 uint
            Block dimension for kernel launch.
        """
        input_nds = self._handle_array(inputs)
        output_nds = self._handle_array(outputs)
        check_call(_LIB.MXRtcPush(self.handle,
                                  mx_uint(len(inputs)),
                                  mx_uint(len(outputs)),
                                  input_nds,
                                  output_nds,
                                  mx_uint(grid_dims[0]),
                                  mx_uint(grid_dims[1]),
                                  mx_uint(grid_dims[2]),
                                  mx_uint(block_dims[0]),
                                  mx_uint(block_dims[1]),
                                  mx_uint(block_dims[2])))