# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
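"""Tests for CUDA codegen of the float4_e2m1fn (FP4 E2M1) data type."""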

from itertools import product

import numpy as np
import pytest

import tvm
import tvm.testing
from tvm.script import tir as T

try:
    from ml_dtypes import float4_e2m1fn

    ML_DTYPES_AVAILABLE = True
except ImportError:
    ML_DTYPES_AVAILABLE = False


@pytest.mark.parametrize("promoted_dtype", ["float32x2", "float16x2"])
@tvm.testing.requires_cuda_compute_version(10)
def test_e2m1_vector_conversions(promoted_dtype):
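    """Add two float4_e2m1fnx2 buffers on CUDA by promoting each element to
    `promoted_dtype`, summing, and casting back to FP4, then check the result."""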
    native_dtype = "float4_e2m1fnx2"
    vector_length = 64

    @T.prim_func
    def add(
        A: T.Buffer((vector_length,), native_dtype),
        B: T.Buffer((vector_length,), native_dtype),
        C: T.Buffer((vector_length,), native_dtype),
    ):
        T.func_attr({"tir.noalias": True})
        for i in range(vector_length):
            with T.block("C"):
                v_i = T.axis.spatial(vector_length, i)
                T.reads(A[v_i], B[v_i])
                T.writes(C[v_i])
                C[v_i] = T.Cast(
                    native_dtype, T.Cast(promoted_dtype, A[v_i]) + T.Cast(promoted_dtype, B[v_i])
                )

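    # Split the element loop into blocks of 32 threads and bind it to the CUDA grid.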
    sch = tvm.tir.Schedule(add)
    block = sch.get_block("C")
    b = sch.get_loops(block)
    bx, tx = sch.split(b[0], factors=[None, 32])
    sch.bind(bx, "blockIdx.x")
    sch.bind(tx, "threadIdx.x")

    target = "cuda"
    fadd = tvm.compile(sch.mod, target=target)
    dev = tvm.device(target, 0)

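    # Derive the number of lanes and the scalar base dtype from the vector dtype strings.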
    if "x" in native_dtype:
        lanes = int(native_dtype.split("x")[-1])
    else:
        lanes = 1

    if "x" in promoted_dtype:
        promoted_base_dtype = promoted_dtype.split("x")[0]
    else:
        promoted_base_dtype = promoted_dtype

    np_shape = (vector_length, lanes) if lanes > 1 else (vector_length,)

    # Create the test data: use ml_dtypes when available, otherwise fall back to int8
    # arrays holding raw 4-bit FP4 bit patterns.
    if ML_DTYPES_AVAILABLE:
        a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(float4_e2m1fn)
        b_np = np.random.uniform(low=0, high=5, size=np_shape).astype(float4_e2m1fn)
    else:
        # Every 4-bit pattern is a valid float4_e2m1fn encoding; the representable
        # magnitudes are [0, 0.5, 1, 1.5, 2, 3, 4, 6].
        valid_fp4_values = list(range(16))  # all 16 possible 4-bit values
        a_np = np.random.choice(valid_fp4_values, size=np_shape).astype(np.int8)
        b_np = np.random.choice(valid_fp4_values, size=np_shape).astype(np.int8)

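    # Allocate device buffers with the native FP4 vector dtype, copy the inputs in, and run.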
    a = tvm.runtime.empty(shape=(vector_length,), dtype=native_dtype, device=dev)
    a.copyfrom(a_np)
    b = tvm.runtime.empty(shape=(vector_length,), dtype=native_dtype, device=dev)
    b.copyfrom(b_np)
    c = tvm.runtime.empty(shape=(vector_length,), dtype=native_dtype, device=dev)
    fadd(a, b, c)

    # For the comparison, convert the result to the promoted base dtype.
    # Note: when ml_dtypes is not available, we skip the numpy-level comparison and only
    # verify that the CUDA kernel compiles and executes without error.
    c_result = c.numpy().astype(promoted_base_dtype)

    if ML_DTYPES_AVAILABLE:
        # Full comparison against a numpy reference when ml_dtypes is available.
        expected = (a_np + b_np).astype(promoted_base_dtype)
        tvm.testing.assert_allclose(c_result, expected)
    else:
        # Without ml_dtypes we cannot compute a host-side reference, so just check that
        # the kernel produced a result with the expected shape and dtype.
        assert c_result.shape == np_shape
        assert c_result.dtype == promoted_base_dtype


@tvm.testing.requires_cuda_compute_version(10)
def test_e2m1_dequantize():
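    """Dequantize FP4 values packed eight per uint32 word into float16, checking only
    that both the shuffle (vectorized) and scalar lowering variants compile for CUDA."""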
    n = 128

    dev = tvm.device("cuda", 0)
    target = tvm.target.Target.from_device(dev)
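    # Eight 4-bit FP4 values are packed into each 32-bit storage word.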
    num_elem_per_storage = 32 // 4

    def get_reinterpret_mod(func_type, vector_length):
        @T.prim_func
        def shuffle_reinterpret(
            A: T.Buffer((n // num_elem_per_storage,), "uint32"),
            B: T.Buffer((n,), "float16"),
        ):
            T.func_attr({"tir.noalias": True})
            for i in range(n):
                with T.block("C"):
                    v_i = T.axis.spatial(n, i)
                    T.reads(A[v_i])
                    T.writes(B[v_i])
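                    # Extract the byte holding two adjacent FP4 nibbles, reinterpret it as a
                    # float4_e2m1fnx2 pair, widen to float16x2, then pick this element's lane.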
                    B[v_i] = T.Shuffle(
                        [
                            T.reinterpret(
                                "float4_e2m1fnx2",
                                T.bitwise_and(
                                    T.shift_right(
                                        A[v_i // num_elem_per_storage],
                                        ((v_i % num_elem_per_storage) // 2 * 4 * 2).astype(
                                            "uint32"
                                        ),
                                    ),
                                    T.uint32((1 << (4 * 2)) - 1),
                                ).astype("uint8"),
                            ).astype("float16x2")
                        ],
                        indices=[v_i % 2],
                    )

        @T.prim_func
        def scalar_reinterpret(
            A: T.Buffer((n // num_elem_per_storage,), "uint32"),
            B: T.Buffer((n,), "float16"),
        ):
            T.func_attr({"tir.noalias": True})
            for i in range(n):
                with T.block("C"):
                    v_i = T.axis.spatial(n, i)
                    T.reads(A[v_i])
                    T.writes(B[v_i])
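                    # Shift out and mask a single 4-bit value, reinterpret it as
                    # float4_e2m1fn, and widen to float16.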
                    B[v_i] = T.reinterpret(
                        "float4_e2m1fn",
                        T.bitwise_and(
                            T.shift_right(
                                A[v_i // num_elem_per_storage],
                                (v_i % num_elem_per_storage * 4).astype("uint32"),
                            ),
                            T.uint32((1 << 4) - 1),
                        ).astype("uint8"),
                    ).astype("float16")

        func = shuffle_reinterpret if func_type == "shuffle" else scalar_reinterpret
        sch = tvm.tir.Schedule(func)
        block = sch.get_block("C")
        b = sch.get_loops(block)
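        # Split into CUDA blocks and 32-thread groups, then vectorize the innermost loop.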
        bx, tx, vec = sch.split(b[0], factors=[None, 32, vector_length])
        sch.bind(bx, "blockIdx.x")
        sch.bind(tx, "threadIdx.x")
        sch.vectorize(vec)
        return sch.mod

    # We only test whether the code can be compiled.
    for func_type, vector_length in product(["shuffle", "scalar"], [1, 2, 4]):
        if func_type == "shuffle" and vector_length == 1:
            # Vectorization is required for the shuffle variant.
            continue
        mod = get_reinterpret_mod(func_type, vector_length)
        tvm.compile(mod, target=target)


if __name__ == "__main__":
    tvm.testing.main()