blob: c17709cfaa7b0ef7956d26a61f15876b3fa3e42e [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for T.cuda.cta_reduce / cta_sum / cta_max / cta_min intrinsics."""
import numpy as np
import pytest
import tvm
from tvm.script import tirx as Tx
DEV = tvm.cuda(0)
TARGET = tvm.target.Target("cuda")
def _build_and_run(func, n):
mod = tvm.IRModule({"main": func})
mod = tvm.compile(mod, target=TARGET, tir_pipeline="tirx")
out_np = np.zeros(n, dtype="float32")
out = tvm.runtime.tensor(out_np, device=DEV)
mod(out)
return out.numpy(), mod
def test_cta_sum_4_warps():
"""CTA sum with 4 warps (128 threads): all threads get the same sum."""
NUM_WARPS = 4
N = NUM_WARPS * 32
# fmt: off
@Tx.prim_func
def func(out_ptr: Tx.handle):
out = Tx.match_buffer(out_ptr, (N,), "float32")
Tx.device_entry()
cta_id = Tx.cta_id([1])
warp_id = Tx.warp_id([NUM_WARPS])
lane_id = Tx.lane_id([32])
tid = Tx.thread_id([N])
with Tx.cta():
scratch = Tx.alloc_buffer((NUM_WARPS,), "float32", scope="shared")
with Tx.thread():
val: Tx.f32 = Tx.float32(tid + 1)
val = Tx.cuda.cta_sum(val, NUM_WARPS, scratch.ptr_to([0]))
out[tid] = val
# fmt: on
result, mod = _build_and_run(func, N)
expected = np.float32(N * (N + 1) / 2) # sum(1..128)
np.testing.assert_allclose(result, np.full(N, expected))
assert "cta_reduce_sum_4" in mod.mod.imports[0].inspect_source()
def test_cta_sum_8_warps():
"""CTA sum with 8 warps (256 threads)."""
NUM_WARPS = 8
N = NUM_WARPS * 32
# fmt: off
@Tx.prim_func
def func(out_ptr: Tx.handle):
out = Tx.match_buffer(out_ptr, (N,), "float32")
Tx.device_entry()
cta_id = Tx.cta_id([1])
warp_id = Tx.warp_id([NUM_WARPS])
lane_id = Tx.lane_id([32])
tid = Tx.thread_id([N])
with Tx.cta():
scratch = Tx.alloc_buffer((NUM_WARPS,), "float32", scope="shared")
with Tx.thread():
val: Tx.f32 = Tx.float32(tid + 1)
val = Tx.cuda.cta_sum(val, NUM_WARPS, scratch.ptr_to([0]))
out[tid] = val
# fmt: on
result, _ = _build_and_run(func, N)
expected = np.float32(N * (N + 1) / 2)
np.testing.assert_allclose(result, np.full(N, expected))
def test_cta_max_4_warps():
"""CTA max with 4 warps: all threads get the maximum value."""
NUM_WARPS = 4
N = NUM_WARPS * 32
# fmt: off
@Tx.prim_func
def func(out_ptr: Tx.handle):
out = Tx.match_buffer(out_ptr, (N,), "float32")
Tx.device_entry()
cta_id = Tx.cta_id([1])
warp_id = Tx.warp_id([NUM_WARPS])
lane_id = Tx.lane_id([32])
tid = Tx.thread_id([N])
with Tx.cta():
scratch = Tx.alloc_buffer((NUM_WARPS,), "float32", scope="shared")
with Tx.thread():
val: Tx.f32 = Tx.float32(tid + 1)
val = Tx.cuda.cta_max(val, NUM_WARPS, scratch.ptr_to([0]))
out[tid] = val
# fmt: on
result, _ = _build_and_run(func, N)
np.testing.assert_allclose(result, np.full(N, float(N)))
def test_cta_min_4_warps():
"""CTA min with 4 warps: all threads get the minimum value."""
NUM_WARPS = 4
N = NUM_WARPS * 32
# fmt: off
@Tx.prim_func
def func(out_ptr: Tx.handle):
out = Tx.match_buffer(out_ptr, (N,), "float32")
Tx.device_entry()
cta_id = Tx.cta_id([1])
warp_id = Tx.warp_id([NUM_WARPS])
lane_id = Tx.lane_id([32])
tid = Tx.thread_id([N])
with Tx.cta():
scratch = Tx.alloc_buffer((NUM_WARPS,), "float32", scope="shared")
with Tx.thread():
val: Tx.f32 = Tx.float32(tid + 1)
val = Tx.cuda.cta_min(val, NUM_WARPS, scratch.ptr_to([0]))
out[tid] = val
# fmt: on
result, _ = _build_and_run(func, N)
np.testing.assert_allclose(result, np.full(N, 1.0))
def test_cta_sum_1_warp():
"""CTA sum with 1 warp: degenerates to a pure warp reduce."""
NUM_WARPS = 1
N = 32
# fmt: off
@Tx.prim_func
def func(out_ptr: Tx.handle):
out = Tx.match_buffer(out_ptr, (N,), "float32")
Tx.device_entry()
cta_id = Tx.cta_id([1])
warp_id = Tx.warp_id([NUM_WARPS])
lane_id = Tx.lane_id([32])
tid = Tx.thread_id([N])
with Tx.cta():
scratch = Tx.alloc_buffer((NUM_WARPS,), "float32", scope="shared")
with Tx.thread():
val: Tx.f32 = Tx.float32(tid + 1)
val = Tx.cuda.cta_sum(val, NUM_WARPS, scratch.ptr_to([0]))
out[tid] = val
# fmt: on
result, _ = _build_and_run(func, N)
expected = np.float32(32 * 33 / 2)
np.testing.assert_allclose(result, np.full(N, expected))
@pytest.mark.parametrize("num_warps", [1, 2, 4, 8, 16])
def test_cta_sum_all_warp_counts(num_warps):
"""Parametric test: cta_sum with various warp counts."""
N = num_warps * 32
# fmt: off
@Tx.prim_func
def func(out_ptr: Tx.handle):
out = Tx.match_buffer(out_ptr, (N,), "float32")
Tx.device_entry()
cta_id = Tx.cta_id([1])
warp_id = Tx.warp_id([num_warps])
lane_id = Tx.lane_id([32])
tid = Tx.thread_id([N])
with Tx.cta():
scratch = Tx.alloc_buffer((num_warps,), "float32", scope="shared")
with Tx.thread():
val: Tx.f32 = Tx.float32(tid + 1)
val = Tx.cuda.cta_sum(val, num_warps, scratch.ptr_to([0]))
out[tid] = val
# fmt: on
result, _ = _build_and_run(func, N)
expected = np.float32(N * (N + 1) / 2)
np.testing.assert_allclose(result, np.full(N, expected))