# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""
The definition of the "state" in the search.
Each LoopState corresponds to a schedule for its ComputeDAG.
A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to
construct the loop structure.
The loop structure keeps a preview of how the schedule will finally look after lowering the
current state (e.g. number of iterators, the extent of each iterator, the compute_at locations
...).
During the schedule search process, the loop structure can provide the search policy with the
necessary information on how to manipulate the current state.
The transform history is a sequence of `TransformStep` which will finally be mapped to TVM
schedule primitives. The steps are also used for the serialization of a state.
The LoopState can be seen as a lightweight loop structure IR specifically for schedule search.
We don't use the existing TVM IR but instead build a new structure on top of it because:
1. We want fast incremental changes to the loop structures. The search policy needs to see the
updated loop structure immediately, rather than only after TVM lowering;
2. We want serializable transform history for replay, backtracking, and mutation;
3. We may create some macro schedule primitives that represent the combination of several
TVM schedule primitives.
When the search is finished, we will lower the state to TVM IR with TVM's schedule primitives.
Since we share a lot of common objects during search, the transformation is implemented in
copy-on-write style. All objects are immutable, which is similar to TVM IR.
"""
import tvm._ffi
from tvm.te.tensor import Operation, Tensor
from tvm.runtime import Object
from . import _ffi_api
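# A rough end-to-end sketch of how a State is typically used (kept as a comment so that
# importing this module stays side-effect free). The matmul definition is purely
# illustrative; the calls shown (ComputeDAG, get_init_state, split, apply_steps_from_state)
# are the usual entry points for manipulating a State:
#
#   from tvm import te, auto_scheduler
#
#   A = te.placeholder((512, 512), name="A")
#   B = te.placeholder((512, 512), name="B")
#   k = te.reduce_axis((0, 512), name="k")
#   C = te.compute((512, 512), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
#
#   dag = auto_scheduler.ComputeDAG([A, B, C])
#   state = dag.get_init_state()                   # State wrapping the initial loop structure
#   mi, mj, mk = state[C].iters                    # iterators of the stage computing C
#   mi_outer, mi_inner = state.split(C, mi, [16])
#   sch, args = dag.apply_steps_from_state(state)  # lower the steps back to a te.Schedule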
@tvm._ffi.register_object("auto_scheduler.Iterator")
class Iterator(Object):
""" A loop iterator structure. """
@tvm._ffi.register_object("auto_scheduler.Stage")
class Stage(Object):
""" A stage in the compute declaration. Similar to tvm.te.schedule.Stage. """
# Static trans table for compute_at location
# This is used to transform the compute_at location to C++ enum
COMPUTE_AT_TRANS_TABLE = {"root": 0, "inlined": 1, "iter": 2}
@tvm._ffi.register_object("auto_scheduler.State")
class StateObject(Object):
""" The internal State object """
def __eq__(self, other):
return _ffi_api.StateEqual(self, other)
class State:
"""
A state in the search process. It consists of the current loop structure
and a list of transformation steps used to construct it.
Each State corresponds to a specific schedule for its ComputeDAG.
Parameters
----------
state_object : StateObject
The StateObject corresponding to C++ internal State object.
dag : ComputeDAG
The original ComputeDAG of this State.
Notes
-----
This is a wrapper class of StateObject to deal with copy-on-write property
"""
# Static trans table for thread bind and annotation
# This is used to transform the annotation name to C++ enum
ANNOTATION_TRANS_TABLE = {
"none": 0,
"unroll": 1,
"vectorize": 2,
"parallel": 3,
"vthread": 4,
"blockIdx.x": 5,
"threadIdx.x": 6,
"blockIdx.y": 7,
"threadIdx.y": 8,
"blockIdx.z": 9,
"threadIdx.z": 10,
"tensorize": 11,
}
def __init__(self, state_object, dag):
self.state_object = state_object
self.compute_dag = dag
self.stage_id_map = {} # A dict maps operation to stage id
self._update_stage_id_map()
@property
def stages(self):
"""
Returns
-------
stages : List[Stage]
"""
return self.state_object.stages
@property
def transform_steps(self):
"""
Returns
-------
transform_steps : List[TransformStep]
"""
return self.state_object.transform_steps
@property
def stage_ops(self):
"""
Returns
-------
ops: List[Operation]
"""
return [stage.op for stage in self.stages]
def bind(self, stage, iterator, thread_name):
"""Schedule primitive corresponding to `te.Stage.bind`.
See also the `te.Stage` for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be bound, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to be bound.
thread_name : str
The thread type to be bound. Candidates:
- vthread
- blockIdx.x
- threadIdx.x
- blockIdx.y
- threadIdx.y
- blockIdx.z
- threadIdx.z
Returns
-------
res_it : Iterator
The bound Iterator.
"""
if thread_name not in State.ANNOTATION_TRANS_TABLE:
raise ValueError("Invalid thread_name: " + thread_name)
self.state_object, res = _ffi_api.StateBind(
self.state_object,
self._resolve_stage_id(stage),
iterator,
State.ANNOTATION_TRANS_TABLE[thread_name],
)
return res
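# A short sketch of typical GPU thread binding (illustrative; `C` and `state` are assumed
# to come from a ComputeDAG as in the module-level sketch above):
#
#   mi, mj, mk = state[C].iters
#   mi_outer, mi_inner = state.split(C, mi, [32])
#   state.bind(C, mi_outer, "blockIdx.x")
#   state.bind(C, mi_inner, "threadIdx.x")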
def parallel(self, stage, iterator):
"""Schedule primitive corresponding to `te.Stage.parallel`.
See also the `te.Stage` for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be parallelized, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to be parallelized.
Returns
-------
res_it : Iterator
The parallelized Iterator.
"""
self.state_object, res = _ffi_api.StateParallel(
self.state_object, self._resolve_stage_id(stage), iterator
)
return res
def unroll(self, stage, iterator, max_unroll=None):
"""Schedule primitive corresponding to `te.Stage.unroll`.
See also the `te.Stage` for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be unrolled, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to be unrolled.
max_unroll : Optional[int]
The max unroll limit. An iterator with extent larger than this limit will be skipped.
Returns
-------
res_it : Iterator
The unrolled Iterator.
"""
self.state_object, res = _ffi_api.StateUnroll(
self.state_object,
self._resolve_stage_id(stage),
iterator,
max_unroll if max_unroll else -1,
)
return res
def vectorize(self, stage, iterator):
"""Schedule primitive corresponding to `te.Stage.vectorize`.
See also the `te.Stage` for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be vectorized, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to be vectorized.
Returns
-------
res_it : Iterator
The vectorized Iterator.
"""
self.state_object, res = _ffi_api.StateVectorize(
self.state_object, self._resolve_stage_id(stage), iterator
)
return res
def fuse(self, stage, iters):
"""Schedule primitive corresponding to `te.Stage.fuse`.
See also the `te.Stage` for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be fused, which can be specified by the integer index, Operation,
or output tensor of the stage.
iters : List[Iterator]
The iterators to be fused.
Returns
-------
res_it : Iterator
The fused Iterator.
Notes
-----
If the iterators to be fused have stages attached at them(by compute_at), the fused
result will become the new attach point.
"""
self.state_object, res = _ffi_api.StateFuse(
self.state_object, self._resolve_stage_id(stage), iters
)
return res
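# A small sketch that fuses the two spatial loops of a stage and parallelizes the result
# (illustrative; `C` and `state` as in the sketches above):
#
#   mi, mj, mk = state[C].iters
#   fused = state.fuse(C, [mi, mj])
#   state.parallel(C, fused)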
def pragma(self, stage, iterator, pragma_type):
"""Schedule primitive corresponding to `te.Stage.pragma`.
See also the `te.Stage` for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to add pragma, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to add pragma.
pragma_type : str
The pragma string.
"""
self.state_object = _ffi_api.StatePragma(
self.state_object, self._resolve_stage_id(stage), iterator, pragma_type
)
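# One common use is the unrolling hint applied by the search policy, which encodes the
# value after a "$" in the pragma string (sketch; `C` and `state` as above):
#
#   mi, mj, mk = state[C].iters
#   state.pragma(C, mi, "auto_unroll_max_step$64")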
def reorder(self, stage, order):
"""Schedule primitive corresponding to `te.Stage.reorder`.
See also the `te.Stage` for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be reordered, which can be specified by the integer index, Operation,
or output tensor of the stage.
order : List[Iterator]
Iterators in the expected order.
"""
self.state_object = _ffi_api.StateReorder(
self.state_object, self._resolve_stage_id(stage), order
)
def split(self, stage, iterator, lengths, inner_to_outer=True):
"""Schedule primitive corresponding to `te.Stage.split`.
See also the `te.Stage` for more details.
This API supports multiple split factors (e.g. with 2 split factors, the original iterator
will be split into 3 parts). Use `inner_to_outer` to control the split order.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be split, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to be split.
lengths: List[int]
The multiple split factors. Can be None to be filled by search policy.
inner_to_outer: boolean = True
Whether the factors go from inner to outer, or from outer to inner.
Returns
-------
res_its : List[Iterator]
The new Iterators after splitting.
Notes
-----
If we split an iterator that has stages attached at it (by compute_at), the innermost
iterator of the split results will become the new attach point.
"""
self.state_object, res = _ffi_api.StateSplit(
self.state_object, self._resolve_stage_id(stage), iterator, lengths, inner_to_outer
)
return res
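# A short sketch of a multi-level split (illustrative; `C` and `state` as above). Two
# factors produce three iterators; with the default `inner_to_outer=True` the last factor
# becomes the innermost extent:
#
#   mi, mj, mk = state[C].iters
#   mi0, mi1, mi2 = state.split(C, mi, [8, 4])   # extents ~ extent(mi)/32, 8, 4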
def follow_split(self, stage, iterator, src_step_id, n_split):
"""The schedule primitive similar to split, but uses split factors from previous steps.
This step splits the iterator by the same factors as the given SplitStep.
Notes
-----
This step is useful in a scenario where we have a subgraph Dense -> ReLU,
and we want to compute the Dense stage at the ReLU stage. In this case, we need them to have
the same tiling structure of common outer loops.
The follow_split step can be used here to split the Dense stage and make sure its
splitting factors are the same as those of the given split step for the ReLU stage.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be split, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to split.
src_step_id : int
The index of the split step to be followed in the history.
n_split : int
The number of split levels.
Returns
-------
res_its : List[Iterator]
The new Iterators after splitting.
"""
self.state_object, res = _ffi_api.StateFollowSplit(
self.state_object, self._resolve_stage_id(stage), iterator, src_step_id, n_split
)
return res
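# A rough sketch of the Dense -> ReLU scenario from the note above (illustrative; `dense`
# and `relu` are assumed output tensors of the ComputeDAG, and the step index is taken
# from the transform history):
#
#   ri, rj = state[relu].iters[:2]
#   state.split(relu, ri, [8, 4])
#   split_step_id = len(state.transform_steps) - 1   # the split step just recorded
#   di = state[dense].iters[0]
#   state.follow_split(dense, di, split_step_id, 2)  # reuse the same 2 split factors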
def follow_fused_split(self, stage, iterator, src_step_ids, level, factor_or_nparts):
"""Schedule primitive extends to split step.
This step is used to split an iterator by the same factors
as the given list of SplitSteps and FuseSteps.
Notes
-----
This step is useful in a scenario where we have a subgraph
in a GPU schedule: Input -> Dense
for i.0@j.0 = ... : Bind to blockIdx.x
    for i.1@j.1 = ... : Bind to threadIdx.x
        for i.2@j.2 = ...
            Input_shared = Input ...
            for k = ...
                Dense = ...
We intend to apply cooperative fetching with the input stage, while the threadIdx.x
axis is bound to an iterator generated by a split & fuse step.
The follow_fused_split step is used to split the iterator into 2 parts, so that the split
factor matches the final extent of the iterator bound to threadIdx.x.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be split, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to split.
src_step_ids : List[int]
The indices of the split steps to be followed in the history.
level : int
Use the length in this split level.
factor_or_nparts : bool
True to use `factor` for split from inner to outer,
False to use `nparts` for split from outer to inner.
Returns
-------
res_its : List[Iterator]
The new Iterators after splitting.
"""
self.state_object, res = _ffi_api.StateFollowFusedSplit(
self.state_object,
self._resolve_stage_id(stage),
iterator,
src_step_ids,
level,
factor_or_nparts,
)
return res
def storage_align(self, stage, iterator, factor, offset):
"""Schedule primitive corresponding to `te.Stage.storage_align`.
See also the `te.Stage` for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be storage aligned, which can be specified by the integer index,
Operation, or output tensor of the stage.
iterator : Iterator
The iterator to be aligned.
factor : int
The factor in alignment specification.
offset : int
The offset in the alignment specification.
"""
self.state_object = _ffi_api.StateStorageAlign(
self.state_object, self._resolve_stage_id(stage), iterator, factor, offset
)
def compute_at(self, stage, target_stage, target_iter):
"""Schedule primitive corresponding to `te.Stage.compute_at`.
See also the `te.Stage` for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The source Stage of computed at, which can be specified by the integer index,
Operation, or output tensor of the stage.
target_stage : Union[int, Operation, Tensor]
The target stage of compute_at, which can be specified by the integer index, Operation,
or output tensor of the stage.
target_iter : Iterator
The target Iterator of compute_at.
Notes
-----
After compute_at, we need careful dependency analysis to compute the accurate bound
information. However, it is relatively expensive and complicated, so we just fill "None"
as bound for the newly created iterators.
Call ComputeDAG::InferBound on the returned state to get the complete bound information.
"""
self.state_object = _ffi_api.StateComputeAt(
self.state_object,
self._resolve_stage_id(stage),
self._resolve_stage_id(target_stage),
target_iter,
)
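# A short sketch of computing one stage at an iterator of another and restoring the bound
# information afterwards (illustrative; `B_inter` is assumed to be an intermediate tensor
# consumed by `C`, and `dag` is the ComputeDAG of this state):
#
#   mi, mj, mk = state[C].iters
#   state.compute_at(B_inter, C, mj)
#   state = dag.infer_bound_from_state(state)   # re-fill the "None" bounds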
def compute_inline(self, stage):
"""Schedule primitive corresponding to `te.Stage.compute_inline`, see also the `te.Stage`
for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be marked compute inlined, which can be specified by the integer index,
Operation, or output tensor of the stage.
"""
self.state_object = _ffi_api.StateComputeInline(
self.state_object, self._resolve_stage_id(stage)
)
def compute_root(self, stage):
"""Schedule primitive corresponding to `te.Stage.compute_root`.
See also the `te.Stage` for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be marked compute at root, which can be specified by the integer index,
Operation, or output tensor of the stage.
Notes
-----
After compute_root, we need careful dependency analysis to compute the accurate bound
information. However, it is relatively expensive and complicated, so we just fill "None"
as bound for the newly created iterators.
Call ComputeDAG::InferBound on the returned state to get the complete bound information.
"""
self.state_object = _ffi_api.StateComputeRoot(
self.state_object, self._resolve_stage_id(stage)
)
def cache_read(self, stage, scope_name, reader_stages):
"""Schedule primitive corresponding to `te.Schedule.cache_read`.
See also the `te.Schedule` for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be cache_read, which can be specified by the integer index, Operation,
or output tensor of the stage.
scope_name : str
The scope name of the newly added read stage.
reader_stages : List[Union[int, Operation, Tensor]]
The reader stages. Each item of the list can be specified by the integer index, Operation,
or output tensor of the stage.
Returns
-------
new_stage_op : Operation
The Operation of the newly added stage.
Notes
-----
The cache read step will insert an extra stage into the original ComputeDAG (right after
the target stage).
"""
reader_stage_ids = [self._resolve_stage_id(i) for i in reader_stages]
self.state_object, new_stage_id = _ffi_api.StateCacheRead(
self.state_object,
self._resolve_stage_id(stage),
scope_name,
reader_stage_ids,
self.compute_dag,
)
# Adding a new stage changes all ops behind the added stage. But we still want to keep the
# original op-to-stage-id map, so apply a stage id offset to stage_id_map to keep it valid.
self._apply_stage_id_offset(int(new_stage_id))
self._update_stage_id_map()
return self.stages[int(new_stage_id)].op
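# A brief sketch of adding a shared-memory read cache for an input (illustrative; `A` is
# assumed to be an input tensor read by stage `C`):
#
#   A_shared_op = state.cache_read(A, "shared", [C])
#   # The new stage can then be scheduled through the returned Operation, e.g.
#   # state.compute_at(A_shared_op, C, some_iterator)   # `some_iterator` is a placeholder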
def cache_write(self, stage, scope_name):
"""Schedule primitive corresponding to `te.Schedule.cache_write`.
See also the `te.Schedule` for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be cache_write, which can be specified by the integer index, Operation,
or output tensor of the stage.
scope_name : str
The scope name of the newly added compute stage.
Returns
-------
new_stage_op : Operation
The Operation of the newly added stage.
Notes
-----
The cache write step will insert an extra stage into the original ComputeDAG (right before
the target stage).
This step will cache write all output tensors of the target stage.
"""
self.state_object, new_stage_id = _ffi_api.StateCacheWrite(
self.state_object, self._resolve_stage_id(stage), scope_name, self.compute_dag
)
# Adding a new stage changes all ops behind the added stage. But we still want to keep the
# original op-to-stage-id map, so apply a stage id offset to stage_id_map to keep it valid.
self._apply_stage_id_offset(int(new_stage_id))
self._update_stage_id_map()
return self.stages[int(new_stage_id)].op
def rfactor(self, stage, iterator, factor_iter_id):
"""Schedule primitive corresponding to `te.Schedule.rfactor`.
See also the `te.Schedule` for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be factored, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The reduction iterator to be factored.
factor_iter_id : int
The position where the new iterator is placed.
Returns
-------
new_stage_op : Operation
The Operation of the newly added stage.
Notes
-----
The rfactor step will insert an extra stage into the original ComputeDAG (right before
the target stage).
"""
self.state_object, new_stage_id = _ffi_api.StateRfactor(
self.state_object,
self._resolve_stage_id(stage),
iterator,
factor_iter_id,
self.compute_dag,
)
# Adding a new stage changes all ops behind the added stage. But we still want to keep the
# original op-to-stage-id map, so apply a stage id offset to stage_id_map to keep it valid.
self._apply_stage_id_offset(int(new_stage_id))
self._update_stage_id_map()
return self.stages[int(new_stage_id)].op
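# A brief sketch of factoring a reduction (illustrative; `C` and `state` as above, with
# `mk` being the reduction iterator):
#
#   mi, mj, mk = state[C].iters
#   mk_outer, mk_inner = state.split(C, mk, [16])
#   C_rf_op = state.rfactor(C, mk_outer, 2)   # place the factored iterator at position 2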
def copy(self):
""" Do deep copy of this State. """
state = State(self.state_object, self.compute_dag)
state.stage_id_map = self.stage_id_map.copy()
return state
def _resolve_stage_id(self, stage_id):
if isinstance(stage_id, Operation):
return self.stage_id_map[stage_id]
if isinstance(stage_id, Tensor):
return self.stage_id_map[stage_id.op]
if isinstance(stage_id, int):
return stage_id
raise ValueError(
"Invalid stage: " + str(stage_id) + ". Expected an int, Operation, or Tensor"
)
def _update_stage_id_map(self):
for index, stage in enumerate(self.stages):
self.stage_id_map[stage.op] = index
def _apply_stage_id_offset(self, start_id, offset=1):
for key, value in self.stage_id_map.items():
if value >= start_id:
self.stage_id_map[key] = value + offset
def __getitem__(self, key):
if isinstance(key, Tensor):
key = key.op
if isinstance(key, Operation):
return self.stages[self.stage_id_map[key]]
raise ValueError("Invalid item: " + key + " . Expect to be a Operation or Tensor")
def __str__(self):
return str(self.state_object)
def __eq__(self, other):
return _ffi_api.StateEqual(self.state_object, other.state_object)