# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm.contrib import tvmjs, utils

import pytest
import numpy as np

from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode


def test_make_shape():
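    # make_shape(heap, ndim, code_0, operand_0, ..., code_{ndim-1}, operand_{ndim-1})
    # builds a ShapeTuple: USE_IMM takes its operand as an immediate dimension
    # value, while LOAD_SHAPE reads the dimension from heap[operand].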
    MK = MakeShapeCode
    make_shape = tvm.get_global_func("vm.builtin.make_shape")
    heap = tvm.runtime.tensor(np.arange(10).astype("int64"))
    s = make_shape(heap, 3, MK.USE_IMM, 10, MK.LOAD_SHAPE, 0, MK.LOAD_SHAPE, 2)

    assert s == tvm.runtime.container.ShapeTuple([10, 0, 2])


def test_match_shape():
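    # match_shape(value, heap, ndim, code_0, operand_0, ..., err_ctx) checks a
    # ShapeTuple or tensor shape one dimension at a time: ASSERT_EQUAL_TO_IMM
    # compares a dimension against an immediate, ASSERT_EQUAL_TO_LOAD compares
    # it against heap[operand], STORE_TO_HEAP writes it into heap[operand],
    # and NO_OP ignores it; the trailing string is the error context.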
    MS = MatchShapeCode
    match_shape = tvm.get_global_func("vm.builtin.match_shape")
    heap = tvm.runtime.tensor(np.zeros(10).astype("int64"))

    assert heap.numpy()[2] == 0

    s = tvm.runtime.container.ShapeTuple([1, 2, 3])
    x = tvm.runtime.tensor(np.zeros([1, 2, 3]))

    match_shape(s, heap, 3, MS.ASSERT_EQUAL_TO_IMM, 1, MS.STORE_TO_HEAP, 2, MS.NO_OP, 0, "")

    assert heap.numpy()[2] == 2

    match_shape(
        x,
        heap,
        3,
        MS.ASSERT_EQUAL_TO_IMM,
        1,
        MS.ASSERT_EQUAL_TO_LOAD,
        2,
        MS.ASSERT_EQUAL_TO_IMM,
        3,
        "",
    )

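    # ndim mismatch: s has three dimensions but only two are matched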
    with pytest.raises(RuntimeError):
        match_shape(s, heap, 2, MS.ASSERT_EQUAL_TO_IMM, 1, MS.STORE_TO_HEAP, 2, "")

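    # value mismatch: s[0] is 1, not 2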
    with pytest.raises(RuntimeError):
        match_shape(s, heap, 3, MS.ASSERT_EQUAL_TO_IMM, 2, MS.STORE_TO_HEAP, 2, MS.NO_OP, 0, "")


def test_check_shape_info():
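    # check_shape_info(value, ndim, err_ctx) asserts that value is a shape
    # tuple with the given number of dimensions; ndim == -1 accepts any rank.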
    check_shape_info = tvm.get_global_func("vm.builtin.check_shape_info")
    s = tvm.runtime.container.ShapeTuple([1, 2, 3])

    check_shape_info(s, 3, "")
    check_shape_info(s, -1, "")

    # wrong ndim
    with pytest.raises(ValueError):
        check_shape_info(s, 2, "")

    # wrong type
    with pytest.raises(TypeError):
        check_shape_info([], 2, "")


def test_check_tensor_info():
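    # check_tensor_info(value, ndim, dtype, err_ctx) asserts that value is a
    # tensor with the given rank and dtype; ndim == -1 accepts any rank, and
    # an empty or omitted dtype skips the dtype check.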
    check_tensor_info = tvm.get_global_func("vm.builtin.check_tensor_info")
    x = tvm.runtime.tensor(np.zeros((2, 3)).astype("int32"))

    check_tensor_info(x, 2, "int32", "")
    check_tensor_info(x, -1, "int32", "")
    check_tensor_info(x, 2, "", "")
    check_tensor_info(x, -1, "", "")

    # the dtype argument may also be omitted entirely
    check_tensor_info(x, 2, "")
    check_tensor_info(x, -1, "")

    # ndim mismatch
    with pytest.raises(ValueError, match=r".* ndim .*"):
        check_tensor_info(x, 3, "int32", "")

    # dtype mismatch
    with pytest.raises(ValueError, match=r"myerror.* dtype .*"):
        check_tensor_info(x, 2, "float32", "myerror")

    # the error message carries the caller-supplied context string
    with pytest.raises(ValueError, match=r".* myerror .*"):
        check_tensor_info(x, 3, "myerror")

    # wrong type
    with pytest.raises(TypeError):
        check_tensor_info([], 2, "", "")


def test_check_tuple_info():
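    # check_tuple_info(value, size, err_ctx) asserts that value is a tuple
    # with exactly `size` fields.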
    check_tuple_info = tvm.get_global_func("vm.builtin.check_tuple_info")
    x = tvm.runtime.tensor(np.zeros((2, 3)).astype("int32"))
    t = tvm.runtime.convert([x, x, x])

    check_tuple_info(t, 3, "")

    # wrong number of elements
    with pytest.raises(ValueError, match=r".*elements.*"):
        check_tuple_info(t, 2, "")

    # wrong type
    with pytest.raises(TypeError):
        check_tuple_info(x, 2, "")


def test_check_func_info():
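    # check_func_info(value, err_ctx) asserts that value is a callable
    # (e.g. a lambda converted via tvm.runtime.convert).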
    check_func_info = tvm.get_global_func("vm.builtin.check_func_info")
    f = tvm.runtime.convert(lambda x: x)
    x = tvm.runtime.tensor(np.zeros((2, 3)).astype("int32"))

    check_func_info(f, "")

    # wrong type
    with pytest.raises(TypeError, match=".*myerror.*"):
        check_func_info(x, "myerror")


def test_tuple_getitem():
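    # tuple_getitem(tuple, index) returns the tuple element at the given index.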
    tuple_getitem = tvm.get_global_func("vm.builtin.tuple_getitem")
    x = tvm.runtime.tensor(np.zeros((2, 3)).astype("int32"))
    y = tvm.runtime.tensor(np.zeros((2, 3)).astype("int32"))
    t = tvm.runtime.convert([x, y])

    assert tuple_getitem(t, 0) == x
    assert tuple_getitem(t, 1) == y


def test_attention_kv_cache():
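    # fcreate takes an initial buffer, a reserved shape, and the number of
    # rows already valid in the buffer (0 here, so the cache starts empty);
    # fappend grows the cache along the first axis and fview reads back the
    # filled region with the requested shape.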
    fcreate = tvm.get_global_func("vm.builtin.attention_kv_cache_create")
    fappend = tvm.get_global_func("vm.builtin.attention_kv_cache_append")
    fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")

    cache = fcreate(tvm.runtime.empty((1, 2), dtype="int32"), tvm.runtime.ShapeTuple([2, 2]), 0)
    num_steps = 2
    for i in range(num_steps):
        cache = fappend(cache, tvm.runtime.tensor(i * np.ones((1, 2)).astype("int32")))

    res = fview(cache, tvm.runtime.ShapeTuple((num_steps, 2))).numpy()
    for i in range(num_steps):
        assert res[i][0] == i
        assert res[i][1] == i


def test_tensor_cache():
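    # dump_tensor_cache writes param_dict to disk, with float32 arrays stored
    # as bfloat16 under the "f32-to-bf16" encoding; the load builtin reads the
    # cache onto the given device and param_array_from_cache collects the
    # parameters named "x_0", "x_1", ... (-1: take every consecutive entry).
    # float32 values are therefore compared after a bf16 round trip.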
    fload = tvm.get_global_func("vm.builtin.tensor_cache.load")
    fget_params = tvm.get_global_func("vm.builtin.param_array_from_cache")

    param_dict = {
        "x_0": np.array([1, 2, 3], dtype="int32"),
        "x_1": np.random.uniform(size=[10, 20]).astype("float32"),
    }

    temp = utils.tempdir()
    tvmjs.dump_tensor_cache(param_dict, temp.path, encode_format="f32-to-bf16")
    fload(str(temp.path), tvm.cpu().dlpack_device_type(), 0)
    res = fget_params("x", -1)
    for i, v in enumerate(res):
        v_np = param_dict[f"x_{i}"]
        if v_np.dtype == "float32":
            v_np = tvmjs._convert_bf16_to_f32(tvmjs._convert_f32_to_bf16(v_np))
        np.testing.assert_allclose(v.numpy(), v_np, atol=1e-6, rtol=1e-6)


def test_tensor_cache_update():
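    # Re-dumping with update_if_exists=True must overwrite the changed "x_1"
    # and add the new "x_2" while keeping the untouched "x_0" intact.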
    fload = tvm.get_global_func("vm.builtin.tensor_cache.load")
    fget_params = tvm.get_global_func("vm.builtin.param_array_from_cache")

    param_dict = {
        "x_0": np.array([1, 2, 3], dtype="int32"),
        "x_1": np.random.uniform(size=[10, 20]).astype("float32"),
    }

    temp = utils.tempdir()
    tvmjs.dump_tensor_cache(param_dict, temp.path, encode_format="f32-to-bf16")
    param_dict["x_1"] = np.random.uniform(size=[10, 20]).astype("float32")
    param_dict["x_2"] = np.random.uniform(size=[10]).astype("float32")
    tvmjs.dump_tensor_cache(
        param_dict, temp.path, encode_format="f32-to-bf16", update_if_exists=True
    )
    fload(str(temp.path), tvm.cpu().dlpack_device_type(), 0)
    res = fget_params("x", -1)
    for i, v in enumerate(res):
        v_np = param_dict[f"x_{i}"]
        if v_np.dtype == "float32":
            v_np = tvmjs._convert_bf16_to_f32(tvmjs._convert_f32_to_bf16(v_np))
        np.testing.assert_allclose(v.numpy(), v_np, atol=1e-6, rtol=1e-6)


def test_attention_kv_cache_window_override():
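    # foverride appends like fappend until the cache reaches its size cap (16
    # rows here), then wraps around and overwrites the oldest rows, ring-buffer
    # style; current_pos mirrors the expected write position. The cache is
    # created with 4 rows already marked valid.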
    fcreate = tvm.get_global_func("vm.builtin.attention_kv_cache_create")
    foverride = tvm.get_global_func("vm.builtin.attention_kv_cache_window_override")
    fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")

    current_pos = 4
    cache = fcreate(
        tvm.runtime.tensor(np.full((16, 2), -1).astype("int32")),
        tvm.runtime.ShapeTuple([16, 2]),
        current_pos,
    )
    np_all_arrays = np.zeros((0, 2)).astype("int32")

    num_steps = 10
    for i in range(1, num_steps):
        np_array = i * np.ones((i, 2)).astype("int32")
        np_all_arrays = np.concatenate((np_all_arrays, np_array), axis=0)
        cache = foverride(cache, tvm.runtime.tensor(np_array), 16)
        current_pos = (current_pos + i) % 16

    res = fview(cache, tvm.runtime.ShapeTuple((16, 2))).numpy()

    # un-rotate the ring buffer and check it matches the last 16 appended rows
    assert (
        np_all_arrays[np_all_arrays.shape[0] - 16 :, :]
        == np.concatenate((res[current_pos:], res[:current_pos]))
    ).all()


def test_attention_kv_cache_window_override_with_sinks():
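    # Variant of window_override in which the first num_attention_sinks rows
    # are pinned ("attention sinks") and never overwritten; once the cache is
    # full, wrap-around writes skip past the sink region.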
    fcreate = tvm.get_global_func("vm.builtin.attention_kv_cache_create")
    foverride = tvm.get_global_func("vm.builtin.attention_kv_cache_window_override_with_sinks")
    fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")

    num_attention_sinks = 2
    has_sink = False
    current_pos = 0

    cache = fcreate(
        tvm.runtime.tensor(np.full((16, 2), -1).astype("int32")),
        tvm.runtime.ShapeTuple([16, 2]),
        current_pos,
    )
    np_all_arrays = np.zeros((0, 2)).astype("int32")

    num_steps = 40
    for i in range(num_steps):
        np_array = i * np.ones((1, 2)).astype("int32")
        np_all_arrays = np.concatenate((np_all_arrays, np_array), axis=0)
        cache = foverride(cache, tvm.runtime.tensor(np_array), 16, num_attention_sinks)

        if has_sink:
            current_pos = max((current_pos + 1) % 16, num_attention_sinks)
        else:
            current_pos += 1
            has_sink = current_pos >= num_attention_sinks

    res = fview(cache, tvm.runtime.ShapeTuple((16, 2))).numpy()

    # un-rotate the non-sink region and check the cache holds the sink rows
    # plus the most recent 16 - num_attention_sinks appended rows
    assert (
        np.concatenate(
            (np_all_arrays[:num_attention_sinks, :], np_all_arrays[-16 + num_attention_sinks :, :])
        )
        == np.concatenate(
            (res[:num_attention_sinks], res[current_pos:], res[num_attention_sinks:current_pos])
        )
    ).all()


if __name__ == "__main__":
    tvm.testing.main()