blob: a76cb50da3bd39165cfe3f4d072f1c337139d23b [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test layout and bijective-layout node"""
import pytest
import tvm
import tvm.error
from tvm.topi.utils import get_const_tuple
def test_layout():
    """Exercise the query interface of the "NCHW16c" layout.

    Covers ``factor_of``, ``index_of``, membership via ``in`` and
    subscripting with both non-negative and negative positions.
    """
    layout = tvm.tir.layout("NCHW16c")
    assert layout is not None
    assert isinstance(layout, tvm.tir.Layout)

    # Expected factor_of results: either case of the channel axis gives 16,
    # while "N" gives -1.
    expected_factors = (("c", 16), ("C", 16), ("N", -1))
    for axis_name, factor in expected_factors:
        assert layout.factor_of(axis_name) == factor

    # Axis names in the order they are expected to appear in the layout.
    axis_order = ("N", "C", "H", "W", "c")

    for position, axis_name in enumerate(axis_order):
        assert layout.index_of(axis_name) == position
    # An axis that is absent from the layout is reported as -1.
    assert layout.index_of("O") == -1

    for axis_name in axis_order:
        assert axis_name in layout
    assert "O" not in layout

    for position, axis_name in enumerate(axis_order):
        assert layout[position] == axis_name
    assert layout[-1] == "c"
def test_layout_dtype():
    """Check the dtype carried by layout axes and rejection of bad dtypes.

    A layout built without ``dtype`` is expected to use "int32" for each
    checked axis variable and its domain bounds; ``dtype="int64"`` is expected
    to switch them to "int64". ``"float32"`` and ``None`` must raise
    ``TypeError``.
    """
    layout_i32 = tvm.tir.layout("NCHW")
    for axis_idx in (0, 1):
        axis = layout_i32.axes[axis_idx]
        assert axis.var.dtype == "int32"
        assert axis.dom.min.dtype == "int32"
        assert axis.dom.extent.dtype == "int32"

    layout_i64 = tvm.tir.layout("NCHW", dtype="int64")
    for axis_idx in (2, 3):
        axis = layout_i64.axes[axis_idx]
        assert axis.var.dtype == "int64"
        assert axis.dom.min.dtype == "int64"
        assert axis.dom.extent.dtype == "int64"

    # Neither of these dtype arguments is accepted.
    for bad_dtype in ("float32", None):
        with pytest.raises(TypeError):
            tvm.tir.layout("NCHW", dtype=bad_dtype)
def test_bilayout_convertible():
    """Check which layout pairs ``tvm.tir.bijective_layout`` accepts.

    Pairs expected to be rejected must yield ``None``; the pair expected to
    be accepted must yield something other than ``None``.
    """
    # not convertible
    rejected_pairs = (
        ("NCHW", "ABCD"),
        ("__undef__", "NCHW"),
        ("NCHW", "__undef__"),
        ("__undef__", "__undef__"),
        ("", "NCHW"),
        ("NCHW", ""),
        ("", ""),
    )
    for src_layout, dst_layout in rejected_pairs:
        assert tvm.tir.bijective_layout(src_layout, dst_layout) is None

    # convertible
    mapping = tvm.tir.bijective_layout("NCHW", "NCHW16c")
    assert mapping is not None
def test_bilayout_shape():
    """Round-trip a shape through the "NCHW" -> "NCHW16c" bijective layout."""
    mapping = tvm.tir.bijective_layout("NCHW", "NCHW16c")
    assert isinstance(mapping, tvm.tir.BijectiveLayout)

    packed_shape = mapping.forward_shape((1, 32, 7, 7))
    assert get_const_tuple(packed_shape) == (1, 2, 7, 7, 16)

    # Feed the forward result straight back in; the source shape must return.
    restored_shape = mapping.backward_shape(packed_shape)
    assert get_const_tuple(restored_shape) == (1, 32, 7, 7)
def test_bilayout_index():
    """Map an index forward and backward across "NCHW" -> "NCHW16c"."""
    mapping = tvm.tir.bijective_layout("NCHW", "NCHW16c")

    source_index = [0, 18, 6, 6]
    packed_index = mapping.forward_index(source_index)
    assert get_const_tuple(packed_index) == (0, 1, 6, 6, 2)

    # The backward direction is checked with a literal index rather than
    # the forward result.
    literal_packed_index = [0, 1, 6, 6, 2]
    restored_index = mapping.backward_index(literal_packed_index)
    assert get_const_tuple(restored_index) == (0, 18, 6, 6)
if __name__ == "__main__":
    # When executed as a script, run the listed tests in order without
    # going through a pytest runner.
    for run_test in (
        test_layout,
        test_layout_dtype,
        test_bilayout_convertible,
        test_bilayout_shape,
        test_bilayout_index,
    ):
        run_test()