blob: c6b30d38edac0475c81364adc8ef377235629603 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Common base structures."""
import tvm._ffi
import tvm.error
import tvm.runtime._ffi_node_api
from tvm.runtime import Object
from . import _ffi_api
from . import json_compact
class Node(Object):
    """Base class of all IR nodes; provides text rendering via ``astext`` and ``str()``."""

    def __str__(self):
        # str(node) is delegated to the pretty printer exposed through the FFI.
        return _ffi_api.PrettyPrint(self)

    def astext(self, show_meta_data=True, annotate=None):
        """Get the text format of the expression.

        Parameters
        ----------
        show_meta_data : bool
            Whether to include the meta data section in the text
            when there is meta data.

        annotate : Optional[Object->str]
            Optional annotation function that provides additional
            information in the comment block.

        Returns
        -------
        text : str
            The text format of the expression.

        Notes
        -----
        The meta data section is necessary to fully parse the text format.
        However, it can contain large dumps (e.g. constant weights),
        so it can be helpful to skip printing the meta data section.
        """
        text = _ffi_api.AsText(self, show_meta_data, annotate)
        return text
@tvm._ffi.register_object("SourceName")
class SourceName(Object):
    """An identifier for a source location.

    Parameters
    ----------
    name : str
        The name of the source.
    """

    def __init__(self, name):
        # Build the underlying handle through the FFI constructor.
        constructor = _ffi_api.SourceName
        self.__init_handle_by_constructor__(constructor, name)
@tvm._ffi.register_object("Span")
class Span(Object):
    """Specifies a location in a source program.

    Parameters
    ----------
    source_name : SourceName
        The source name.

    line : int
        The line number.

    end_line : int
        The end line number (presumably where the span stops; confirm against
        the ``_ffi_api.Span`` constructor).

    column : int
        The column offset of the location.

    end_column : int
        The end column offset (presumably where the span stops; confirm against
        the ``_ffi_api.Span`` constructor).
    """

    def __init__(self, source_name, line, end_line, column, end_column):
        # Arguments are forwarded to the FFI constructor in declaration order.
        location = (source_name, line, end_line, column, end_column)
        self.__init_handle_by_constructor__(_ffi_api.Span, *location)
@tvm._ffi.register_object
class EnvFunc(Object):
    """Environment function.

    This is a global function object that can be serialized by its name.
    """

    @staticmethod
    def get(name):
        """Get a static env function.

        Parameters
        ----------
        name : str
            The name of the function.

        Returns
        -------
        The result of ``_ffi_api.EnvFuncGet(name)``.
        """
        env_func = _ffi_api.EnvFuncGet(name)
        return env_func

    @property
    def func(self):
        """The value returned by ``_ffi_api.EnvFuncGetPackedFunc`` for this object."""
        packed = _ffi_api.EnvFuncGetPackedFunc(self)
        return packed

    def __call__(self, *args):
        """Call the environment function, forwarding all positional arguments."""
        result = _ffi_api.EnvFuncCall(self, *args)
        return result
def load_json(json_str):
    """Load a tvm object from json_str.

    The string is first loaded as-is. If that raises ``tvm.error.TVMError``,
    it is passed through ``json_compact.upgrade_json`` and loading is
    attempted once more; an error from the second attempt propagates.

    Parameters
    ----------
    json_str : str
        The json string.

    Returns
    -------
    node : Object
        The loaded tvm node.
    """
    loader = tvm.runtime._ffi_node_api.LoadJSON
    try:
        return loader(json_str)
    except tvm.error.TVMError:
        # Direct load failed: upgrade the JSON and retry.
        upgraded = json_compact.upgrade_json(json_str)
        return loader(upgraded)
def save_json(node):
    """Save a tvm object as a json string.

    Parameters
    ----------
    node : Object
        A TVM object to be saved.

    Returns
    -------
    json_str : str
        Saved json string.
    """
    serialize = tvm.runtime._ffi_node_api.SaveJSON
    return serialize(node)
def structural_equal(lhs, rhs, map_free_vars=False):
    """Check structural equality of lhs and rhs.

    The structural equality is recursively defined in the DAG of IRNodes.
    There are two kinds of nodes:

    - Graph node: a graph node in lhs can only be mapped as equal to
      one and only one graph node in rhs.
    - Normal node: equality is recursively defined without the restriction
      of graph nodes.

    Vars (tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes.
    For example, it means that `%1 = %x + %y; %1 + %1` is not structurally equal
    to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay.

    A var-type node (e.g. tir::Var, TypeVar) can be mapped as equal to another var
    with the same type if one of the following conditions holds:

    - They appear in the same definition point (e.g. function argument).
    - They point to the same VarNode via the same_as relation.
    - They appear in the same usage point, and map_free_vars is set to True.

    The rules for vars are used to remap variables occurring in function
    arguments and let-bindings.

    Parameters
    ----------
    lhs : Object
        The left operand.

    rhs : Object
        The right operand.

    map_free_vars : bool
        Whether free variables (i.e. variables without a definition site) should be mapped
        as equal to each other.

    Returns
    -------
    result : bool
        The comparison result.

    See Also
    --------
    structural_hash
    assert_structural_equal
    """
    left = tvm.runtime.convert(lhs)
    right = tvm.runtime.convert(rhs)
    # The third argument is False here; assert_structural_equal passes True instead.
    outcome = tvm.runtime._ffi_node_api.StructuralEqual(left, right, False, map_free_vars)
    return bool(outcome)
def get_first_structural_mismatch(lhs, rhs, map_free_vars=False):
    """Like structural_equal(), but returns the ObjectPaths of the first detected mismatch.

    Parameters
    ----------
    lhs : Object
        The left operand.

    rhs : Object
        The right operand.

    map_free_vars : bool
        Whether free variables (i.e. variables without a definition site) should be mapped
        as equal to each other.

    Returns
    -------
    mismatch : Optional[Tuple[ObjectPath, ObjectPath]]
        `None` if `lhs` and `rhs` are structurally equal.
        Otherwise, a tuple of two ObjectPath objects that point to the first detected mismatch.
    """
    left = tvm.runtime.convert(lhs)
    right = tvm.runtime.convert(rhs)
    found = tvm.runtime._ffi_node_api.GetFirstStructuralMismatch(left, right, map_free_vars)
    if found is not None:
        # Unpack the mismatch record into a plain (lhs_path, rhs_path) tuple.
        return found.lhs_path, found.rhs_path
    return None
def assert_structural_equal(lhs, rhs, map_free_vars=False):
    """Assert lhs and rhs are structurally equal to each other.

    Parameters
    ----------
    lhs : Object
        The left operand.

    rhs : Object
        The right operand.

    map_free_vars : bool
        Whether free variables that are not bound to any definition
        should be mapped as equal to each other.

    Raises
    ------
    ValueError : if assertion does not hold.

    See Also
    --------
    structural_equal
    """
    left = tvm.runtime.convert(lhs)
    right = tvm.runtime.convert(rhs)
    # The third argument is True here (structural_equal passes False); the result is discarded.
    check = tvm.runtime._ffi_node_api.StructuralEqual
    check(left, right, True, map_free_vars)
def structural_hash(node, map_free_vars=False):
    """Compute the structural hash of node.

    The structural hash value is recursively defined in the DAG of IRNodes.
    There are two kinds of nodes:

    - Normal node: the hash value is defined by its content and type only.
    - Graph node: each graph node will be assigned a unique index ordered by the
      first occurrence during the visit. The hash value of a graph node is
      combined from the hash values of its contents and the index.

    structural_hash is made to be consistent with structural_equal.
    If two nodes are structurally equal to each other,
    then their structural hash (with the same map_free_vars option)
    should be equal to each other as well.

    If the structural hashes of two nodes are equal to each other,
    then it is highly likely (except for rare hash value collision cases)
    that the two nodes are structurally equal to each other.

    Parameters
    ----------
    node : Object
        The input to be hashed.

    map_free_vars : bool
        If map_free_vars is set to true, we will hash free variables
        by the order of their occurrences. Otherwise, we will hash by
        their in-memory pointer address.

    Returns
    -------
    result : int
        The hash result.

    See Also
    --------
    structural_equal
    """
    hasher = tvm.runtime._ffi_node_api.StructuralHash
    return hasher(node, map_free_vars)