We propose to integrate DietCode, an auto-scheduler for dynamic tensor programs, into AutoTIR.
DietCode was published at MLSys 2022, so please see the paper for more details and evaluations. Meanwhile, the latest DietCode codebase is publicly available at https://github.com/UofT-EcoSystem/DietCode.
Achieving high performance for compute-intensive operators in machine learning workloads is a crucial but challenging task. Many machine learning and system practitioners rely on vendor libraries or auto-schedulers to do the job. While the former requires significant engineering effort, the latter (TVM's existing auto-scheduler) only supports static-shape workloads. It is difficult, if not impractical, to apply the existing auto-scheduler directly to dynamic-shape workloads, as doing so leads to extremely long tuning times.
We observe that the key challenge faced by existing auto-schedulers when handling a dynamic-shape workload is that they cannot construct a single search space that covers all possible shapes of the workload, because their search space is shape-dependent. To address this, this RFC adds dynamic-shape support to AutoTIR by integrating the DietCode framework, which constructs a shape-generic search space and cost model to auto-schedule dynamic-shape workloads efficiently.
Our evaluation shows that DietCode has key strengths in both auto-scheduling time and runtime performance when auto-scheduling an entire model end-to-end; please refer to the paper for the detailed results.
The existing experiments were largely conducted with the auto-scheduler. However, having synced with the AutoTIR team for several quarters, we plan to integrate this RFC into MetaSchedule (AutoTIR), because it provides a more systematic interface and a cleaner integration path with fewer hacks.
The following example shows the additional information users are required to feed to the system (see https://github.com/UofT-EcoSystem/DietCode/tree/MLSys2022_AE for a PoC design):
```python
# A symbolic shape constraint
T = tir.ShapeVar('T')
I = tir.ShapeVar('I')
H = tir.ShapeVar('H')

# The candidate values of `T`
T_vals = range(1, 128)

wkl_insts = []
for t in T_vals:
    wkl_insts.append((t, 768, 768))
    wkl_insts.append((t, 768, 3072))
    wkl_insts.append((t, 3072, 768))

task = Task(func=Dense,
            args=(16*T, I, H),
            shape_vars=(T, I, H),
            wkl_insts=wkl_insts,
            wkl_inst_weights=([1. for _ in T_vals],))
```
To enable auto-scheduling for dynamic-shape workloads, users only need to declare the shape variables via `ShapeVar`, use them in the TE/TensorIR computation, and provide the candidate workload instances (and their weights) as in the example above.
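For illustration, a dense computation written against these shape variables could look like the sketch below. This follows the PoC API referenced above (`tir.ShapeVar` is not part of mainline TVM yet), and the exact TE/TensorIR spelling may change during integration:

```python
# A sketch following the PoC design above (not mainline TVM API yet):
# a dense workload whose sequence-length dimension is symbolic.
from tvm import te, tir

T = tir.ShapeVar('T')   # dynamic sequence length
I = tir.ShapeVar('I')   # input feature size
H = tir.ShapeVar('H')   # hidden size

X = te.placeholder((16 * T, I), name='X')
W = te.placeholder((H, I), name='W')
k = te.reduce_axis((0, I), name='k')
Dense = te.compute(
    (16 * T, H),
    lambda i, j: te.sum(X[i, k] * W[j, k], axis=k),
    name='Dense')
```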
Here is an overview of the DietCode framework design.
We accept the shape variables and the workload instances from the programmer. If none are detected, the auto-scheduler treats the workload as static and applies the current workflow to it.
We construct a shape-generic search space that consists of micro-kernels (incomplete programs that each carry out a tile of the complete computation) to efficiently support dynamic-shape workloads.
We use the hardware constraints (e.g., the maximum number of threads, the amount of shared and local memory) rather than the shape information to determine the micro-kernel candidates. Those candidates serve as the building blocks and are executed repeatedly to carry out a workload instance (defined as a static-shape instance of the dynamic-shape workload).
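As a concrete illustration of this shape-independence, candidate enumeration could look roughly like the sketch below; the tile sizes, per-thread workload, and hardware constants are assumptions for illustration, not DietCode's actual search-space construction:

```python
# A rough sketch: micro-kernel tile candidates are filtered by hardware
# limits only, so the same candidate set serves every workload instance.
MAX_THREADS_PER_BLOCK = 1024       # assumed GPU limit
SHARED_MEM_PER_BLOCK = 48 * 1024   # assumed shared-memory budget in bytes

def micro_kernel_candidates(dtype_bytes=4, tile_k=32):
    candidates = []
    for tile_m in (16, 32, 64, 128):
        for tile_n in (16, 32, 64, 128):
            threads = (tile_m // 4) * (tile_n // 4)          # assume 4x4 outputs per thread
            smem = (tile_m + tile_n) * tile_k * dtype_bytes  # X and W tiles in shared memory
            if threads <= MAX_THREADS_PER_BLOCK and smem <= SHARED_MEM_PER_BLOCK:
                candidates.append((tile_m, tile_n))
    return candidates
```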
We build a micro-kernel-based cost model. The key insight is that the cost of a complete program P that is made up of a micro-kernel M can be decomposed into two parts: (1) a shape-generic cost function f_MK that predicts the cost of the micro-kernel M itself, and (2) a shape-dependent adaptation cost f_adapt that measures how well M fits the workload instance.
While f_MK is a function that has to be learned and updated by real hardware measurements during the auto-scheduling process, f_adapt is a simple term that can be evaluated using the core occupancy and the padding ratio (in other words, it does not require feature extraction from the schedules).
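The sketch below illustrates this decomposition; the helper names and the exact form of the adaptation term are assumptions for illustration, not the actual DietCode cost model:

```python
import math
from dataclasses import dataclass

@dataclass
class MicroKernel:
    tile_m: int
    tile_n: int

def predict_cost(f_MK, mk: MicroKernel, M: int, N: int, num_cores: int) -> float:
    """Cost(P) ~ (#micro-kernel invocations) * f_MK(M) * f_adapt(M, P)."""
    # Learned part: cost of one micro-kernel invocation, fitted against
    # real hardware measurements during auto-scheduling.
    mk_cost = f_MK(mk)

    # Analytical part (f_adapt): only the padding ratio and core occupancy
    # are needed, so no feature extraction from the schedule is required.
    tiles = math.ceil(M / mk.tile_m) * math.ceil(N / mk.tile_n)
    padded = (math.ceil(M / mk.tile_m) * mk.tile_m) * (math.ceil(N / mk.tile_n) * mk.tile_n)
    padding_ratio = padded / (M * N)
    occupancy = min(tiles, num_cores) / num_cores

    f_adapt = padding_ratio / occupancy
    return tiles * mk_cost * f_adapt

# e.g., predict_cost(lambda mk: 1.0, MicroKernel(64, 64), M=1000, N=768, num_cores=80)
```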
We generate one kernel per workload instance and use the scikit-learn framework to train a decision-tree dispatcher that maps each workload instance to its corresponding kernel. The decision tree is output in predicate-only format for efficient runtime dispatching and embedded as part of the host code. As an example, one possible auto-scheduling outcome can look like the following:
```c
__global__ void default_function0(float* X, float* W, float* Y) {...}
__global__ void default_function1(float* X, float* W, float* Y) {...}
__global__ void default_function2(float* X, float* W, float* Y) {...}

// host code
if (T < 16)
  call(default_function0)
else if (T < 64)
  call(default_function1)
else
  call(default_function2)
```
Because everything can be included in a single `PackedFunc` object, the workflow is fully compatible with the Relay workflow.
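For instance, the dispatcher described above could be trained along the lines of the sketch below; the data and tree depth are made up for illustration, and only the use of scikit-learn's decision tree matches the design above:

```python
# A minimal sketch of training the decision-tree dispatcher that maps a
# workload instance (here, the dynamic dimension T) to the best kernel.
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Suppose tuning measured, for every T, which generated kernel is fastest.
T_vals = np.arange(1, 128).reshape(-1, 1)
best_kernel = np.array([0 if t < 16 else 1 if t < 64 else 2
                        for t in T_vals.ravel()])

tree = DecisionTreeClassifier(max_depth=4).fit(T_vals, best_kernel)

# The fitted tree is then lowered into the nested if/else predicates shown
# in the generated host code above.
```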
An alternative is to generate a single shape-generic kernel that takes the dynamic dimension as a runtime argument:

```c
__global__ void default_function(float* X, float* W, float* Y, const int T)  // Note the `T` here.
```

Our evaluations indicate that such a program has at least 5% worse performance compared with the static-shape alternatives. Hence, we decided to sacrifice binary size for runtime performance, which can potentially be problematic when hardware resources are limited.
There is an approach proposed by Nimble, which partitions the range of a dynamic shape into buckets and tunes one kernel for each bucket. We could, of course, implement this approach in the current auto-scheduler and AutoTIR. However, as evaluated in the DietCode paper, this approach is not guaranteed to achieve performance comparable to static shapes.
Reuse-based Tuner
Selective Tuning (Cody Yu, 2019) and ETO (Jingzhi Fang et al., VLDB 2021) group workloads into clusters based on a set of pre-defined rules (e.g., the similarity ratio in Selective Tuning) and reuse the same schedule within each cluster.
Dynamic Neural Networks
Dynamic batching is a common graph-level optimization adopted by frameworks such as DyNet (Graham Neubig et al. 2017), Cavs (Shizhen Xu et al. USENIX ATC 2018), BatchMaker (Pin Gao et al. EuroSys 2018), and TensorFlow Fold (Moshe Looks et al. ICLR 2017) for cases when the batch size is dynamic.
Nimble (Haichen Shen et al. MLSys 2021) and DISC (Kai Zhu et al. EuroMLSys 2021) both design a compiler to represent and execute dynamic neural networks.
Cortex (Pratik Fegade et al., MLSys 2021) is a compiler-based framework for recursive neural networks.
Those works focus on graph-level optimizations and are therefore orthogonal to DietCode, which operates on each individual layer. In fact, those graph-level solutions can also leverage DietCode for efficient operator code generation.
We propose the following milestones for upstreaming, where each milestone corresponds to a PR with unit tests of roughly several hundred lines.
When testing, we will follow the same testing procedure as the meta-scheduler. We do not require any extra hardware platforms. Our plan is to take dynamic-shape workloads (e.g., dense from BERT and conv2d from ResNet-50) and compare their performance numbers with those delivered by the meta-scheduler on static-shape workloads. The performance difference is expected to be smaller than 5%.