| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import numpy as np |
| import pytest |
| import tvm |
| from tvm import relax |
| import tvm.topi.testing |
| from tvm.relax.transform import LegalizeOps |
| from tvm.script import relax as R, tir as T |
| import tvm.testing |
| |
# TODO(tvm-team): `tir.transform.DefaultGPUSchedule` does not work, so these
# tests are compiled for and executed on the LLVM CPU target only.
target = "llvm"
dev = tvm.cpu()
| |
| |
def build(mod):
    """Compile ``mod`` for the module-level ``target`` and load it into a Relax VM.

    Parameters
    ----------
    mod : tvm.IRModule
        The Relax IRModule to compile.

    Returns
    -------
    relax.VirtualMachine
        A virtual machine bound to the module-level device ``dev`` that can
        run the compiled module's functions.
    """
    executable = tvm.compile(mod, target=target)
    machine = relax.VirtualMachine(executable, dev)
    return machine
| |
| |
@pytest.mark.parametrize(
    "begin, end, strides",
    [
        ([0, 2, 4, 4], [5, 5, 7, 8], [1, 1, 2, 3]),
        ([0, 2, 4, 4], [5, 5, 11, 10], [1, 1, 1, 1]),
        ([0, 2, 10, 14], [0, 5, 1, 1], [1, 1, -1, -2]),
    ],
)
def test_dynamic_strided_slice(begin, end, strides):
    """Compare ``R.dynamic_strided_slice`` on a static-shape input with the NumPy reference.

    The slice ``begin``/``end``/``strides`` are fed to the compiled function as
    runtime int64 tensors rather than compile-time constants.
    """
    # fmt: off
    @tvm.script.ir_module
    class DynamicStridedSlice:
        @R.function
        def main(x: R.Tensor((8, 9, 10, 10), "float32"), begin: R.Tensor((4,),"int64"), end: R.Tensor((4,),"int64"), strides: R.Tensor((4,),"int64")) -> R.Tensor("float32", ndim=4):
            gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, begin, end, strides)
            return gv
    # fmt: on

    def as_int64_tensor(values):
        # Each slice parameter is handed to the VM as a 1-D int64 tensor.
        return tvm.runtime.tensor(np.array(values).astype("int64"), dev)

    machine = build(DynamicStridedSlice)

    input_np = np.random.rand(8, 9, 10, 10).astype(np.float32)
    slice_args = [as_int64_tensor(vals) for vals in (begin, end, strides)]

    # Expected output computed purely on the host with the TOPI reference.
    expected = tvm.topi.testing.strided_slice_python(input_np, begin, end, strides)
    actual = machine["main"](tvm.runtime.tensor(input_np, dev), *slice_args)
    tvm.testing.assert_allclose(actual.numpy(), expected)
| |
| |
@pytest.mark.parametrize(
    "begin, end, strides",
    [
        ([0, 2, 4, 4], [5, 5, 7, 8], [1, 1, 2, 3]),
        ([0, 2, 4, 4], [5, 5, 11, 10], [1, 1, 1, 1]),
        ([0, 2, 10, 14], [0, 5, 1, 1], [1, 1, -1, -2]),
    ],
)
def test_dynamic_strided_slice_symbolic(begin, end, strides):
    """Compare ``R.dynamic_strided_slice`` with the NumPy reference when leading dims are symbolic.

    The Relax signature declares the first two input dimensions as the
    symbolic vars ``m`` and ``n``; at run time they are bound by an
    (8, 9, 10, 10) input. The slice parameters arrive as runtime int64 tensors.
    """
    # fmt: off
    @tvm.script.ir_module
    class DynamicStridedSlice:
        @R.function
        def main(x: R.Tensor(("m", "n", 10, 10), "float32"), begin: R.Tensor((4,),"int64"), end: R.Tensor((4,),"int64"), strides: R.Tensor((4,),"int64")) -> R.Tensor("float32", ndim=4):
            m = T.int64()
            n = T.int64()
            gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, begin, end, strides)
            return gv
    # fmt: on

    def as_int64_tensor(values):
        # Each slice parameter is handed to the VM as a 1-D int64 tensor.
        return tvm.runtime.tensor(np.array(values).astype("int64"), dev)

    machine = build(DynamicStridedSlice)

    input_np = np.random.rand(8, 9, 10, 10).astype(np.float32)
    slice_args = [as_int64_tensor(vals) for vals in (begin, end, strides)]

    # Expected output computed purely on the host with the TOPI reference.
    expected = tvm.topi.testing.strided_slice_python(input_np, begin, end, strides)
    actual = machine["main"](tvm.runtime.tensor(input_np, dev), *slice_args)
    tvm.testing.assert_allclose(actual.numpy(), expected)
| |
| |
if __name__ == "__main__":
    # Entry point when this file is executed directly as a script.
    tvm.testing.main()