blob: d38f038ffd269d02511fcc91e38326a05c0f6635 [file] [log] [blame]
# pylint: skip-file
import mxnet as mx
import numpy as np
from numpy.testing import assert_allclose
if __name__ == '__main__':
    # Smoke test for MXNet runtime compilation (mx.rtc): build a tiny CUDA
    # kernel that copies x through a shared-memory buffer and writes
    # exp(5 * x) into y, then compare the device result with NumPy on the
    # host. Requires a CUDA-capable GPU at device 0.

    # Element count. It is also the shared-buffer length and the number of
    # threads in the single launched block; deriving all three from one
    # value keeps them from drifting apart (a mismatch would make the kernel
    # index past the end of an array).
    n = 10

    x = mx.nd.zeros((n,), ctx=mx.gpu(0))
    x[:] = 1
    y = mx.nd.zeros((n,), ctx=mx.gpu(0))
    # Prefill with a value different from the expected output, so any element
    # the kernel fails to write makes the final check fail.
    y[:] = 2

    # Kernel body only. It refers to the arrays by the names given in the
    # input/output lists passed to mx.rtc ('x' and 'y'). Each thread reads
    # back only the shared slot it wrote itself, so no __syncthreads() is
    # needed. '5.0f' keeps the arithmetic in single precision.
    kernel = """
    __shared__ float s_rec[%d];
    s_rec[threadIdx.x] = x[threadIdx.x];
    y[threadIdx.x] = expf(s_rec[threadIdx.x] * 5.0f);""" % n

    rtc = mx.rtc('abc', [('x', x)], [('y', y)], kernel)
    # Launch a (1, 1, 1) grid with an (n, 1, 1) block: one thread per element,
    # indexed by threadIdx.x in the kernel.
    rtc.push([x], [y], (1, 1, 1), (n, 1, 1))

    # CUDA's single-precision expf is not correctly rounded (up to 2 ulp), and
    # one float32 ulp near exp(5) is ~1e-7 relative, so assert_allclose's
    # default rtol=1e-7 could fail spuriously. Use a float32-appropriate
    # tolerance instead.
    assert_allclose(y.asnumpy(), np.exp(x.asnumpy() * 5.0), rtol=1e-5)