/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
| |
| #include <gtest/gtest.h> |
| #include <tvm/relay/dataflow_pattern.h> |
| #include <tvm/tir/analysis.h> |
| |
| TEST(DFPattern, IsVar) { |
| using namespace tvm; |
| using namespace tvm::relay; |
| auto pattern = IsVar("add"); |
| auto* node = pattern.as<VarPatternNode>(); |
| ICHECK(node); |
| ICHECK(node->name == String("add")); |
| } |
| |
| TEST(DFPattern, IsConstant) { |
| using namespace tvm; |
| using namespace tvm::relay; |
| auto pattern = IsConstant(); |
| auto* node = pattern.as<ConstantPatternNode>(); |
| ICHECK(node); |
| } |
| |
| TEST(DFPattern, IsOp) { |
| using namespace tvm; |
| using namespace tvm::relay; |
| auto pattern = IsOp("add"); |
| auto* node = pattern.as<ExprPatternNode>(); |
| ICHECK(node); |
| ICHECK(node->expr == Op::Get("add")); |
| } |
| |
| TEST(DFPattern, IsTuple) { |
| using namespace tvm; |
| using namespace tvm::relay; |
| auto a = WildcardPattern(); |
| auto b = WildcardPattern(); |
| auto pattern = IsTuple({a, b}); |
| auto* node = pattern.as<TuplePatternNode>(); |
| ICHECK(node); |
| ICHECK(node->fields[0] == a); |
| ICHECK(node->fields[1] == b); |
| } |
| |
| TEST(DFPattern, IsTupleGetItem) { |
| using namespace tvm; |
| using namespace tvm::relay; |
| auto a = WildcardPattern(); |
| auto b = WildcardPattern(); |
| auto tuple = IsTuple({a, b}); |
| auto pattern = IsTupleGetItem(tuple, 1); |
| auto* node = pattern.as<TupleGetItemPatternNode>(); |
| ICHECK(node); |
| ICHECK(node->tuple == tuple); |
| ICHECK(node->index == 1); |
| } |
| |
| TEST(DFPattern, ADD) { |
| using namespace tvm; |
| using namespace tvm::relay; |
| auto a = WildcardPattern(); |
| auto b = WildcardPattern(); |
| auto pattern = a + b; |
| auto* node = pattern.as<CallPatternNode>(); |
| ICHECK(node); |
| ICHECK(node->args[0] == a); |
| ICHECK(node->args[1] == b); |
| auto* expr_pattern = node->op.as<ExprPatternNode>(); |
| ICHECK(expr_pattern); |
| ICHECK(expr_pattern->expr == Op::Get("add")); |
| } |
| |
| TEST(DFPattern, SUB) { |
| using namespace tvm; |
| using namespace tvm::relay; |
| auto a = WildcardPattern(); |
| auto b = WildcardPattern(); |
| auto pattern = a - b; |
| auto* node = pattern.as<CallPatternNode>(); |
| ICHECK(node); |
| ICHECK(node->args[0] == a); |
| ICHECK(node->args[1] == b); |
| auto* expr_pattern = node->op.as<ExprPatternNode>(); |
| ICHECK(expr_pattern); |
| ICHECK(expr_pattern->expr == Op::Get("subtract")); |
| } |
| |
| TEST(DFPattern, MUL) { |
| using namespace tvm; |
| using namespace tvm::relay; |
| auto a = WildcardPattern(); |
| auto b = WildcardPattern(); |
| auto pattern = a * b; |
| auto* node = pattern.as<CallPatternNode>(); |
| ICHECK(node); |
| ICHECK(node->args[0] == a); |
| ICHECK(node->args[1] == b); |
| auto* expr_pattern = node->op.as<ExprPatternNode>(); |
| ICHECK(expr_pattern); |
| ICHECK(expr_pattern->expr == Op::Get("multiply")); |
| } |
| |
| TEST(DFPattern, DIV) { |
| using namespace tvm; |
| using namespace tvm::relay; |
| auto a = WildcardPattern(); |
| auto b = WildcardPattern(); |
| auto pattern = a / b; |
| auto* node = pattern.as<CallPatternNode>(); |
| ICHECK(node); |
| ICHECK(node->args[0] == a); |
| ICHECK(node->args[1] == b); |
| auto* expr_pattern = node->op.as<ExprPatternNode>(); |
| ICHECK(expr_pattern); |
| ICHECK(expr_pattern->expr == Op::Get("divide")); |
| } |
| |
| TEST(DFPattern, OR) { |
| using namespace tvm; |
| using namespace tvm::relay; |
| auto a = WildcardPattern(); |
| auto b = WildcardPattern(); |
| auto pattern = a || b; |
| auto* node = pattern.as<AltPatternNode>(); |
| ICHECK(node); |
| ICHECK(node->left == a); |
| ICHECK(node->right == b); |
| } |
| |
| TEST(DFPattern, Optional) { |
| using namespace tvm; |
| using namespace tvm::relay; |
| DFPattern a = WildcardPattern(); |
| DFPattern b = WildcardPattern(); |
| auto pattern = a.Optional([b](const DFPattern& other) { return other + b; }); |
| auto* node = pattern.as<AltPatternNode>(); |
| ICHECK(node); |
| ICHECK(node->left == a); |
| auto* right_node = node->right.as<CallPatternNode>(); |
| ICHECK(right_node); |
| ICHECK(right_node->args.size() == 2); |
| ICHECK(right_node->args[0] == a); |
| ICHECK(right_node->args[1] == b); |
| auto* expr_pattern = right_node->op.as<ExprPatternNode>(); |
| ICHECK(expr_pattern); |
| ICHECK(expr_pattern->expr == Op::Get("add")); |
| } |
| |
| TEST(DFPattern, HasAttr) { |
| using namespace tvm; |
| using namespace tvm::relay; |
| auto a = WildcardPattern(); |
| Map<String, ObjectRef> attrs; |
| auto b = String("b"); |
| attrs.Set("a", b); |
| auto pattern = a.HasAttr(attrs); |
| auto* node = pattern.as<AttrPatternNode>(); |
| ICHECK(node); |
| ICHECK(node->pattern == a); |
| ICHECK(node->attrs->dict.at("a") == b); |
| } |
| |
| TEST(DFPattern, HasType) { |
| using namespace tvm; |
| using namespace tvm::relay; |
| auto a = WildcardPattern(); |
| TensorType type({1, 2, 3}, DataType(runtime::String2DLDataType("float32"))); |
| auto pattern = a.HasType(type); |
| auto* node = pattern.as<TypePatternNode>(); |
| ICHECK(node); |
| ICHECK(node->pattern == a); |
| ICHECK(node->type == type); |
| } |
| |
| TEST(DFPattern, HasDtype) { |
| using namespace tvm; |
| using namespace tvm::relay; |
| auto a = WildcardPattern(); |
| auto pattern = a.HasDtype("float32"); |
| auto* node = pattern.as<DataTypePatternNode>(); |
| ICHECK(node); |
| ICHECK(node->pattern == a); |
| ICHECK(runtime::DLDataType2String(node->dtype.operator DLDataType()) == "float32"); |
| } |
| |
| TEST(DFPattern, HasShape) { |
| using namespace tvm; |
| using namespace tvm::relay; |
| auto a = WildcardPattern(); |
| Array<PrimExpr> shape{1, 2, 3}; |
| auto pattern = a.HasShape(shape); |
| auto* node = pattern.as<ShapePatternNode>(); |
| ICHECK(node); |
| ICHECK(node->pattern == a); |
| ICHECK(node->shape == shape); |
| } |