| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| |
| from __future__ import division |
| |
| from collections import Counter, deque |
| import numpy as np |
| |
| from singa import tensor |
| from singa import utils |
| from .tensor import Tensor |
| from . import singa_wrap as singa |
| |
| CTensor = singa.Tensor |
| training = False |
| |
| |
| def axis_helper(y_shape, x_shape): |
| """ |
| check which axes the x has been broadcasted |
| Args: |
| y_shape: the shape of result |
| x_shape: the shape of x |
| Return: |
| a tuple refering the axes |
| """ |
| res = [] |
| j = len(x_shape) - 1 |
| for i in range(len(y_shape) - 1, -1, -1): |
| if j < 0 or x_shape[j] != y_shape[i]: |
| res.append(i) |
| j -= 1 |
| return tuple(res[::-1]) |
| |
| |
| def back_broadcast(y_shape, x_shape, x): |
| """ |
| for a brodcasted tensor, restore its shape of x from y_shape to x_shape |
| Args: |
| y_shape: the shape of result |
| x_shape: the shape of x |
| x: the input |
| Return: |
| a tensor |
| """ |
| if y_shape != x_shape: |
| x = tensor.from_raw_tensor(x) |
| axis = axis_helper(y_shape, x_shape) |
| x = tensor.sum(x, axis) |
| x = tensor.reshape(x, x_shape) |
| x = x.data |
| return x |
| |
| |
| def infer_dependency(op): |
| """ |
| Infer the dependency of all operations with the |
| given op as the last operation. |
| Operator A is depending on B if A uses the output(s) of B. |
| |
| Args: |
| op: an Operator instance, e.g. the loss operation. |
| |
| Return: |
| a Counter instance with the operation as the key, |
| and the number of operations that are depending on it as the value; |
| and a Counter instance with the id of the output tensor as the key, and |
| the number of operations that are depending on it as the value. |
| """ |
| |
| # current op is not inserted into the dependency_count |
| # if the current op is not a terminal op, then this function may just |
| # count dependency of a branch. |
| op_count = Counter() |
| tensor_count = Counter() |
| queue = deque([op]) |
| while len(queue) > 0: |
| cur_op = queue.pop() |
| for src_op, xid, _, _ in cur_op.src: |
| if src_op not in op_count: |
| op_count[src_op] = 1 |
| queue.append(src_op) |
| else: |
| op_count[src_op] += 1 |
| tensor_count[xid] += 1 |
| return op_count, tensor_count |
| |
| |
| def gradients(y, dy=None): |
| """ |
| Compute the gradients of the output w.r.t the parameters |
| |
| Args: |
| y: the output tensor, e.g., the loss |
| dy: gradient of the target w.r.t y; None indicates the gradient is 1.0; |
| it can be used to rescale the loss. |
| |
| Return: |
| a dictionary storing the gradient tensors of all tensors |
| whose stores_grad is true (e.g. parameter tensors) |
| """ |
| grads = {} # mapping: x->dx if x.stores_grad |
| for p, dp in backward(y, dy): |
| # TODO: this fn is only helper for test case for now. |
| # 1. could implement __hash__ or |
| # 2. make grad as a attribute of tensor class |
| # p.grad = dp |
| grads[id(p)] = dp |
| return grads |
| |
| |
| def backward(y, dy=None): |
| """ |
| Run the backward propagation starting at y. |
| Args: |
| y: a Tensor instance, usually the loss |
| dy: a number or a Tensor instance, for the gradient of the |
| objective/loss w.r.t y, usually None, i.e., 1.0 |
| Return: |
| yeild the parameter (tensor with stores_grad true) and the |
| gradient tensors. |
| """ |
| assert isinstance(y, Tensor), "wrong input type." |
| op_dep, tensor_dep = infer_dependency(y.creator) |
| assert y.size() == 1, ("y must be a Tensor with a single value;" |
| "size of y is % d" % y.size()) |
| |
| # by default the dy is a tensor with 1.0 for each sample; |
| if dy is None: |
| dy = float(1.0) |
| elif isinstance(dy, Tensor): |
| dy = dy.data |
| else: |
| dy = float(dy) |
| |
| # ready is a queue of (operation, dy list) |
| ready = deque([(y.creator, (dy,))]) |
| not_ready = {} # mapping: op->[dy] |
| |
| if y.stores_grad: |
| # gradients[y] = dy |
| if isinstance(dy, float): |
| g = np.array(dy) |
| else: |
| g = dy |
| tg = Tensor(device=g.device(), data=g) |
| yield (y, tg) |
| |
| while len(ready) > 0: |
| op, dys = ready.pop() |
| if not op.requires_grad or isinstance(op, Dummy): |
| continue |
| # if not isinstance(op, tensor.Dummy): |
| dxs = op._do_backward(*dys) |
| # TODO src and dx must match |
| |
| assert len(op.src) == len(dxs), ( |
| "the number of src ops (=%d) and dx (=%d) not match" % |
| (len(op.src), len(dxs))) |
| for (src_op, x_id, y, y_stores_grad), dx in zip(op.src, dxs): |
| # prefix x is w.r.t op; prefix y is w.r.t src_op. |
| # x_id is the python id of one input arg of src_op, denoted as x. |
| # y_idx (below) is the index of x among the outputs of src_op. |
| # not_ready[src_op][y_idx] records the intermediate gradient |
| # of the y_idx'th output of src_op. 'intermediate gradient' |
| # indicates that if this output is used in multiple children |
| # operations, then we have to add the graident (dx) from all these |
| # children operations. When src_op is ready, it means that |
| # the gradient of all its outputs are available, i.e. all children |
| # operations have been backwarded. |
| # y is None if y.stores_grad is false; otherwise it is a Tensor |
| |
| if isinstance(src_op, Dummy) and (not src_op.stores_grad): |
| continue |
| |
| y_idx = src_op.y_id2idx[x_id] |
| if src_op not in not_ready: |
| # src_op may have mulitple outputs |
| not_ready[src_op] = [None for _ in src_op.y_id2idx] |
| not_ready[src_op][y_idx] = dx |
| else: |
| dxs_ = not_ready[src_op] |
| if dxs_[y_idx] is None: |
| dxs_[y_idx] = dx |
| else: |
| # add the gradient from another children operation that |
| # uses y_idx'th output of src_op as input arg |
| dxs_[y_idx] += dx |
| |
| op_dep[src_op] -= 1 |
| tensor_dep[x_id] -= 1 |
| if y_stores_grad and tensor_dep[x_id] == 0: |
| # store the gradient for final return, e.g. for parameters. |
| # it may cause a delay to yield. Only after src_op's all |
| # output tensors have recieved the gradients, then output |
| g = not_ready[src_op][y_idx] |
| tg = Tensor(device=g.device(), |
| data=g, |
| name=src_op.grad_name(y_idx)) |
| yield (y, tg) |
| |
| if op_dep[src_op] == 0: |
| if src_op.requires_grad is True: |
| assert not isinstance( |
| src_op, Dummy), "Dummy op does not do backward()" |
| ready.append((src_op, not_ready[src_op])) |
| del not_ready[src_op] |
| del op # delete the operation to free all tensors from this op |
| |
| |
| class Operator(object): |
| """ |
| An operation includes the forward and backward function of |
| tensor calculation. |
| Steps to add a specific operation Xxxx: |
| 1. create a subclass of Operator, name it as Xxxx |
| 2. override the forward() and backward(); The arguments of forward() |
| and backward() should only include CTensor; |
| """ |
| |
| op_count = 0 |
| |
| def __init__(self, name=None): |
| if name is None: |
| self.name = "{}#{}".format(self.__class__.__name__, |
| Operator.op_count) |
| Operator.op_count += 1 |
| else: |
| self.name = name |
| |
| def __call__(self, *xs): |
| return self._do_forward(*xs) |
| |
| def output_name(self, idx): |
| """ |
| Args: |
| idx: index of the output among all outputs |
| |
| Return: |
| the name of the output tensor |
| """ |
| return "{}:{}".format(self.name, idx) |
| |
| def grad_name(self, idx): |
| """ |
| Args: |
| idx: index of the output among all outputs |
| |
| Return: |
| the name of the gradient of the output tensor |
| """ |
| return "{}_g".format(self.output_name(idx)) |
| |
| def _do_forward(self, *xs): |
| """ |
| Do not call this function from user code. It is called by __call__(). |
| Args: |
| xs, Tensor instance(s) |
| Returns: |
| Tensor instance(s) |
| """ |
| # TODO add the pre hook |
| assert all([isinstance(x, Tensor) for x in xs |
| ]), "xs should include only Tensor instances" |
| |
| # need to do backward if any of its input arg needs gradient |
| self.requires_grad = any([x.requires_grad for x in xs]) |
| |
| self.src = [] |
| for x in xs: |
| if x.stores_grad: |
| # store the tensor whose gradient needs be returned in |
| # backward(), e.g. if x is parameter |
| self.src.append((x.creator, id(x), x, x.stores_grad)) |
| else: |
| # for intermediate tensors, they will be released soon; |
| # no need to store them --> use None |
| self.src.append((x.creator, id(x), None, x.stores_grad)) |
| |
| # get the CTensor (data) if the input arg is Tensor |
| xs = tuple(x.data for x in xs) |
| ys = self.forward(*xs) |
| if not isinstance(ys, tuple): |
| ys = (ys,) |
| # create Tensor based on CTensor(data); |
| # assume outputs are all Tensor instances |
| ys = tuple( |
| Tensor( |
| device=y.device(), |
| data=y, |
| requires_grad=self.requires_grad, |
| creator=self, |
| name=self.output_name(idx), |
| ) for idx, y in enumerate(ys)) |
| # map from python id to output index |
| self.y_id2idx = {id(y): i for i, y in enumerate(ys)} |
| # TODO add the post hook |
| return ys |
| |
| def _do_backward(self, *dys): |
| dxs = self.backward(*dys) |
| if not isinstance(dxs, tuple): |
| dxs = (dxs,) |
| return dxs |
| |
| def forward(self, *xs): |
| """Forward propagation. |
| Args: |
| xs: input args consisting of only CTensors. |
| Returns: |
| CTensor instance(s) |
| """ |
| raise NotImplementedError |
| |
| def backward(self, *dys): |
| """ Backward propagation. |
| Args: |
| dys: input args consisting of only CTensors. |
| Returns: |
| CTensor instance(s) |
| """ |
| raise NotImplementedError |
| |
| def get_params(self): |
| return [] |
| |
| |
| class Dummy(Operator): |
| """Dummy operation whice serves as a placehoder for autograd |
| Args: |
| name(string): set it for debug |
| """ |
| |
| def __init__(self, tensor, name=None): |
| super(Dummy, self).__init__(name) |
| self.src = [] |
| self.y_id2idx = {id(tensor): 0} |
| self.tensor = tensor |
| self.requires_grad = False |
| |
| def output_name(self, idx): |
| return self.name |
| |
| def grad_name(self, idx): |
| return "{}_g".format(self.name) |
| |
| def __getattr__(self, name): |
| return self.tensor.__getattribute__(name) |
| |
| |
| class Mean(Operator): |
| """ |
| Element-wise mean of each of the input CTensors. |
| """ |
| |
| def __init__(self): |
| super(Mean, self).__init__() |
| |
| def forward(self, *l): |
| """ |
| Args: |
| l (a list of CTensor): a list of CTensor for element-wise mean. |
| Returns: |
| a new CTensor. |
| """ |
| if training: |
| self.l = len(l) |
| assert (len(l) > 0) |
| x = singa.Tensor(list(l[0].shape()), l[0].device()) |
| x.SetFloatValue(0.0) |
| for i in range(len(l)): |
| x += l[i] |
| return singa.MultFloat(x, 1 / len(l)) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): dL / dy. |
| Returns: |
| a list of dx (CTensor). |
| """ |
| return [singa.MultFloat(dy, 1 / self.l)] * self.l |
| |
| |
| def mean(*l): |
| """ |
| Element-wise mean of each of the input tensors. |
| Args: |
| l (a list of Tensor): element-wise mean operator. |
| Returns: |
| a new Tensor. |
| """ |
| return Mean()(*l)[0] |
| |
| |
| class ReLU(Operator): |
| """ |
| Relu means rectified linear function, i.e, y = max(0, x) is applied to the |
| CTensor elementwise. |
| """ |
| |
| def __init__(self): |
| super(ReLU, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): input tensor. |
| Returns: |
| a new CTensor whose element y = x if x >= 0; otherwise 0. |
| """ |
| if training: |
| self.input = x |
| return singa.ReLU(x) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): dL / dy. |
| Returns: |
| dx (CTensor): dL / dx = dy if x >= 0; otherwise 0. |
| """ |
| return singa.ReLUBackward(dy, self.input) |
| |
| |
| def relu(x): |
| """ |
| Relu means rectified linear function, i.e, y = max(0, x) is applied to the |
| CTensors elementwise. |
| Args: |
| x (Tensor): input tensor. |
| Returns: |
| a new Tensor whose element y = x if x >= 0; otherwise 0. |
| """ |
| return ReLU()(x)[0] |
| |
| |
| class Less(Operator): |
| """ |
| Returns the tensor resulted from performing the less logical operation |
| elementwise on the input CTensors x and y. |
| """ |
| |
| def __init__(self): |
| super(Less, self).__init__() |
| |
| def forward(self, x, y): |
| """ |
| Return a<b, where a and b are CTensor. |
| """ |
| cur = singa.LTFloat(singa.__sub__(x, y), 0) |
| if training: |
| self.cache = cur |
| return cur |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): data for the dL / dy, L is the loss. |
| Raises: |
| AssertionError: no backward function for this operator. |
| """ |
| assert False, ('no backward function for less') |
| |
| |
| def less(x, y): |
| """ |
| Return a<b, where a and b are CTensor. |
| """ |
| return Less()(x, y)[0] |
| |
| |
| class Clip(Operator): |
| """ |
| Clip operator limits the given input within an interval. The interval |
| is specified by the inputs 'min' and 'max'. |
| """ |
| |
| def __init__(self, min, max): |
| """ |
| Args: |
| min (float): min value, under which element is replaced by min. |
| max (float): max value, above which element is replaced by max. |
| """ |
| super(Clip, self).__init__() |
| self.max = max |
| self.min = min |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): input tensor |
| Returns: |
| a new CTensor with np.clip(x,min,max) |
| """ |
| self.mask = singa.Tensor(list(x.shape()), x.device()) |
| self.mask.SetFloatValue(1.0) |
| |
| if self.min is not None: |
| self.min = float(self.min) |
| mask0 = singa.LTFloat(x, self.min) |
| mask1 = singa.GEFloat(x, self.min) |
| self.mask = singa.__mul__(mask1, self.mask) |
| x = singa.__add__(singa.MultFloat(mask0, self.min), |
| singa.__mul__(mask1, x)) |
| |
| if self.max is not None: |
| self.max = float(self.max) |
| mask0 = singa.GTFloat(x, self.max) |
| mask1 = singa.LEFloat(x, self.max) |
| self.mask = singa.__mul__(mask1, self.mask) |
| x = singa.__add__(singa.MultFloat(mask0, self.max), |
| singa.__mul__(mask1, x)) |
| |
| return x |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): dL / dy |
| Returns: |
| dx (CTensor): dL / dx |
| """ |
| return singa.__mul__(dy, self.mask) |
| |
| |
| def clip(x, min=None, max=None): |
| """ |
| Clip operator limits the given input within an interval. The interval |
| is specified by the inputs 'min' and 'max'. |
| Args: |
| x (Tensor): input tensor |
| min (float): Minimum value, under which element is replaced by min. |
| max (float): Maximum value, above which element is replaced by max. |
| Returns: |
| a new Tensor with np.clip(x,min,max). |
| """ |
| return Clip(min, max)(x)[0] |
| |
| |
| class Identity(Operator): |
| """ |
| Init a identity operator |
| """ |
| |
| def __init__(self): |
| super(Identity, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): input tensor. |
| Returns: |
| the same CTensor x. |
| """ |
| return x |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): dL / dy. |
| Returns: |
| dx (CTensor): dL / dx. |
| """ |
| return dy |
| |
| |
| def identity(x): |
| """ |
| Init a identity operator. |
| Args: |
| x (Tensor): input tensor. |
| Returns: |
| the same Tensor with x. |
| """ |
| return Identity()(x)[0] |
| |
| |
| class Matmul(Operator): |
| """ |
| Init matrix multiplication operator. |
| """ |
| |
| def __init__(self): |
| super(Matmul, self).__init__() |
| |
| def forward(self, x, w): |
| """ |
| Return `np.matmul(x,w)`, where x and w are CTensor. |
| """ |
| # todo, cannot do Mult for dims more than 2 |
| if training: |
| self.input = (x, w) |
| res = singa.Mult(x, w) |
| return res |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): data for the dL / dy, L is the loss. |
| Returns: |
| a tuple for (dx, dw). |
| """ |
| return ( |
| singa.Mult(dy, singa.DefaultTranspose(self.input[1])), |
| singa.Mult(singa.DefaultTranspose(self.input[0]), dy), |
| ) |
| |
| |
| def matmul(x, w): |
| """ |
| Return `np.matmul(x,w)`, where x and w are Tensor. |
| """ |
| return Matmul()(x, w)[0] |
| |
| |
| class Greater(Operator): |
| """ |
| Returns the tensor resulted from performing the greater logical |
| operation elementwise on the input tensors A and B. |
| """ |
| |
| def __init__(self): |
| super(Greater, self).__init__() |
| |
| def forward(self, x, y): |
| """ |
| Return a>b, where a and b are CTensor. |
| """ |
| cur = singa.GTFloat(singa.__sub__(x, y), 0) |
| if training: |
| self.cache = cur |
| return cur |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): data for the dL / dy, L is the loss. |
| Raises: |
| AssertionError: no backward function for this operator. |
| """ |
| assert False, ('no backward function for greater') |
| |
| |
| def greater(x, y): |
| """ |
| Return a>b, where a and b are Tensor. |
| """ |
| return Greater()(x, y)[0] |
| |
| |
| class AddBias(Operator): |
| """ |
| Add Bias to each row / column of the Tensor, depending on the axis arg. |
| """ |
| |
| def __init__(self, axis=0): |
| """ |
| To indicate the calculation axis, 0 for row, 1 for column. |
| Args: |
| axis (int): 0 or 1, default is 0. |
| """ |
| super(AddBias, self).__init__() |
| self.axis = axis |
| |
| def forward(self, x, b): |
| """ |
| Args: |
| x (CTensor): matrix. |
| b (CTensor): bias to be added. |
| Return: |
| the result Tensor |
| """ |
| if self.axis == 0: |
| singa.AddRow(b, x) |
| elif self.axis == 1: |
| singa.AddColumn(b, x) |
| return x |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): data for the dL / dy, L is the loss. |
| Return: |
| a tuple for (db, dx), db is data for dL / db, dx is data |
| for dL / dx. |
| """ |
| dtype = dy.data_type() |
| _dy = dy.AsType(tensor.float32) |
| if self.axis == 0: |
| return dy, singa.Sum(_dy, 0).AsType(dtype) |
| elif self.axis == 1: |
| return dy, singa.Sum(_dy, 0).AsType(dtype) |
| |
| |
| def add_bias(x, b, axis=0): |
| """ |
| Add Bias to each row / column of the Tensor, depending on the axis arg. |
| Args: |
| x (Tensor): matrix. |
| b (Tensor): bias to be added. |
| axis (int): 0 or 1, default is 0. |
| Return: |
| the result Tensor |
| """ |
| assert x.ndim() == 2, "1st arg required 2d tensor. got shape: %s" % ( |
| x.shape) |
| assert b.ndim() == 1, "2nd arg required 1d tensor. got shape: %s" % ( |
| b.shape) |
| assert axis in [0, 1], "allowed axis: 0 or 1" |
| return AddBias(axis)(x, b)[0] |
| |
| |
| class Reshape(Operator): |
| """ |
| Reshape the input tensor similar to np.reshape. |
| """ |
| |
| def __init__(self, shape): |
| """ |
| Args: |
| shape (list of int): Specified shape for output. At most one |
| dimension of the new shape can be -1. In this case, the |
| value is inferred from the size of the tensor and the |
| remaining dimensions. A dimension could also be 0, |
| in which case the actual dimension value is unchanged |
| (i.e. taken from the input tensor). |
| """ |
| super(Reshape, self).__init__() |
| self.shape = shape |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): matrix. |
| Return: |
| the result CTensor |
| """ |
| self._shape = x.shape() |
| shape = list(self.shape) |
| # handle the shape with 0 |
| shape = [ |
| self._shape[i] |
| if i < len(self._shape) and shape[i] == 0 else shape[i] |
| for i in range(len(shape)) |
| ] |
| # handle the shape with -1 |
| hidden_shape = int(np.prod(self._shape) // np.abs(np.prod(shape))) |
| self.cache = [int(s) if s != -1 else hidden_shape for s in shape] |
| return singa.Reshape(x, self.cache) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): dL / dy |
| Returns: |
| dx (CTensor): dL / dx |
| """ |
| return singa.Reshape(dy, self._shape) |
| |
| |
| def reshape(x, shape): |
| """ |
| Reshape the input tensor similar to mp.reshape. |
| Args: |
| x (Tensor): matrix. |
| shape (list of int): Specified shape for output. At most one |
| dimension of the new shape can be -1. In this case, the |
| value is inferred from the size of the tensor and the |
| remaining dimensions. A dimension could also be 0, |
| in which case the actual dimension value is unchanged |
| (i.e. taken from the input tensor). |
| Return: |
| the result Tensor |
| """ |
| return Reshape(shape)(x)[0] |
| |
| |
| class PRelu(Operator): |
| """ |
| PRelu applies the function `f(x) = slope * x` for x < 0, |
| `f(x) = x` for x >= 0 to the data tensor elementwise. |
| """ |
| |
| def __init__(self): |
| super(PRelu, self).__init__() |
| |
| def forward(self, x, slope): |
| """ |
| Args: |
| x (CTensor): matrix. |
| Return: |
| the result CTensor |
| """ |
| mask0 = singa.LTFloat(x, 0.0) |
| res = singa.__mul__(x, mask0) |
| res = singa.__mul__(res, slope) |
| res += singa.ReLU(x) |
| if training: |
| self.input = x |
| self.slope = slope |
| self.mask0 = mask0 |
| self.shape0 = list(x.shape()) |
| self.shape1 = list(slope.shape()) |
| self.shape3 = list(res.shape()) |
| return res |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): dL / dy |
| Returns: |
| dx (CTensor): dL / dx |
| """ |
| dx1mask = singa.GEFloat(self.input, 0.0) |
| dx2 = singa.__mul__(self.mask0, self.slope) |
| dx = singa.__add__(dx1mask, dx2) |
| dx = singa.__mul__(dy, dx) |
| dslope = singa.__mul__(dy, singa.__mul__(self.mask0, self.input)) |
| if (type(dy) == float) or self.shape0 == self.shape1: |
| assert self.shape0 == self.shape1, ('should have same shape') |
| return dx, dslope |
| # handle broadcast |
| dx = back_broadcast(self.shape3, self.shape0, dx) |
| dslope = back_broadcast(self.shape3, self.shape1, dslope) |
| return dx, dslope |
| |
| |
| def prelu(x, slope): |
| """ |
| PRelu applies the function `f(x) = slope * x` for x < 0, |
| `f(x) = x` for x >= 0 to the data tensor elementwise. |
| Args: |
| x (Tensor): matrix. |
| Return: |
| the result Tensor |
| """ |
| return PRelu()(x, slope)[0] |
| |
| |
| class Add(Operator): |
| """ |
| Performs element-wise binary addition. |
| """ |
| |
| def __init__(self): |
| super(Add, self).__init__() |
| |
| def forward(self, a, b): |
| """ |
| Return `a+b`, where a and b are CTensor. |
| """ |
| res = singa.__add__(a, b) |
| if training: |
| self.shape0 = list(a.shape()) |
| self.shape1 = list(b.shape()) |
| self.shape3 = list(res.shape()) |
| return res |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy(CTensor): dL / dy |
| Return: |
| a tuple for (dx0, dx1), dx0 is data for dL / da, dx1 is data |
| for dL / db. |
| """ |
| dx0, dx1 = dy, dy |
| if (type(dy) == float) or self.shape0 == self.shape1: |
| assert self.shape0 == self.shape1, ('should have same shape') |
| return dx0, dx1 |
| # handle broadcast |
| dx0 = back_broadcast(self.shape3, self.shape0, dx0) |
| dx1 = back_broadcast(self.shape3, self.shape1, dx1) |
| return dx0, dx1 |
| |
| |
| def add(a, b): |
| """ |
| Return `a+b`, where a and b are Tensor. |
| """ |
| return Add()(a, b)[0] |
| |
| |
| class Elu(Operator): |
| """ |
| `f(x) = alpha * (exp(x) - 1.)` for x < 0, `f(x) = x` for x >= 0., is applied to |
| the tensor elementwise. |
| """ |
| |
| def __init__(self, alpha=1.): |
| """ |
| Args: |
| alpha (float): Coefficient of ELU, default is 1.0 |
| """ |
| super(Elu, self).__init__() |
| self.alpha = alpha |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): matrix |
| Returns: |
| a CTensor for the result |
| """ |
| #f(x) = alpha * (exp(x) - 1.) for x < 0, f(x) = x for x >= 0 |
| if training: |
| self.input = x |
| x1 = singa.LTFloat(x, 0.0) |
| x1 *= x |
| x1 = singa.MultFloat(singa.SubFloat(singa.Exp(x1), 1.0), self.alpha) |
| x2 = singa.ReLU(x) |
| x1 += x2 |
| return x1 |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): dL / dy |
| Returns: |
| dx (CTensor): dL / dx |
| """ |
| dx1mask = singa.LTFloat(self.input, 0.0) |
| dx = singa.MultFloat(singa.Exp(self.input), self.alpha) |
| dx *= dx1mask |
| |
| dx2mask = singa.GEFloat(self.input, 0.0) |
| |
| dx += dx2mask |
| dx *= dy |
| return dx |
| |
| |
| def elu(x, alpha=1): |
| """ |
| `f(x) = alpha * (exp(x) - 1.)` for x < 0, `f(x) = x` for x >= 0., is applied to |
| the tensor elementwise. |
| Args: |
| x (Tensor): matrix |
| alpha (float): Coefficient of ELU, default is 1.0 |
| Returns: |
| a Tensor for the result |
| """ |
| return Elu(alpha)(x)[0] |
| |
| |
| class Equal(Operator): |
| """ |
| Returns the tensor resulted from performing the equal logical operation |
| elementwise on the input tensors x and y. |
| """ |
| |
| def __init__(self): |
| super(Equal, self).__init__() |
| |
| def forward(self, x, y): |
| """ |
| Return `a=b`, where a and b are CTensor. |
| """ |
| return singa.__eq__(x, y) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): data for the dL / dy, L is the loss |
| Raises: |
| AssertionError: no backward function for this operator |
| """ |
| assert False, ('no backward function for equal') |
| |
| |
| def equal(x, y): |
| """ |
| Return `a=b`, where a and b are Tensor. |
| """ |
| return Equal()(x, y)[0] |
| |
| |
| class SeLU(Operator): |
| """ |
| `y = gamma * (alpha * e^x - alpha)` for x <= 0, `y = gamma * x` for x > 0 |
| is applied to the tensor elementwise. |
| """ |
| |
| def __init__(self, alpha=1.67326, gamma=1.0507): |
| """ |
| Args: |
| alpha (float): Coefficient of SELU default to 1.67326 |
| gamma (float): Coefficient of SELU default to 1.0507 |
| """ |
| super(SeLU, self).__init__() |
| self.alpha = alpha |
| self.gamma = gamma |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): matrix |
| Returns: |
| a CTensor for the result |
| """ |
| #y = gamma * (alpha * e^x - alpha) for x <= 0, y = gamma * x for x > 0 |
| if training: |
| self.input = x |
| x1 = singa.LEFloat(x, 0.0) |
| x1 *= x |
| x1 = singa.MultFloat(singa.SubFloat(singa.Exp(x1), 1.0), |
| self.alpha * self.gamma) |
| x2 = singa.ReLU(x) |
| x2 = singa.MultFloat(x2, self.gamma) |
| x1 += x2 |
| return x1 |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): dL / dy |
| Returns: |
| dx (CTensor): dL / dx |
| """ |
| dx1mask = singa.LEFloat(self.input, 0.0) |
| dx1 = singa.MultFloat(singa.Exp(self.input), self.gamma * self.alpha) |
| dx1 = singa.__mul__(dx1mask, dx1) |
| |
| dx2mask = singa.GTFloat(self.input, 0.0) |
| dx2 = singa.MultFloat(dx2mask, self.gamma) |
| |
| dx = singa.__add__(dx1, dx2) |
| dx *= dy |
| return dx |
| |
| |
| def selu(x, alpha=1.67326, gamma=1.0507): |
| """ |
| `y = gamma * (alpha * e^x - alpha)` for x <= 0, `y = gamma * x` for x > 0 |
| is applied to the tensor elementwise. |
| Args: |
| x (Tensor): matrix |
| alpha (float): Coefficient of SELU default to 1.67326 |
| gamma (float): Coefficient of SELU default to 1.0507 |
| Returns: |
| a Tensor for the result |
| """ |
| return SeLU(alpha, gamma)(x)[0] |
| |
| |
| class SoftMax(Operator): |
| """ |
| Apply SoftMax for each row of the Tensor or each column of the Tensor |
| according to the parameter axis. |
| """ |
| |
| def __init__(self, axis=1): |
| """ |
| Args: |
| axis (int): axis of softmax, default to 1 |
| """ |
| super(SoftMax, self).__init__() |
| self.axis = axis |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): the input 1d or 2d tensor |
| Returns: |
| the result CTensor |
| """ |
| self.output = singa.SoftMax(x, self.axis) |
| return self.output |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): dL / dy |
| Returns: |
| dx (CTensor): dL / dx |
| """ |
| return singa.SoftMaxBackward(dy, self.axis, self.output) |
| |
| |
| def softmax(x, axis=1): |
| """ |
| Apply SoftMax for each row of the Tensor or each column of the Tensor |
| according to the parameter axis. |
| Args: |
| x (Tensor): the input 1d or 2d tensor |
| axis (int): axis of softmax, default to 1 |
| Returns: |
| the result Tensor |
| """ |
| return SoftMax(axis)(x)[0] |
| |
| |
| class Sum(Operator): |
| """ |
| Element-wise sum of each of the input tensors |
| """ |
| |
| def __init__(self): |
| super(Sum, self).__init__() |
| |
| def forward(self, *l): |
| """ |
| Args: |
| l (a list of CTensor): element-wise sum operator |
| Returns: |
| a CTensor for the result |
| """ |
| if training: |
| self.l = len(l) |
| assert (len(l) > 0) |
| x = singa.Tensor(list(l[0].shape()), l[0].device()) |
| x.SetFloatValue(0.0) |
| for i in range(len(l)): |
| x += l[i] |
| return x |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): dL / dy |
| Returns: |
| dx (CTensor): dL / dx |
| """ |
| return [dy] * self.l |
| |
| |
| def sum(*l): |
| """ |
| Element-wise sum of each of the input tensors |
| Args: |
| l (a list of Tensor): element-wise sum operator |
| Returns: |
| a Tensor for the result |
| """ |
| return Sum()(*l)[0] |
| |
| |
| class BinaryCrossEntropy(Operator): |
| |
| def __init__(self, t): |
| super(BinaryCrossEntropy, self).__init__() |
| self.t = t.data |
| |
| """ |
| Calculte negative log likelihood loss for a batch of training data. |
| """ |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): 1d or 2d tensor, the prediction data(output) |
| of current network. |
| t (CTensor): 1d or 2d tensor, the target data for training. |
| Returns: |
| loss (CTensor): scalar. |
| """ |
| posx = singa.AddFloat(x, 0.0001) |
| loss = singa.SumAll(singa.__mul__(self.t, singa.Log(posx))) |
| negt = singa.AddFloat(singa.MultFloat(self.t, -1.0), 1.0) |
| negx = singa.AddFloat(singa.MultFloat(x, -1.0), 1.0001) |
| negLoss = singa.SumAll(singa.__mul__(negt, singa.Log(negx))) |
| loss += negLoss |
| loss /= -x.shape()[0] |
| self.x = singa.AddFloat(x, 0.0001) |
| return loss |
| |
| def backward(self, dy=1.0): |
| """ |
| Args: |
| dy (float or CTensor): scalar, accumulate gradient from outside |
| of current network, usually equal to 1.0 |
| Returns: |
| dx (CTensor): data for the dL /dx, L is the loss, x is the output |
| of current network. note that this is true for |
| dy = 1.0 |
| """ |
| |
| dx = singa.__div__(self.t, self.x) |
| negt = singa.AddFloat(self.t, -1.0) |
| negx = singa.AddFloat(self.x, -0.9999) |
| dx -= singa.__div__(negt, negx) |
| dx *= float(-1.0 / self.x.shape()[0]) |
| if isinstance(dy, float): |
| # dtype of dy: float |
| dx *= dy |
| return dx |
| elif isinstance(dy, CTensor): |
| pass # TODO, broadcast elementwise multiply seems not support |
| |
| |
| def binary_cross_entropy(x, t): |
| return BinaryCrossEntropy(t)(x)[0] |
| |
| |
| class CrossEntropy(Operator): |
| |
| def __init__(self, t): |
| super(CrossEntropy, self).__init__() |
| self.t = t.data |
| |
| """ |
| Calculte negative log likelihood loss for a batch of training data. |
| """ |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): 1d or 2d tensor, the prediction data(output) |
| of current network. |
| t (CTensor): 1d or 2d tensor, the target data for training. |
| Returns: |
| loss (CTensor): scalar. |
| """ |
| loss = singa.SumAll(singa.__mul__(self.t, singa.Log(x))) |
| loss /= -x.shape()[0] |
| self.x = x |
| return loss |
| |
| def backward(self, dy=1.0): |
| """ |
| Args: |
| dy (float or CTensor): scalar, accumulate gradient from outside |
| of current network, usually equal to 1.0 |
| Returns: |
| dx (CTensor): data for the dL /dx, L is the loss, x is the output |
| of current network. note that this is true for |
| dy = 1.0 |
| """ |
| |
| dx = singa.__div__(self.t, self.x) |
| dx *= float(-1.0 / self.x.shape()[0]) |
| if isinstance(dy, float): |
| # dtype of dy: float |
| dx *= dy |
| return dx |
| elif isinstance(dy, CTensor): |
| pass # TODO, broadcast elementwise multiply seems not support |
| |
| |
| def cross_entropy(x, t): |
| assert x.ndim() == 2, "1st arg required 2d tensor. got shape: " + str( |
| x.shape) |
| assert t.ndim() <= 2, "2nd arg required <=2d tensor. got shape: " + str( |
| t.shape) |
| # x is the logits and t is the ground truth. |
| return CrossEntropy(t)(x)[0] |
| |
| |
| class RankingLoss(Operator): |
| |
| def __init__(self, M=0.2): |
| super().__init__() |
| # margin |
| self.M = M |
| |
| def forward(self, pos, neg): |
| # L = max{0, M - fn(pos) + fn(neg)} |
| zero = singa.Tensor(list(pos.shape()), pos.device()) |
| zero.SetFloatValue(0.0) |
| val = singa.AddFloat(singa.__sub__(neg, pos), self.M) |
| gt_zero = singa.__gt__(val, zero) |
| if training: |
| self.inputs = (gt_zero,) # (BS,) |
| all_loss = singa.__mul__(gt_zero, val) |
| loss = singa.SumAll(all_loss) |
| loss /= (pos.shape()[0]) |
| return loss |
| |
| def backward(self, dy=1.0): |
| assert training, "enable training mode to do backward" |
| # dpos = -1 if M-pos+neg > 0 else 0 |
| # dneg = 1 if M-pos+neg > 0 else 0 |
| gt_zero = self.inputs[0] |
| dpos_factor = singa.Tensor(list(gt_zero.shape()), gt_zero.device()) |
| dpos_factor.SetFloatValue(-1.0 / gt_zero.Size()) |
| dneg_factor = singa.Tensor(list(gt_zero.shape()), gt_zero.device()) |
| dneg_factor.SetFloatValue(1.0 / gt_zero.Size()) |
| dpos = singa.__mul__(gt_zero, dpos_factor) |
| dneg = singa.__mul__(gt_zero, dneg_factor) |
| return dpos, dneg |
| |
| |
| def ranking_loss(pos, neg, M=0.2): |
| assert pos.shape == neg.shape, "input and target shape different: %s, %s" % ( |
| pos.shape, neg.shape) |
| return RankingLoss(M)(pos, neg)[0] |
| |
| |
| class SoftMaxCrossEntropy(Operator): |
| |
| def __init__(self, t): |
| super(SoftMaxCrossEntropy, self).__init__() |
| self.t = t.data |
| |
| def forward(self, x): |
| self.p = singa.SoftMax(x) |
| ret = singa.CrossEntropyFwd(self.p, self.t) |
| loss = singa.SumAll(ret) |
| loss /= x.shape()[0] |
| return loss |
| |
| def backward(self, dy=1.0): |
| dx = singa.SoftmaxCrossEntropyBwd(self.p, self.t) |
| dx /= float(self.p.shape()[0]) |
| return dx |
| |
| |
| def softmax_cross_entropy(x, t): |
| assert x.ndim() == 2, "1st arg required 2d tensor. got shape: " + str( |
| x.shape) |
| assert t.ndim() <= 2, "2nd arg required <=2d tensor. got shape: " + str( |
| t.shape) |
| # x is the logits and t is the ground truth. |
| return SoftMaxCrossEntropy(t)(x)[0] |
| |
| |
| class MeanSquareError(Operator): |
| |
| def __init__(self, t): |
| super(MeanSquareError, self).__init__() |
| self.t = t.data |
| |
| def forward(self, x): |
| self.err = singa.__sub__(x, self.t) |
| sqr = singa.Square(self.err) |
| loss = singa.SumAll(sqr) |
| self.n = 1 |
| for s in x.shape(): |
| self.n *= s |
| loss /= self.n |
| return loss |
| |
| def backward(self, dy=1.0): |
| dx = self.err |
| dx *= float(2 / self.n) |
| dx *= dy |
| return dx |
| |
| |
| def mse_loss(x, t): |
| assert x.shape == t.shape, "input and target shape different: %s, %s" % ( |
| x.shape, t.shape) |
| return MeanSquareError(t)(x)[0] |
| |
| |
| def ctensor2numpy(x): |
| """ |
| To be used in SoftMax Operator. |
| Convert a singa_tensor to numpy_tensor. |
| """ |
| np_array = x.GetFloatValue(int(x.Size())) |
| return np_array.reshape(x.shape()) |
| |
| |
| class Flatten(Operator): |
| """ |
| Flattens the input tensor into a 2D matrix. If input tensor has shape |
| `(d_0, d_1, ... d_n)` then the output will have shape `(d_0 X d_1 ... |
| d_(axis-1), d_axis X d_(axis+1) ... X dn)`. |
| """ |
| |
| def __init__(self, axis=1): |
| """ |
| Args: |
| axis (int): Indicate up to which input dimensions (exclusive) |
| should be flattened to the outer dimension of the output. The |
| value for axis must be in the range [-r, r], where r is the |
| rank of the input tensor. Negative value means counting |
| dimensions from the back. When axis = 0, the shape of the |
| output tensor is `(1, (d_0 X d_1 ... d_n)`, where the shape |
| of the input tensor is `(d_0, d_1, ... d_n)`. |
| Returns: |
| the result CTensor |
| """ |
| super(Flatten, self).__init__() |
| self.axis = axis |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): the input tensor |
| Returns: |
| the result CTensor |
| """ |
| self.shape = list(x.shape()) |
| shape, axis = self.shape, self.axis |
| # the axis must be within this range (0, r-1) |
| assert axis <= len( |
| shape) - 1 or axis >= 0, "the axis must be within (0, %d-1)" % len( |
| shape) |
| # calculate the new shape |
| new_shape = (1, int(np.prod(shape))) if axis == 0 else ( |
| int(np.prod(shape[0:axis]).astype(int)), |
| int(np.prod(shape[axis:]).astype(int))) |
| y = singa.Reshape(x, new_shape) |
| return y |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): data for the dL / dy, L is the loss |
| Returns: |
| dx (CTensor): data for the dL / dx, L is the loss, |
| """ |
| dx = singa.Reshape(dy, self.shape) |
| return dx |
| |
| |
| def flatten(x, axis=1): |
| """ |
| Flattens the input tensor into a 2D matrix. If input tensor has shape |
| `(d_0, d_1, ... d_n)` then the output will have shape `(d_0 X d_1 ... |
| d_(axis-1), d_axis X d_(axis+1) ... X dn)`. |
| Args: |
| x (Tensor): the input tensor |
| axis (int): Indicate up to which input dimensions (exclusive) |
| should be flattened to the outer dimension of the output. The |
| value for axis must be in the range [-r, r], where r is the |
| rank of the input tensor. Negative value means counting |
| dimensions from the back. When axis = 0, the shape of the |
| output tensor is `(1, (d_0 X d_1 ... d_n)`, where the shape |
| of the input tensor is `(d_0, d_1, ... d_n)`. |
| Returns: |
| the result Tensor |
| """ |
| return Flatten(axis)(x)[0] |
| |
| |
| class ScatterElements(Operator): |
| """ |
| ScatterElements operator following ONNX Operator Schemas |
| https://github.com/onnx/onnx/blob/master/docs/Changelog.md#ScatterElements-11 |
| |
| Example usage: |
| data = [ |
| [0.0, 0.0, 0.0], |
| [0.0, 0.0, 0.0], |
| [0.0, 0.0, 0.0], |
| ] |
| axis = 0 |
| indices = [ |
| [1, 0, 2], |
| [0, 2, 1], |
| ] |
| updates = [ |
| [1.0, 1.1, 1.2], |
| [2.0, 2.1, 2.2], |
| ] |
| output = [ |
| [2.0, 1.1, 0.0] |
| [1.0, 0.0, 2.2] |
| [0.0, 2.1, 1.2] |
| ] |
| |
| """ |
| |
| def __init__(self, indices, updates, axis=0): |
| """ |
| Args: |
| indices (Tensor): index tensor |
| updates (Tensor): source tensor |
| axis (int): Which axis to scatter on. A negative value means |
| counting dimension from the back. Accepted range is [-r,r-1] |
| where r=rank(destination_tensor) |
| """ |
| super(ScatterElements, self).__init__() |
| self.indices = indices |
| self.updates = updates |
| self.axis = axis |
| |
| def forward(self, x): |
| x_shape = x.shape() |
| x_rank = len(x_shape) |
| if isinstance(self.indices, Tensor): |
| self.indices = tensor.to_numpy(self.indices) |
| elif isinstance(self.indices, (list, tuple)): |
| self.indices = np.array(self.indices) |
| if isinstance(self.updates, Tensor): |
| self.updates = tensor.to_numpy(self.updates) |
| elif isinstance(self.updates, (list, tuple)): |
| self.updates = np.array(self.updates) |
| self.updates.astype(np.int32) |
| _x = tensor.to_numpy(tensor.from_raw_tensor(x)) |
| _x = _x.astype(np.float32) |
| |
| assert x_rank == 2, "Only support 2D input." |
| assert x_rank == len( |
| self.indices.shape |
| ), "Index should have the same number of dimensions as output" |
| assert -x_rank < self.axis <= x_rank, "Axis is out of range" |
| assert np.logical_and( |
| -_x.shape[self.axis] < self.indices, |
| self.indices <= _x.shape[self.axis]).all( |
| ), "The values of the indexes should be between %d and %d" % ( |
| -_x.shape[self.axis], _x.shape[self.axis] - 1) |
| |
| self.axis = self.axis % x_rank |
| u_shape = self.updates.shape |
| y = _x.copy() |
| for i in range(u_shape[0]): |
| for j in range(u_shape[1]): |
| idx = int(self.indices[i][j]) |
| if self.axis == 0: |
| y[idx][j] = self.updates[i][j] |
| else: |
| y[i][idx] = self.updates[i][j] |
| y = tensor.from_numpy(y) |
| y.to_device(x.device()) |
| return y.data |
| |
| def backward(self, dy): |
| mask = np.ones(dy.shape(), dtype=np.float32) |
| u_shape = self.updates.shape |
| for i in range(u_shape[0]): |
| for j in range(u_shape[1]): |
| idx = int(self.indices[i][j]) |
| if self.axis == 0: |
| mask[idx][j] = 0. |
| else: |
| mask[i][idx] = 0. |
| mask = tensor.from_numpy(mask) |
| mask.to_device(dy.device()) |
| return singa.__mul__(dy, mask.data) |
| |
| |
| def scatter_elements(x, indices, updates, axis=0): |
| """ |
| Produces a ScatterElements operator |
| Args: |
| x (Tensor): input tensor. |
| indices (Tensor): index tensor |
| updates (Tensor): source tensor |
| axis (int): Which axis to scatter on. A negative value means |
| counting dimension from the back. Accepted range is [-r,r-1] |
| where r=rank(destination_tensor) |
| Returns: |
| the output Tensor. |
| """ |
| return ScatterElements(indices, updates, axis)(x)[0] |
| |
| |
| class Concat(Operator): |
| """ |
| Concatenate a list of tensors into a single tensor. All input tensors must |
| have the same shape, except for the dimension size of the axis to |
| concatenate on. |
| """ |
| |
| def __init__(self, axis=0): |
| """ |
| Args: |
| axis (int): Which axis to concat on. A negative value means |
| counting dimensions from the back. Accepted range is [-r, r-1] |
| where r = rank(inputs). |
| Returns: |
| the result CTensor |
| """ |
| super(Concat, self).__init__() |
| self.axis = axis |
| |
| def forward(self, *xs): |
| """ |
| Args: |
| xs (a list of CTensor): List of tensors for concatenation |
| Returns: |
| a CTensor for the result |
| """ |
| if self.axis < 0: |
| self.axis = self.axis % len(xs[0].shape()) |
| if training: |
| offset = 0 |
| self.slice_point = [] |
| for t in xs: |
| offset += t.shape()[self.axis] |
| self.slice_point.append(offset) |
| x = singa.VecTensor(list(xs)) |
| return singa.ConcatOn(x, self.axis) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): data for the dL / dy, L is the loss |
| Returns: |
| dxs (a tuple of CTensor): data for the dL / dxs, L is the loss, |
| """ |
| assert hasattr( |
| self, "slice_point"), "Please set training as True before do BP. " |
| assert self.slice_point[-1] == dy.shape()[self.axis], "Shape mismatch." |
| dxs = [] |
| last_offset = 0 |
| for p in self.slice_point: |
| dxs.append(singa.SliceOn(dy, last_offset, p, self.axis)) |
| last_offset = p |
| return tuple(dxs) |
| |
| |
| def cat(xs, axis=0): |
| """ |
| Concatenate a list of tensors into a single tensor. All input tensors must |
| have the same shape, except for the dimension size of the axis to |
| concatenate on. |
| Args: |
| xs (a list of Tensor): List of tensors for concatenation |
| axis (int): Which axis to concat on. A negative value means |
| counting dimensions from the back. Accepted range is [-r, r-1] |
| where r = rank(inputs). |
| Returns: |
| a Tensor for the result |
| """ |
| return Concat(axis)(*xs)[0] |
| |
| |
| """ |
| def make_slice(arr, axis, i): # type: ignore |
| slc = [slice(None)] * arr.ndim |
| slc[axis] = i |
| return slc |
| """ |
| |
| |
| class _Conv2d(Operator): |
| """ |
| Init a conv 2d operator |
| """ |
| |
| def __init__(self, handle, odd_padding=(0, 0, 0, 0)): |
| """ |
| Args: |
| handle (object): ConvHandle for cpu or CudnnConvHandle for gpu |
| odd_padding (tuple of four ints):, the odd paddding is the value |
| that cannot be handled by the tuple padding (w, h) mode so |
| we need to firstly handle the input, then use the nomal padding |
| method. |
| """ |
| super(_Conv2d, self).__init__() |
| self.handle = handle |
| self.odd_padding = odd_padding |
| |
| def forward(self, x, W, b=None): |
| """ |
| Args: |
| x (CTensor): input |
| W (CTensor): weight |
| b (CTensor): bias |
| Returns: |
| CTensor |
| """ |
| assert x.nDim() == 4, "The dimensions of input should be 4D." |
| if self.odd_padding != (0, 0, 0, 0): |
| x = utils.handle_odd_pad_fwd(x, self.odd_padding) |
| |
| if training: |
| if self.handle.bias_term: |
| self.inputs = (x, W, b) |
| else: |
| self.inputs = (x, W) |
| |
| if not self.handle.bias_term: |
| # create empty bias tensor for Cpp API |
| b = CTensor((self.handle.num_filters,), x.device()) |
| b.SetFloatValue(0.0) |
| |
| if (type(self.handle) != singa.ConvHandle): |
| return singa.GpuConvForward(x, W, b, self.handle) |
| else: |
| return singa.CpuConvForward(x, W, b, self.handle) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): dL / dy |
| Returns: |
| dx (CTensor): dL / dx |
| """ |
| assert training is True and hasattr( |
| self, "inputs"), "Please set training as True before do BP. " |
| |
| if (type(self.handle) != singa.ConvHandle): |
| dx = singa.GpuConvBackwardx(dy, self.inputs[1], self.inputs[0], |
| self.handle) |
| dW = singa.GpuConvBackwardW(dy, self.inputs[0], self.inputs[1], |
| self.handle) |
| db = singa.GpuConvBackwardb( |
| dy, self.inputs[2], |
| self.handle) if self.handle.bias_term else None |
| else: |
| dx = singa.CpuConvBackwardx(dy, self.inputs[1], self.inputs[0], |
| self.handle) |
| dW = singa.CpuConvBackwardW(dy, self.inputs[0], self.inputs[1], |
| self.handle) |
| db = singa.CpuConvBackwardb( |
| dy, self.inputs[2], |
| self.handle) if self.handle.bias_term else None |
| if self.odd_padding != (0, 0, 0, 0): |
| dx = utils.handle_odd_pad_bwd(dx, self.odd_padding) |
| |
| if db: |
| return dx, dW, db |
| |
| else: |
| return dx, dW |
| |
| |
| def conv2d(handle, x, W, b=None, odd_padding=(0, 0, 0, 0)): |
| """ |
| Conv 2d operator |
| Args: |
| handle (object): ConvHandle for cpu or CudnnConvHandle for gpu |
| x (Tensor): input |
| W (Tensor): weight |
| b (Tensor): bias |
| odd_padding (tuple of four ints):, the odd paddding is the value |
| that cannot be handled by the tuple padding (w, h) mode so |
| we need to firstly handle the input, then use the nomal padding |
| method. |
| """ |
| if b is None: |
| return _Conv2d(handle, odd_padding)(x, W)[0] |
| else: |
| return _Conv2d(handle, odd_padding)(x, W, b)[0] |
| |
| |
| class _BatchNorm2d(Operator): |
| """ |
| Carries out batch normalization as described in the paper |
| https://arxiv.org/abs/1502.03167. |
| """ |
| |
| def __init__(self, handle, running_mean, running_var, name=None): |
| """ |
| Args: |
| handle (object): BatchNormHandle for cpu and CudnnBatchNormHandle |
| for gpu |
| running_mean (float): the running_mean |
| running_var (float): the running_var |
| name (string): the name assigned to this operator |
| """ |
| super(_BatchNorm2d, self).__init__(name) |
| self.handle = handle |
| self.running_mean = running_mean.data |
| self.running_var = running_var.data |
| |
| def forward(self, x, scale, bias): |
| """ |
| Args: |
| x (CTensor): the input tensor |
| scale (CTensor): the bias tensor |
| bias (CTensor): the bias tensor |
| Returns: |
| the result CTensor |
| """ |
| if training: |
| if (type(self.handle) == singa.BatchNormHandle): |
| y, mean, var = singa.CpuBatchNormForwardTraining( |
| self.handle, x, scale, bias, self.running_mean, |
| self.running_var) |
| |
| self.cache = (x, scale, mean, var, y, bias) |
| else: |
| y, mean, var = singa.GpuBatchNormForwardTraining( |
| self.handle, x, scale, bias, self.running_mean, |
| self.running_var) |
| |
| self.cache = (x, scale, mean, var) |
| |
| else: |
| |
| if (type(self.handle) == singa.BatchNormHandle): |
| y = singa.CpuBatchNormForwardInference( |
| self.handle, |
| x, |
| scale, |
| bias, |
| self.running_mean, |
| self.running_var, |
| ) |
| else: |
| y = singa.GpuBatchNormForwardInference( |
| self.handle, |
| x, |
| scale, |
| bias, |
| self.running_mean, |
| self.running_var, |
| ) |
| return y |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): data for the dL / dy, L is the loss |
| Returns: |
| dx (CTensor): data for the dL / dx, L is the loss |
| ds (CTensor): data for the dL / ds, L is the loss |
| db (CTensor): data for the dL / db, L is the loss |
| """ |
| assert training is True and hasattr( |
| self, "cache"), "Please set training as True before do BP. " |
| |
| if (type(self.handle) == singa.BatchNormHandle): |
| x, scale, mean, var, y, bias = self.cache |
| dx, ds, db = singa.CpuBatchNormBackwardx(self.handle, y, dy, x, |
| scale, bias, mean, var) |
| else: |
| x, scale, mean, var = self.cache |
| dx, ds, db = singa.GpuBatchNormBackward(self.handle, dy, x, scale, |
| mean, var) |
| |
| return dx, ds, db |
| |
| |
| def batchnorm_2d(handle, x, scale, bias, running_mean, running_var): |
| """ |
| Carries out batch normalization as described in the paper |
| https://arxiv.org/abs/1502.03167. |
| Args: |
| handle (object): BatchNormHandle for cpu and CudnnBatchNormHandle |
| for gpu |
| x (Tensor): the input tensor |
| scale (Tensor): the bias tensor |
| bias (Tensor): the bias tensor |
| running_mean (float): the running_mean |
| running_var (float): the running_var |
| Returns: |
| the result Tensor |
| """ |
| return _BatchNorm2d(handle, running_mean, running_var)(x, scale, bias)[0] |
| |
| |
| class _Pooling2d(Operator): |
| """ |
| Init a pool 2d operator |
| """ |
| |
| def __init__(self, handle, odd_padding=(0, 0, 0, 0)): |
| """ |
| Args: |
| handle (object): PoolingHandle for cpu or CudnnPoolingHandle for |
| gpu |
| odd_padding (tuple of four int): the odd paddding is the value |
| that cannot be handled by the tuple padding (w, h) mode so |
| it needs to firstly handle the input, then use the normal |
| padding method. |
| """ |
| super(_Pooling2d, self).__init__() |
| self.handle = handle |
| self.odd_padding = odd_padding |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): the input tensor |
| Returns: |
| the result CTensor |
| """ |
| assert x.nDim() == 4, "The dimensions of input should be 4D." |
| if self.odd_padding != (0, 0, 0, 0): |
| x = utils.handle_odd_pad_fwd(x, self.odd_padding, True) |
| |
| if (type(self.handle) != singa.PoolingHandle): |
| y = singa.GpuPoolingForward(self.handle, x) |
| else: |
| y = singa.CpuPoolingForward(self.handle, x) |
| if training: |
| self.cache = (x, y) |
| return y |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): data for the dL / dy, L is the loss |
| Returns: |
| dx (CTensor): data for the dL / dx, L is the loss, |
| """ |
| if (type(self.handle) != singa.PoolingHandle): |
| dx = singa.GpuPoolingBackward(self.handle, dy, self.cache[0], |
| self.cache[1]) |
| else: |
| dx = singa.CpuPoolingBackward(self.handle, dy, self.cache[0], |
| self.cache[1]) |
| if self.odd_padding != (0, 0, 0, 0): |
| dx = utils.handle_odd_pad_bwd(dx, self.odd_padding) |
| |
| return dx |
| |
| |
| def pooling_2d(handle, x, odd_padding=(0, 0, 0, 0)): |
| """ |
| Pooling 2d operator |
| Args: |
| handle (object): PoolingHandle for cpu or CudnnPoolingHandle for |
| gpu |
| x (Tensor): input |
| odd_padding (tuple of four int): the odd paddding is the value |
| that cannot be handled by the tuple padding (w, h) mode so |
| it needs to firstly handle the input, then use the normal |
| padding method. |
| Returns: |
| the result Tensor |
| """ |
| return _Pooling2d(handle, odd_padding)(x)[0] |
| |
| |
| class Tanh(Operator): |
| """ |
| Calculates the hyperbolic tangent of the given input tensor element-wise. |
| """ |
| |
| def __init__(self): |
| super(Tanh, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| out = singa.Tanh(x) |
| if training: |
| self.cache = (out,) |
| return out |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.__mul__(self.cache[0], self.cache[0]) |
| dx = singa.MultFloat(dx, -1.0) |
| dx = singa.AddFloat(dx, 1.0) |
| dx *= dy |
| return dx |
| |
| |
| def tanh(x): |
| """ |
| Calculates the hyperbolic tangent of the given input tensor element-wise. |
| Args: |
| x (Tensor): Input tensor |
| Returns: |
| Tensor, the output |
| """ |
| return Tanh()(x)[0] |
| |
| |
| class Cos(Operator): |
| """ |
| Calculates the cosine of the given input tensor, element-wise. |
| """ |
| |
| def __init__(self): |
| super(Cos, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| if training: |
| self.input = x |
| return singa.Cos(x) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.Sin(self.input) |
| dx = singa.MultFloat(dx, -1.0) |
| dx *= dy |
| return dx |
| |
| |
| def cos(x): |
| """ |
| Calculates the cosine of the given input tensor, element-wise. |
| Args: |
| x (Tensor): Input tensor |
| Returns: |
| Tensor, the output |
| """ |
| |
| return Cos()(x)[0] |
| |
| |
| class Cosh(Operator): |
| """ |
| Calculates the hyperbolic cosine of the given input tensor element-wise. |
| """ |
| |
| def __init__(self): |
| super(Cosh, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| if training: |
| self.input = x |
| return singa.Cosh(x) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.Sinh(self.input) |
| dx *= dy |
| return dx |
| |
| |
| def cosh(x): |
| """ |
| Calculates the hyperbolic cosine of the given input tensor element-wise. |
| Args: |
| x (Tensor): Input tensor |
| Returns: |
| Tensor, the output |
| """ |
| return Cosh()(x)[0] |
| |
| |
| class Acos(Operator): |
| """ |
| Calculates the arccosine (inverse of cosine) of the given input tensor, |
| element-wise. |
| """ |
| |
| def __init__(self): |
| super(Acos, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| if training: |
| self.input = x |
| return singa.Acos(x) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.Square(self.input) |
| dx = singa.MultFloat(dx, -1.0) |
| dx = singa.AddFloat(dx, 1.0) |
| dx = singa.PowFloat(dx, -0.5) |
| dx = singa.MultFloat(dx, -1.0) |
| dx *= dy |
| return dx |
| |
| |
| def acos(x): |
| """ |
| Calculates the arccosine (inverse of cosine) of the given input tensor, |
| element-wise. |
| Args: |
| x (Tensor): Input tensor |
| Returns: |
| Tensor, the output |
| """ |
| return Acos()(x)[0] |
| |
| |
| class Acosh(Operator): |
| """ |
| Calculates the hyperbolic arccosine of the given input tensor element-wise. |
| """ |
| |
| def __init__(self): |
| super(Acosh, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| if training: |
| self.input = x |
| return singa.Acosh(x) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.SubFloat(self.input, 1.0) |
| dx = singa.Sqrt(dx) |
| temp = singa.AddFloat(self.input, 1.0) |
| temp = singa.Sqrt(temp) |
| dx = singa.__mul__(dx, temp) |
| dx = singa.PowFloat(dx, -1.0) |
| dx *= dy |
| return dx |
| |
| |
| def acosh(x): |
| """ |
| Calculates the hyperbolic arccosine of the given input tensor element-wise. |
| Args: |
| x (Tensor): Input tensor |
| Returns: |
| Tensor, the output |
| """ |
| return Acosh()(x)[0] |
| |
| |
| class Sin(Operator): |
| """ |
| Calculates the sine of the given input tensor, element-wise. |
| """ |
| |
| def __init__(self): |
| super(Sin, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| if training: |
| self.input = x |
| return singa.Sin(x) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.Cos(self.input) |
| dx *= dy |
| return dx |
| |
| |
| def sin(x): |
| """ |
| Calculates the sine of the given input tensor, element-wise. |
| Args: |
| x (Tensor): Input tensor |
| Returns: |
| Tensor, the output |
| """ |
| return Sin()(x)[0] |
| |
| |
| class Sinh(Operator): |
| """ |
| Calculates the hyperbolic sine of the given input tensor element-wise. |
| """ |
| |
| def __init__(self): |
| super(Sinh, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| if training: |
| self.input = x |
| return singa.Sinh(x) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.Cosh(self.input) |
| dx *= dy |
| return dx |
| |
| |
| def sinh(x): |
| """ |
| Calculates the hyperbolic sine of the given input tensor element-wise. |
| Args: |
| x (Tensor): Input tensor |
| Returns: |
| Tensor, the output |
| """ |
| return Sinh()(x)[0] |
| |
| |
| class Asin(Operator): |
| """ |
| Calculates the arcsine (inverse of sine) of the given input tensor, element-wise. |
| """ |
| |
| def __init__(self): |
| super(Asin, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| if training: |
| self.input = x |
| return singa.Asin(x) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.Square(self.input) |
| dx = singa.MultFloat(dx, -1.0) |
| dx = singa.AddFloat(dx, 1.0) |
| dx = singa.PowFloat(dx, -0.5) |
| dx *= dy |
| return dx |
| |
| |
| def asin(x): |
| """ |
| Calculates the arcsine (inverse of sine) of the given input tensor, element-wise. |
| Args: |
| x (Tensor): Input tensor |
| Returns: |
| Tensor, the output |
| """ |
| |
| return Asin()(x)[0] |
| |
| |
| class Asinh(Operator): |
| """ |
| Calculates the hyperbolic arcsine of the given input tensor element-wise. |
| """ |
| |
| def __init__(self): |
| super(Asinh, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| if training: |
| self.input = x |
| return singa.Asinh(x) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.Square(self.input) |
| dx = singa.AddFloat(dx, 1.0) |
| dx = singa.PowFloat(dx, -0.5) |
| dx *= dy |
| return dx |
| |
| |
| def asinh(x): |
| """ |
| Calculates the hyperbolic arcsine of the given input tensor element-wise. |
| Args: |
| x (Tensor): Input tensor |
| Returns: |
| Tensor, the output |
| """ |
| return Asinh()(x)[0] |
| |
| |
| class Tan(Operator): |
| """ |
| Insert single-dimensional entries to the shape of an input tensor (data). |
| """ |
| |
| def __init__(self): |
| super(Tan, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| if training: |
| self.input = x |
| return singa.Tan(x) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.Cos(self.input) |
| dx = singa.Square(dx) |
| dx = singa.PowFloat(dx, -1.0) |
| dx *= dy |
| return dx |
| |
| |
| def tan(x): |
| """ |
| Calculates the tangent of the given input tensor, element-wise. |
| Args: |
| x (Tensor): Input tensor |
| Returns: |
| Tensor, the output |
| """ |
| return Tan()(x)[0] |
| |
| |
| class Atan(Operator): |
| """ |
| Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise. |
| """ |
| |
| def __init__(self): |
| super(Atan, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| if training: |
| self.input = x |
| return singa.Atan(x) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.Square(self.input) |
| dx = singa.AddFloat(dx, 1.0) |
| dx = singa.PowFloat(dx, -1.0) |
| dx *= dy |
| return dx |
| |
| |
| def atan(x): |
| """ |
| Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise. |
| Args: |
| x (Tensor): Input tensor |
| Returns: |
| Tensor, the output |
| """ |
| return Atan()(x)[0] |
| |
| |
| class Atanh(Operator): |
| """ |
| Calculates the hyperbolic arctangent of the given input tensor element-wise. |
| """ |
| |
| def __init__(self): |
| super(Atanh, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| if training: |
| self.input = x |
| return singa.Atanh(x) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.Square(self.input) |
| dx = singa.MultFloat(dx, -1.0) |
| dx = singa.AddFloat(dx, 1.0) |
| dx = singa.PowFloat(dx, -1.0) |
| dx *= dy |
| return dx |
| |
| |
| def atanh(x): |
| """ |
| Calculates the hyperbolic arctangent of the given input tensor element-wise. |
| Args: |
| x (Tensor): Input tensor |
| Returns: |
| Tensor, the output |
| """ |
| return Atanh()(x)[0] |
| |
| |
| class Sigmoid(Operator): |
| """ |
| `y = 1 / (1 + exp(-x))`, is applied to the tensor elementwise. |
| """ |
| |
| def __init__(self): |
| super(Sigmoid, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| out = singa.Sigmoid(x) |
| if training: |
| self.cache = (out,) |
| return out |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.MultFloat(self.cache[0], -1.0) |
| dx = singa.AddFloat(dx, 1.0) |
| dx = singa.__mul__(self.cache[0], dx) |
| dx *= dy |
| return dx |
| |
| |
| def sigmoid(x): |
| """ |
| `y = 1 / (1 + exp(-x))`, is applied to the tensor elementwise. |
| Args: |
| x (Tensor): Input tensor |
| Returns: |
| Tensor, the output |
| """ |
| return Sigmoid()(x)[0] |
| |
| |
| class Mul(Operator): |
| """ |
| Performs element-wise binary multiplication (with Numpy-style broadcasting |
| support). |
| """ |
| |
| def __init__(self): |
| super(Mul, self).__init__() |
| |
| def forward(self, a, b): |
| """ |
| Return `np.multiply(a,b)`, where a and b are CTensor. |
| """ |
| # todo we cannot support mul op for int tensors |
| _a, _b = a, b |
| dtype0 = _a.data_type() |
| dtype1 = _b.data_type() |
| if dtype0 == singa.kInt or dtype1 == singa.kInt: |
| _a = a.AsType(singa.kFloat32) |
| _b = b.AsType(singa.kFloat32) |
| res = singa.__mul__(_a, _b) |
| res = res.AsType(singa.kInt) |
| else: |
| res = singa.__mul__(_a, _b) |
| if training: |
| self.input = (_a, _b) |
| self.shape0 = list(_a.shape()) |
| self.shape1 = list(_b.shape()) |
| self.shape3 = list(res.shape()) |
| return res |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| a tuple for (da, db), da is data for dL / da, db is data |
| for dL / db. |
| """ |
| dx0 = singa.__mul__(dy, self.input[1]) |
| dx1 = singa.__mul__(dy, self.input[0]) |
| if (type(dy) == float) or self.shape0 == self.shape1: |
| assert self.shape0 == self.shape1, ('should have same shape') |
| return dx0, dx1 |
| # handle broadcast |
| dx0 = back_broadcast(self.shape3, self.shape0, dx0) |
| dx1 = back_broadcast(self.shape3, self.shape1, dx1) |
| return dx0, dx1 |
| |
| |
| def mul(x, y): |
| """ |
| Return `np.multiply(x,y)`, where a and b are Tensor. |
| """ |
| return Mul()(x, y)[0] |
| |
| |
| class Unsqueeze(Operator): |
| """ |
| Insert single-dimensional entries to the shape of an input tensor (data). |
| """ |
| |
| def __init__(self, axis): |
| """ |
| Args: |
| axis (list of int): the dimensions to be inserted. |
| """ |
| super(Unsqueeze, self).__init__() |
| if (type(axis) is int): |
| self.axis = list(axis) |
| else: |
| self.axis = axis |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| self.cache = x.shape() |
| cur = list(self.cache) |
| # todo, need optimize after we have scalar tensor |
| if len(self.cache) == 1 and self.axis == [0]: |
| return x |
| for i in self.axis: |
| cur.insert(i, 1) |
| return singa.Reshape(x, cur) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| return singa.Reshape(dy, self.cache) |
| |
| |
| def unsqueeze(x, axis=-1): |
| """ |
| Insert single-dimensional entries to the shape of an input tensor (data). |
| Args: |
| x (Tensor): Input tensor |
| axis (list of int): the dimensions to be inserted. |
| Returns: |
| Tensor, the output |
| """ |
| return Unsqueeze(axis)(x)[0] |
| |
| |
| class Transpose(Operator): |
| """ |
| Transpose the input tensor similar to numpy.transpose. |
| """ |
| |
| def __init__(self, perm): |
| """ |
| Args: |
| perm (list of ints): A list of integers. By default, reverse the |
| dimensions, otherwise permute the axes according to the values given. |
| """ |
| super(Transpose, self).__init__() |
| self.perm = list(perm) |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| return singa.Transpose(x, self.perm) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| cur = [] |
| for i in range(len(self.perm)): |
| cur += [self.perm.index(i)] |
| return singa.Transpose(dy, cur) |
| |
| |
| def transpose(x, shape): |
| """ |
| Transpose the input tensor similar to numpy.transpose. |
| Args: |
| x (Tensor): Input tensor |
| perm (list of ints): A list of integers. By default, reverse the |
| dimensions, otherwise permute the axes according to the values given. |
| Returns: |
| Tensor, the output |
| """ |
| return Transpose(shape)(x)[0] |
| |
| |
| def add_all(*xs): |
| assert len(xs) > 2 |
| y = add(xs[0], xs[1]) |
| for x in xs[2:]: |
| y = add(y, x) |
| return |
| |
| |
| class Abs(Operator): |
| """ |
| `y = abs(x)`, is applied to the tensor elementwise. |
| """ |
| |
| def forward(self, a): |
| """ |
| Return `abs(a)`, where a is CTensor. |
| """ |
| if training: |
| self.input = a |
| return singa.Abs(a) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.Sign(self.input) |
| dx *= dy |
| return dx |
| |
| |
| def abs(a): |
| """ |
| Return abs(a), where a is Tensor. |
| """ |
| return Abs()(a)[0] |
| |
| |
| class Exp(Operator): |
| """ |
| `y = exp(x)`, is applied to the tensor elementwise. |
| """ |
| |
| def forward(self, a): |
| """ |
| Return `exp(a)`, where a is Tensor. |
| """ |
| if training: |
| self.input = a |
| return singa.Exp(a) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.Exp(self.input) |
| dx *= dy |
| return dx |
| |
| |
| def exp(a): |
| """ |
| Return `exp(a)`, where a is Tensor. |
| """ |
| return Exp()(a)[0] |
| |
| |
| class LeakyRelu(Operator): |
| """ |
| `f(x) = alpha * x` for x < 0, `f(x) = x` for x >= 0, is applied to the tensor elementwise. |
| """ |
| |
| def __init__(self, a): |
| """ |
| Args: |
| a (float): Coefficient of leakage. |
| """ |
| super(LeakyRelu, self).__init__() |
| self.a = a |
| |
| def forward(self, x): |
| """ |
| Args: |
| x (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| if training: |
| self.input = x |
| x1 = singa.LTFloat(x, 0.0) |
| x1 = singa.__mul__(x, x1) |
| x1 = singa.MultFloat(x1, self.a) |
| x2 = singa.ReLU(x) |
| x1 = singa.__add__(x1, x2) |
| return x1 |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| # TODO(wangwei) check the correctness |
| dx1 = singa.GTFloat(self.input, 0.0) |
| dx2 = singa.LTFloat(self.input, 0.0) |
| dx2 = singa.MultFloat(dx2, self.a) |
| dx = singa.__add__(dx1, dx2) |
| dx *= dy |
| return dx |
| |
| |
| def leakyrelu(x, a=0.01): |
| """ |
| `f(x) = alpha * x` for x < 0, `f(x) = x` for x >= 0 is applied to the tensor |
| elementwise. |
| Args: |
| x (Tensor): Input tensor |
| a (float): Coefficient of leakage, default to 0.01. |
| Returns: |
| Tensor, the output |
| """ |
| return LeakyRelu(a)(x)[0] |
| |
| |
| class Sign(Operator): |
| """ |
| Calculate the sign of the given input tensor element-wise. If input > 0, |
| output 1. if input < 0, output -1. if input == 0, output 0. |
| """ |
| |
| def __init__(self): |
| super(Sign, self).__init__() |
| |
| def forward(self, a): |
| """ |
| Args: |
| a (CTensor): Input tensor |
| Returns: |
| CTensor, the output |
| """ |
| if training: |
| self.input = a |
| return singa.Sign(a) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.MultFloat(dy, 0.0) |
| return dx |
| |
| |
| def sign(a): |
| """ |
| Calculate the sign of the given input tensor element-wise. If input > 0, |
| output 1. if input < 0, output -1. if input == 0, output 0. |
| Args: |
| a (Tensor): Input tensor |
| Returns: |
| Tensor, the output |
| """ |
| return Sign()(a)[0] |
| |
| |
| class Pow(Operator): |
| """ |
| `f(x) = a^b`, is applied to the tensor elementwise. |
| """ |
| |
| def __init__(self): |
| super(Pow, self).__init__() |
| |
| def forward(self, a, b): |
| """ |
| Return `a^b`, where a and b are CTensor. |
| """ |
| res = singa.Pow(a, b) |
| if training: |
| self.input = (a, b) |
| self.shape0 = list(a.shape()) |
| self.shape1 = list(b.shape()) |
| self.shape3 = list(res.shape()) |
| return res |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| a tuple for (da, db), da is data for dL / da, db is data |
| for dL / db. |
| """ |
| da1 = singa.__mul__( |
| self.input[1], |
| singa.Pow(self.input[0], singa.SubFloat(self.input[1], 1.0))) |
| dx0 = singa.__mul__(da1, dy) |
| db1 = singa.__mul__(singa.Pow(self.input[0], self.input[1]), |
| singa.Log(self.input[0])) |
| dx1 = singa.__mul__(db1, dy) |
| if (type(dy) == float) or self.shape0 == self.shape1: |
| assert self.shape0 == self.shape1, ('should have same shape') |
| return dx0, dx1 |
| # handle broadcast |
| dx0 = back_broadcast(self.shape3, self.shape0, dx0) |
| dx1 = back_broadcast(self.shape3, self.shape1, dx1) |
| return dx0, dx1 |
| |
| |
| def pow(a, b): |
| """ |
| Return `a^b`, where a and b are Tensor. |
| """ |
| return Pow()(a, b)[0] |
| |
| |
| class SoftSign(Operator): |
| """ |
| Calculates the softsign `(x/(1+|x|))` of the given input tensor element-wise. |
| """ |
| |
| def __init__(self): |
| super(SoftSign, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Return `(x/(1+|x|))`, where x is CTensor. |
| """ |
| # y = x / (1 + np.abs(x)) |
| if training: |
| self.input = x |
| x1 = singa.AddFloat(singa.Abs(x), 1.0) |
| y = singa.__div__(x, x1) |
| |
| return y |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.AddFloat(singa.Abs(self.input), 1.0) |
| dx = singa.PowFloat(singa.Square(dx), -1.0) |
| dx = singa.__mul__(dy, dx) |
| return dx |
| |
| |
| def softsign(x): |
| """ |
| Return `(x/(1+|x|))`, where x is Tensor. |
| """ |
| return SoftSign()(x)[0] |
| |
| |
| class Sqrt(Operator): |
| """ |
| `y = x^0.5`, is applied to the tensor elementwise. |
| """ |
| |
| def __init__(self): |
| super(Sqrt, self).__init__() |
| |
| def forward(self, x): |
| """ |
| Return `x^0.5`, where x is CTensor. |
| """ |
| if training: |
| self.input = x |
| return singa.Sqrt(x) |
| |
| def backward(self, dy): |
| """ |
| Args: |
| dy (CTensor): the gradient tensor from upper operations |
| Returns: |
| CTensor, the gradient over input |
| """ |
| dx = singa.PowFloat(self.input, -0.5) |
| dx = |