# blob: 1aa78ee68c5bbc9f5b588f4143f5c81736170704
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
# mypy: ignore-errors
import sys
import pytest
import tvm.testing
from tvm.tir.schedule import BlockRV, Instruction, InstructionKind, LoopRV
def test_inst_kind_get():
    """Fetch the "EnterPostproc" kind by name; check it is not pure and keeps its name."""
    enter_postproc = InstructionKind.get("EnterPostproc")
    assert not enter_postproc.is_pure
    assert enter_postproc.name == "EnterPostproc"
def test_inst_construct_1():
    """Hand-build a "GetLoops" instruction and verify its fields and printed form.

    The instruction is given one BlockRV input, no attributes, and two LoopRV
    outputs; each of them must be retrievable unchanged from the instruction.
    """
    src_block = BlockRV()
    first_loop = LoopRV()
    second_loop = LoopRV()
    get_loops_kind = InstructionKind.get("GetLoops")
    inst = Instruction(
        kind=get_loops_kind,
        inputs=[src_block],
        attrs=[],
        outputs=[first_loop, second_loop],
    )

    # Printed form: random variables are rendered as "_" placeholders.
    assert str(inst) == "_, _ = sch.get_loops(block=_)"

    # Arity of each field.
    assert len(inst.inputs) == 1
    assert len(inst.attrs) == 0
    assert len(inst.outputs) == 2

    # A second lookup by name must refer to the same kind the instruction holds.
    assert inst.kind.same_as(InstructionKind.get("GetLoops"))

    # Inputs and outputs are the very objects that were passed in, in order.
    assert inst.inputs[0].same_as(src_block)
    for position, expected_loop in enumerate((first_loop, second_loop)):
        assert inst.outputs[position].same_as(expected_loop)
def test_inst_construct_2():
    """Hand-build a "ComputeInline" instruction and verify its fields and printed form.

    The instruction is given one BlockRV input and has neither attributes nor
    outputs.
    """
    target_block = BlockRV()
    compute_inline_kind = InstructionKind.get("ComputeInline")
    inst = Instruction(
        kind=compute_inline_kind,
        inputs=[target_block],
        attrs=[],
        outputs=[],
    )

    # With no outputs, the printed form has no assignment on the left-hand side.
    assert str(inst) == "sch.compute_inline(block=_)"

    # Arity of each field.
    assert len(inst.inputs) == 1
    assert len(inst.attrs) == 0
    assert len(inst.outputs) == 0

    # A second lookup by name must refer to the same kind the instruction holds.
    assert inst.kind.same_as(InstructionKind.get("ComputeInline"))

    # The sole input is the very block object that was passed in.
    assert inst.inputs[0].same_as(target_block)
# When executed directly as a script (rather than imported), hand control to
# TVM's test entry point -- presumably it runs the tests defined in this file.
if __name__ == "__main__":
    tvm.testing.main()