blob: a6bcc28dad433b9a126ca92202b754a0522b324b [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tool to upgrade json from historical versions."""
import json
import tvm.ir
import tvm.runtime
def create_updater(node_map, from_ver, to_ver):
    """Create an updater to update json loaded data.

    Parameters
    ----------
    node_map : Map[str, Function]
        Map from type_key to an updating function, or to a list of updating
        functions applied in order. Each function is called as
        ``f(item, nodes)`` and returns the (possibly replaced) node dict.

    from_ver : str
        Prefix of the version that we can accept.

    to_ver : str
        The target version.

    Returns
    -------
    fupdater : function
        The updater function. It mutates ``data`` in place, stamps
        ``data["attrs"]["tvm_version"]`` with ``to_ver`` and returns ``data``.
        It raises ValueError when the input version does not start with
        ``from_ver``.
    """

    def _updater(data):
        version = data["attrs"]["tvm_version"]
        # Use an explicit raise instead of ``assert`` so the check survives
        # ``python -O``; ValueError matches the error raised by upgrade_json.
        if not version.startswith(from_ver):
            raise ValueError(
                "Cannot update from version %s, expected a version starting with %s"
                % (version, from_ver)
            )
        nodes = data["nodes"]
        # Passes may append new nodes to ``nodes``; enumerate over a list also
        # yields entries appended during iteration, so those nodes are visited too.
        for idx, item in enumerate(nodes):
            f = node_map.get(item["type_key"], None)
            if isinstance(f, list):
                for fpass in f:
                    item = fpass(item, nodes)
            elif f:
                item = f(item, nodes)
            nodes[idx] = item
        data["attrs"]["tvm_version"] = to_ver
        return data

    return _updater
def create_updater_08_to_09():
    """Create an updater that upgrades json from v0.8 to v0.9.

    Every listed expression / GlobalVar node that does not yet carry a
    ``virtual_device_`` attribute gets one with the value ``"0"``.

    Returns
    -------
    fupdater : function
        The updater function
    """

    def _initialize_virtual_device(item, _):
        # "0" is presumably the null-node reference in TVM's json graph
        # (attr values are node indices) -- confirm against the loader.
        item["attrs"].setdefault("virtual_device_", "0")
        return item

    # Base IR node types that gained a ``virtual_device_`` field in v0.9.
    upgraded_type_keys = (
        "GlobalVar",
        "relay.Var",
        "relay.Function",
        "relay.Tuple",
        "relay.Call",
        "relay.Let",
        "relay.If",
        "relay.TupleGetItem",
        "relay.RefCreate",
        "relay.RefRead",
        "relay.RefWrite",
        "relay.Match",
        "relay.Constant",
    )
    node_map = dict.fromkeys(upgraded_type_keys, _initialize_virtual_device)
    return create_updater(node_map, "0.8", "0.9")
def create_updater_07_to_08():
    """Create an updater that upgrades json from v0.7 to v0.8.

    IRModule nodes missing an ``attrs`` attribute get one with the value ``"0"``.

    Returns
    -------
    fupdater : function
        The updater function
    """

    def _initialize_module_attributes(item, _):
        assert item["type_key"] == "IRModule", "Only initialize the attributes for IRModules"
        # "0" presumably refers to the null node (no module attrs) -- confirm.
        item["attrs"].setdefault("attrs", "0")
        return item

    return create_updater({"IRModule": _initialize_module_attributes}, "0.7", "0.8")
def create_updater_06_to_07():
    """Create an updater that upgrades json from v0.6 to v0.7.

    Covers the relay -> IR / transform / tir type_key renames and converts
    string attributes into references to separately stored String nodes.

    Returns
    -------
    fupdater : function
        The updater function
    """

    def _ftype_var(item, nodes):
        # Old relay.(Global)TypeVar nodes point at another node via
        # attrs["var"]; copy that node's name onto the type var itself.
        attrs = item["attrs"]
        var_node = nodes[int(attrs["var"])]
        attrs["name_hint"] = var_node["attrs"]["name"]
        # Blank the referenced node's type_key (intended as "set to null");
        # no entry in this node_map matches "", so no pass here touches it.
        var_node["type_key"] = ""
        del attrs["var"]
        assert item["type_key"].startswith("relay.")
        item["type_key"] = item["type_key"][len("relay.") :]
        return item

    def _rename(new_name):
        def _convert(item, _):
            item["type_key"] = new_name
            return item

        return _convert

    def _update_tir_var(new_name):
        retype = _rename(new_name)

        def _convert(item, nodes):
            item = retype(item, nodes)
            # "0" presumably marks an empty (null) type annotation -- confirm.
            item["attrs"]["type_annotation"] = "0"
            return item

        return _convert

    def _update_global_key(item, _):
        # The top-level "global_key" field moves to "repr_str".
        if "global_key" in item:
            item["repr_str"] = item.pop("global_key")
        return item

    def _update_from_std_str(key):
        def _convert(item, nodes):
            # Serialize the raw string as a tvm.runtime.String, append its
            # root node to ``nodes`` and store that node's index in the attr.
            string_graph = json.loads(tvm.ir.save_json(tvm.runtime.String(item["attrs"][key])))
            string_node = string_graph["nodes"][string_graph["root"]]
            item["attrs"][key] = str(len(nodes))
            nodes.append(string_node)
            return item

        return _convert

    # Entries that need more than a plain type_key rename.
    node_map = {
        # Base IR
        "SourceName": _update_global_key,
        "EnvFunc": _update_global_key,
        "relay.Op": [_update_global_key, _rename("Op")],
        "relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")],
        "TypeVar": _update_from_std_str("name_hint"),
        "relay.Id": [_update_from_std_str("name_hint")],
        "relay.GlobalTypeVar": [_ftype_var, _update_from_std_str("name_hint")],
        "GlobalTypeVar": _update_from_std_str("name_hint"),
        "relay.Constructor": _update_from_std_str("name_hint"),
        "relay.Module": _rename("IRModule"),
        "relay.GlobalVar": [_rename("GlobalVar"), _update_from_std_str("name_hint")],
        "GlobalVar": _update_from_std_str("name_hint"),
        "StrMap": _rename("Map"),
        # TIR variables: new type_key, type annotation and boxed name string
        "Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")],
        "SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")],
    }

    # relay.* nodes that simply drop the "relay." prefix.
    for name in (
        "Type",
        "TupleType",
        "TypeConstraint",
        "FuncType",
        "IncompleteType",
        "TypeRelation",
        "TypeCall",
        "SourceName",
        "Span",
    ):
        node_map["relay." + name] = _rename(name)

    # Pass infrastructure moves from relay.* to transform.*.
    for name in ("Pass", "PassInfo", "PassContext", "ModulePass", "Sequential"):
        node_map["relay." + name] = _rename("transform." + name)

    # TIR nodes that only gain the "tir." prefix.
    for name in (
        "Cast",
        "Add",
        "Sub",
        "Mul",
        "Div",
        "Mod",
        "FloorDiv",
        "FloorMod",
        "Min",
        "Max",
        "EQ",
        "NE",
        "LT",
        "LE",
        "GT",
        "GE",
        "And",
        "Or",
        "Not",
        "Select",
        "Load",
        "BufferLoad",
        "Ramp",
        "Broadcast",
        "Shuffle",
        "Let",
        "Any",
        "LetStmt",
        "AssertStmt",
        "Store",
        "BufferStore",
        "BufferRealize",
        "Allocate",
        "IfThenElse",
        "Evaluate",
        "Prefetch",
    ):
        node_map[name] = _rename("tir." + name)

    # TIR nodes that gain the "tir." prefix and also hold string attributes.
    for name, string_attrs in (
        ("StringImm", ("value",)),
        ("Call", ("name",)),
        ("AttrStmt", ("attr_key",)),
        ("Layout", ("name",)),
        ("Buffer", ("name", "scope")),
    ):
        node_map[name] = [_rename("tir." + name)] + [
            _update_from_std_str(attr) for attr in string_attrs
        ]

    return create_updater(node_map, "0.6", "0.7")
def upgrade_json(json_str):
    """Upgrade a json string saved by a historical TVM version.

    Parameters
    ----------
    json_str : str
        A historical json file.

    Returns
    -------
    updated_json : str
        The updated version, re-serialized with an indent of 2.

    Raises
    ------
    ValueError
        If the stored ``tvm_version`` does not start with 0.6, 0.7 or 0.8.
    """
    data = json.loads(json_str)
    from_version = data["attrs"]["tvm_version"]
    # Ordered upgrade chain: (version prefix, factory of the updater that
    # lifts data from that version to the next one).
    upgrade_chain = (
        ("0.6", create_updater_06_to_07),
        ("0.7", create_updater_07_to_08),
        ("0.8", create_updater_08_to_09),
    )
    start = next(
        (pos for pos, (prefix, _) in enumerate(upgrade_chain) if from_version.startswith(prefix)),
        None,
    )
    if start is None:
        raise ValueError("Cannot update from version %s" % from_version)
    # Apply every remaining step in order, from the matched version up to the latest.
    for _, make_updater in upgrade_chain[start:]:
        data = make_updater()(data)
    return json.dumps(data, indent=2)