blob: 1d9ee6da4f7ac0b8ce8515dd3c617fee49c6ad10 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Workload registration and serialization.
We use a json string to represent a workload (a computation graph).
The format of the string is `[func_name, arg_0, arg_1, ...]`: the serialized arguments
follow the function name as the remaining elements of the same JSON list.
The DAG is the return value of `func_name(*args)`.
Rationale: The workload is actually a compute dag defined by tvm dsl. But serializing compute dags
and matching them efficiently is not easy. Therefore, we use the above string to encode a compute
dag.
These strings are efficient for serialization/matching and won't be too long.
When we need the dag, we decode the string and call the function, which will return the dag.
"""
import pickle
import json
import tvm._ffi
from .utils import serialize_args, deserialize_args, get_func_name
# Global registry mapping a workload function name (str) to the function that
# generates that workload's compute declaration.
WORKLOAD_FUNC_REGISTRY = {}


def register_workload(func_name, f=None, override=False):
    """Register a function that generates a certain workload.

    The input function should take hashable and jsonable arguments
    (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of tvm.tensor.Tensor.

    Parameters
    ----------
    func_name : Union[Function, str]
        The generation function that returns the compute declaration Tensors or its function name.
    f : Optional[Function]
        The generation function to be registered.
    override : boolean = False
        Whether to override an existing entry with the same name.

    Returns
    -------
    Union[Function, Callable[[Function], Function]]
        The registered function itself when a function is supplied (either as
        `func_name` or as `f`); otherwise a decorator that registers the
        function it is applied to under `func_name`.

    Raises
    ------
    ValueError
        If `func_name` is neither callable nor a string.
    RuntimeError
        If the name is already registered and `override` is False.

    Examples
    --------
    .. code-block:: python

      @auto_scheduler.register_workload
      def matmul(N, M, K):
          A = te.placeholder((N, K), name='A')
          B = te.placeholder((K, M), name='B')
          k = te.reduce_axis((0, K), name='k')
          C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
          return [A, B, C]
    """
    if callable(func_name):
        # Used as a bare decorator: derive the registry key from the function.
        f = func_name
        func_name = get_func_name(f)
    if not isinstance(func_name, str):
        raise ValueError("expect string function name")

    def register(myf):
        """Store `myf` under `func_name`, refusing duplicates unless `override`."""
        if func_name in WORKLOAD_FUNC_REGISTRY and not override:
            raise RuntimeError("%s has been registered already" % func_name)
        WORKLOAD_FUNC_REGISTRY[func_name] = myf
        return myf

    # Compare against None instead of relying on truthiness: a callable object
    # may define __bool__/__len__ and evaluate as False, which previously made
    # this return the decorator instead of registering the function.
    if f is not None:
        return register(f)
    return register
def make_workload_key(func, args):
    """Make a workload key by function and arguments.

    Parameters
    ----------
    func : Union[Function, str]
        The function that returns the compute declaration Tensors.
        Can be either the function itself or its registered name.
    args : Args
        The args of the function.

    Returns
    -------
    workload_key : str
        The workload key of the function: a JSON list whose first element is
        the function name and whose remaining elements are the serialized args.

    Raises
    ------
    ValueError
        If `func` is neither callable nor a string, or if the resolved function
        name has not been registered with `register_workload`.
    """
    if callable(func):
        func_name = get_func_name(func)
    elif isinstance(func, str):
        func_name = func
    else:
        raise ValueError(
            "Invalid function: "
            + str(func)
            + " . `make_workload_key` expects a callable function or its function name"
        )

    if func_name not in WORKLOAD_FUNC_REGISTRY:
        # Build ONE message string. Passing two positional arguments to
        # ValueError (as before) renders the message as a tuple.
        raise ValueError(
            "%s is not registered. " % func_name
            + "Please register it with @auto_scheduler.register_workload"
        )

    # `serialize_args` must return a tuple here, since it is concatenated
    # onto `(func_name,)` to form a single flat list.
    serialized_args = serialize_args(args)
    return json.dumps((func_name,) + serialized_args)
def decode_workload_key_to_func_args(workload_key):
    """Split a workload key into its registered function name and arguments.

    Parameters
    ----------
    workload_key : str
        A JSON-encoded workload key whose first element is the function name
        and whose remaining elements are the serialized arguments.

    Returns
    -------
    name : str
        The registered function name stored in the key.
    args : List
        The generation function's arguments, as returned by `deserialize_args`.

    Raises
    ------
    ValueError
        If the function name in the key is not in the registry.
    """
    decoded = json.loads(workload_key)
    name = decoded[0]

    # Guard clause: reject keys whose function was never registered.
    if name not in WORKLOAD_FUNC_REGISTRY:
        message = (
            "%s is not registered. " % name
            + "Please register it with @auto_scheduler.register_workload"
        )
        raise ValueError(message)

    remaining = decoded[1:]
    return name, deserialize_args(remaining)
@tvm._ffi.register_func("auto_scheduler.workload_key_to_tensors")
def workload_key_to_tensors(workload_key):
    """Rebuild the input/output tensors described by a workload key.

    Exposed to the runtime as the packed function
    ``auto_scheduler.workload_key_to_tensors``; typically used when a
    ComputeDAG has to be created from a workload key.

    Parameters
    ----------
    workload_key : str
        The workload key to decode.

    Returns
    -------
    tensors : List[Tensor]
        Whatever the registered generation function returns when called with
        the decoded arguments (the compute declaration Tensors).
    """
    func_name, func_args = decode_workload_key_to_func_args(workload_key)
    generator = WORKLOAD_FUNC_REGISTRY[func_name]
    assert callable(generator)
    return generator(*func_args)
def save_workload_func_registry(filename):
    """Dump the workload function registry to a pickle binary file.

    Parameters
    ----------
    filename : str
        The filename to dump the workload function registry to.

    Notes
    -----
    pickle stores plain functions by reference (module + qualified name), so
    presumably only module-level registered functions can be saved and later
    reloaded; lambdas or nested functions in the registry will fail to pickle.
    """
    # The context manager guarantees the file is flushed and closed, even if
    # pickling fails part-way. The original left the handle open.
    with open(filename, "wb") as out_file:
        pickle.dump(WORKLOAD_FUNC_REGISTRY, out_file)
def load_workload_func_registry(filename):
    """Load the workload function registry from a pickle binary file.

    The loaded object replaces (rebinds) the module-level
    ``WORKLOAD_FUNC_REGISTRY``. Functions in this module look the name up at
    call time and see the new registry. Code elsewhere that imported the dict
    object directly keeps referring to the old one.

    Parameters
    ----------
    filename : str
        The filename to load the workload function registry from.
    """
    global WORKLOAD_FUNC_REGISTRY
    # SECURITY: pickle.load can execute arbitrary code while unpickling. Only
    # load registry files produced by a trusted source.
    with open(filename, "rb") as in_file:
        WORKLOAD_FUNC_REGISTRY = pickle.load(in_file)