blob: df12aaece2da46994335f192102ed96093e064bd [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Algebraic data types in Relay."""
from tvm.ir import Constructor, TypeData
from tvm.runtime import Object
import tvm._ffi
from .base import RelayNode
from . import _ffi_api
from .ty import Type
from .expr import ExprWithOp, RelayExpr, Call
class Pattern(RelayNode):
"""Base type for pattern matching constructs."""
@tvm._ffi.register_object("relay.PatternWildcard")
class PatternWildcard(Pattern):
"""Wildcard pattern in Relay: Matches any ADT and binds nothing."""
def __init__(self):
"""Constructs a wildcard pattern.
Parameters
----------
None
Returns
-------
wildcard: PatternWildcard
a wildcard pattern.
"""
self.__init_handle_by_constructor__(_ffi_api.PatternWildcard)
@tvm._ffi.register_object("relay.PatternVar")
class PatternVar(Pattern):
"""Variable pattern in Relay: Matches anything and binds it to the variable."""
def __init__(self, var):
"""Construct a variable pattern.
Parameters
----------
var: tvm.relay.Var
Returns
-------
pv: PatternVar
A variable pattern.
"""
self.__init_handle_by_constructor__(_ffi_api.PatternVar, var)
@tvm._ffi.register_object("relay.PatternConstructor")
class PatternConstructor(Pattern):
"""Constructor pattern in Relay: Matches an ADT of the given constructor, binds recursively."""
def __init__(self, constructor, patterns=None):
"""Construct a constructor pattern.
Parameters
----------
constructor: Constructor
The constructor.
patterns: Optional[List[Pattern]]
Optional subpatterns: for each field of the constructor,
match to the given subpattern (treated as a variable pattern by default).
Returns
-------
wildcard: PatternWildcard
a wildcard pattern.
"""
if patterns is None:
patterns = []
self.__init_handle_by_constructor__(_ffi_api.PatternConstructor, constructor, patterns)
@tvm._ffi.register_object("relay.PatternTuple")
class PatternTuple(Pattern):
"""Constructor pattern in Relay: Matches a tuple, binds recursively."""
def __init__(self, patterns=None):
"""Construct a tuple pattern.
Parameters
----------
patterns: Optional[List[Pattern]]
Optional subpatterns: for each field of the constructor,
match to the given subpattern (treated as a variable pattern by default).
Returns
-------
wildcard: PatternWildcard
a wildcard pattern.
"""
if patterns is None:
patterns = []
self.__init_handle_by_constructor__(_ffi_api.PatternTuple, patterns)
@tvm._ffi.register_object("relay.Clause")
class Clause(Object):
"""Clause for pattern matching in Relay."""
def __init__(self, lhs, rhs):
"""Construct a clause.
Parameters
----------
lhs: tvm.relay.Pattern
Left-hand side of match clause.
rhs: tvm.relay.Expr
Right-hand side of match clause.
Returns
-------
clause: Clause
The Clause.
"""
self.__init_handle_by_constructor__(_ffi_api.Clause, lhs, rhs)
@tvm._ffi.register_object("relay.Match")
class Match(ExprWithOp):
"""Pattern matching expression in Relay."""
def __init__(self, data, clauses, complete=True):
"""Construct a Match.
Parameters
----------
data: tvm.relay.Expr
The value being deconstructed and matched.
clauses: List[tvm.relay.Clause]
The pattern match clauses.
complete: Optional[Bool]
Should the match be complete (cover all cases)?
If yes, the type checker will generate an error if there are any missing cases.
Returns
-------
match: tvm.relay.Expr
The match expression.
"""
self.__init_handle_by_constructor__(_ffi_api.Match, data, clauses, complete)