# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
"""Dynamic loss scaler for AMP."""
import logging
from .. import autograd as ag
from .. import ndarray
from ..util import is_np_array


class LossScaler(object):
    """Dynamic loss scaler for AMP.

    Properties
    ----------
    loss_scale : float
        The current loss scale
    """
    def __init__(self):
        self._loss_scale = 2.**16        # scale currently applied to the loss
        self._next_loss_scale = self._loss_scale
        self._max_loss_scale = 2.**24    # upper bound for the loss scale
        self._scale_seq_len = 2000       # overflow-free steps before the scale grows
        self._unskipped = 0              # overflow-free steps since the last change

    @property
    def loss_scale(self):
        return self._loss_scale

    def has_overflow(self, params):
        """Check gradients for overflow and update the loss scale accordingly."""
        # Pick the finiteness check and the ones() constructor matching the
        # active array protocol (numpy-style arrays vs. legacy ndarrays).
        if is_np_array():
            all_finite_f = ndarray.numpy._internal.multi_all_finite
            ones_f = lambda ctx: ndarray.numpy.ones((1,), device=ctx)
        else:
            all_finite_f = ndarray.multi_all_finite
            ones_f = lambda ctx: ndarray.ones((1,), ctx=ctx)
        with ag.pause():
            chunk_size = 200
            valid_params = [p._grad[0] for p in params if p._grad is not None]
            # Flag on the gradients' device: stays 1 while all checked arrays
            # are finite, becomes 0 once a NaN/Inf gradient is found.
            gpu_output = ones_f(valid_params[0].context)
            nb_params = len(valid_params)
            # Check the gradients in chunks to bound the number of arrays
            # passed to a single multi_all_finite call.
            for idx in range(0, nb_params, chunk_size):
                all_finite_f(*valid_params[idx:idx+chunk_size],
                             num_arrays=len(valid_params[idx:idx+chunk_size]),
                             init_output=False, out=gpu_output)
            has_overflow = not bool(gpu_output.asnumpy())
        # Apply the scale planned at the previous step, then plan the next one:
        # halve it on overflow, double it after _scale_seq_len clean steps.
        self._loss_scale = self._next_loss_scale
        if has_overflow:
            self._next_loss_scale = self._loss_scale / 2.
            self._unskipped = 0
            logging.info("AMP: decreasing loss scale to %f", self._next_loss_scale)
        else:
            self._unskipped += 1
            if self._unskipped == self._scale_seq_len:
                self._unskipped = 0
                self._next_loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
                logging.info("AMP: increasing loss scale to %f", self._next_loss_scale)
        return has_overflow
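

# A minimal usage sketch (illustration only, not part of this module): it
# assumes a Gluon ``net``, ``trainer``, ``loss_fn``, ``data_iter`` and
# ``batch_size`` defined elsewhere. In practice this class is normally driven
# by the AMP frontend rather than used directly; the loop below only shows how
# ``loss_scale`` and ``has_overflow`` are meant to interact.
#
#     scaler = LossScaler()
#     for data, label in data_iter:
#         with ag.record():
#             loss = loss_fn(net(data), label) * scaler.loss_scale  # scale the loss up
#         loss.backward()
#         if not scaler.has_overflow(net.collect_params().values()):
#             # Fold the unscaling into the update: dividing the gradients by
#             # (batch_size * loss_scale) undoes the scaling applied above.
#             trainer.step(batch_size * scaler.loss_scale)
#         # On overflow the step is skipped and the scale has already been halved.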