blob: 2558c7669ac97db92be60918bf899e0cfc7b1473 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-variable,invalid-name,unused-argument
"""
Decorators for registering tunable templates to TOPI.

These decorators let a simple implementation use different configurations
for different workloads.
Here we directly use all arguments to the TOPI call as the "workload", so make sure all the
arguments (except tvm.te.Tensor) in your calls are hashable. A tvm.te.Tensor
is serialized to a hashable tuple.

See tvm/topi/python/topi/arm_cpu/ for example usage.
"""
import tvm.te._ffi_api
from import Target
from tvm.te import tensor
from .task import (
# Task extractor for relay program
class TaskExtractEnv:
    """Global environment for extracting tuning tasks from graph.

    The environment collects ``(task_name, serialized_args)`` keys while it is
    in tracing mode.  Tracing is switched on by entering the object as a
    context manager and switched off again on exit.
    """

    # Shared instance, created lazily by :meth:`get`.
    current = None
    # Not read or written by any method of this class.
    registered = None

    def __init__(self, allow_duplicate=False):
        # When True, identical task keys are recorded every time they are seen.
        self.allow_duplicate = allow_duplicate
        # Collected (task_name, serialized_args) keys, in insertion order.
        self.task_collection = []
        # Filter set through reset(); not consulted inside this class.
        self.wanted_relay_ops = None
        # Not used by any method of this class.
        self.modified_funcs = []
        # True only between __enter__ and __exit__.
        self.tracing = False

    def __enter__(self):
        # Every tracing session starts from an empty collection.
        self.task_collection = []
        self.tracing = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so exceptions raised inside the block propagate.
        self.tracing = False

    def reset(self, wanted_relay_ops=None):
        """Reset task collections

        Parameters
        ----------
        wanted_relay_ops: List of
            The relay ops to be extracted
        """
        self.task_collection = []
        self.wanted_relay_ops = wanted_relay_ops

    def add_task(self, task_name, args):
        """Add AutoTVM task

        Parameters
        ----------
        task_name: str
            AutoTVM task name.

        args: tuple
            Arguments to the TOPI function.
        """
        key = (task_name, serialize_args(args))
        # Skip keys that were already collected unless duplicates are requested.
        if self.allow_duplicate or key not in self.task_collection:
            self.task_collection.append(key)

    def get_tasks(self):
        """Get collected tasks

        Returns
        -------
        tasks: List of tuple(name, args)
            A list of tasks extracted from the graph
        """
        return self.task_collection

    @staticmethod
    def get(allow_duplicate=False):
        """Get the single instance of TaskExtractEnv

        Parameters
        ----------
        allow_duplicate : boolean
            Whether to fetch all workloads in the network,
            even though some of them are the same. This is
            useful for graph tuning.

        Returns
        -------
        env: TaskExtractEnv
            The single instance of TaskExtractEnv
        """
        if TaskExtractEnv.current is None:
            TaskExtractEnv.current = TaskExtractEnv(allow_duplicate)
        # An already existing instance adopts the most recently requested flag.
        TaskExtractEnv.current.allow_duplicate = allow_duplicate
        return TaskExtractEnv.current
def register_topi_compute(task_name, func=None):
    """Register a tunable template for a topi compute function.

    The registration will wrap this topi compute to take `cfg` as the first argument,
    followed by the original argument list. It uses all its argument as workload and
    stores this "workload" to its final ComputeOp, which can be used to reconstruct
    "workload" in the following topi_schedule call.

    Parameters
    ----------
    task_name: str
        The AutoTVM task name

    func: None or callable
        If it is None, return a decorator.
        If is callable, decorate this function.

    Returns
    -------
    decorator: callable
        A decorator

    Examples
    --------
    See tvm/topi/python/topi/arm_cpu/ for example usage.
    """

    def _decorate(topi_compute):
        def wrapper(*args, **kwargs):
            """wrapper function for topi compute"""
            assert not kwargs, "Do not support kwargs in template function call"
            # Record the call when a task-extraction session is tracing.
            task_env = TaskExtractEnv.current
            if task_env is not None and task_env.tracing:
                task_env.add_task(task_name, args)
            workload = args_to_workload(args, task_name)
            tgt = Target.current()
            cfg = DispatchContext.current.query(tgt, workload)
            node = topi_compute(cfg, *args)

            # attach workload to return op
            op = node.op
            # Copy the existing attributes and add the workload entry.
            attrs = dict(op.attrs.items())
            attrs["workload"] = workload
            # Rebuild the op with the extended attributes.
            # NOTE(review): the leading `op.name` argument and the ExternOp
            # field list were lost in this copy of the file and are restored
            # from the FFI constructors' expected signature; confirm upstream.
            if isinstance(op, tensor.ComputeOp):
                op = tvm.te._ffi_api.ComputeOp(op.name, op.tag, attrs, op.axis, op.body)
            elif isinstance(op, tensor.ExternOp):
                op = tvm.te._ffi_api.ExternOp(
                    op.name,
                    op.tag,
                    attrs,
                    op.inputs,
                    op.input_placeholders,
                    op.output_placeholders,
                    op.body,
                )
            else:
                raise RuntimeError("Unsupported op type: " + str(type(op)))

            # Return the same shape as the template did: a single tensor, or
            # one output per element of the returned sequence.
            if isinstance(node, tensor.Tensor):
                return op.output(0)
            return [op.output(i) for i in range(len(node))]

        return wrapper

    # Support both `register_topi_compute(name)(f)` and
    # `register_topi_compute(name, f)`.
    if func:
        return _decorate(func)
    return _decorate
def register_topi_schedule(task_name, func=None):
    """Register a tunable template for a topi schedule function.

    The registration will wrap this topi schedule to take `cfg` as the first argument,
    followed by the original argument list.

    Note that this function will try to find "workload" from all the ComputeOp in the input.
    You can attach "workload" to your compute op by using :any:`register_topi_compute`.

    The task name has to be the same as that of the corresponding topi compute function.

    Parameters
    ----------
    task_name: str
        The AutoTVM task name

    func: None or callable
        If it is None, return a decorator.
        If is callable, decorate this function.

    Returns
    -------
    decorator: callable
        A decorator

    Examples
    --------
    See tvm/topi/python/topi/arm_cpu/ for example usage.
    """

    def _decorate(topi_schedule):
        def wrapper(outs, *args, **kwargs):
            """wrapper function for topi schedule"""
            # Look up the workload attached to the ops reachable from `outs`.
            workload = get_workload(outs, task_name)
            if workload is None:
                raise RuntimeError("Cannot find workload in attribute of this schedule")
            tgt = Target.current()
            cfg = DispatchContext.current.query(tgt, workload)
            return topi_schedule(cfg, outs, *args, **kwargs)

        return wrapper

    # Support both `register_topi_schedule(name)(f)` and
    # `register_topi_schedule(name, f)`.
    if func:
        return _decorate(func)
    return _decorate
def get_workload(outs, task_name=None):
    """Retrieve the workload from outputs

    Parameters
    ----------
    outs: tvm.te.Tensor or iterable of tensors
        Output tensor(s) whose producing ops are searched.

    task_name: str, optional
        When given, only a workload whose first element equals this name is
        accepted.

    Returns
    -------
    workload: tuple or None
        The first matching workload found, or None if there is none.
    """

    def traverse(tensors):
        """traverse all ops to find attached workload"""
        for t in tensors:
            op = t.op
            # Depth first: workloads on upstream ops take precedence.
            wkl = traverse(op.input_tensors)
            if wkl:
                return wkl

            if "workload" in op.attrs:
                ret = args_to_workload(op.attrs["workload"])
                if task_name is None or ret[0] == task_name:
                    return ret
        # Nothing found in any of the tensors or their inputs.
        return None

    # Accept a single tensor as well as a sequence of tensors.
    outs = [outs] if isinstance(outs, tensor.Tensor) else outs
    return traverse(outs)