/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <dlpack/dlpack.h>
#include <stdint.h>
#include <tvm/runtime/logging.h>
#include "custom_allreduce_kernels.h"
namespace tensorrt_llm {
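// Flag store/load helpers with release/acquire semantics, used by the cross-GPU barriers below.
// On sm_70+ they map to the native system-scope .release/.acquire instructions; on older
// architectures they fall back to volatile accesses (with a __threadfence_system() before the
// store) to obtain comparable ordering.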
static inline __device__ void st_flag_release(uint32_t& flag, uint32_t* flag_addr) {
#if __CUDA_ARCH__ >= 700
asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#else
__threadfence_system();
asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void ld_flag_acquire(uint32_t& flag, uint32_t* flag_addr) {
#if __CUDA_ARCH__ >= 700
asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#else
asm volatile("ld.global.volatile.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Type converters that pack the element type into a 128-bit (int4) value for vectorized loads and
// stores.
using PackedFloat = union {
int4 packed;
float unpacked[4];
};
using PackedHalf = union {
int4 packed;
half2 unpacked[4];
};
template <typename T>
struct PackedOn16Bytes {};
template <>
struct PackedOn16Bytes<float> {
using Type = PackedFloat;
};
template <>
struct PackedOn16Bytes<half> {
using Type = PackedHalf;
};
#ifdef ENABLE_BF16
using PackedBFloat16 = union {
int4 packed;
__nv_bfloat162 unpacked[4];
};
template <>
struct PackedOn16Bytes<__nv_bfloat16> {
using Type = PackedBFloat16;
};
#endif
// Element-wise addition of two 128-bit packed values.
template <typename T>
inline __device__ int4 add128b(T& a, T& b) {
T c;
c.unpacked[0] = a.unpacked[0] + b.unpacked[0];
c.unpacked[1] = a.unpacked[1] + b.unpacked[1];
c.unpacked[2] = a.unpacked[2] + b.unpacked[2];
c.unpacked[3] = a.unpacked[3] + b.unpacked[3];
return c.packed;
}
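// Spin barrier across the GPUs of one node. Thread `tidx` of block 0 raises this rank's flag on
// peer `tidx`, and every thread with tidx < world_size then spins on its local flag array until
// all peers have published the current `flag` value.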
__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, const uint32_t flag,
const size_t rank, const size_t world_size,
int const tidx, int const bidx) {
  // At the end of the function, we know that at least block 0 from all other GPUs has reached
  // this point.
uint32_t volatile* my_signals = signals[rank];
if (tidx < world_size) {
// The 1st block notifies the other ranks.
if (bidx == 0) {
signals[tidx][rank] = flag;
}
// Busy-wait until all ranks are ready.
while (my_signals[tidx] != flag) {
}
}
// Make sure we can move on...
__syncthreads();
}
__global__ void multiGpuBarrierKernel(AllReduceParams params) {
multi_gpu_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank,
params.ranks_per_node, threadIdx.x, blockIdx.x);
}
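// One-shot allreduce: each rank's input is expected to already reside in its peer-visible
// communication buffer (peer_comm_buffer_ptrs). After the entry barrier, every rank reads all
// RANKS_PER_NODE copies, reduces them locally, and writes the full result to its output buffer.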
template <typename T, int RANKS_PER_NODE>
static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
int const bidx = blockIdx.x;
int const tidx = threadIdx.x;
  // The number of elements packed into one 128-bit access for comms.
static constexpr int NUM_ELTS = 16 / sizeof(T);
// Packed data type for comms
using PackedStruct = typename PackedOn16Bytes<T>::Type;
multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank,
RANKS_PER_NODE, tidx, bidx);
// The source pointers. Distributed round-robin for the different warps.
T const* src_d[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
src_d[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
}
// The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128).
size_t offset = bidx * params.elts_per_block + tidx * NUM_ELTS;
// The end of the segment computed by that block.
size_t max_offset = min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
// Each block accumulates the values from the different GPUs on the same node.
for (size_t iter_offset = offset; iter_offset < max_offset;
iter_offset += blockDim.x * NUM_ELTS) {
// Iterate over the different ranks/devices on the node to load the values.
PackedStruct vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
vals[ii].packed = *reinterpret_cast<int4 const*>(&src_d[ii][iter_offset]);
}
// Sum the values from the different ranks.
PackedStruct sums;
sums.packed = {0, 0, 0, 0};
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
sums.packed = add128b(sums, vals[ii]);
}
// Store to the destination buffer.
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) =
sums.packed;
}
}
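// Two-shot allreduce (reduce-scatter followed by all-gather): in the first phase each rank reduces
// only its own elts_per_rank slice (starting at rank_offset) and writes the partial result back
// into its local communication buffer. A flag-based inter-rank barrier then makes those partial
// results visible, and in the second phase each rank gathers the reduced slices from all peers
// into its output buffer.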
template <typename T, int RANKS_PER_NODE>
static __global__ void twoShotAllReduceKernel(AllReduceParams params) {
// The block index.
int const bidx = blockIdx.x;
  // The thread index within the block.
int const tidx = threadIdx.x;
  // The number of elements packed into one 128-bit access for comms.
static constexpr int NUM_ELTS = 16 / sizeof(T);
// Packed data type for comms
using PackedType = typename PackedOn16Bytes<T>::Type;
// The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128).
const size_t block_offset = bidx * params.elts_per_block + tidx * NUM_ELTS;
const size_t block_start = params.rank_offset + block_offset;
// The end of the segment computed by that block.
size_t max_offset =
min(block_start + params.elts_per_block, params.rank_offset + params.elts_per_rank);
multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank,
RANKS_PER_NODE, tidx, bidx);
// The source pointers. Distributed round-robin for the different warps.
T* src_d[RANKS_PER_NODE];
// The destination ranks for round-robin gathering
size_t dst_rank[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
src_d[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
dst_rank[ii] = rank;
}
// Each block accumulates the values from the different GPUs on the same node.
for (size_t local_offset = block_start; local_offset < max_offset;
local_offset += blockDim.x * NUM_ELTS) {
// Iterate over the different ranks/devices on the node to load the values.
PackedType vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
vals[ii].packed = *reinterpret_cast<int4 const*>(&src_d[ii][local_offset]);
}
// Sum the values from the different ranks.
PackedType sums;
sums.packed = {0, 0, 0, 0};
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
sums.packed = add128b(sums, vals[ii]);
}
// Store to the local buffer.
*reinterpret_cast<int4*>(&src_d[0][local_offset]) = sums.packed;
}
  // Sync threads to make sure all threads in the block have written their partial sums.
__syncthreads();
  // Barrier among the blocks with the same index across ranks (release-acquire semantics).
if (tidx < RANKS_PER_NODE) {
    // Every block notifies the other ranks. The flags for this barrier start after the first
    // RANKS_PER_NODE slots used by multi_gpu_barrier, with one group of RANKS_PER_NODE flags per
    // block.
uint32_t flag_block_offset = RANKS_PER_NODE + bidx * RANKS_PER_NODE;
st_flag_release(params.barrier_flag,
params.peer_barrier_ptrs_in[tidx] + flag_block_offset + params.local_rank);
// Busy-wait until all ranks are ready.
uint32_t rank_barrier = 0;
uint32_t* peer_barrier_d =
params.peer_barrier_ptrs_in[params.local_rank] + flag_block_offset + tidx;
do {
ld_flag_acquire(rank_barrier, peer_barrier_d);
} while (rank_barrier != params.barrier_flag);
}
  // Sync threads to make sure all other ranks have the final partial results.
__syncthreads();
size_t max_block_offset = min(block_offset + params.elts_per_block, params.elts_per_rank);
// Gather all needed elts from other intra-node ranks
for (size_t local_offset = block_offset; local_offset < max_block_offset;
local_offset += blockDim.x * NUM_ELTS) {
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
// use round-robin gathering from other ranks
size_t offset_rank = dst_rank[ii] * params.elts_per_rank + local_offset;
if (offset_rank >= params.elts_total) {
continue;
}
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[offset_rank]) =
*reinterpret_cast<int4*>(&src_d[ii][offset_rank]);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int divUp(int a, int b) { return (a + b - 1) / b; }
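// Compute the launch configuration and partition the work for the chosen strategy. For ONESHOT,
// every rank reduces the whole tensor (elts_per_rank == elts_total); for TWOSHOT, each rank owns
// elts_total / ranks_per_node elements starting at rank_offset, and the grid is sized so that the
// blocks divide the work evenly while respecting DEFAULT_BLOCK_SIZE and MAX_ALL_REDUCE_BLOCKS.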
std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& param,
size_t elts_per_thread) {
ICHECK(param.elts_total % elts_per_thread == 0);
int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
const size_t total_threads = param.elts_total / elts_per_thread;
switch (algo) {
case AllReduceStrategyType::ONESHOT: { // one stage all reduce algo
      if (total_threads <= DEFAULT_BLOCK_SIZE) {  // the whole reduction fits in a single block
threads_per_block = WARP_SIZE * divUp(total_threads, WARP_SIZE);
blocks_per_grid = 1;
      } else {  // spread the reduction over multiple blocks
threads_per_block = DEFAULT_BLOCK_SIZE;
blocks_per_grid = divUp(total_threads, DEFAULT_BLOCK_SIZE);
blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), blocks_per_grid);
}
param.elts_per_rank = param.elts_total;
param.elts_per_block =
elts_per_thread * divUp(param.elts_per_rank, elts_per_thread * blocks_per_grid);
break;
}
case AllReduceStrategyType::TWOSHOT: { // two stage all reduce algo
const size_t elts_per_rank = param.elts_total / param.ranks_per_node;
ICHECK(elts_per_rank % elts_per_thread == 0);
size_t total_threads = elts_per_rank / elts_per_thread;
total_threads = WARP_SIZE * ((total_threads + WARP_SIZE - 1) / WARP_SIZE);
ICHECK(total_threads % WARP_SIZE == 0);
while (total_threads % blocks_per_grid != 0 ||
total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) {
blocks_per_grid += 1;
}
threads_per_block = total_threads / blocks_per_grid;
      // Cap the grid at MAX_ALL_REDUCE_BLOCKS by shrinking it evenly; each block then iterates
      // over proportionally more data.
if (static_cast<size_t>(blocks_per_grid) > MAX_ALL_REDUCE_BLOCKS) {
size_t iter_factor = 1;
while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS ||
blocks_per_grid % iter_factor) {
iter_factor += 1;
}
blocks_per_grid /= iter_factor;
}
param.elts_per_rank = param.elts_total / param.ranks_per_node;
param.elts_per_block = param.elts_per_rank / blocks_per_grid;
param.elts_per_block = elts_per_thread * divUp(param.elts_per_block, elts_per_thread);
param.rank_offset = param.rank * param.elts_per_rank;
break;
}
default:
LOG(FATAL) << ("Algorithm not supported here.");
}
return std::make_tuple(blocks_per_grid, threads_per_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, int RANKS_PER_NODE>
void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid,
int threads_per_block, cudaStream_t stream) {
if (algo == AllReduceStrategyType::ONESHOT) {
oneShotAllReduceKernel<T, RANKS_PER_NODE>
<<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
} else {
twoShotAllReduceKernel<T, RANKS_PER_NODE>
<<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
}
}
template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat,
cudaStream_t stream) {
ICHECK(strat == AllReduceStrategyType::ONESHOT || strat == AllReduceStrategyType::TWOSHOT);
auto last_error = cudaGetLastError();
if (last_error != cudaSuccess) {
LOG(INFO) << "cuda error:" << cudaGetErrorString(last_error);
}
size_t elts_per_thread = 16 / sizeof(T);
auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
switch (param.ranks_per_node) {
case 2:
dispatchARKernels<T, 2>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 4:
dispatchARKernels<T, 4>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 6:
dispatchARKernels<T, 6>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 8:
dispatchARKernels<T, 8>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
default:
break;
}
last_error = cudaGetLastError();
if (last_error != cudaSuccess) {
LOG(INFO) << "cuda error:" << cudaGetErrorString(last_error);
}
}
void invokeMultiGpuBarrier(AllReduceParams& param, cudaStream_t stream) {
multiGpuBarrierKernel<<<1, param.ranks_per_node, 0, stream>>>(param);
}
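// Sketch of the expected call pattern (variable names are illustrative; the setup of
// AllReduceParams, e.g. exchanging peer_comm_buffer_ptrs and peer_barrier_ptrs_in/out between
// ranks, typically via CUDA IPC, happens outside this file; the field names below are the ones
// this file reads):
//
//   AllReduceParams params = ...;            // peer pointers already exchanged by the caller
//   params.ranks_per_node = world_size;      // 2, 4, 6, or 8 GPUs per node
//   params.rank = params.local_rank = rank;  // single-node case: rank == local_rank
//   params.barrier_flag = flag_counter++;    // must differ from the previous call so that
//                                            // stale flag values cannot satisfy the barriers
//   customAllReduce(params, output_ptr, num_elements, dtype, AllReduceStrategyType::ONESHOT,
//                   stream);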
void customAllReduce(AllReduceParams& params, void* data, size_t elts, DLDataType dataType,
AllReduceStrategyType strat, cudaStream_t stream) {
params.local_output_buffer_ptr = data;
params.elts_total = elts;
if (dataType.code == kDLFloat && dataType.bits == 32) {
invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
} else if (dataType.code == kDLFloat && dataType.bits == 16) {
invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
}
#ifdef ENABLE_BF16
else if (dataType.code == kDLBfloat && dataType.bits == 16) {
invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
}
#endif
else {
LOG(FATAL) << ("Unsupported dataType for customAllReduce");
}
}
} // namespace tensorrt_llm