| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """Wrapping existing transformations.""" |
| # pylint: disable=invalid-name |
| from typing import Optional |
| from . import _ffi_api |
| from . import function_pass as _fpass |
| |
| |
def Apply(ftransform):
    """Build a pass that runs ftransform on each function in the Module.

    This is a thin wrapper around tvm.tir.transform.prim_func_pass.

    Parameters
    ----------
    ftransform: tvm.tir.PrimFunc -> tvm.tir.PrimFunc
        The transformation applied to every function.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """

    def _run_on_func(func, mod, ctx):  # pylint: disable=unused-argument
        # Only the function itself is forwarded; mod and ctx are ignored.
        return ftransform(func)

    apply_pass = _fpass.prim_func_pass(_run_on_func, opt_level=0, name="Apply")  # type: ignore
    return apply_pass
| |
| |
def Filter(fcond):
    """Build a pass that keeps only the functions accepted by fcond.

    Each function is passed to fcond; when the result is truthy the function
    is returned unchanged, otherwise None is returned in its place
    (presumably dropping the function from the module -- confirm against
    prim_func_pass).

    Parameters
    ----------
    fcond : tvm.tir.PrimFunc -> bool
        The condition of the filtering.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """

    def _keep_if_accepted(func, mod, ctx):  # pylint: disable=unused-argument
        if fcond(func):
            return func
        return None

    filter_pass = _fpass.prim_func_pass(_keep_if_accepted, opt_level=0, name="Filter")  # type: ignore
    return filter_pass
| |
| |
def InjectPrefetch():
    """Create the pass that injects prefetch instructions into stmt.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.InjectPrefetch()  # type: ignore
    return fpass
| |
| |
def StorageFlatten(cache_line_size, create_bound_attribute: bool = False):
    """Create the pass that flattens multi-dimensional read/write to 1D.

    Parameters
    ----------
    cache_line_size : int
        The size of CPU cache line.

    create_bound_attribute : bool
        Whether to create bound attributes.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.StorageFlatten(  # type: ignore
        cache_line_size, create_bound_attribute
    )
    return fpass
| |
| |
def InjectCopyIntrin(pragma_key: str, fintrin):
    """Inject copy intrinsics with optional pad.

    Parameters
    ----------
    pragma_key : str
        The pragma key for hint of copy.

    fintrin : function
        The function with signature copyintrin(src, dst, pad_before, pad_after, pad_value)

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.InjectCopyIntrin(pragma_key, fintrin)  # type: ignore
| |
| |
def CoProcSync():
    """Create the pass that detects and inserts sync points to co-processor.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.CoProcSync()  # type: ignore
    return fpass
| |
| |
def LiftAttrScope(attr_key: str):
    """Create the pass that lifts common attrs with attr_key to outer scope.

    Parameters
    ----------
    attr_key : str
        The attribute key to be checked.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.LiftAttrScope(attr_key)  # type: ignore
    return fpass
| |
| |
def LoopPartition():
    """Partition loops in the stmt.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.LoopPartition()  # type: ignore
| |
| |
def VectorizeLoop(enable_vectorize: bool = True):
    """Create the pass that lowers vectorization loops.

    Parameters
    ----------
    enable_vectorize : bool
        Whether vectorization is enabled.
        When it is turned off, the loops are lowered to scalar loops.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.VectorizeLoop(enable_vectorize)  # type: ignore
    return fpass
| |
| |
def InjectVirtualThread():
    """Create the pass that injects virtual thread loops.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.InjectVirtualThread()  # type: ignore
    return fpass
| |
| |
def InjectDoubleBuffer():
    """Create the pass that injects double buffer statements.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.InjectDoubleBuffer()  # type: ignore
    return fpass
| |
| |
def StorageRewrite():
    """Create the pass that rewrites the storage allocation pattern.

    The pass moves each allocation to the outermost possible scope and
    tries to share space between allocations, so that a static allocation
    plan is made when possible.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.StorageRewrite()  # type: ignore
    return fpass
| |
| |
def UnrollLoop():
    """Create the pass that unrolls the constant loops marked by unroll.

    The pass also automatically attaches the pragma unroll tag to loops
    which meet the standard.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.UnrollLoop()  # type: ignore
    return fpass
| |
| |
def RemoveNoOp():
    """Create the pass that removes No Op from the Stmt.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.RemoveNoOp()  # type: ignore
    return fpass
| |
| |
def BF16Legalize():
    """Create the pass that legalizes bf16 typed Ops.

    Runs BF16Promote, BF16CastElimination and BF16TypeLowering.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.BF16Legalize()  # type: ignore
    return fpass
| |
| |
def BF16Promote():
    """Create the pass that promotes bf16 to fp32.

    A cast to fp32 is added before Ops, then a cast back to bf16 is added.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.BF16Promote()  # type: ignore
    return fpass
| |
| |
def BF16CastElimination():
    """Create the pass that eliminates verbose casting between fp32 and bf16.

    The pass checks if the AST has the pattern:
        castto32(castto16(some_fp32_op(...)))
    The verbose casting is generated by BF16Promote for multiple
    bf16 Ops in a row. e.g.:
        X[i] + Y[i] + T[i] =>
        bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i])))
    After this pass:
        bf16(float32(X[i]) + float32(Y[i]) + float32(T[i]))

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.BF16CastElimination()  # type: ignore
    return fpass
| |
| |
def BF16TypeLowering():
    """Create the pass that replaces all bf16 types with uint16.

    The casting between fp32 and bf16 is lowered as well.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.BF16TypeLowering()  # type: ignore
    return fpass
| |
| |
def RewriteUnsafeSelect():
    """Create the pass that detects and rewrites unsafe select containing memory access.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.RewriteUnsafeSelect()  # type: ignore
    return fpass
| |
| |
def Simplify():
    """Create the pass that runs arithmetic simplifications on statements and expressions.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.Simplify()  # type: ignore
    return fpass
| |
| |
def InstrumentBoundCheckers():
    """Create the pass that instruments bound checkers.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.InstrumentBoundCheckers()  # type: ignore
    return fpass
| |
| |
def LowerCustomDatatypes():
    """Create the pass that lowers custom datatypes.

    See tvm::datatypes::Registry for more information on adding custom datatypes.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.LowerCustomDatatypes()  # type: ignore
    return fpass
| |
| |
def MakePackedAPI(num_unpacked_params: int = -1):
    """Create the pass that transforms the PrimFuncs in the module to a packed func API.

    Parameters
    ----------
    num_unpacked_params : int
        Number of parameters that we hope to directly pass via normal arguments
        following the PackedFunc input signature. If it is specified as -1 or it
        is less than the number of arguments, the pass will still pack the arguments.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.MakePackedAPI(num_unpacked_params)  # type: ignore
    return fpass
| |
| |
def MakeUnpackedAPI():
    """Create the pass that transforms the PrimFuncs in the module to a C API
    compatible with internal calls.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.MakeUnpackedAPI()  # type: ignore
    return fpass
| |
| |
def SplitHostDevice():
    """Create the pass that splits the function into a host function and device functions.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.SplitHostDevice()  # type: ignore
    return fpass
| |
| |
def DecorateDeviceScope():
    """Create the pass that decorates all the function's body as device function.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.DecorateDeviceScope()  # type: ignore
    return fpass
| |
| |
def SkipAssert():
    """Create the pass that skips assert stmt.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.SkipAssert()  # type: ignore
    return fpass
| |
| |
def ThreadSync(storage_scope: str):
    """Create the pass that inserts sync between parallel read/write of shared buffers.

    Parameters
    ----------
    storage_scope : str
        The target storage scope.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.ThreadSync(storage_scope)  # type: ignore
    return fpass
| |
| |
def LowerThreadAllreduce():
    """Create the pass that lowers cross thread allreduce.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.LowerThreadAllreduce()  # type: ignore
    return fpass
| |
| |
def InferFragment():
    """Create the pass that infers the TensorCore fragment information using tensor intrinsics.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.InferFragment()  # type: ignore
    return fpass
| |
| |
def LowerWarpMemory():
    """Create the pass that lowers warp memory access to low-level device
    related function calls.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.LowerWarpMemory()  # type: ignore
    return fpass
| |
| |
def LowerTVMBuiltin():
    """Create the pass that lowers tvm builtin intrinsics.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.LowerTVMBuiltin()  # type: ignore
    return fpass
| |
| |
def LegalizePackedCalls():
    """Create the pass that legalizes packed calls to have their arguments
    wrapped in TVMValues.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.LegalizePackedCalls()  # type: ignore
    return fpass
| |
| |
def LowerIntrin():
    """Create the pass that lowers target specific intrinsic calls.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.LowerIntrin()  # type: ignore
    return fpass
| |
| |
def LowerDeviceStorageAccessInfo():
    """Create the pass that lowers attached storage access information on device.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass

    Note
    ----
    Run this pass after all storage access analysis finish.
    """
    fpass = _ffi_api.LowerDeviceStorageAccessInfo()  # type: ignore
    return fpass
| |
| |
def CombineContextCall():
    """Create the pass that combines context calls in the host function.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.CombineContextCall()  # type: ignore
    return fpass
| |
| |
def NarrowDataType(target_bits: int):
    """Create the pass that narrows down PrimExpr datatype in stmt to target_bits.

    Parameters
    ----------
    target_bits : int
        The target bit configuration.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass

    Note
    ----
    Run this pass after StorageFlatten.
    """
    fpass = _ffi_api.NarrowDataType(target_bits)  # type: ignore
    return fpass
| |
| |
def VerifyMemory():
    """Create the pass that verifies if func contains illegal host side direct memory access.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.VerifyMemory()  # type: ignore
    return fpass
| |
| |
| # pylint: disable=no-else-return,inconsistent-return-statements |
def HoistIfThenElse(variant: Optional[str] = None):
    """Hoist loop-invariant IfThenElse nodes to outside the eligible loops.

    Parameters
    ----------
    variant : Optional[String]
        The variant of the pass.
        variant can have any one of following values ["basic", None(Default)].

        The basic variant supports basic hoisting scenarios where it expects
        the For & If Nodes are in place consecutively and does not involve
        global scope variables or more advanced scenarios.

        Default variant supports all hoisting scenarios,i.e., {"Basic" + "Advanced"}
        supported with control with PassContext configs like below:

            config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass

    Raises
    ------
    ValueError
        If variant is neither "basic" nor None.
    """
    if variant is None:
        return _ffi_api.HoistIfThenElse()  # type: ignore
    if variant == "basic":
        return _ffi_api.HoistIfThenElseBasic()  # type: ignore
    # Previously an unknown variant fell through and returned None, which only
    # failed later when the result was used as a pass; fail at the call site.
    raise ValueError(
        f"Unsupported HoistIfThenElse variant {variant!r}; expected 'basic' or None"
    )
| |
| |
def LowerInitBlock():
    """Create the pass that lowers block init stmt into IfThenElse statements.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.LowerInitBlock()  # type: ignore
    return fpass
| |
| |
def PlanAndUpdateBufferAllocationLocation():
    """Create the pass that locates the buffer allocation to the exact
    position (usually the lca of buffer access).

    The pass will inject opaque block with alloc_buffers at the allocation site.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.PlanAndUpdateBufferAllocationLocation()  # type: ignore
    return fpass
| |
| |
def ConvertBlocksToOpaque():
    """Create the pass that converts blocks into opaque ones.

    All the block vars are substituted with the PrimExprs they are bound to,
    indicated by the corresponding iter_values in BlockRealize. The blocks are
    then made opaque by removing all the iter_values in BlockRealize and
    iter_vars in Block.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.ConvertBlocksToOpaque()  # type: ignore
    return fpass
| |
| |
def CompactBufferAllocation():
    """Create the pass that compacts the buffer access region.

    Buffer regions that are not accessed are removed, i.e. the buffer shape
    is narrowed and the access region is adjusted if necessary.

    Example
    -------

    Before narrowing, ``B`` is a ``[16, 16]`` buffer, but only a
    skinny vector ``B[i, 0:16]`` is accessed.

    .. code-block:: python

        for i in range(0, 16):
            with tir.block([]):
                B = tir.alloc_buffer(16, 16)
                for j in range(0, 16):
                    B[i, j] = A[i, j] + 1
                for j in range(0, 16):
                    C[i, j] = B[i, j] + 1

    This pass narrows the buffer shape and adjusts its accessed region
    accordingly. In this particular case, because only a ``1 * 16``
    vector of ``B`` is accessed, the pass narrows ``B`` to shape ``[1,
    16]``, and changes the access to ``B[i, j]`` to ``B[0, j]``.

    .. code-block:: python

        for i in range(0, 16):
            with tir.block([]):
                B = tir.alloc_buffer(1, 16)
                for j in range(0, 16):
                    B[0, j] = A[i, j] + 1
                for j in range(0, 16):
                    C[i, j] = B[0, j] + 1

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass

    """
    fpass = _ffi_api.CompactBufferAllocation()  # type: ignore
    return fpass
| |
| |
def LowerMatchBuffer():
    """Create the pass that removes match buffers inside the block.

    The pass will also validate the binding.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.LowerMatchBuffer()  # type: ignore
    return fpass
| |
| |
def FlattenBuffer():
    """Create the pass that flattens the multi-dimensional BufferLoad and
    BufferStore to single dimensional Load/Store.

    Block is also removed to ensure that the flattened TIR can not be
    scheduled again.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.FlattenBuffer()  # type: ignore
    return fpass
| |
| |
def MergeDynamicSharedMemoryAllocations():
    """Create the pass that merges multiple TIR-level dynamic shared memory
    allocations into one allocation.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.MergeDynamicSharedMemoryAllocations()  # type: ignore
    return fpass