# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging

from mxnet.ndarray import NDArray, topk, abs as NDabs
from mxnet.optimizer import SGD, register

log = 'Sparsity Update:\t'

@register
class SparseSGD(SGD):
    """The SGD optimizer with weight pruning.

    This class implements the optimizer described in the paper *DSD: Dense-Sparse-Dense
    Training for Deep Neural Networks*, available at https://arxiv.org/pdf/1607.04381.pdf

    The optimizer updates the weights the same way as SGD does, but first applies the
    following preprocessing::

        if a threshold is given, all weights whose absolute value is below the
        threshold are pruned:
            mask = abs(weight) >= threshold
        if a sparsity level is given, the smallest (sparsity)% of weights in
        absolute value are pruned (equivalently, only the largest (100-sparsity)%
        of weights in absolute value are kept):
            mask = topk(abs(weight), ret_typ='mask', k=weight.size*(100-sparsity)/100)

        => mask[i,j] = {0 if weight[i,j] is pruned, 1 otherwise} (for a matrix representation)

        weight = weight * mask
        grad = grad * mask
        state = state * mask

    This optimizer accepts the following parameters in addition to those accepted
    by :class:`.SGD`.

    Parameters
    ----------
    pruning_switch_epoch : list of ints, optional
        The epochs at which the sparsity level changes (must be in ascending order).
    weight_sparsity : list of floats, optional
        The sparsity level to enforce on the weights during each stage of sparse training.
    bias_sparsity : list of floats, optional
        The sparsity level to enforce on the biases during each stage of sparse training.
    weight_threshold : list of floats, optional
        The absolute-value threshold applied to the weights during each stage of sparse training.
    bias_threshold : list of floats, optional
        The absolute-value threshold applied to the biases during each stage of sparse training.
    batches_per_epoch : int, optional
        The number of batches in each epoch
        (the ceiling of number_of_examples / batch_size).
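
    Examples
    --------
    A minimal construction sketch; the schedule below (50% weight sparsity for
    epochs 1-4, then 70% from epoch 5 on) is illustrative, not taken from the paper:

    >>> optimizer = SparseSGD(learning_rate=0.1, momentum=0.9,
    ...                       pruning_switch_epoch=[4, 8],
    ...                       weight_sparsity=[50.0, 70.0],
    ...                       bias_sparsity=[0.0, 0.0],
    ...                       batches_per_epoch=100)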
"""
    def __init__(self, pruning_switch_epoch, batches_per_epoch,
                 weight_sparsity=None, bias_sparsity=None,
                 weight_threshold=None, bias_threshold=None, **kwargs):
        super(SparseSGD, self).__init__(**kwargs)

        self.masks = []
        self.masks_updated = False
        self.epoch = 0
        self.pruning_switch_epoch = pruning_switch_epoch
        self.batches_per_epoch = batches_per_epoch

        # get weight and bias sparsity percentages
        self.weight_sparsity = weight_sparsity
        self.bias_sparsity = bias_sparsity
        if weight_sparsity is not None:
            assert len(weight_sparsity) == len(bias_sparsity), \
                'weight_sparsity and bias_sparsity should have the same length'
            assert len(weight_sparsity) == len(pruning_switch_epoch), \
                'pruning_switch_epoch and weight_sparsity should have the same length'

        # get weight and bias sparsity thresholds
        self.weight_threshold = weight_threshold
        self.bias_threshold = bias_threshold
        if weight_threshold is not None:
            assert len(weight_threshold) == len(bias_threshold), \
                'weight_threshold and bias_threshold should have the same length'
            assert len(weight_threshold) == len(pruning_switch_epoch), \
                'pruning_switch_epoch and weight_threshold should have the same length'

        # either sparsity percentages or thresholds must be given
        assert weight_sparsity is not None or weight_threshold is not None, \
            'either weight_sparsity or weight_threshold should be given'

    def update_masks(self, index, weight):
        """Updates the masks for sparse training.

        Parameters
        ----------
        index : int
            The index of the weight.
        weight : NDArray
            The weight matrix.

        Returns
        -------
        bool
            Whether the masks were changed.
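
        Examples
        --------
        An illustrative sketch of the mask produced by ``topk`` (the expected
        result is given as a comment rather than as exact output formatting):

        >>> import mxnet as mx
        >>> w = mx.nd.array([0.1, -0.5, 0.3, -0.05])
        >>> mask = mx.nd.topk(mx.nd.abs(w), axis=None, ret_typ='mask', k=2)
        >>> # keeps the two largest magnitudes: [0., 1., 1., 0.]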
"""
        # determine the number of updates without actually updating the count
        if index not in self._index_update_count:
            num_update = self.begin_num_update
        else:
            num_update = self._index_update_count[index]
        num_update += 1
        num_update = max(num_update, self.num_update)

        # calculate the current (1-based) epoch
        epoch = (num_update - 1) // self.batches_per_epoch + 1

        # determine if masks need to be updated, and get the corresponding parameters
        if index == 0:
            self.masks_updated = True
        if self.epoch != epoch:
            self.epoch = epoch
            if epoch == 1:
                self.masks_updated = False
                if self.weight_sparsity is not None:
                    logging.info(log + 'bias-sparsity={}, weight-sparsity={}'.format(
                        self.bias_sparsity[0], self.weight_sparsity[0]))
                else:
                    logging.info(log + 'bias-threshold={}, weight-threshold={}'.format(
                        self.bias_threshold[0], self.weight_threshold[0]))
            if self.pruning_switch_epoch[0] + 1 == epoch:
                self.masks_updated = False
                self.pruning_switch_epoch.pop(0)
                if self.weight_sparsity is not None:
                    self.weight_sparsity.pop(0)
                    self.bias_sparsity.pop(0)
                    logging.info(log + 'bias-sparsity={}, weight-sparsity={}'.format(
                        self.bias_sparsity[0], self.weight_sparsity[0]))
                else:
                    self.weight_threshold.pop(0)
                    self.bias_threshold.pop(0)
                    logging.info(log + 'bias-threshold={}, weight-threshold={}'.format(
                        self.bias_threshold[0], self.weight_threshold[0]))

        # update masks if needed
        if not self.masks_updated:
            # initialize the mask list on the first epoch, one entry per parameter
            # index (assumes indices are visited in ascending order at first)
            if epoch == 1:
                self.masks.append(None)
            # if sparsity percentages are given, keep the largest (100-sparsity)% by magnitude
            if self.weight_sparsity is not None:
                if len(weight.shape) == 1:
                    sparsity = self.bias_sparsity[0]
                else:
                    sparsity = self.weight_sparsity[0]
                number_unpruned = int((100.0 - sparsity) * weight.size / 100.0)
                self.masks[index] = topk(NDabs(weight), axis=None, ret_typ='mask',
                                         k=number_unpruned)
            # if thresholds are given, keep only weights whose magnitude reaches the threshold
            else:
                if len(weight.shape) == 1:
                    threshold = self.bias_threshold[0]
                else:
                    threshold = self.weight_threshold[0]
                self.masks[index] = NDabs(weight) >= threshold

        return not self.masks_updated

    def update(self, index, weight, grad, state):
        assert isinstance(weight, NDArray)
        assert isinstance(grad, NDArray)

        # preprocessing for pruning: when the masks have just changed, zero out the
        # newly pruned weights; the gradient and the momentum state are masked on
        # every update so that pruned weights stay at zero
        if self.update_masks(index, weight):
            weight[:] = weight * self.masks[index]
        grad[:] = grad * self.masks[index]
        if state is not None:
            state[:] = state * self.masks[index]

        super(SparseSGD, self).update(index, weight, grad, state)
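

# What follows is a minimal smoke-test sketch, not part of the original example:
# it runs a few updates on a random weight matrix and reports the resulting
# sparsity. (`@register` above also makes the optimizer available through
# `mx.optimizer.create` under the lowercased class name, 'sparsesgd'.)
if __name__ == '__main__':
    import mxnet as mx

    logging.basicConfig(level=logging.INFO)
    opt = SparseSGD(learning_rate=0.1, momentum=0.9,
                    pruning_switch_epoch=[2],
                    weight_sparsity=[50.0], bias_sparsity=[0.0],
                    batches_per_epoch=10)
    weight = mx.nd.random.uniform(-1, 1, shape=(16, 16))
    grad = mx.nd.random.uniform(-1, 1, shape=(16, 16))
    state = opt.create_state(0, weight)
    for _ in range(10):
        opt.update(0, weight, grad, state)
    # with weight_sparsity=[50.0], roughly half of the entries should now be zero
    pruned = (weight == 0).sum().asscalar() / weight.size
    print('fraction of weights pruned: {:.2f}'.format(pruned))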