# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# =============================================================================
'''
This module provides the Module base class, which allows Python users to run
their models with SINGA's computational graph.
'''
from functools import wraps
from singa import autograd
from . import singa_wrap as singa
from .device import get_default_device
import gc
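
# The Graph metaclass below wraps a Module's forward, loss and optim methods.
# In graph mode, the first call to a wrapped method buffers its operations
# into the device's computational graph and then executes the whole graph;
# later calls return the cached result tensors, while Module.__call__ replays
# the buffered graph once per iteration via Device.RunGraph.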
class Graph(type):
def buffer_operation(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if self.graph_mode and self.training:
name = func.__name__
if name not in self._called:
# tag this function
self._called.add(name)
# buffer operations
self._device.EnableGraph(True)
ret = func(self, *args, **kwargs)
self._device.Sync()
self._device.EnableGraph(False)
# deconstruct Operations before running the entire graph
if name == 'optim':
for fname in self._results:
if isinstance(self._results[fname], list):
for _matrix in self._results[fname]:
_matrix.creator = None
else:
self._results[fname].creator = None
# make sure all Operations are deallocated
gc.collect()
# add result tensor
self._results[name] = ret
# run graph
self._device.RunGraph(self.sequential)
self.initialized = True
return ret
return self._results[name]
else:
return func(self, *args, **kwargs)
return wrapper
    def __new__(cls, name, bases, attr):
        # wrap forward, loss and optim only when the class defines them, so
        # that a subclass can inherit already-wrapped methods from its base
        for fname in ('forward', 'loss', 'optim'):
            if fname in attr:
                attr[fname] = Graph.buffer_operation(attr[fname])
        return super(Graph, cls).__new__(cls, name, bases, attr)
class Module(object, metaclass=Graph):
""" Base class for your neural network modules.
Example usage::
import numpy as np
from singa import opt
from singa import tensor
from singa import device
from singa import autograd
from singa.module import Module
class Model(Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = autograd.Conv2d(1, 20, 5, padding=0)
self.conv2 = autograd.Conv2d(20, 50, 5, padding=0)
self.sgd = opt.SGD(lr=0.01)
def forward(self, x):
y = self.conv1(x)
y = self.conv2(y)
return y
def loss(self, out, y):
return autograd.softmax_cross_entropy(out, y)
def optim(self, loss):
self.sgd.backward_and_update(loss)
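
    After the model is defined, a typical training loop enables the graph and
    then drives it through __call__, loss and optim. The names ``dev``, ``tx``
    and ``ty`` below are illustrative placeholders for a target device and the
    input/label tensors created elsewhere::

        model = Model()
        model.on_device(dev)
        model.graph(mode=True, sequential=False)

        for epoch in range(10):
            out = model(tx)
            loss = model.loss(out, ty)
            model.optim(loss)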
"""
    def __init__(self):
        """
        Initializes internal Module state.
        """
        self.training = True  # training vs evaluation mode
        self.graph_mode = True  # whether ops are buffered into a graph
        self.sequential = False  # run graph ops in the order they were added
        self.initialized = False  # True once the graph has been buffered
        self._device = get_default_device()
        self._results = {}  # result tensors of each buffered method
        self._called = set()  # names of methods whose ops are buffered
def forward(self, *input):
"""Defines the computation performed at every call.
Should be overridden by all subclasses.
Args:
*input: the input training data for the module
Returns:
out: the outputs of the forward propagation.
"""
raise NotImplementedError
    def loss(self, *args, **kwargs):
        """Defines the loss computation performed when training the module.
        """
        pass
    def optim(self, *args, **kwargs):
        """Defines the optimization step performed during the backward pass.
        """
        pass
    def train(self, mode=True):
        """Sets the module in training mode.

        Args:
            mode(bool): when mode is True, this module will enter training
                mode; when it is False, the module enters evaluation mode
        """
        self.training = mode
        autograd.training = mode
def eval(self):
"""Sets the module in evaluation mode.
"""
self.train(mode=False)
autograd.training = False
    def graph(self, mode=True, sequential=False):
        """Enables or disables the computational graph and sets its execution
        mode.

        Args:
            mode(bool): when mode is True, the module will use the
                computational graph
            sequential(bool): when sequential is True, the module will execute
                ops in the graph in the order in which they joined the graph
        """
        self.graph_mode = mode
        self.sequential = sequential
    def on_device(self, device):
        """Sets the target device.

        Subsequent training will be performed on that device.

        Args:
            device(Device): the target device
        """
        self._device = device
def __get_name__(self):
return self.__class__.__name__
    def __call__(self, *input, **kwargs):
        if self.graph_mode and self.training:
            if self.initialized:
                # the graph has already been buffered; replay it for this
                # iteration before forward returns the cached results
                self._device.RunGraph(self.sequential)
        return self.forward(*input, **kwargs)
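

if __name__ == '__main__':
    # A minimal, illustrative training sketch, not part of the library API.
    # It mirrors the docstring example but adds a flatten + linear head so
    # the loss is computed on a 2D output; the layer sizes, batch shape and
    # random one-hot labels are assumptions chosen to make the sketch run.
    import numpy as np
    from singa import opt
    from singa import tensor

    class _DemoModel(Module):

        def __init__(self):
            super(_DemoModel, self).__init__()
            self.conv = autograd.Conv2d(1, 20, 5, padding=0)
            # 28x28 input with a 5x5 kernel and no padding gives 24x24 maps
            self.linear = autograd.Linear(20 * 24 * 24, 10)
            self.sgd = opt.SGD(lr=0.01)

        def forward(self, x):
            y = self.conv(x)
            y = autograd.flatten(y)
            return self.linear(y)

        def loss(self, out, y):
            return autograd.softmax_cross_entropy(out, y)

        def optim(self, loss):
            self.sgd.backward_and_update(loss)

    dev = get_default_device()
    model = _DemoModel()
    model.on_device(dev)
    model.graph(mode=True, sequential=False)

    # random inputs and one-hot labels, just to exercise the graph
    x = tensor.Tensor((4, 1, 28, 28), dev)
    x.gaussian(0.0, 1.0)
    labels = np.eye(10, dtype=np.float32)[np.random.randint(0, 10, 4)]
    y = tensor.from_numpy(labels)
    y.to_device(dev)

    # first iteration buffers the graph; later ones replay it
    for _ in range(3):
        out = model(x)
        loss = model.loss(out, y)
        model.optim(loss)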