// blob: 5063509e4e3501be63266f51746299b52fbe6b2f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../src/arith/pattern_match.h"
#include <gtest/gtest.h>
#include <tvm/tir/analysis.h>
// Exercises the arith pattern-matching DSL end to end: building patterns from
// pattern variables (PVar), matching them against concrete expressions with
// Match(), and reading the captured bindings back with Eval(). Covers
// arithmetic, comparison/logical, bitwise, select, if_then_else, cast, ramp
// and broadcast patterns.
TEST(Pattern, Basic) {
  using namespace tvm;
  using namespace tvm::tir;
  using namespace tvm::arith;
  tvm::tir::Var x("x"), y("y"), z("z");
  arith::PVar<PrimExpr> px, py, pz;
  arith::PVar<DataType> pt;
  arith::PVar<int> planes;

  // arithmetics
  auto r = 1 + (y + 1);
  // A pattern variable that occurs several times has to bind to equal
  // sub-expressions at every occurrence, so these two must be rejected.
  CHECK(!(px + (px + px)).Match(r));
  CHECK(!(px + (py + py)).Match(r));
  CHECK((px + (py + pz)).Match(r));
  // A pattern may also be stored in a local before it is matched.
  auto pattern = px + (py + pz);
  CHECK(pattern.Match(r));
  {
    // px occurs twice and binds to the constant 1 both times; py binds to y.
    CHECK((px + (py + px)).Match(r));
    // Eval() on a pattern rebuilds an expression from the captured bindings.
    auto rr = (px + py).Eval();
    CHECK(tir::ExprDeepEqual()(rr, 1 + y));
    CHECK(tir::ExprDeepEqual()(px.Eval() + py.Eval(), 1 + y));
  }
  {
    // A repeated variable may bind to a compound expression (x + 1).
    CHECK((px + max(py, px)).Match((x + 1) + max(y, (x + 1))));
    CHECK(tir::ExprDeepEqual()(px.Eval(), x + 1));
  }
  // min and max are distinct node kinds and must not match each other.
  CHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1))));
  CHECK((px + min(py, px)).Match(z + min(y, z)));
  CHECK((px + truncdiv(py, px * py)).Match(x + truncdiv(2, x * 2)));
  CHECK((px - truncmod(py, px * pz)).Match(x - truncmod(2, x * 2)));
  // PConst embeds a fixed expression (here the literal 2) into the pattern.
  CHECK((px - floormod(py, px * PConst<PrimExpr>(2))).Match(x - floormod(2, x * 2)));

  // logicals
  CHECK((px == pz).Match(x == 1));
  CHECK((px != pz).Match(x != 1));
  CHECK((px > py).Match(x > y));
  CHECK((px < py).Match(x < y));
  CHECK((px <= py).Match(x <= y));
  CHECK((px >= py).Match(x >= y));
  CHECK((px >= py && px < pz).Match(x >= y && x < z));
  CHECK((!(px > py || px != py)).Match(!(x > y || x != y)));
  {
    CHECK(select(px >= pz, py, py + pz).Match(tir::Select((x + 1) >= 1, y, y + 1)));
    CHECK(tir::ExprDeepEqual()(px.Eval(), x + 1));
  }

  // bit intrinsics
  {
    CHECK((px >> pz).Match(x >> 1));
    CHECK(is_const_int(pz.Eval(), 1));
  }
  // A right-shift pattern must not accept a left shift.
  CHECK(!(px >> pz).Match(x << 1));
  CHECK((px << pz).Match(x << 1));
  CHECK((px & pz).Match(x & 1));
  CHECK((px | pz).Match(x | 1));
  CHECK((px ^ pz).Match(x ^ 1));
  CHECK((px - (~(py | (px * pz)))).Match(x - (~(2 | (x * 2)))));

  // select
  {
    CHECK(select(px > pz, py, py + pz).Match(tir::Select(x > 1, y, y + 1)));
    CHECK(is_const_int(pz.Eval(), 1));
  }
  // pz would have to be 2 (condition) and 1 (false branch) at the same time.
  CHECK(!select(px > pz, py, py + pz).Match(tir::Select(x > 2, y, y + 1)));
  // Both branches are py in the pattern, but y and y + 1 differ.
  CHECK(!select(px > pz, py, py).Match(tir::Select(x > 2, y, y + 1)));
  {
    // With three independent variables any Select matches.
    CHECK(select(px, py, pz).Match(tir::Select(x > 2, y, y + 1)));
    CHECK(tir::ExprDeepEqual()(pz.Eval(), y + 1));
  }

  // if_then_else
  {
    CHECK(if_then_else(px > pz, py, py + pz).Match(if_then_else(x > 1, y, y + 1)));
    CHECK(is_const_int(pz.Eval(), 1));
  }

  // cast pattern
  {
    // A fixed Int(32) target dtype must not accept a Float(64) cast.
    CHECK(!cast(PConst<DataType>(DataType::Int(32)), px).Match(tir::Cast(DataType::Float(64), x)));
    CHECK(cast(pt, px).Match(tir::Cast(DataType::Float(64), x)));
    CHECK(pt.Eval() == DataType::Float(64));
    // Eval() on the cast pattern rebuilds the cast from the bindings captured
    // by the match above (pt -> Float(64), px -> x); verify the result.
    auto zz = cast(pt, px).Eval();
    CHECK(tir::ExprDeepEqual()(zz, tir::Cast(DataType::Float(64), x)));
    // NOTE(review): this is expected to match although the operands are
    // written with different cast dtypes; presumably building the subtraction
    // converts both operands to a common dtype first -- confirm against the
    // PrimExpr operator- implementation.
    CHECK((cast(pt, px) - cast(pt, py))
              .Match(tir::Cast(DataType::Float(64), x) - tir::Cast(DataType::Int(64), x)));
    // pt cannot be Int(32) for the outer cast and Float(64) for the inner one.
    auto expr = tir::Cast(DataType::Int(32), tir::Cast(DataType::Float(64), x));
    CHECK(!(cast(pt, cast(pt, px))).Match(expr));
  }

  // ramp pattern
  {
    CHECK(ramp(px, PConst<PrimExpr>(1), planes).Match(tir::Ramp(x, 1, 10)));
    CHECK(planes.Eval() == 10);
    // The stride is pinned to 1 by PConst, so a stride of 2 is rejected.
    CHECK(!ramp(px, PConst<PrimExpr>(1), planes).Match(tir::Ramp(x, 2, 10)));
  }

  // broadcast pattern
  {
    CHECK(broadcast(px, planes).Match(tir::Broadcast(x, 10)));
    CHECK(planes.Eval() == 10);
    CHECK(broadcast(px * py, planes).Match(tir::Broadcast(x * 10, 10)));
  }
}
// Checks pattern variables that are typed more narrowly than PrimExpr: a
// PVar<tir::Var> and a PVar<IntImm> are matched against products whose
// operands are or are not of those node kinds.
TEST(Pattern, IntImm) {
  using namespace tvm;
  tir::Var lhs, rhs;
  arith::PVar<tir::Var> var_slot;
  arith::PVar<IntImm> imm_slot;
  // "variable times integer immediate", shared by the checks below.
  auto var_times_imm = var_slot * imm_slot;

  // Var and IntImm are both special-case containers of Expr, so each typed
  // slot can capture the corresponding operand of lhs * 3.
  CHECK(var_times_imm.Match(lhs * 3));
  CHECK_EQ(imm_slot.Eval()->value, 3);
  // A plain integer literal may stand in for the immediate inside a pattern.
  CHECK((var_slot * 3).Match(lhs * 3));

  // The IntImm slot cannot capture the variable rhs.
  CHECK(!var_times_imm.Match(lhs * rhs));
  // The Var slot cannot capture the compound expression lhs + 1.
  CHECK(!var_times_imm.Match((lhs + 1) * 3));
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}