/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <fstream>
#include <iostream>
#include <sstream>
#include <type_traits>
#include <variant>
#include <vector>
#include "../../cuda/cuda_common.h"
// clang-format off
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
// clang-format on
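// Fail fast with the human-readable CUTLASS status string if a CUTLASS call does not
// return Status::kSuccess.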
#define CUTLASS_CHECK(status)                                      \
  {                                                                \
    cutlass::Status error = status;                                \
    CHECK(error == cutlass::Status::kSuccess)                      \
        << "Got cutlass error: " << cutlassGetStatusString(error); \
  }
using namespace cute;
using tvm::runtime::Tensor;
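// Runner for an FP8 (e4m3) GEMM with groupwise (blockwise) dequantization scales on SM100,
// built with the CUTLASS 3.x CollectiveBuilder API. A is row-major, B is column-major, the
// accumulator is fp32, and the output element type ElementD is a template parameter.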
template <typename TileShape, typename ClusterShape, typename ElementD>
struct CutlassFP8ScaledGroupwiseGemmRunnerSM100 {
using ElementA = cutlass::float_e4m3_t;
using LayoutA = cutlass::layout::RowMajor;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
using ElementB = cutlass::float_e4m3_t;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
using ElementC = void;
using LayoutC = cutlass::layout::RowMajor;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
using LayoutD = LayoutC;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// MMA type
  using ElementAccumulator = float;  // the accumulator type also serves as the scale-factor type
using ElementCompute = float;
using ElementBlockScale = float;
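  // Blockwise scaling granularity: one fp32 scale per row of A for every 128-wide K block,
  // and one fp32 scale per 128x128 (N x K) block of B.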
static constexpr int ScaleGranularityM = 1;
static constexpr int ScaleGranularityN = 128;
static constexpr int ScaleGranularityK = 128;
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, UMMA::Major::MN, UMMA::Major::K>;
using LayoutSFA =
decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand
using LayoutSFB =
decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, ElementC,
LayoutC, AlignmentC, ElementD, LayoutC, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementA,
cute::tuple<LayoutA, LayoutSFA>, AlignmentA, ElementB, cute::tuple<LayoutB, LayoutSFB>,
AlignmentB, ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::KernelScheduleSm100Blockwise>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
void>; // Default to ClusterLaunchControl (CLC) based tile scheduler
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideD = typename Gemm::GemmKernel::StrideD;
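  // Runs a single scaled GEMM on `stream`: computes D = A * B with the blockwise
  // dequantization scales applied to the fp8 operands. `l` is the batch count and
  // `workspace` is a caller-allocated scratch buffer of at least `workspace_size` bytes.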
void run_gemm(const ElementA* a_ptr, const ElementB* b_ptr, const ElementBlockScale* scales_a_ptr,
const ElementBlockScale* scales_b_ptr, ElementD* o_ptr, int m, int n, int k, int l,
uint8_t* workspace, int64_t workspace_size, cudaStream_t stream) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
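    // A and D are row-major, B is column-major; the last mode of each stride is the
    // distance between consecutive batches (the l dimension).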
StrideA stride_a =
cute::make_stride(static_cast<int64_t>(k), Int<1>{}, static_cast<int64_t>(m * k));
StrideB stride_b =
cute::make_stride(static_cast<int64_t>(k), Int<1>{}, static_cast<int64_t>(n * k));
StrideD stride_d =
cute::make_stride(static_cast<int64_t>(n), Int<1>{}, static_cast<int64_t>(m * n));
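    // Derive the scale-factor tensor layouts for SFA/SFB from the problem shape according
    // to the blockwise ScaleConfig.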
auto layout_scales_a = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, l));
auto layout_scales_b = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, l));
typename Gemm::Arguments arguments = {cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, l},
{a_ptr, stride_a, b_ptr, stride_b, scales_a_ptr,
layout_scales_a, scales_b_ptr, layout_scales_b},
{{}, o_ptr, stride_d, o_ptr, stride_d},
hw_info};
Gemm gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(arguments));
CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
CUTLASS_CHECK(gemm_op.run(stream));
}
};
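// Dispatch helper: instantiates the SM100 runner for the given tile/cluster shapes and
// output type, then runs one groupwise-scaled FP8 GEMM on the provided stream.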
template <typename TileShape, typename ClusterShape, typename ElementA, typename ElementB,
typename ElementD, typename ElementBlockScale>
void cutlass_fp8_groupwise_scaled_mm_sm100(ElementA* a, ElementB* b, ElementBlockScale* scales_a,
ElementBlockScale* scales_b, ElementD* out,
uint8_t* workspace, int64_t workspace_size, int64_t m,
int64_t n, int64_t k, int64_t l, cudaStream_t stream) {
using Runner = CutlassFP8ScaledGroupwiseGemmRunnerSM100<TileShape, ClusterShape, ElementD>;
Runner runner;
runner.run_gemm(a, b, scales_a, scales_b, out, m, n, k, l, workspace, workspace_size, stream);
}
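// Example usage (a minimal sketch; the tile/cluster shapes, output type, and the pointer
// and size variables below are illustrative assumptions, not fixed by this file):
//
//   using TileShape = Shape<_128, _128, _128>;
//   using ClusterShape = Shape<_1, _1, _1>;
//   cutlass_fp8_groupwise_scaled_mm_sm100<TileShape, ClusterShape, cutlass::float_e4m3_t,
//                                         cutlass::float_e4m3_t, cutlass::bfloat16_t, float>(
//       a_ptr, b_ptr, scales_a_ptr, scales_b_ptr, out_ptr, workspace_ptr, workspace_size,
//       m, n, k, /*l=*/1, stream);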