blob: 0d7f08b0d80a85cd98def36990371a3039dd9d49 [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file ptx.cc
*/
#include "ptx.h"
#include <tvm/ffi/reflection/registry.h>
#include <algorithm>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "../../../support/utils.h"
namespace tvm {
namespace codegen {
// PTX related data structures and functions.
namespace ptx {
/*!
 * \brief PTX data type.
 * \note
 * PTX fundamental data types:
 * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types
 * PTX matrix data types:
 * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types
 *
 * The integer value of each enumerator is used directly as an index into the
 * `dtype_str` and `num_bits` lookup tables, so the enumerator order and the
 * order of both tables must be kept in sync when a type is added.
 */
enum class DataType : int {
// Signed / unsigned integer types.
kInt4 = 0,
kUInt4 = 1,
kInt8 = 2,
kUInt8 = 3,
kInt16 = 4,
kUInt16 = 5,
kInt32 = 6,
kUInt32 = 7,
kInt64 = 8,
kUInt64 = 9,
// Narrow (4/6/8-bit) floating point types.
kFloat4_e2m1fn = 10,
kFloat6_e2m3fn = 11,
kFloat6_e3m2fn = 12,
kFloat8_e4m3fn = 13,
kFloat8_e4m3fnuz = 14,
kFloat8_e5m2 = 15,
kFloat8_e8m0fnu = 16,
// 16/32/64-bit floating point types.
kFloat16 = 17,
kBFloat16 = 18,
kFloat16x2 = 19,
kFloat32 = 20,
kTensorFloat32 = 21,
kFloat64 = 22,
// Untyped bit types.
kBit1 = 23,
kBit8 = 24,
kBit16 = 25,
kBit32 = 26,
kBit64 = 27,
};
// Lookup tables indexed by the integer value of ptx::DataType.
// `dtype_str[i]` is the PTX type suffix and `num_bits[i]` the bit width of the
// type whose enumerator value is `i`; both must list the types in enum order.
// NOTE(review): ".ue4m3" is listed as 7 bits while the other 8-bit float
// formats use 8 -- confirm this is intended.
static const char* dtype_str[] = {".s4",    ".u4",   ".s8",    ".u8",  ".s16",  ".u16",   ".s32",
                                  ".u32",   ".s64",  ".u64",   ".e2m1", ".e2m3", ".e3m2",  ".e4m3",
                                  ".ue4m3", ".e5m2", ".ue8m0", ".f16",  ".bf16", ".f16x2", ".f32",
                                  ".tf32",  ".f64",  ".b1",    ".b8",   ".b16",  ".b32",   ".b64"};
static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 4, 6,  6,  8,
                                    7, 8, 8, 16, 16, 32, 32, 32, 64, 1, 8, 16, 32, 64};
// The two tables are parallel arrays; fail the build if they ever diverge in length.
static_assert(sizeof(dtype_str) / sizeof(dtype_str[0]) == sizeof(num_bits) / sizeof(num_bits[0]),
              "dtype_str and num_bits must have the same number of entries");
/*!
 * \brief Create PTX data type from string.
 *
 * Accepts both TVM-style names (e.g. "int8", "float16") and PTX suffixes
 * (e.g. ".s8", ".f16"). Every suffix produced by DTypeToString is accepted, so
 * a DataType -> string -> DataType round trip always succeeds.
 *
 * \param str The data type name or PTX suffix.
 * \return The matching PTX data type.
 * \note Throws InternalError when the string is not recognized.
 */
inline DataType DTypeFromString(const std::string& str) {
  if (str == "int4" || str == ".s4") {
    return DataType::kInt4;
  } else if (str == "uint4" || str == ".u4") {
    return DataType::kUInt4;
  } else if (str == "int8" || str == ".s8") {
    return DataType::kInt8;
  } else if (str == "uint8" || str == ".u8") {
    return DataType::kUInt8;
  } else if (str == "int16" || str == ".s16") {
    return DataType::kInt16;
  } else if (str == "uint16" || str == ".u16") {
    return DataType::kUInt16;
  } else if (str == "int32" || str == ".s32") {
    return DataType::kInt32;
  } else if (str == "uint32" || str == ".u32") {
    return DataType::kUInt32;
  } else if (str == "int64" || str == ".s64") {
    return DataType::kInt64;
  } else if (str == "uint64" || str == ".u64") {
    return DataType::kUInt64;
  } else if (str == "e2m1" || str == ".e2m1" || str == "float4_e2m1fn") {
    return DataType::kFloat4_e2m1fn;
  } else if (str == "e2m3" || str == ".e2m3" || str == "float6_e2m3fn") {
    return DataType::kFloat6_e2m3fn;
  } else if (str == "e3m2" || str == ".e3m2" || str == "float6_e3m2fn") {
    return DataType::kFloat6_e3m2fn;
  } else if (str == "e4m3" || str == ".e4m3" || str == "float8_e4m3fn") {
    return DataType::kFloat8_e4m3fn;
  } else if (str == "float8_e4m3fnuz" || str == "float8_e4m3b11fnuz" || str == ".ue4m3") {
    // ".ue4m3" is the suffix DTypeToString emits for kFloat8_e4m3fnuz.
    return DataType::kFloat8_e4m3fnuz;
  } else if (str == "e5m2" || str == ".e5m2" || str == "float8_e5m2" || str == "float8_e5m2fn" ||
             str == "float8_e5m2fnuz") {
    return DataType::kFloat8_e5m2;
  } else if (str == "ue8m0" || str == ".ue8m0" || str == "float8_e8m0fnu") {
    return DataType::kFloat8_e8m0fnu;
  } else if (str == "float16" || str == "fp16" || str == ".f16") {
    return DataType::kFloat16;
  } else if (str == "bfloat16" || str == "bf16" || str == ".bf16") {
    return DataType::kBFloat16;
  } else if (str == ".f16x2") {
    return DataType::kFloat16x2;
  } else if (str == "float32" || str == "fp32" || str == ".f32") {
    return DataType::kFloat32;
  } else if (str == "tf32" || str == ".tf32") {
    return DataType::kTensorFloat32;
  } else if (str == "float64" || str == "fp64" || str == ".f64") {
    return DataType::kFloat64;
  } else if (str == "int1" || str == ".b1") {
    return DataType::kBit1;
  } else if (str == ".b8") {
    return DataType::kBit8;
  } else if (str == ".b16") {
    return DataType::kBit16;
  } else if (str == ".b32") {
    return DataType::kBit32;
  } else if (str == ".b64") {
    return DataType::kBit64;
  } else {
    TVM_FFI_THROW(InternalError) << "Unrecognized PTX data type " << str;
  }
}
/*!
 * \brief Look up the PTX suffix (e.g. ".f16") of the given PTX data type.
 * \param dtype The PTX data type; its integer value indexes `dtype_str`.
 * \return The suffix as an owning string.
 */
inline std::string DTypeToString(DataType dtype) {
  const int table_index = static_cast<int>(dtype);
  return std::string(dtype_str[table_index]);
}
void RegisterCudaPTXHelpers() {
static bool registered = false;
if (registered) return;
registered = true;
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("tirx.intrinsics.cuda.PTXDTypeFromString",
[](const std::string& str) -> int { return static_cast<int>(DTypeFromString(str)); })
.def("tirx.intrinsics.cuda.PTXDTypeToString", [](const int dtype) -> std::string {
return DTypeToString(static_cast<DataType>(dtype));
});
}
TVM_FFI_STATIC_INIT_BLOCK() { RegisterCudaPTXHelpers(); }
/*!
 * \brief Look up the bit width of the given PTX data type.
 * \param dtype The PTX data type; its integer value indexes `num_bits`.
 * \return The number of bits recorded for that type.
 */
inline uint32_t DTypeBits(DataType dtype) {
  const int table_index = static_cast<int>(dtype);
  return num_bits[table_index];
}
/*!
* \brief Extract the value m, n, k from string m*n*k*
*/
inline std::tuple<int, int, int> ParseMMAShape(const std::string& str) {
size_t pos_m = str.find("m"), pos_n = str.find("n"), pos_k = str.find("k");
TVM_FFI_ICHECK(pos_m != str.npos && pos_n != str.npos && pos_k != str.npos)
<< "Cannot parse MMA shape " << str;
int m = std::stoi(str.substr(pos_m + 1, pos_n - pos_m - 1)),
n = std::stoi(str.substr(pos_n + 1, pos_k - pos_n - 1)), k = std::stoi(str.substr(pos_k + 1));
return std::make_tuple(m, n, k);
}
/*!
 * \brief Layout Type
 * \note The integer value of each enumerator is used as an index into
 *       `layout_type_str` by LayoutTypeToString.
 */
enum class LayoutType : int { kRowMajor = 0, kColumnMajor = 1 };
/*!
 * \brief Convert a layout name to its LayoutType.
 * \param str Either "row" or "col".
 * \return kRowMajor for "row", kColumnMajor for "col".
 * \note Throws InternalError for any other string.
 */
LayoutType LayoutTypeFromString(const std::string& str) {
  if (str == "row") return LayoutType::kRowMajor;
  if (str == "col") return LayoutType::kColumnMajor;
  TVM_FFI_THROW(InternalError) << "Unrecognized layout type " << str;
}
// Layout names indexed by the integer value of LayoutType.
static const char* layout_type_str[] = {"row", "col"};
/*!
 * \brief Look up the name ("row" / "col") of the given layout type.
 * \param layout The layout type; its integer value indexes `layout_type_str`.
 */
inline std::string LayoutTypeToString(LayoutType layout) {
  const int table_index = static_cast<int>(layout);
  return std::string(layout_type_str[table_index]);
}
/*!
 * \brief MMA Configurations, used to determine validity.
 *
 * A configuration is the instruction shape (m, n, k), the multiplicand data
 * type, whether a bit operator is used, and whether the MMA is sparse.
 */
struct MMAConfig {
  explicit MMAConfig(int m, int n, int k, DataType dtype_mul, bool use_bit_op, bool sparse)
      : m(m), n(n), k(k), dtype_mul(dtype_mul), use_bit_op(use_bit_op), sparse(sparse) {}
  /*! \brief The M, N, K of the instruction shape. */
  int m, n, k;
  /*! \brief Data type of the multiplicands. */
  DataType dtype_mul;
  /*! \brief Whether a bit operator is used. */
  bool use_bit_op;
  /*! \brief Whether this is a sparse MMA. */
  bool sparse;
  /*!
   * \brief Field-wise equality.
   * \note Declared const so that it can be invoked on const configurations
   *       (such as entries of a const table) on either side of `==`.
   */
  inline bool operator==(const MMAConfig& other) const {
    return m == other.m && n == other.n && k == other.k && dtype_mul == other.dtype_mul &&
           use_bit_op == other.use_bit_op && sparse == other.sparse;
  }
};
/*!
 * \brief Valid MMA configurations
 * \note Reference:
 * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-shape
 *
 * Each entry is MMAConfig(m, n, k, multiplicand dtype, use_bit_op, sparse).
 */
const MMAConfig valid_mma_configs[] = {
// Dense configurations without a bit operator (use_bit_op = false, sparse = false).
MMAConfig(8, 8, 4, DataType::kFloat64, false, false),
MMAConfig(8, 8, 4, DataType::kFloat16, false, false),
MMAConfig(16, 8, 8, DataType::kFloat16, false, false),
MMAConfig(16, 8, 16, DataType::kFloat16, false, false),
MMAConfig(16, 8, 8, DataType::kBFloat16, false, false),
MMAConfig(16, 8, 16, DataType::kBFloat16, false, false),
MMAConfig(16, 8, 4, DataType::kTensorFloat32, false, false),
MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, false),
MMAConfig(8, 8, 16, DataType::kInt8, false, false),
MMAConfig(16, 8, 16, DataType::kInt8, false, false),
MMAConfig(16, 8, 32, DataType::kInt8, false, false),
MMAConfig(8, 8, 16, DataType::kUInt8, false, false),
MMAConfig(16, 8, 16, DataType::kUInt8, false, false),
MMAConfig(16, 8, 32, DataType::kUInt8, false, false),
MMAConfig(8, 8, 32, DataType::kInt4, false, false),
MMAConfig(16, 8, 32, DataType::kInt4, false, false),
MMAConfig(16, 8, 64, DataType::kInt4, false, false),
MMAConfig(8, 8, 32, DataType::kUInt4, false, false),
MMAConfig(16, 8, 32, DataType::kUInt4, false, false),
MMAConfig(16, 8, 64, DataType::kUInt4, false, false),
// 1-bit configurations; these are the only entries with use_bit_op = true.
MMAConfig(8, 8, 128, DataType::kBit1, true, false),
MMAConfig(16, 8, 128, DataType::kBit1, true, false),
MMAConfig(16, 8, 256, DataType::kBit1, true, false),
// Sparse configurations (sparse = true).
MMAConfig(16, 8, 16, DataType::kFloat16, false, true),
MMAConfig(16, 8, 32, DataType::kFloat16, false, true),
MMAConfig(16, 8, 16, DataType::kBFloat16, false, true),
MMAConfig(16, 8, 32, DataType::kBFloat16, false, true),
MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, true),
MMAConfig(16, 8, 16, DataType::kTensorFloat32, false, true),
MMAConfig(16, 8, 32, DataType::kInt8, false, true),
MMAConfig(16, 8, 64, DataType::kInt8, false, true),
MMAConfig(16, 8, 32, DataType::kUInt8, false, true),
MMAConfig(16, 8, 64, DataType::kUInt8, false, true),
MMAConfig(16, 8, 64, DataType::kInt4, false, true),
MMAConfig(16, 8, 128, DataType::kInt4, false, true),
MMAConfig(16, 8, 64, DataType::kUInt4, false, true),
MMAConfig(16, 8, 128, DataType::kUInt4, false, true),
// 8-bit float (e4m3 / e5m2) configurations: one dense and one sparse entry per type.
MMAConfig(16, 8, 32, DataType::kFloat8_e4m3fn, false, false),
MMAConfig(16, 8, 64, DataType::kFloat8_e4m3fn, false, true),
MMAConfig(16, 8, 32, DataType::kFloat8_e5m2, false, false),
MMAConfig(16, 8, 64, DataType::kFloat8_e5m2, false, true),
};
/*!
 * \brief Verify that the multiplicand and accumulator data types form a combination
 *        accepted for MMA computation; fails an internal check otherwise.
 * \param dtype_a The data type of multiplicand a.
 * \param dtype_b The data type of multiplicand b.
 * \param dtype_c The data type of accumulator c.
 * \note Reference:
 * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types
 */
void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_c) {
  const std::string mismatch_msg = "The multiplicands' data type " + DTypeToString(dtype_a) +
                                   DTypeToString(dtype_b) + " do not match.";
  // Classify multiplicand a once; both validation stages below branch on it.
  const bool a_must_equal_b = dtype_a == DataType::kBit1 || dtype_a == DataType::kFloat16 ||
                              dtype_a == DataType::kBFloat16 ||
                              dtype_a == DataType::kTensorFloat32 ||
                              dtype_a == DataType::kFloat64;
  const bool a_is_4bit_int = dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4;
  const bool a_is_8bit_int = dtype_a == DataType::kInt8 || dtype_a == DataType::kUInt8;
  const bool a_is_8bit_float =
      dtype_a == DataType::kFloat8_e4m3fn || dtype_a == DataType::kFloat8_e5m2;

  // Stage 1: multiplicand b must be compatible with multiplicand a.
  if (a_must_equal_b) {
    TVM_FFI_ICHECK(dtype_a == dtype_b) << mismatch_msg;
  } else if (a_is_4bit_int) {
    // Signedness may differ between a and b.
    TVM_FFI_ICHECK(dtype_b == DataType::kInt4 || dtype_b == DataType::kUInt4) << mismatch_msg;
  } else if (a_is_8bit_int) {
    TVM_FFI_ICHECK(dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8) << mismatch_msg;
  } else if (a_is_8bit_float) {
    // e4m3 and e5m2 may be mixed.
    TVM_FFI_ICHECK(dtype_b == DataType::kFloat8_e4m3fn || dtype_b == DataType::kFloat8_e5m2)
        << mismatch_msg;
  } else {
    TVM_FFI_ICHECK(false) << "Invalid multiplicand data types: " << DTypeToString(dtype_a)
                          << DTypeToString(dtype_b);
  }

  // Stage 2: accumulator c must be compatible with the multiplicands.
  if (dtype_a == DataType::kBit1 || a_is_4bit_int || a_is_8bit_int) {
    TVM_FFI_ICHECK(dtype_c == DataType::kInt32)
        << "For multiplicand data type " << DTypeToString(dtype_a) << DTypeToString(dtype_b)
        << ", accumulator data type should be s32.";
  } else if (dtype_a == DataType::kFloat16) {
    TVM_FFI_ICHECK(dtype_c == DataType::kFloat16 || dtype_c == DataType::kFloat32)
        << "For multiplicand data type f16, accumulator data type should be f16/f32.";
  } else if (dtype_a == DataType::kBFloat16 || dtype_a == DataType::kTensorFloat32) {
    TVM_FFI_ICHECK(dtype_c == DataType::kFloat32)
        << "For multiplicand data type bf16/tf32, accumulator data type can only be f32.";
  } else if (dtype_a == DataType::kFloat64) {
    TVM_FFI_ICHECK(dtype_c == DataType::kFloat64)
        << "For multiplicand data type f64, accumulator data type can only be f64.";
  } else if (a_is_8bit_float) {
    TVM_FFI_ICHECK(dtype_c == DataType::kFloat32)
        << "For multiplicand data type e4m3/e5m2, accumulator data type can only be f32.";
  } else {
    TVM_FFI_ICHECK(false) << "Invalid multiplicand/accumulator data types: "
                          << DTypeToString(dtype_a) << DTypeToString(dtype_b)
                          << DTypeToString(dtype_c) << ".";
  }
}
/*!
* \brief Check whether the given configuration is valid for MMA computation.
* \param m The M in mMnNkK of MMA instructions.
* \param n The N in mMnNkK of MMA instructions.
* \param k The K in mMnNkK of MMA instructions.
* \param layout_a The layout of multiplicand A (row/col).
* \param layout_b The layout of multiplicand B (row/col).
* \param dtype_a The data type of multiplicand A.
* \param dtype_b The data type of multiplicand B.
* \param dtype_c The data type of accumulator C.
* \param bit_op The bit operator for 1-bit MMA computation, can be "xor"/"and" or ""(if it's not
* 1-bit MMA).
* \param sparse Whether it's Sparse MMA or not.
* \param saturate Whether saturate output or not.
*/
void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType layout_b,
DataType dtype_a, DataType dtype_b, DataType dtype_c,
const std::string& bit_op, bool sparse, bool saturate) {
TVM_FFI_ICHECK(bit_op == "xor" || bit_op == "and" || bit_op == "")
<< "Unrecognized 1-bit operation " << bit_op << " , can only be xor/and.";
bool use_bit_op = !bit_op.empty();
if (use_bit_op) {
TVM_FFI_ICHECK(dtype_a == DataType::kBit1)
<< "Bit operator is only compatible with 1-bit multiplicand.";
}
CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c);
if (saturate) {
TVM_FFI_ICHECK(dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4 ||
dtype_a == DataType::kInt8 || dtype_a == DataType::kUInt8)
<< "Output saturation only applicable to multiplicand type s4/u4/s8/u8.";
}
if (!(m == 8 && n == 8 && k == 4 && dtype_a == ptx::DataType::kFloat16)) {
// Only MMA on m8n8k4 for fp16 supports customized layouts.
TVM_FFI_ICHECK(layout_a == LayoutType::kRowMajor && layout_b == LayoutType::kColumnMajor)
<< "Invalid layout combination " << LayoutTypeToString(layout_a) << ","
<< LayoutTypeToString(layout_b) << ".";
}
MMAConfig config(m, n, k, dtype_a, use_bit_op, sparse);
bool match = false;
for (const MMAConfig& valid_config : valid_mma_configs) {
if (config == valid_config) {
match = true;
break;
}
}
TVM_FFI_ICHECK(match) << "Cannot find matched MMA configurations.";
}
/*!
 * \brief Fragment attributes
 *
 * Bundles the inline-assembly register constraint letter, the fragment size and
 * the C pointer-cast text used to access a fragment.
 */
class FragAttrs {
 public:
  /*!
   * \param reg_type PTX register type (constraint letter).
   * \param size Fragment size.
   * \param ptr_type Fragment pointer type, e.g. "(unsigned *)".
   */
  explicit FragAttrs(char reg_type, uint32_t size, std::string ptr_type)
      // ptr_type is taken by value as a sink argument, so move it into place
      // instead of copying it a second time.
      : reg_type(reg_type), size(size), ptr_type(std::move(ptr_type)) {}
  /*! \brief PTX register type */
  char reg_type;
  /*! \brief Fragment size */
  uint32_t size;
  /*! \brief Fragment pointer type */
  std::string ptr_type;
};
/*!
 * \brief Fragment attributes of given data type.
 * \param dtype The PTX data type of the matrix fragment.
 * \return The register constraint letter, fragment size and pointer cast to use.
 * \note Fails an internal check when the data type is not a matrix data type in MMA.
 */
inline FragAttrs GetFragAttrs(DataType dtype) {
  // Types with a dedicated register class / pointer type.
  if (dtype == DataType::kInt32) {
    return FragAttrs('r', 32, "(int *)");
  }
  if (dtype == DataType::kFloat32) {
    return FragAttrs('f', 32, "(float *)");
  }
  if (dtype == DataType::kFloat64) {
    return FragAttrs('d', 64, "(double *)");
  }
  // Remaining matrix types are accessed through 32-bit unsigned registers
  // (for kFloat16 this is a .f16x2 register).
  const bool uses_unsigned_reg =
      dtype == DataType::kBit1 || dtype == DataType::kInt4 || dtype == DataType::kUInt4 ||
      dtype == DataType::kInt8 || dtype == DataType::kUInt8 ||
      dtype == DataType::kFloat8_e4m3fn || dtype == DataType::kFloat8_e5m2 ||
      dtype == DataType::kBit16 || dtype == DataType::kFloat16 ||
      dtype == DataType::kBFloat16 || dtype == DataType::kTensorFloat32;
  if (uses_unsigned_reg) {
    return FragAttrs('r', 32, "(unsigned *)");
  }
  TVM_FFI_ICHECK(false) << DTypeToString(dtype) << " is not matrix data type in MMA.";
  return FragAttrs('\0', 0, "");
}
}; // namespace ptx
/*!
 * \brief Replace patterns with replacement strings.
 *
 * Rules are applied in registration order. For each rule every occurrence of the
 * pattern is replaced left to right; text inserted by a replacement is not
 * rescanned for the same rule, but it is visible to the rules that follow.
 * \note should use std::format instead when codebase is ported to C++20.
 */
class Replacer {
 public:
  /*!
   * \brief Append a rule.
   * \param pattern The literal text to search for.
   * \param replacement The text substituted for every occurrence of the pattern.
   */
  void register_rule(const std::string& pattern, const std::string& replacement) {
    _rules.emplace_back(pattern, replacement);
  }
  /*!
   * \brief Apply all registered rules to the given string.
   * \param str The input text (taken by value and edited in place).
   * \return The rewritten text.
   */
  std::string rewrite(std::string str) const {
    for (const auto& [pattern, replacement] : _rules) {
      // An empty pattern matches at every position, which would make the loop
      // below insert the replacement forever; treat such a rule as a no-op.
      if (pattern.empty()) continue;
      size_t pos = str.find(pattern);
      while (pos != std::string::npos) {
        str.replace(pos, pattern.size(), replacement);
        // Resume after the inserted text so a replacement that contains its
        // own pattern is not expanded again.
        pos = str.find(pattern, pos + replacement.size());
      }
    }
    return str;
  }
  /*! \brief Remove all registered rules. */
  void empty_rules() { _rules.clear(); }

 private:
  /*! \brief (pattern, replacement) pairs in registration order. */
  std::vector<std::pair<std::string, std::string>> _rules;
};
/*!
 * \brief Get the number of MMA computations for given shape and datatype.
 * \return 4 for the m8n8k4 shape on fp16, 1 for everything else.
 */
inline uint32_t GetNumMMAComputations(int m, int n, int k, ptx::DataType dtype) {
  // MMA for m8n8k4 on fp16 would launch 4 MMA computations instead of one.
  const bool is_fp16_m8n8k4 = m == 8 && n == 8 && k == 4 && dtype == ptx::DataType::kFloat16;
  return is_fp16_m8n8k4 ? 4 : 1;
}
/*!
* \brief Return template string, input operands string and output operands string.
* \param m The M in mMnNkK of MMA instructions.
* \param n The N in mMnNkK of MMA instructions.
* \param k The K in mMnNkK of MMA instructions.
* \param dtype_a The data type of multiplicand a.
* \param dtype_b The data type of multiplicand b.
* \param dtype_c The data type of accumulator c.
* \param sparse Whether it's Sparse MMA or not.
*/
inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, int n, int k,
ptx::DataType dtype_a,
ptx::DataType dtype_b,
ptx::DataType dtype_c,
bool sparse) {
std::stringstream templates, inputs, outputs;
const ptx::FragAttrs frag_attr_a = ptx::GetFragAttrs(dtype_a),
frag_attr_b = ptx::GetFragAttrs(dtype_b),
frag_attr_c = ptx::GetFragAttrs(dtype_c);
constexpr uint32_t warp_size = 32;
const uint32_t threads = warp_size / GetNumMMAComputations(m, n, k, dtype_a);
const int num_operands_a =
(m * k) * ptx::DTypeBits(dtype_a) / frag_attr_a.size / threads / (sparse ? 2 : 1),
num_operands_b = (k * n) * ptx::DTypeBits(dtype_b) / frag_attr_b.size / threads,
num_operands_c = (m * n) * ptx::DTypeBits(dtype_c) / frag_attr_c.size / threads;
// generate templates;
int arg_counter = 0;
templates << "{"
<< "%" << arg_counter++;
for (int i = 1; i < num_operands_c; ++i) {
templates << ", %" << arg_counter++;
}
templates << "}, {"
<< "%" << arg_counter++;
for (int i = 1; i < num_operands_a; ++i) {
templates << ", %" << arg_counter++;
}
templates << "}, {"
<< "%" << arg_counter++;
for (int i = 1; i < num_operands_b; ++i) {
templates << ", %" << arg_counter++;
}
templates << "}, {"
<< "%" << arg_counter++;
for (int i = 1; i < num_operands_c; ++i) {
templates << ", %" << arg_counter++;
}
templates << "}";
// templates of metadata and sparse selector for sparse mma.
if (sparse) {
templates << ", %" << (arg_counter++) << ", F";
}
// generate inputs
for (int i = 0; i < num_operands_a; ++i) {
if (i != 0) {
inputs << ", ";
}
inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_type << "(A))[" << i
<< "])";
}
for (int i = 0; i < num_operands_b; ++i) {
inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_type << "(B))[" << i
<< "])";
}
for (int i = 0; i < num_operands_c; ++i) {
inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(C))[" << i
<< "])";
}
// input of metadata for sparse mma.
if (sparse) {
inputs << ", \"r\"(((unsigned *)(E))[0])";
}
// generate outputs
for (int i = 0; i < num_operands_c; ++i) {
if (i != 0) {
outputs << ",";
}
outputs << " \"=" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(D))[" << i
<< "])";
}
return std::make_tuple(templates.str(), inputs.str(), outputs.str());
}
// ldmatrix assembly emitter.
// `local_elem_offset` / `smem_elem_offset` are element offsets in the
// respective buffer's dtype; the generated C expression `ptr + offset`
// relies on C pointer arithmetic to scale them to bytes.
//
// Returns (templates, outputs): `templates` is "{%0, ..., %<num-1>}, [%<num>]"
// and `outputs` is the comma-separated list of "=r" output constraints, one per
// destination register, each addressing `(unsigned *)(local_ptr + local_elem_offset)`.
inline std::tuple<std::string, std::string> GetLoadMatrixOperands(
    int num, const std::string& local_ptr, const std::string& local_elem_offset) {
  // Destination registers occupy placeholders %0 .. %<num-1>; %0 is always emitted.
  std::string template_text = "{%0";
  int next_arg = 1;
  for (; next_arg < num; ++next_arg) {
    template_text += ", %" + std::to_string(next_arg);
  }
  // The placeholder after the registers is the shared-memory address.
  template_text += "}, [%" + std::to_string(next_arg) + "]";

  const std::string base_expr = "((unsigned *)(" + local_ptr + " + " + local_elem_offset + "))";
  std::string output_text;
  for (int reg = 0; reg < num; ++reg) {
    if (reg > 0) {
      output_text += ", ";
    }
    output_text += "\"=r\"(" + base_expr + "[" + std::to_string(reg) + "])";
  }
  return std::make_tuple(template_text, output_text);
}
std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type,
const std::string& local_ptr,
const std::string& local_elem_offset,
const std::string& smem_ptr,
const std::string& smem_elem_offset) {
TVM_FFI_ICHECK(num == 1 || num == 2 || num == 4)
<< "ldmatrix only accept loading 1/2/4 matrices.";
ptx::DataType data_type = ptx::DTypeFromString(type);
TVM_FFI_ICHECK(data_type == ptx::DataType::kBit16)
<< "ldmatrix only accept matrix with type .b16.";
std::string asm_code = R"(
{
unsigned int addr = __cvta_generic_to_shared({smem_addr});
__asm__ __volatile__(
"ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}"
"{templates};\n"
: {outputs}
: "r"(addr)
);
}
)";
auto [templates_str, outputs_str] = GetLoadMatrixOperands(num, local_ptr, local_elem_offset);
Replacer replacer;
replacer.register_rule("{.shape}", ".m8n8");
replacer.register_rule("{.num}", ".x" + std::to_string(num));
replacer.register_rule("{.trans}", trans ? ".trans" : "");
replacer.register_rule("{.ss}", ".shared");
replacer.register_rule("{.type}", ptx::DTypeToString(data_type));
replacer.register_rule("{smem_addr}", smem_ptr + " + " + smem_elem_offset);
replacer.register_rule("{templates}", templates_str);
replacer.register_rule("{outputs}", outputs_str);
asm_code = replacer.rewrite(asm_code);
return asm_code;
}
/*!
 * \brief Generate the inline-assembly source text for an `mma` instruction.
 *
 * The configuration is validated first (data types, layouts, shape, bit
 * operator, sparsity, saturation); an invalid one fails an internal check.
 * \param shape The instruction shape, e.g. "m16n8k16".
 * \param A_layout Layout of multiplicand A ("row" / "col").
 * \param B_layout Layout of multiplicand B ("row" / "col").
 * \param A_dtype Data type of multiplicand A.
 * \param B_dtype Data type of multiplicand B.
 * \param C_dtype Data type of the accumulator; also used as the result type.
 * \param a_ptr Pointer expression of multiplicand A.
 * \param a_elem_offset Element offset added to `a_ptr`.
 * \param b_ptr Pointer expression of multiplicand B.
 * \param b_elem_offset Element offset added to `b_ptr`.
 * \param c_ptr Pointer expression of the accumulator; the result is written to
 *        the same location.
 * \param c_elem_offset Element offset added to `c_ptr`.
 * \param metadata Pointer expression of the sparse metadata.
 * \param metadata_offset Offset added to `metadata`.
 * \param sparsity_selector The sparsity selector operand text.
 * \param bit_op The bit operator ("xor" / "and"), or "" when none is used.
 * \param sparse Whether to emit the sparse (`.sp`) variant.
 * \param saturate Whether to emit the `.satfinite` qualifier.
 * \return The generated code block as a string.
 */
std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
                             const std::string& B_layout, const std::string& A_dtype,
                             const std::string& B_dtype, const std::string& C_dtype,
                             const std::string& a_ptr, const std::string& a_elem_offset,
                             const std::string& b_ptr, const std::string& b_elem_offset,
                             const std::string& c_ptr, const std::string& c_elem_offset,
                             const std::string& metadata, const std::string& metadata_offset,
                             const std::string& sparsity_selector, const std::string& bit_op,
                             bool sparse, bool saturate) {
  ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype), dtype_b = ptx::DTypeFromString(B_dtype),
                dtype_c = ptx::DTypeFromString(C_dtype);
  ptx::LayoutType layout_a = ptx::LayoutTypeFromString(A_layout),
                  layout_b = ptx::LayoutTypeFromString(B_layout);
  auto [m, n, k] = ptx::ParseMMAShape(shape);
  CheckMMAConfigValidity(m, n, k, layout_a, layout_b, dtype_a, dtype_b, dtype_c, bit_op, sparse,
                         saturate);
  std::string asm_code = R"(
{
__asm__ __volatile__(
"mma{.sparse}.sync.aligned{.shape}{.alayout}{.blayout}{.saturate}{.dtype}{.atype}{.btype}{.ctype}{.bitop}"
"{templates};\n"
: {outputs}
: {inputs});
}
)";
  auto [templates_str, inputs_str, outputs_str] =
      GetMMAOperands(m, n, k, dtype_a, dtype_b, dtype_c, sparse);
  // replace patterns
  Replacer replacer;
  replacer.register_rule("{.sparse}", sparse ? ".sp" : "");
  replacer.register_rule("{.shape}", "." + shape);
  replacer.register_rule("{.saturate}", saturate ? ".satfinite" : "");
  replacer.register_rule("{.alayout}", "." + A_layout);
  replacer.register_rule("{.blayout}", "." + B_layout);
  replacer.register_rule("{.atype}", ptx::DTypeToString(dtype_a));
  replacer.register_rule("{.btype}", ptx::DTypeToString(dtype_b));
  replacer.register_rule("{.ctype}", ptx::DTypeToString(dtype_c));
  replacer.register_rule("{.dtype}", ptx::DTypeToString(dtype_c));
  replacer.register_rule("{.bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc");
  replacer.register_rule("{templates}", templates_str);
  replacer.register_rule("{outputs}", outputs_str);
  replacer.register_rule("{inputs}", inputs_str);
  asm_code = replacer.rewrite(asm_code);

  // Substitute the symbolic operand names A..F with the caller's expressions.
  // This is done in a single left-to-right pass that never rescans inserted
  // text: applying the six single-letter rules one after another would also
  // rewrite capital letters inside expressions inserted by an earlier rule
  // (e.g. a 'B' or 'E' occurring in the name of the A buffer).
  const std::string operand_a = a_ptr + " + " + a_elem_offset;
  const std::string operand_b = b_ptr + " + " + b_elem_offset;
  // C (accumulator input) and D (result) refer to the same location.
  const std::string operand_c = c_ptr + " + " + c_elem_offset;
  const std::string operand_e = metadata + " + " + metadata_offset;
  std::string result;
  result.reserve(asm_code.size());
  for (const char ch : asm_code) {
    switch (ch) {
      case 'A':
        result += operand_a;
        break;
      case 'B':
        result += operand_b;
        break;
      case 'C':
      case 'D':
        result += operand_c;
        break;
      case 'E':
        result += operand_e;
        break;
      case 'F':
        result += sparsity_selector;
        break;
      default:
        result += ch;
        break;
    }
  }
  return result;
}
/*!
 * \brief Generate the inline-assembly source text for a `cp.async.bulk` copy from
 *        global to shared memory that signals an mbarrier.
 * \param shared_ptr Pointer expression of the shared-memory destination.
 * \param shared_elem_offset Element offset added to `shared_ptr`.
 * \param global_ptr Pointer expression of the global-memory source.
 * \param global_elem_offset Element offset added to `global_ptr`.
 * \param bytes Expression for the number of bytes to copy.
 * \param barrier Expression naming the barrier object; its address is taken.
 * \return The generated code block as a string.
 */
std::string PrintCpAsyncBulkAsm(const std::string& shared_ptr,
                                const std::string& shared_elem_offset,
                                const std::string& global_ptr,
                                const std::string& global_elem_offset, const std::string& bytes,
                                const std::string& barrier) {
  const std::string code_template = R"(
{
unsigned int smem_addr_int = __cvta_generic_to_shared({smem_addr});
unsigned int barrier_addr_int = __cvta_generic_to_shared({barrier});
__asm__ __volatile__(
"cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];"
:: "r"(smem_addr_int), "l"({global_ptr}), "r"({bytes}), "r"(barrier_addr_int)
: "memory"
);
}
)";
  // Placeholder -> text, in the order the substitutions are applied.
  const std::pair<std::string, std::string> substitutions[] = {
      {"{smem_addr}", shared_ptr + " + " + shared_elem_offset},
      {"{global_ptr}", global_ptr + " + " + global_elem_offset},
      {"{bytes}", bytes},
      {"{barrier}", "&" + barrier},
  };
  Replacer replacer;
  for (const auto& substitution : substitutions) {
    replacer.register_rule(substitution.first, substitution.second);
  }
  return replacer.rewrite(code_template);
}
/*!
 * \brief Generate the inline-assembly source text for a
 *        `cp.async.mbarrier.arrive` on the given barrier.
 * \param barrier Expression naming the barrier object; its address is taken.
 * \return The generated code block as a string.
 */
std::string PrintCpAsyncBarrierAsm(const std::string& barrier) {
  const std::string code_template = R"(
{
unsigned int barrier_addr_int = __cvta_generic_to_shared({barrier});
__asm__ __volatile__(
"cp.async.mbarrier.arrive.shared.b64 [%0];"
:: "r" (barrier_addr_int)
);
}
)";
  Replacer replacer;
  const std::string barrier_address = "&" + barrier;
  replacer.register_rule("{barrier}", barrier_address);
  return replacer.rewrite(code_template);
}
} // namespace codegen
} // namespace tvm