blob: c9628daf0d805063ee85762ab9e3390285f85e03 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <gtest/gtest.h>
#include <tvm/relax/nested_msg.h>
#include <tvm/relax/struct_info.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/expr.h>
#include <algorithm>
#include <array>
#include <cstring>
#include <functional>
#include <iterator>
#include <new>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
using namespace tvm;
using namespace tvm::runtime;
using namespace tvm::relax;
TEST(NestedMsg, Basic) {
  // Variables are created without a struct-info annotation; only object identity is checked.
  relax::Var x("x", std::nullopt), y("y", std::nullopt);

  // Build a nested message from a braced list mixing leaves and a null slot.
  NestedMsg<relax::Expr> msg({x, std::nullopt, x});
  EXPECT_TRUE(msg.IsNested());
  EXPECT_FALSE(msg.IsLeaf());
  EXPECT_TRUE(msg != nullptr);
  // A nested message has no leaf value; requesting one is expected to throw.
  EXPECT_ANY_THROW(msg.LeafValue());

  auto elems = msg.NestedArray();
  EXPECT_TRUE(elems[0].LeafValue().same_as(x));
  EXPECT_TRUE(elems[1] == nullptr);
  EXPECT_TRUE(elems[1].IsNull());
  EXPECT_TRUE(elems[2].LeafValue().same_as(x));

  // Take a copy of the first element and reassign it through each kind of value.
  auto slot = elems[0];
  EXPECT_TRUE(slot.IsLeaf());

  // Null assignment.
  slot = std::nullopt;
  EXPECT_TRUE(slot == nullptr);

  // Nested-array assignment.
  slot = {x, {x, std::nullopt, y}};
  EXPECT_TRUE(slot.IsNested());
  auto inner = slot.NestedArray()[1];
  EXPECT_TRUE(inner.IsNested());
  EXPECT_TRUE(inner.NestedArray()[2].LeafValue().same_as(y));

  // Leaf assignment.
  slot = x;
  EXPECT_TRUE(slot.IsLeaf());
  EXPECT_TRUE(slot.LeafValue().same_as(x));
}
TEST(NestedMsg, IntAndAny) {
  // Round-trip a message with integer leaves through the type-erased Any container
  // and confirm the structure and leaf values survive.
  NestedMsg<int64_t> original({1, std::nullopt, 2});
  Any boxed = original;
  auto restored = boxed.cast<NestedMsg<int64_t>>();

  EXPECT_TRUE(restored.IsNested());
  auto fields = restored.NestedArray();
  EXPECT_EQ(fields[0].LeafValue(), 1);
  EXPECT_TRUE(fields[1].IsNull());
  EXPECT_EQ(fields[2].LeafValue(), 2);
}
TEST(NestedMsg, ForEachLeaf) {
  relax::Var x("x", std::nullopt), y("y", std::nullopt);
  // x occurs four times and y twice across the nesting levels; the null slot
  // contributes to neither count.
  NestedMsg<Expr> msg = {x, {x, y}, std::nullopt, {x, {x, y}}};

  int num_x = 0;
  int num_y = 0;
  ForEachLeaf(msg, [&](const Expr& leaf) {
    if (leaf.same_as(x)) {
      ++num_x;
    } else if (leaf.same_as(y)) {
      ++num_y;
    }
  });
  EXPECT_EQ(num_x, 4);
  EXPECT_EQ(num_y, 2);
}
TEST(NestedMsg, Equal) {
  relax::Var x("x", std::nullopt), y("y", std::nullopt);
  relax::Var z("z", std::nullopt);
  // Leaves are considered equal only when they are the very same object.
  auto by_identity = [](Expr lhs, Expr rhs) { return lhs.same_as(rhs); };
  using Msg = NestedMsg<relax::Expr>;

  // Identical shapes with identical leaves compare equal, at every nesting depth.
  EXPECT_TRUE(Equal(Msg(std::nullopt), Msg(std::nullopt), by_identity));
  EXPECT_TRUE(Equal(Msg(x), Msg(x), by_identity));
  EXPECT_TRUE(Equal(Msg({x, y}), Msg({x, y}), by_identity));
  EXPECT_TRUE(Equal(Msg({x, std::nullopt}), Msg({x, std::nullopt}), by_identity));
  EXPECT_TRUE(Equal(Msg({x, {std::nullopt, y}}), Msg({x, {std::nullopt, y}}), by_identity));
  EXPECT_TRUE(Equal(Msg({x, {std::nullopt, y}, {x, z}}), Msg({x, {std::nullopt, y}, {x, z}}),
                    by_identity));

  // Kind mismatches (leaf vs. array, null vs. leaf) and length mismatches compare unequal.
  EXPECT_FALSE(Equal(Msg({x, {std::nullopt, y}, x}), Msg({x, {std::nullopt, y}, {x, z}}),
                     by_identity));
  EXPECT_FALSE(Equal(Msg({x, {std::nullopt, y}, {x, std::nullopt}}),
                     Msg({x, {std::nullopt, y}, {x, z}}), by_identity));
  EXPECT_FALSE(Equal(Msg({x, {std::nullopt, y}}), Msg({x, {std::nullopt, y}, {x, z}}),
                     by_identity));
  EXPECT_FALSE(Equal(Msg(x), Msg(std::nullopt), by_identity));
  EXPECT_FALSE(Equal(Msg(std::nullopt), Msg(x), by_identity));
  EXPECT_FALSE(Equal(Msg(x), Msg(ffi::Array<Msg>({x})), by_identity));
  EXPECT_FALSE(Equal(Msg(ffi::Array<Msg>({x})), Msg(x), by_identity));
}
TEST(NestedMsg, MapAndDecompose) {
  relax::Var x("x", PrimStructInfo(runtime::DataType::Int(16)));
  relax::Var y("y", PrimStructInfo(runtime::DataType::Int(32)));
  relax::Var z("z", PrimStructInfo(runtime::DataType::Int(64)));

  // outer = ((x, y), x, z, (x, y)), built through the BlockBuilder so that
  // GetStructInfo can be queried on it below.
  BlockBuilder bb = BlockBuilder::Create(std::nullopt);
  relax::Expr inner = bb->Normalize(Tuple({x, y}));
  relax::Expr outer = bb->Normalize(Tuple({inner, x, z, inner}));

  auto c0 = Integer(0);
  auto c1 = Integer(1);
  auto c2 = Integer(2);
  // Leaf messages are compared by their integer payload.
  auto same_int = [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; };
  NestedMsg<Integer> expected = {{c0, c1}, c0, c2, {c0, c1}};

  // Map over the expression tree: x -> 0, y -> 1, anything else -> 2.
  auto by_expr = MapToNestedMsg<Integer>(outer, [&](Expr value) {
    if (value.same_as(x)) return c0;
    if (value.same_as(y)) return c1;
    return c2;
  });
  EXPECT_TRUE(Equal(by_expr, expected, same_int));

  // Map over the struct info instead, dispatching on the primitive bit width.
  auto by_sinfo =
      MapToNestedMsg<Integer>(GetStructInfo(outer), [&](StructInfo sinfo) -> NestedMsg<Integer> {
        const auto* prim = sinfo.as<PrimStructInfoNode>();
        if (prim == nullptr) return std::nullopt;
        switch (prim->dtype.bits()) {
          case 16:
            return c0;
          case 32:
            return c1;
          case 64:
            return c2;
          default:
            return std::nullopt;
        }
      });
  EXPECT_TRUE(Equal(by_sinfo, expected, same_int));

  // Decompose `outer` against `expected`; every leaf expression must receive its own message.
  int seen_x = 0, seen_y = 0, seen_z = 0;
  DecomposeNestedMsg(outer, expected, [&](Expr value, NestedMsg<Integer> msg) {
    if (value.same_as(x)) {
      EXPECT_TRUE(msg.LeafValue().same_as(c0));
      ++seen_x;
    } else if (value.same_as(y)) {
      EXPECT_TRUE(msg.LeafValue().same_as(c1));
      ++seen_y;
    } else {
      EXPECT_TRUE(msg.LeafValue().same_as(c2));
      ++seen_z;
    }
  });
  EXPECT_EQ(seen_x, 3);
  EXPECT_EQ(seen_y, 2);
  EXPECT_EQ(seen_z, 1);
}
TEST(NestedMsg, MapToNestedMsgBySInfo) {
  auto sf0 = TensorStructInfo(DataType::Float(32), /*ndim=*/0);
  auto sf1 = TupleStructInfo({sf0, sf0});
  auto sf2 = TupleStructInfo({sf0, sf0});
  // x is a Var (not a Tuple expression) whose struct info is ((sf0, sf0), (sf0, sf0), sf0);
  // the mapped leaves are expected to be TupleGetItem projections of x.
  auto x = relax::Var("x", TupleStructInfo({sf1, sf2, sf0}));
  auto msg = MapToNestedMsgBySInfo<Expr>(x, [](Expr value) { return value; });

  // Structural preconditions use ASSERT_* so that a shape mismatch stops the test
  // before NestedArray()/LeafValue() are called on the wrong kind of message.
  ASSERT_TRUE(msg.IsNested());
  auto arr = msg.NestedArray();

  // Field 0: a nested tuple whose leaves project through x[0].
  ASSERT_TRUE(arr[0].IsNested());
  auto arr0 = arr[0].NestedArray();
  ASSERT_TRUE(arr0[1].IsLeaf());
  EXPECT_TRUE(StructuralEqual()(arr0[1].LeafValue(), TupleGetItem(TupleGetItem(x, 0), 1)));

  // Field 1: a nested tuple whose leaves project through x[1].
  ASSERT_TRUE(arr[1].IsNested());
  auto arr1 = arr[1].NestedArray();
  ASSERT_TRUE(arr1[0].IsLeaf());
  EXPECT_TRUE(StructuralEqual()(arr1[0].LeafValue(), TupleGetItem(TupleGetItem(x, 1), 0)));

  // Field 2: a plain tensor, mapped directly to the projection x[2].
  ASSERT_TRUE(arr[2].IsLeaf());
  EXPECT_TRUE(StructuralEqual()(arr[2].LeafValue(), TupleGetItem(x, 2)));
}
TEST(NestedMsg, NestedMsgToExpr) {
  auto sf0 = TensorStructInfo(DataType::Float(32), /*ndim=*/0);
  auto sf1 = TupleStructInfo({sf0, sf0});
  auto c0 = Integer(0);
  auto c1 = Integer(1);
  auto c2 = Integer(2);
  relax::Var x("x", sf0), y("y", sf0), z("z", sf0);

  // Leaf 0 becomes x, leaf 1 becomes y, anything else becomes z; nesting becomes Tuple nodes.
  NestedMsg<Integer> msg = {c0, {c0, c1}, {c0, {c1, c2}}};
  auto expr = NestedMsgToExpr<Integer>(msg, [&](ffi::Optional<Integer> leaf) {
    ICHECK(leaf.defined());
    int value = leaf.value().IntValue();
    if (value == 0) return x;
    if (value == 1) return y;
    return z;
  });
  Expr want = Tuple({x, Tuple({x, y}), Tuple({x, Tuple({y, z})})});
  EXPECT_TRUE(StructuralEqual()(expr, want));

  // Simplification: a message holding exactly (t[0], t[1]) is expected to collapse back to t.
  relax::Var t("t", sf1);
  NestedMsg<Expr> projections = {TupleGetItem(t, 0), TupleGetItem(t, 1)};
  auto collapsed =
      NestedMsgToExpr<Expr>(projections, [](ffi::Optional<Expr> leaf) { return leaf.value(); });
  EXPECT_TRUE(StructuralEqual()(collapsed, t));
}
TEST(NestedMsg, CombineNestedMsg) {
  auto c0 = Integer(0);
  auto c1 = Integer(1);
  auto c2 = Integer(2);
  // Leaves are combined position-wise by keeping the larger value. Where one side is null
  // the other side's leaf is kept (second field), and null on both sides stays null.
  NestedMsg<Integer> lhs = {c0, {c0, c1}, std::nullopt, {c0, {c1, c2}}};
  NestedMsg<Integer> rhs = {c1, {c2, std::nullopt}, std::nullopt, {c1, {c2, c2}}};
  NestedMsg<Integer> expected = {c1, {c2, c1}, std::nullopt, {c1, {c2, c2}}};

  auto take_max = [](Integer a, Integer b) { return a->value > b->value ? a : b; };
  auto combined = CombineNestedMsg(lhs, rhs, take_max);
  EXPECT_TRUE(Equal(combined, expected,
                    [](Integer a, Integer b) -> bool { return a->value == b->value; }));
}
TEST(NestedMsg, MapNestedMsg) {
  auto c0 = Integer(0);
  auto c1 = Integer(1);
  auto c2 = Integer(2);
  auto c3 = Integer(3);
  // Leaf rewrite rule: 0 -> 3, 1 -> null, anything else unchanged; existing nulls stay null.
  NestedMsg<Integer> msg = {c0, {c0, c1}, std::nullopt, {c0, {c2, c1}}};
  NestedMsg<Integer> expected = {c3, {c3, std::nullopt}, std::nullopt, {c3, {c2, std::nullopt}}};

  auto rewrite = [](Integer leaf) -> NestedMsg<Integer> {
    if (leaf->value == 0) return NestedMsg<Integer>(Integer(3));
    if (leaf->value == 1) return NestedMsg<Integer>();
    return NestedMsg<Integer>(leaf);
  };
  auto mapped = MapNestedMsg(msg, rewrite);
  EXPECT_TRUE(Equal(mapped, expected,
                    [](Integer a, Integer b) -> bool { return a->value == b->value; }));
}
TEST(NestedMsg, TransformTupleLeaf) {
  auto c0 = Integer(0);
  auto c1 = Integer(1);
  auto c2 = Integer(2);
  using NInt = NestedMsg<Integer>;
  NInt msg1 = {c0, {c0, c1}, c2, {c0, {c1, c2}}};
  NInt msg2 = {c1, {c2, c0}, c2, {c1, {c2, c0}}};
  PrimStructInfo s = PrimStructInfo(runtime::DataType::Int(32));
  relax::Var x("x", s), y("y", s), z("z", s);
  BlockBuilder bb = BlockBuilder::Create(std::nullopt);
  // Every leaf of `expr` is x; the transform decides each leaf from the two messages
  // found at the same position.
  Expr expr = bb->Normalize(Tuple({x, Tuple({x, x}), x, Tuple({x, Tuple({x, x})})}));

  // Per-leaf rule: lhs > rhs -> z, lhs == rhs -> keep the original leaf, lhs < rhs -> y.
  // LeafValue() already yields an Integer, so no Downcast is needed, and the payload is
  // read as int64_t rather than being narrowed to int. The message pair is taken by
  // const reference to avoid copying it for every leaf.
  auto ftransleaf = [&](Expr value, const std::array<NInt, 2>& msgs) -> Expr {
    int64_t lhs = msgs[0].LeafValue()->value;
    int64_t rhs = msgs[1].LeafValue()->value;
    if (lhs > rhs) return z;
    if (lhs == rhs) return value;
    return y;
  };

  Expr expected = Tuple({y, Tuple({y, z}), x, Tuple({y, Tuple({y, z})})});
  EXPECT_TRUE(StructuralEqual()(
      TransformTupleLeaf(expr, std::array<NInt, 2>({msg1, msg2}), ftransleaf), expected));
  // When every leaf maps back to itself, the very same expression object is returned.
  EXPECT_TRUE(
      expr.same_as(TransformTupleLeaf(expr, std::array<NInt, 2>({msg1, msg1}), ftransleaf)));
}