blob: d39e37eae66a2c019fd7e9dac1cc404e4f060c30 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
import mxnet as mx
import numpy as np
class Sqr(mx.operator.CustomOp):
'''Example of how to use custom op with sparse ndarrays
'''
def forward(self, is_train, req, in_data, out_data, aux):
inp = in_data[0]
if inp.stype == 'csr':
csr_m = inp.data
out = mx.nd.sparse.csr_matrix((csr_m * csr_m, inp.indices, inp.indptr), shape=inp.shape)
else:
out = inp * inp
self.assign(out_data[0], req[0], out)
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
self.assign(in_grad[0], req[0], 2 * mx.nd.sparse.elemwise_mul(in_data[0], out_grad[0]))
@mx.operator.register("sqr")
class SqrProp(mx.operator.CustomOpProp):
def __init__(self):
super(SqrProp, self).__init__(need_top_grad=True)
def list_arguments(self):
return ['data']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
return in_shape, [in_shape[0]], []
def infer_type(self, in_type):
return in_type, [in_type[0]], []
def infer_storage_type(self, in_stype):
'''Infer storage type logic for the forward pass
Takes a list of storage types for inputs
Returns three lists lists, one for input storage types inferred,
second for output storage types inferred and third for aux storage
types inferred
The in_stype is the list containing storage type for inputs
If the input is a dense ndarray then we infer the input
and output to be dense. If input is csr then input and output
are inferred as csr.
'''
if in_stype[0] == 'default':
return ['default'], ['default'], []
return ['csr'], ['csr'], []
def infer_storage_type_backward(self, ograd_stype, in_stype, out_stype, igrad_stype, aux_stype):
'''Infer storage type logic for the backward pass
Takes storage type of output gradients(ograd_stype), inputs(in_stype),
outputs(out_stype) and aux(aux_stype).
Returns inferred storage types in the following order:
ograd_stype, in_stype, out_stype, igrad_stype (Storage type for input gradients)
and aux_stype.
'''
if in_stype[0] == 'default':
return ['default'], ['default'], ['default'], ['default'], []
return ['csr'], ['csr'], ['csr'], ['default'], []
def create_operator(self, ctx, shapes, dtypes):
return Sqr()
#Default
x = mx.nd.array(np.random.uniform(1, 10, size=(4,10)))
x.attach_grad(stype='default')
z = mx.nd.ones((4,10))
with mx.autograd.record():
y = mx.nd.Custom(x, op_type='sqr')
y.backward(out_grad=z)
print("Original ndarray")
print("--------------")
print(x.asnumpy())
print("Squared ndarray")
print("--------------")
print(y.asnumpy())
print("stype of input is {}".format(x.stype))
print("stype of output is {}".format(y.stype))
print("Grad ndarray")
print(x.grad.asnumpy())
#Sparse
x = mx.nd.array(np.random.uniform(1, 10, size=(4,10)))
x = x.tostype('csr')
x.attach_grad(stype='default')
z = mx.nd.ones((4,10))
z = z.tostype('csr')
with mx.autograd.record():
y = mx.nd.Custom(x, op_type='sqr')
y.backward(out_grad=z)
print("Original ndarray")
print("--------------")
print(x.asnumpy())
print("Squared ndarray")
print("--------------")
print(y.asnumpy())
print("stype of input is {}".format(x.stype))
print("stype of output is {}".format(y.stype))
print("Grad ndarray")
print(x.grad.asnumpy())