# blob: ecd25b3a6749081caacfc85c7486fb51540ccf57
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test line-level debug info for TIR"""
import tvm
import tvm.testing
from tvm import tir
from tvm import relay
from tvm.script import tir as T, ir as I
from typing import List, Dict
import re
def find_di_locations(source: str) -> Dict[str, str]:
    """Parse out DILocation references in printed LLVM IR.

    Parameters
    ----------
    source : str
        Text to scan. It is processed line by line, and only lines that
        *start* with ``!<id> = !DILocation(line: <n>`` are recognised.

    Returns
    -------
    Dict[str, str]
        Maps each metadata id to the ``line:`` value of its DILocation.
        Both are kept as the decimal strings captured from the text, so
        callers must look entries up by string and convert the value with
        ``int`` themselves.  If an id occurs more than once, the last
        occurrence wins.
    """
    pattern = re.compile(r"!(\d+) = !DILocation\(line: (\d+).*")
    locations = {}

    for text in source.splitlines():
        # ``match`` anchors at the beginning of the line, so indented or
        # inline occurrences of ``!DILocation`` are deliberately skipped.
        found = pattern.match(text)
        if found:
            # Use a distinct name for the captured line number instead of
            # rebinding the loop variable.
            debug_id, line_no = found.groups()
            locations[debug_id] = line_no

    return locations
def _module():
    """Build the TVMScript module shared by the tests in this file.

    The module holds a single PrimFunc, ``main``, targeting ``llvm``.  It
    walks eight ``float32`` elements, and every iteration contains an
    always-false assert followed by ``B[vi] = A[vi] + 1.0``.

    Returns
    -------
    The ``MyModule`` class object as produced by ``tvm.script.ir_module``.
    """

    @tvm.script.ir_module
    class MyModule:
        @T.prim_func
        def main(a: T.handle, b: T.handle):
            # Data is exchanged between functions via handles, which are similar to pointers.
            T.func_attr(
                {
                    "tir.noalias": True,
                    "target": T.target("llvm"),
                }
            )
            # Create buffers from the handles.
            A = T.match_buffer(a, (8,), dtype="float32")
            B = T.match_buffer(b, (8,), dtype="float32")
            for i in range(8):
                # A block is an abstraction for computation.
                with T.block("B"):
                    # Define a spatial block iterator and bind it to value i.
                    vi = T.axis.spatial(8, i)
                    # Always-false assert: test_llvm_ir_debug_accuracy looks
                    # for this statement's debug location in the LLVM IR.
                    assert 1 == 0, "Some numbers"
                    B[vi] = A[vi] + 1.0

    return MyModule
def test_tir_debug_info():
    """
    InstallDebugSpans should attach spans that point into the printed TIR.

    Before the pass the inspected statement carries no span; afterwards its
    span must name "main.tir" as the source and report line 4.
    """

    def inspected_span(mod):
        # Span of the statement found at body.block.body of the first
        # function yielded by the module's function map.
        prim_func = next(mod.functions.values())
        return prim_func.body.block.body.span

    original = _module()
    original_span = inspected_span(original)
    assert original_span is None

    install_spans = tir.transform.InstallDebugSpans()
    instrumented = install_spans(original)
    installed_span = inspected_span(instrumented)

    # The source name is derived from the function name, and a line number
    # must have been filled in.
    assert installed_span.source_name.name == "main.tir"
    assert installed_span.line == 4
def test_tir_debug_info_with_subroutine():
    """Like test_tir_debug_info, but with a TIR subroutine

    The current InstallDebugSpans applies to a single PrimFunc.  This
    test runs the pass on a module that additionally contains a
    subroutine called from ``main`` and checks that the inspected span
    still names "main.tir" at line 4.

    NOTE(review): the original sentence here was cut off ("This test
    verifies that the existence of device-side subroutines"); presumably
    it continued with "does not prevent debug spans from being installed"
    -- TODO confirm.
    """

    def find_span(m):
        # Span of the statement at body.block.body of the first function
        # yielded by the module's function map.
        func = next(m.functions.values())
        return func.body.block.body.span

    @tvm.script.ir_module
    class module_before:
        @T.prim_func
        def main(a: T.handle, b: T.handle):
            T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": T.target("llvm")})
            A = T.match_buffer(a, (8,), dtype="float32")
            B = T.match_buffer(b, (8,), dtype="float32")
            for i in range(8):
                with T.block("B"):
                    vi = T.axis.spatial(8, i)
                    # Hand the addresses of the current elements to the subroutine.
                    module_before.subroutine(T.address_of(A[vi]), T.address_of(B[vi]))

        @T.prim_func
        def subroutine(a_ptr: T.handle("float32"), b_ptr: T.handle("float32")):
            # NOTE(review): "global_symbol" is "main" here too, duplicating
            # the symbol of the entry point above -- confirm this is intended.
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            A = T.decl_buffer(1, "float32", data=a_ptr)
            B = T.decl_buffer(1, "float32", data=b_ptr)
            # NOTE(review): A is declared with one element but is read at
            # index 1 -- confirm whether A[0] was meant.
            B[0] = A[1] + 1.0

    span_before = find_span(module_before)
    assert span_before is None
    module_after = tir.transform.InstallDebugSpans()(module_before)
    span_after = find_span(module_after)
    # Check that the module name has been added and a line number is present
    assert span_after.source_name.name == "main.tir"
    assert span_after.line == 4
def test_llvm_ir_debug_info():
    """
    Check that the expected number of debug locations is emitted.

    The shared module is built for LLVM with "tir.enable_debug" switched on,
    and the DILocation entries in the generated source are counted.
    """
    mod = _module()
    debug_config = {"tir.enable_debug": True}

    with tvm.transform.PassContext(opt_level=3, config=debug_config):
        built = tvm.build(mod, target="llvm")

    di_locations = find_di_locations(built.get_source())
    assert len(di_locations) == 41
def test_llvm_ir_debug_accuracy():
    """
    Check that the debug location on an assert is correct.

    The shared module is built for LLVM with "tir.enable_debug" switched on.
    The ``!dbg`` reference attached to the assert's call instruction is then
    resolved through the DILocation table and compared against the expected
    line number in main.tir.
    """
    MyModule = _module()
    with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": True}):
        runtime_module = tvm.build(MyModule, target="llvm")

    source = runtime_module.get_source()
    locations = find_di_locations(source)

    # Find the 'assert' from MyModule
    debug_dir_match = re.search(r"tail call void %0\(.* !dbg !(\d+)\n", source)
    # Fail with a readable message rather than an AttributeError on None
    # when the generated IR no longer contains the expected instruction.
    assert (
        debug_dir_match is not None
    ), "No 'tail call void %0(...)' instruction with a !dbg reference found in the LLVM IR"

    # Extract out the debug directive index
    directive_idx = debug_dir_match.group(1)
    assert (
        directive_idx in locations
    ), f"!dbg !{directive_idx} does not refer to a DILocation found in the LLVM IR"

    # Check that it matches the expected line number (in main.tir)
    debug_line_no = int(locations[directive_idx])
    assert debug_line_no == 60
def test_building_without_llvm_equivalent():
    """A TIR PrimFunc may contain non-LLVM types

    Types used in optimized kernels (e.g. "e4m3_float8") may not have
    an equivalent in DWARF, or the mapping from TIR type to DWARF type
    may not be defined.  If this occurs, the function should still be
    able to be built.
    """

    @I.ir_module
    class Module:
        @T.prim_func(private=True)
        def main(A_data: T.handle("e4m3_float8"), B_data: T.handle("e4m3_float8")):
            A = T.decl_buffer(128, "e4m3_float8", data=A_data)
            B = T.decl_buffer(128, "e4m3_float8", data=B_data)
            # Element-wise copy, so the float8 type appears in loads and stores.
            for i in range(128):
                B[i] = A[i]

    # No assertion on the result: the test passes as long as the build
    # itself completes without raising.
    tvm.target.codegen.build_module(Module, "llvm")
# Hand control to tvm.testing.main() when this file is executed as a script.
if __name__ == "__main__":
    tvm.testing.main()