| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file multi_lans.cu |
| * \brief multi-tensor LANS optimizer |
| * \author Shuai Zheng |
| */ |
| |
#include "./multi_lans-inl.h"

#include <vector>
| |
| namespace mxnet { |
| namespace op { |
| |
// Threads per block used for every KernelStep1 / KernelStep2 launch in this file.
#define BLOCK_SIZE_LAMB 512
// Elements handled by each thread per outer-loop iteration; sizes the per-thread
// scratch arrays and the unrolled inner loops of both kernels.
#define ILP_LAMB 4
| |
/*!
 * \brief First LANS step: refresh the moment estimates and emit the two
 *        per-element update directions.
 *
 * Every block processes one chunk of one tensor; the (tensor, chunk) pair is
 * looked up through block_to_tensor / block_to_chunk using blockIdx.x. For each
 * element the gradient is rescaled, divided by sqrt(g_sq_norm[tensor]) and
 * optionally clipped. mean and var are then updated in place, and the two
 * directions (each with the weight-decay term added) are stored in temp_m and
 * temp_g at offset tensor2temp_g[tensor].
 *
 * \tparam has_mixed_precision read the weights from weights32 instead of weights
 * \param kernel_params per-tensor pointers, sizes and hyper-parameters
 * \param beta1 decay factor applied to the stored mean
 * \param beta2 decay factor applied to the stored var
 * \param beta3 factor applied to the gradient in the mean update
 * \param beta4 factor applied to the squared gradient in the var update
 * \param epsilon added to the square root of the bias-corrected var
 * \param clip_gradient when >= 0, gradients are clamped to [-clip_gradient, clip_gradient]
 * \param rescale_grad factor applied to every gradient before normalization
 * \param g_sq_norm per-tensor value whose square root divides the gradient
 * \param temp_m output buffer for the mean-based direction
 * \param temp_g output buffer for the gradient-based direction
 * \param block_to_tensor tensor index for each block
 * \param block_to_chunk chunk index (within its tensor) for each block
 */
template <bool has_mixed_precision, typename MPDType, typename DType>
__global__ void KernelStep1(const MultiLANSKernelParam<DType, MPDType> kernel_params,
                            const float beta1,
                            const float beta2,
                            const MPDType beta3,
                            const MPDType beta4,
                            const float epsilon,
                            const float clip_gradient,
                            const float rescale_grad,
                            float* g_sq_norm,
                            float* temp_m,
                            float* temp_g,
                            int* block_to_tensor,
                            int* block_to_chunk) {
  // Work assignment of this block.
  const int tensor_id      = block_to_tensor[blockIdx.x];
  const int chunk_idx      = block_to_chunk[blockIdx.x];
  const auto chunk_offset  = chunk_idx * kernel_params.chunk_size;
  const int start_pos      = chunk_offset + threadIdx.x;
  const int stop_pos       = chunk_offset + kernel_params.chunk_size;

  // Per-tensor values that stay constant for the whole chunk.
  const auto tensor_size = kernel_params.sizes[tensor_id];
  const auto temp_offset = kernel_params.tensor2temp_g[tensor_id];
  const auto wd          = kernel_params.wds[tensor_id];
  const float step       = static_cast<float>(kernel_params.step_count[tensor_id]);

  MPDType g_norm = sqrtf(g_sq_norm[tensor_id]);

  // 1 - beta^step terms that divide the mean / var estimates below.
  MPDType biascorrection1 = 1.0 - static_cast<MPDType>(pow(beta1, step));
  MPDType biascorrection2 = 1.0 - static_cast<MPDType>(pow(beta2, step));

  // Per-thread scratch, one slot per element handled in an iteration.
  MPDType r_weight[ILP_LAMB];
  MPDType r_grad[ILP_LAMB];
  MPDType r_mean[ILP_LAMB];
  MPDType r_var[ILP_LAMB];
  MPDType r_m[ILP_LAMB];
  MPDType r_g[ILP_LAMB];

  for (size_t i = start_pos; i < stop_pos && i < tensor_size; i += blockDim.x * ILP_LAMB) {
    // Phase 1: load; slots past the end of the chunk/tensor are zero-filled.
#pragma unroll
    for (int ii = 0; ii < ILP_LAMB; ii++) {
      int load_pos = i + ii * blockDim.x;
      if (load_pos < stop_pos && load_pos < tensor_size) {
        if (has_mixed_precision) {
          r_weight[ii] = kernel_params.weights32[tensor_id][load_pos];
        } else {
          r_weight[ii] = static_cast<MPDType>(kernel_params.weights[tensor_id][load_pos]);
        }
        r_grad[ii] = static_cast<MPDType>(kernel_params.grads[tensor_id][load_pos]);
        r_mean[ii] = kernel_params.mean[tensor_id][load_pos];
        r_var[ii]  = kernel_params.var[tensor_id][load_pos];
      } else {
        r_weight[ii] = static_cast<MPDType>(0);
        r_grad[ii]   = static_cast<MPDType>(0);
        r_mean[ii]   = static_cast<MPDType>(0);
        r_var[ii]    = static_cast<MPDType>(0);
      }
    }

    // Phase 2: arithmetic on the loaded values only.
#pragma unroll
    for (int ii = 0; ii < ILP_LAMB; ii++) {
      r_grad[ii] = (r_grad[ii] * rescale_grad) / g_norm;
      if (clip_gradient >= 0.0f) {
        r_grad[ii] = max(min(r_grad[ii], clip_gradient), -clip_gradient);
      }
      r_mean[ii] = static_cast<MPDType>(beta1) * r_mean[ii] + beta3 * r_grad[ii];
      r_var[ii]  = static_cast<MPDType>(beta2) * r_var[ii] + beta4 * r_grad[ii] * r_grad[ii];

      // Shared denominator of both directions.
      MPDType denom = sqrt(r_var[ii] / biascorrection2) + static_cast<MPDType>(epsilon);
      MPDType m_dir = (r_mean[ii] / biascorrection1) / denom;
      MPDType g_dir = r_grad[ii] / denom;

      // Add the weight-decay term: wd * weight + direction.
      r_m[ii] = __fmaf_rn(wd, r_weight[ii], m_dir);
      r_g[ii] = __fmaf_rn(wd, r_weight[ii], g_dir);
    }

    // Phase 3: write back the in-range slots.
#pragma unroll
    for (int ii = 0; ii < ILP_LAMB; ii++) {
      int store_pos = i + ii * blockDim.x;
      if (store_pos < stop_pos && store_pos < tensor_size) {
        kernel_params.mean[tensor_id][store_pos] = r_mean[ii];
        kernel_params.var[tensor_id][store_pos]  = r_var[ii];
        temp_m[temp_offset + store_pos]          = r_m[ii];
        temp_g[temp_offset + store_pos]          = r_g[ii];
      }
    }
  }
}
| |
/*!
 * \brief Second LANS step: apply the two update directions produced by
 *        KernelStep1 to the weights.
 *
 * Every block processes one chunk of one tensor, selected through
 * block_to_tensor / block_to_chunk. Two per-tensor step sizes are derived from
 * the tensor's learning rate and the ratios r1 / r2_m and r1 / r2_g, where r1,
 * r2_m and r2_g are the square roots of sum_sq_weigths, sum_sq_temp_m and
 * sum_sq_temp_g for that tensor; when r1 or the corresponding r2 is zero the
 * plain learning rate is used. The step sizes are further scaled by beta1 and
 * beta3 respectively. The updated weight is written to out_data through
 * KERNEL_ASSIGN and, when has_mixed_precision is set, to weights32 as well.
 *
 * \tparam has_mixed_precision read/write the weights through weights32
 * \param kernel_params per-tensor pointers, sizes and hyper-parameters
 * \param beta1 scale applied to the temp_m step size
 * \param beta3 scale applied to the temp_g step size
 * \param sum_sq_weigths per-tensor value whose square root is r1
 * \param sum_sq_temp_m per-tensor value whose square root is r2_m
 * \param sum_sq_temp_g per-tensor value whose square root is r2_g
 * \param temp_m mean-based direction written by KernelStep1
 * \param temp_g gradient-based direction written by KernelStep1
 * \param lower_bound when >= 0, r1 is raised to at least this value
 * \param upper_bound when >= 0, r1 is lowered to at most this value
 * \param block_to_tensor tensor index for each block
 * \param block_to_chunk chunk index (within its tensor) for each block
 * \param req write mode forwarded to KERNEL_ASSIGN
 */
template <bool has_mixed_precision, typename MPDType, typename DType>
__global__ void KernelStep2(const MultiLANSKernelParam<DType, MPDType> kernel_params,
                            const float beta1,
                            const MPDType beta3,
                            const float* sum_sq_weigths,
                            const float* sum_sq_temp_m,
                            const float* sum_sq_temp_g,
                            const float* temp_m,
                            const float* temp_g,
                            const float lower_bound,
                            const float upper_bound,
                            int* block_to_tensor,
                            int* block_to_chunk,
                            const OpReqType req) {
  const int tensor_id = block_to_tensor[blockIdx.x];
  const int chunck_id = block_to_chunk[blockIdx.x];
  const int start_pos = chunck_id * kernel_params.chunk_size + threadIdx.x;
  const int stop_pos  = chunck_id * kernel_params.chunk_size + kernel_params.chunk_size;

  MPDType r1   = sqrtf(sum_sq_weigths[tensor_id]);
  MPDType r2_m = sqrtf(sum_sq_temp_m[tensor_id]);
  MPDType r2_g = sqrtf(sum_sq_temp_g[tensor_id]);
  // Negative bounds mean "no bound".
  if (lower_bound >= 0)
    r1 = max(r1, lower_bound);
  if (upper_bound >= 0)
    r1 = min(r1, upper_bound);

  // Fall back to the plain learning rate whenever a ratio would be 0 or x/0.
  MPDType lr_adjusted_m, lr_adjusted_g;
  if (r1 == 0.0f || r2_m == 0.0f)
    lr_adjusted_m = kernel_params.learning_rates[tensor_id];
  else
    lr_adjusted_m = kernel_params.learning_rates[tensor_id] * r1 / r2_m;
  if (r1 == 0.0f || r2_g == 0.0f)
    lr_adjusted_g = kernel_params.learning_rates[tensor_id];
  else
    lr_adjusted_g = kernel_params.learning_rates[tensor_id] * r1 / r2_g;
  lr_adjusted_m *= static_cast<MPDType>(beta1);
  lr_adjusted_g *= beta3;

  MPDType r_weight[ILP_LAMB];
  MPDType r_m[ILP_LAMB];
  MPDType r_g[ILP_LAMB];

  for (size_t i = start_pos; i < stop_pos && i < kernel_params.sizes[tensor_id];
       i += blockDim.x * ILP_LAMB) {
#pragma unroll
    for (int ii = 0; ii < ILP_LAMB; ii++) {
      int load_pos = i + ii * blockDim.x;
      if (load_pos < stop_pos && load_pos < kernel_params.sizes[tensor_id]) {
        r_weight[ii] = has_mixed_precision ?
                           kernel_params.weights32[tensor_id][load_pos] :
                           static_cast<MPDType>(kernel_params.weights[tensor_id][load_pos]);
        r_m[ii] = temp_m[kernel_params.tensor2temp_g[tensor_id] + load_pos];
        r_g[ii] = temp_g[kernel_params.tensor2temp_g[tensor_id] + load_pos];
      } else {
        // Zero-fill the out-of-range slots, as KernelStep1 does: the update
        // loop below touches every slot unconditionally, and without this it
        // would read uninitialized values (the results are never stored).
        r_weight[ii] = static_cast<MPDType>(0);
        r_m[ii]      = static_cast<MPDType>(0);
        r_g[ii]      = static_cast<MPDType>(0);
      }
    }
#pragma unroll
    for (int ii = 0; ii < ILP_LAMB; ii++) {
      r_weight[ii] -= lr_adjusted_m * r_m[ii] + lr_adjusted_g * r_g[ii];
    }
#pragma unroll
    for (int ii = 0; ii < ILP_LAMB; ii++) {
      int store_pos = i + ii * blockDim.x;
      if (store_pos < stop_pos && store_pos < kernel_params.sizes[tensor_id]) {
        if (has_mixed_precision)
          kernel_params.weights32[tensor_id][store_pos] = r_weight[ii];
        KERNEL_ASSIGN(kernel_params.out_data[tensor_id][store_pos], req, r_weight[ii]);
      }
    }
  }
}
| |
/*!
 * \brief Build the block -> (tensor, chunk) lookup tables, upload them and
 *        launch KernelStep1 with one block per chunk.
 *
 * The tables are filled on the host by walking every tensor in steps of
 * kernel_params.chunk_size, then copied to block_to_tensor / block_to_chunk on
 * the stream of \p s before the kernel is enqueued on the same stream.
 *
 * \param s stream wrapper whose CUDA stream receives the copies and the launch
 * \param kernel_params per-tensor pointers, sizes and hyper-parameters
 * \param param optimizer parameters (beta1, beta2, epsilon, clip_gradient, rescale_grad)
 * \param g_sq_norm forwarded to KernelStep1
 * \param temp_m forwarded to KernelStep1
 * \param temp_g forwarded to KernelStep1
 * \param block_to_tensor destination of the tensor-index table; must hold
 *        kernel_params.nchunks ints
 * \param block_to_chunk destination of the chunk-index table; must hold
 *        kernel_params.nchunks ints
 */
template <typename MPDType, typename DType>
void CallKernel1(Stream<gpu>* s,
                 const MultiLANSKernelParam<DType, MPDType>& kernel_params,
                 const MultiLANSParam& param,
                 float* g_sq_norm,
                 float* temp_m,
                 float* temp_g,
                 int* block_to_tensor,
                 int* block_to_chunk) {
  const int nblocks = kernel_params.nchunks;

  // Host-side tables owned by std::vector so they are released on return
  // (the previous malloc'd buffers were never freed and leaked on every call).
  std::vector<int> host_block2tensor(kernel_params.nchunks);
  std::vector<int> host_block2chunk(kernel_params.nchunks);

  int chunk_id = 0;
  for (size_t index = 0; index < kernel_params.ntensors; ++index) {
    int current_chunk = 0;
    for (size_t j = 0; j < kernel_params.sizes[index]; j += kernel_params.chunk_size) {
      host_block2tensor[chunk_id] = index;
      host_block2chunk[chunk_id]  = current_chunk;
      current_chunk++;
      chunk_id++;
    }
  }

  // The sources are ordinary pageable host memory; the CUDA runtime stages
  // such buffers before cudaMemcpyAsync returns, so destroying the vectors at
  // the end of this function does not race with the transfers.
  cudaMemcpyAsync(block_to_tensor,
                  host_block2tensor.data(),
                  kernel_params.nchunks * sizeof(int),
                  cudaMemcpyHostToDevice,
                  Stream<gpu>::GetStream(s));
  cudaMemcpyAsync(block_to_chunk,
                  host_block2chunk.data(),
                  kernel_params.nchunks * sizeof(int),
                  cudaMemcpyHostToDevice,
                  Stream<gpu>::GetStream(s));

  // Mixed precision is in effect whenever storage and compute types differ.
  const bool has_mixed_precision = !std::is_same<DType, MPDType>::value;
  MPDType beta3                  = 1.0 - param.beta1;
  MPDType beta4                  = 1.0 - param.beta2;

  if (has_mixed_precision)
    KernelStep1<true>
        <<<nblocks, BLOCK_SIZE_LAMB, 0, Stream<gpu>::GetStream(s)>>>(kernel_params,
                                                                     param.beta1,
                                                                     param.beta2,
                                                                     beta3,
                                                                     beta4,
                                                                     param.epsilon,
                                                                     param.clip_gradient,
                                                                     param.rescale_grad,
                                                                     g_sq_norm,
                                                                     temp_m,
                                                                     temp_g,
                                                                     block_to_tensor,
                                                                     block_to_chunk);
  else
    KernelStep1<false>
        <<<nblocks, BLOCK_SIZE_LAMB, 0, Stream<gpu>::GetStream(s)>>>(kernel_params,
                                                                     param.beta1,
                                                                     param.beta2,
                                                                     beta3,
                                                                     beta4,
                                                                     param.epsilon,
                                                                     param.clip_gradient,
                                                                     param.rescale_grad,
                                                                     g_sq_norm,
                                                                     temp_m,
                                                                     temp_g,
                                                                     block_to_tensor,
                                                                     block_to_chunk);
}
| |
/*!
 * \brief Launch KernelStep2 with one block per chunk on the stream of \p s.
 *
 * block_to_tensor / block_to_chunk are passed through unchanged; this function
 * does not fill them.
 *
 * \param s stream wrapper whose CUDA stream receives the launch
 * \param kernel_params per-tensor pointers, sizes and hyper-parameters
 * \param param optimizer parameters (beta1, lower_bound, upper_bound)
 * \param r1 passed to the kernel as sum_sq_weigths
 * \param r2_m passed to the kernel as sum_sq_temp_m
 * \param r2_g passed to the kernel as sum_sq_temp_g
 * \param temp_m direction buffer read by the kernel
 * \param temp_g direction buffer read by the kernel
 * \param block_to_tensor device table: tensor index for each block
 * \param block_to_chunk device table: chunk index for each block
 * \param req write mode forwarded to the kernel
 */
template <typename MPDType, typename DType>
void CallKernel2(Stream<gpu>* s,
                 const MultiLANSKernelParam<DType, MPDType>& kernel_params,
                 const MultiLANSParam& param,
                 float* r1,
                 float* r2_m,
                 float* r2_g,
                 float* temp_m,
                 float* temp_g,
                 int* block_to_tensor,
                 int* block_to_chunk,
                 const OpReqType req) {
  const size_t grid_size        = kernel_params.nchunks;
  const MPDType one_minus_beta1 = 1.0 - param.beta1;
  auto stream                   = Stream<gpu>::GetStream(s);

  // Identical storage and compute types mean no separate weights32 copy is used.
  if (std::is_same<DType, MPDType>::value) {
    KernelStep2<false><<<grid_size, BLOCK_SIZE_LAMB, 0, stream>>>(kernel_params,
                                                                  param.beta1,
                                                                  one_minus_beta1,
                                                                  r1,
                                                                  r2_m,
                                                                  r2_g,
                                                                  temp_m,
                                                                  temp_g,
                                                                  param.lower_bound,
                                                                  param.upper_bound,
                                                                  block_to_tensor,
                                                                  block_to_chunk,
                                                                  req);
  } else {
    KernelStep2<true><<<grid_size, BLOCK_SIZE_LAMB, 0, stream>>>(kernel_params,
                                                                 param.beta1,
                                                                 one_minus_beta1,
                                                                 r1,
                                                                 r2_m,
                                                                 r2_g,
                                                                 temp_m,
                                                                 temp_g,
                                                                 param.lower_bound,
                                                                 param.upper_bound,
                                                                 block_to_tensor,
                                                                 block_to_chunk,
                                                                 req);
  }
}
| |
// Attach the GPU FCompute implementation to both multi-tensor LANS operators.
// The second template argument of MultiLANSUpdate is presumably the
// mixed-precision switch (the "_mp_" operator passes true) - confirm against
// multi_lans-inl.h.
NNVM_REGISTER_OP(_multi_lans_update).set_attr<FCompute>(
    "FCompute<gpu>", MultiLANSUpdate<gpu, false>);

NNVM_REGISTER_OP(_multi_mp_lans_update).set_attr<FCompute>(
    "FCompute<gpu>", MultiLANSUpdate<gpu, true>);
| |
| } // namespace op |
| } // namespace mxnet |