blob: a9f59478cca19c3619bd078793b143387fef462e [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file multi_lans.cu
* \brief multi-tensor LANS optimizer
* \author Shuai Zheng
*/
#include "./multi_lans-inl.h"
namespace mxnet {
namespace op {
#define BLOCK_SIZE_LAMB 512
#define ILP_LAMB 4
template <bool has_mixed_precision, typename MPDType, typename DType>
__global__ void KernelStep1(const MultiLANSKernelParam<DType, MPDType> kernel_params,
const float beta1,
const float beta2,
const MPDType beta3,
const MPDType beta4,
const float epsilon,
const float clip_gradient,
const float rescale_grad,
float* g_sq_norm,
float* temp_m,
float* temp_g,
int* block_to_tensor,
int* block_to_chunk) {
const int tensor_id = block_to_tensor[blockIdx.x];
const int chunck_id = block_to_chunk[blockIdx.x];
const int start_pos = chunck_id * kernel_params.chunk_size + threadIdx.x;
const int stop_pos = chunck_id * kernel_params.chunk_size + kernel_params.chunk_size;
MPDType g_norm = sqrtf(g_sq_norm[tensor_id]);
MPDType biascorrection1, biascorrection2;
biascorrection1 = 1.0 - static_cast<MPDType>(
pow(beta1, static_cast<float>(kernel_params.step_count[tensor_id])));
biascorrection2 = 1.0 - static_cast<MPDType>(
pow(beta2, static_cast<float>(kernel_params.step_count[tensor_id])));
MPDType r_weight[ILP_LAMB];
MPDType r_grad[ILP_LAMB];
MPDType r_mean[ILP_LAMB];
MPDType r_var[ILP_LAMB];
MPDType r_m[ILP_LAMB];
MPDType r_g[ILP_LAMB];
for (size_t i = start_pos; i < stop_pos && i < kernel_params.sizes[tensor_id];
i += blockDim.x * ILP_LAMB) {
#pragma unroll
for (int ii = 0; ii < ILP_LAMB; ii++) {
int load_pos = i + ii * blockDim.x;
if (load_pos < stop_pos && load_pos < kernel_params.sizes[tensor_id]) {
r_weight[ii] = has_mixed_precision ?
kernel_params.weights32[tensor_id][load_pos] :
static_cast<MPDType>(kernel_params.weights[tensor_id][load_pos]);
r_grad[ii] = static_cast<MPDType>(kernel_params.grads[tensor_id][load_pos]);
r_mean[ii] = kernel_params.mean[tensor_id][load_pos];
r_var[ii] = kernel_params.var[tensor_id][load_pos];
} else {
r_weight[ii] = static_cast<MPDType>(0);
r_grad[ii] = static_cast<MPDType>(0);
r_mean[ii] = static_cast<MPDType>(0);
r_var[ii] = static_cast<MPDType>(0);
}
}
#pragma unroll
for (int ii = 0; ii < ILP_LAMB; ii++) {
r_grad[ii] = (r_grad[ii] * rescale_grad) / g_norm;
if (clip_gradient >= 0.0f)
r_grad[ii] = max(min(r_grad[ii], clip_gradient), -clip_gradient);
r_mean[ii] = static_cast<MPDType>(beta1) * r_mean[ii] + beta3 * r_grad[ii];
r_var[ii] = static_cast<MPDType>(beta2) * r_var[ii] + beta4 * r_grad[ii] * r_grad[ii];
MPDType r_var_hat = sqrt(r_var[ii] / biascorrection2) + static_cast<MPDType>(epsilon);
r_m[ii] = (r_mean[ii] / biascorrection1) / r_var_hat;
r_g[ii] = r_grad[ii] / r_var_hat;
r_m[ii] = __fmaf_rn(kernel_params.wds[tensor_id], r_weight[ii], r_m[ii]);
r_g[ii] = __fmaf_rn(kernel_params.wds[tensor_id], r_weight[ii], r_g[ii]);
}
#pragma unroll
for (int ii = 0; ii < ILP_LAMB; ii++) {
int store_pos = i + ii * blockDim.x;
if (store_pos < stop_pos && store_pos < kernel_params.sizes[tensor_id]) {
kernel_params.mean[tensor_id][store_pos] = r_mean[ii];
kernel_params.var[tensor_id][store_pos] = r_var[ii];
temp_m[kernel_params.tensor2temp_g[tensor_id] + store_pos] = r_m[ii];
temp_g[kernel_params.tensor2temp_g[tensor_id] + store_pos] = r_g[ii];
}
}
}
}
template <bool has_mixed_precision, typename MPDType, typename DType>
__global__ void KernelStep2(const MultiLANSKernelParam<DType, MPDType> kernel_params,
const float beta1,
const MPDType beta3,
const float* sum_sq_weigths,
const float* sum_sq_temp_m,
const float* sum_sq_temp_g,
const float* temp_m,
const float* temp_g,
const float lower_bound,
const float upper_bound,
int* block_to_tensor,
int* block_to_chunk,
const OpReqType req) {
const int tensor_id = block_to_tensor[blockIdx.x];
const int chunck_id = block_to_chunk[blockIdx.x];
const int start_pos = chunck_id * kernel_params.chunk_size + threadIdx.x;
const int stop_pos = chunck_id * kernel_params.chunk_size + kernel_params.chunk_size;
MPDType r1 = sqrtf(sum_sq_weigths[tensor_id]);
MPDType r2_m = sqrtf(sum_sq_temp_m[tensor_id]);
MPDType r2_g = sqrtf(sum_sq_temp_g[tensor_id]);
if (lower_bound >= 0)
r1 = max(r1, lower_bound);
if (upper_bound >= 0)
r1 = min(r1, upper_bound);
MPDType lr_adjusted_m, lr_adjusted_g;
if (r1 == 0.0f || r2_m == 0.0f)
lr_adjusted_m = kernel_params.learning_rates[tensor_id];
else
lr_adjusted_m = kernel_params.learning_rates[tensor_id] * r1 / r2_m;
if (r1 == 0.0f || r2_g == 0.0f)
lr_adjusted_g = kernel_params.learning_rates[tensor_id];
else
lr_adjusted_g = kernel_params.learning_rates[tensor_id] * r1 / r2_g;
lr_adjusted_m *= static_cast<MPDType>(beta1);
lr_adjusted_g *= beta3;
MPDType r_weight[ILP_LAMB];
MPDType r_m[ILP_LAMB];
MPDType r_g[ILP_LAMB];
for (size_t i = start_pos; i < stop_pos && i < kernel_params.sizes[tensor_id];
i += blockDim.x * ILP_LAMB) {
#pragma unroll
for (int ii = 0; ii < ILP_LAMB; ii++) {
int load_pos = i + ii * blockDim.x;
if (load_pos < stop_pos && load_pos < kernel_params.sizes[tensor_id]) {
r_weight[ii] = has_mixed_precision ?
kernel_params.weights32[tensor_id][load_pos] :
static_cast<MPDType>(kernel_params.weights[tensor_id][load_pos]);
r_m[ii] = temp_m[kernel_params.tensor2temp_g[tensor_id] + load_pos];
r_g[ii] = temp_g[kernel_params.tensor2temp_g[tensor_id] + load_pos];
}
}
#pragma unroll
for (int ii = 0; ii < ILP_LAMB; ii++) {
r_weight[ii] -= lr_adjusted_m * r_m[ii] + lr_adjusted_g * r_g[ii];
}
#pragma unroll
for (int ii = 0; ii < ILP_LAMB; ii++) {
int store_pos = i + ii * blockDim.x;
if (store_pos < stop_pos && store_pos < kernel_params.sizes[tensor_id]) {
if (has_mixed_precision)
kernel_params.weights32[tensor_id][store_pos] = r_weight[ii];
KERNEL_ASSIGN(kernel_params.out_data[tensor_id][store_pos], req, r_weight[ii]);
}
}
}
}
template <typename MPDType, typename DType>
void CallKernel1(Stream<gpu>* s,
const MultiLANSKernelParam<DType, MPDType>& kernel_params,
const MultiLANSParam& param,
float* g_sq_norm,
float* temp_m,
float* temp_g,
int* block_to_tensor,
int* block_to_chunk) {
int nblocks = kernel_params.nchunks;
int* host_block2tensor = reinterpret_cast<int*>(malloc(kernel_params.nchunks * sizeof(int)));
int* host_block2chunk = reinterpret_cast<int*>(malloc(kernel_params.nchunks * sizeof(int)));
int chunk_id = 0;
for (size_t index = 0; index < kernel_params.ntensors; ++index) {
int current_chunk = 0;
for (size_t j = 0; j < kernel_params.sizes[index]; j += kernel_params.chunk_size) {
host_block2tensor[chunk_id] = index;
host_block2chunk[chunk_id] = current_chunk;
current_chunk++;
chunk_id++;
}
}
cudaMemcpyAsync(block_to_tensor,
host_block2tensor,
kernel_params.nchunks * sizeof(int),
cudaMemcpyHostToDevice,
Stream<gpu>::GetStream(s));
cudaMemcpyAsync(block_to_chunk,
host_block2chunk,
kernel_params.nchunks * sizeof(int),
cudaMemcpyHostToDevice,
Stream<gpu>::GetStream(s));
bool has_mixed_precision = !std::is_same<DType, MPDType>::value;
MPDType beta3 = 1.0 - param.beta1;
MPDType beta4 = 1.0 - param.beta2;
if (has_mixed_precision)
KernelStep1<true>
<<<nblocks, BLOCK_SIZE_LAMB, 0, Stream<gpu>::GetStream(s)>>>(kernel_params,
param.beta1,
param.beta2,
beta3,
beta4,
param.epsilon,
param.clip_gradient,
param.rescale_grad,
g_sq_norm,
temp_m,
temp_g,
block_to_tensor,
block_to_chunk);
else
KernelStep1<false>
<<<nblocks, BLOCK_SIZE_LAMB, 0, Stream<gpu>::GetStream(s)>>>(kernel_params,
param.beta1,
param.beta2,
beta3,
beta4,
param.epsilon,
param.clip_gradient,
param.rescale_grad,
g_sq_norm,
temp_m,
temp_g,
block_to_tensor,
block_to_chunk);
}
template <typename MPDType, typename DType>
void CallKernel2(Stream<gpu>* s,
const MultiLANSKernelParam<DType, MPDType>& kernel_params,
const MultiLANSParam& param,
float* r1,
float* r2_m,
float* r2_g,
float* temp_m,
float* temp_g,
int* block_to_tensor,
int* block_to_chunk,
const OpReqType req) {
size_t nblocks = kernel_params.nchunks;
bool has_mixed_precision = !std::is_same<DType, MPDType>::value;
MPDType beta3 = 1.0 - param.beta1;
if (has_mixed_precision)
KernelStep2<true><<<nblocks, BLOCK_SIZE_LAMB, 0, Stream<gpu>::GetStream(s)>>>(kernel_params,
param.beta1,
beta3,
r1,
r2_m,
r2_g,
temp_m,
temp_g,
param.lower_bound,
param.upper_bound,
block_to_tensor,
block_to_chunk,
req);
else
KernelStep2<false>
<<<nblocks, BLOCK_SIZE_LAMB, 0, Stream<gpu>::GetStream(s)>>>(kernel_params,
param.beta1,
beta3,
r1,
r2_m,
r2_g,
temp_m,
temp_g,
param.lower_bound,
param.upper_bound,
block_to_tensor,
block_to_chunk,
req);
}
NNVM_REGISTER_OP(_multi_lans_update)
.set_attr<FCompute>("FCompute<gpu>", MultiLANSUpdate<gpu, false>);
NNVM_REGISTER_OP(_multi_mp_lans_update)
.set_attr<FCompute>("FCompute<gpu>", MultiLANSUpdate<gpu, true>);
} // namespace op
} // namespace mxnet