The goal of this RFC is to extend the capability of tracing source information between different IRs for debugging purposes. Three features benefit from this change: the frontend span filler, the pass source information builder, and the Schedule/Stage visualization enhancement.
These changes give users a clear backtrace of an IR in CLI text format. Furthermore, paired with our ongoing project, TVM Explorer, a colorful and convenient GUI can improve the user experience further. We demonstrate the use cases of TVM Explorer with examples in the following sections.
We aim to ease the debugging process by enhancing and creating features that carry source information. TVM performs a number of transformations to optimize an ML frontend IR and deploy it to a target device. However, the modules that record source information between IRs are currently not fully used, which makes it hard for users to trace the source of a transformed IR. Usually, users have to investigate the source code to understand the details of a transformation.
We provide the following enhancements to reduce users' effort by recording source information between IRs and the schedules of op implementations:
- Frontend span filler: Fill the layer name into the Relay IR during the frontend conversion.
- Pass source information builder: Construct SequentialSpan from Span, and SIBuilder, to handle source information for both Relay IR and TIR.
- Schedule/Stage visualization enhancement: Record and propagate an op's schedule snapshots, with the primitives applied, in the regular build flow.

After these modifications, users can obtain the source information at a glance or via a debugger.
Finally, inspired by Compiler Explorer, we build a web GUI for TVM, TVM Explorer. Based on the infrastructure above, TVM Explorer provides a better user experience when comparing IRs or analyzing schedules (the code base of TVM Explorer is maintained in another git repository and is not included in this RFC).
Based on ExprMutator, we implement set_span to recursively fill the source information into the Relay IR during op conversion. We can obtain a Relay IR with spans even in a one-to-many conversion. Take the Pack op from TF as an example: it inserts multiple expand_dims during conversion:
    # implementation of the pack TF conversion
    def _pack():
        def _impl(inputs, attr, params, mod):
            axis = int(attr["axis"])
            inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
            return _op.concatenate(inputs_reshaped, axis)

        return _impl

    # After converting an op from the frontend
    ret = self.convert_map[op_code_str](op)
    ret = set_span(ret, frontend_layer_name)

    '''
    The result after set_span of a pack op conversion
    def @main (%input: Tensor[(?, ?, 3, 1), float32]) {
      %0 = shape_of(%input, dtype="int32") /* Shape */;
      %1 = strided_slice(%0, …) /* strided_slice */;
      %2 = squeeze(%1) /* strided_slice */;
    }
    ======>
    def @main (%input: Tensor[(?, ?, 3, 1), float32]) {
      %0 = shape_of(%input, dtype="int32") /* Shape */;
      %1 = strided_slice(%0, …) /* strided_slice */;
      %2 = squeeze(%1) /* strided_slice */;
      %3 = expand_dims(%2, axis=0) /* stack */;
      %4 = expand_dims(3, axis=0) /* stack */;
      %5 = expand_dims(3, axis=0) /* stack */;
      %6 = (%3, %4, %5) /* stack */;
      %7 = concatenate(%6) /* stack */;
    }
    '''
To manage span propagation in passes, we extend Span with SequentialSpan and create a new class, SIBuilder. First, we construct a container class, SequentialSpan, which carries a set of source spans in its member variable for many-to-n (n>=1) conversions, which are common in transformations between passes:
    // C++
    SequentialSpan new_span{expr_1->span, expr_2->span}

    # Python
    relay.SequentialSpan([expr_1, expr_2])
Take the IfNode condition in the FoldConstant pass as an example. When the condition is a constant, FoldConstant extracts the expression of the triggered path as the result. We create a SequentialSpan to keep both the existing span from the selected branch and the span from the discarded If expression.
    Expr VisitExpr_(const IfNode* if_node) final {
      If new_if = Downcast<If>(ExprMutator::VisitExpr_(if_node));
      if (const auto* const_node = AsIgnoringOnDevice<ConstantNode>(new_if->cond)) {
        Expr ret;
        if (reinterpret_cast<uint8_t*>(const_node->data->data)[0]) {
          ret = new_if->true_branch;
        } else {
          ret = new_if->false_branch;
        }
        // keep the span of the selected branch and the span of the discarded If
        ret->span = SequentialSpan({ret->span, new_if->span});
        return ret;
      }
      return std::move(new_if);
    }
On the other hand, SIBuilder aims to ease developers' workload when filling spans in pass transformations. Based on our experience of filling spans in existing passes, we provide two functionalities in SIBuilder. First, RecursivelyFillSpan provides an easy way to automatically fill the source span into conversions that result in multiple expressions. Given a source span, RecursivelyFillSpan applies a DFS traversal from start_expression and fills the source span until it encounters any of the given inputs.
    SIBuilder si_builder(source_span);
    si_builder.RecursivelyFillSpan(start_expression, {inputs_of_the_first_new_generated_expr});
A use case of RecursivelyFillSpan is SimplifyInference. This pass simplifies certain operators during inference. Take BatchNorm as an example: SimplifyInference unpacks the Call of BatchNorm and its TupleGetItem indexed at 0 into several simplified expressions. In this case we can invoke RecursivelyFillSpan to fill the span into all of the newly generated expressions at once.
    Expr BatchNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta,
                                Expr moving_mean, Expr moving_var, Type tdata, Span span) {
      auto ttype = tdata.as<TensorTypeNode>();
      ICHECK(ttype);
      const auto param = attrs.as<BatchNormAttrs>();
      Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
      Expr var_add_eps = Add(moving_var, epsilon);
      Expr sqrt_var = Sqrt(var_add_eps);
      Expr scale = Divide(MakeConstantScalar(ttype->dtype, 1.0f), sqrt_var);
      //...
      Expr out = Multiply(data, scale);
      out = Add(out, shift);

      SIBuilder si_builder(span);
      si_builder.RecursivelyFillSpan(/* entry */ out,
                                     /* inputs */ {data, gamma, beta, moving_mean, moving_var});
      return out;
    }
Second, SIBuilder provides a constructor that collects a continuous sequence of source spans. Starting from the entry, it puts the span of each Expr into its array member variable and continues the traversal until it hits the inputs. Finally, invoke CreateSpan on the created SIBuilder instance to obtain the source span.
    SIBuilder si_builder(entry_expr, {inputs});
    new_span = si_builder.CreateSpan();
This constructor works properly in the SimplifyExpr pass. One rewrite in SimplifyExpr is SimplifyReshape; one of its patterns is an expression followed by two consecutive reshapes or contrib_reverse_reshapes. In this case we can use the SIBuilder constructor above to obtain all source spans of the matched pattern.
    class SimplifyReshape : public DFPatternRewrite {
     public:
      SimplifyReshape() {
        x_ = IsWildcard();
        auto reshape1 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
        auto reshape2 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
        pattern_ = reshape1({reshape2({x_})});
      }

      Expr Callback(const Expr& pre, const Expr& post,
                    const Map<DFPattern, Array<Expr>>& node_map) const override {
        //...
        if (const_shape) {
          auto x = node_map[x_][0];
          auto ret = MakeReshape(x, newshape);

          SIBuilder si_builder(/* entry */ node_map[pattern_][0], /* inputs */ {x});
          ret->span = si_builder.CreateSpan();

          return ret;
        //...
    };
Based on the classes above, we have filled spans in all Relay passes in the build flow.
Tensor Expressions are scheduled with primitives, and a schedule becomes complicated quickly as the number of applied primitives increases. TEDD (Tensor Expression Debug Display) already provides a mechanism to visualize different kinds of schedule diagrams (Schedule Tree, Itervar Relationship and Dataflow), but it is still hard to recognize the effect of each applied primitive from the resulting information.
We propose a change that records a snapshot of the schedule after each new primitive is applied, by introducing some modifications to the interface of the Schedule/Stage classes. To inspect the schedules created inside the TVM build flow, new APIs will also be added.
By doing so, we can leverage TEDD to display a sequence of schedule diagrams. The following snippet shows the driving code:
    # load TFLite model
    tflite_model_buf = open('mobilenet.tflite', "rb").read()
    model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
    input_shape = {'input': (1, 224, 224, 3)}
    mod, params = relay.frontend.from_tflite(model, input_shape)

    # invoke build process
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, 'llvm', params=params)

    # (new API) get schedule from 17th node in TVM graph
    sch = lib.build_module.get_node_schedule(17)
    # (new API) get schedule record (snapshots of schedule)
    schs = sch.schedule_record

    # the second to last schedule
    ori_dot = tedd.viz_schedule_tree(schs[-2].normalize(), dot_file_path="ori.dot")
    # the last schedule with all the optimization strategies
    cmp_dot = tedd.viz_schedule_tree(schs[-1].normalize(), dot_file_path="cmp.dot")
Inspired by Compiler Explorer, TVM Explorer is our ongoing project: a web GUI for investigating TVM behaviors. Based on the infrastructure above, TVM Explorer achieves the following goals:
TVM Explorer provides a mechanism to visualize the computation graph generated from GraphExecutor. With the proposed changes, the Schedule data structure is kept inside each graph node, so users are able to visualize the implementation details.

The frontend span filler was introduced previously in PR-9723, but was reverted because of the unexpected duplicated-expressions problem in PR-10072. We fix the issue in PR-10072 and propose a modified version with the following differences:
- The previous version did not invoke set_span in each required condition, and did not handle tuple/list types properly during PyTorch conversion. As a result, duplicated expressions were generated. After investigation, we insert set_span at each required place to avoid the duplication.
- An environment variable, TVM_SPANFILLING, disables/enables span filling: use export TVM_SPANFILLING=0 to disable the procedure.

    def set_span(sym, span):
        """Set up the span of relay expression(s) while converting OP"""

        class SpanFiller(ExprMutator):
            """SpanFiller"""

        return SpanFiller(span).fill(sym) if _should_fill_span() else sym
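The source names _should_fill_span() but does not show its body. The following is a minimal sketch of how such an environment-variable gate could look, assuming that only the value "0" disables filling and that an unset variable means enabled. The helper names should_fill_span and set_span_if_enabled, and the environ parameter, are hypothetical and exist only for this illustration.

```python
import os


def should_fill_span(environ=os.environ):
    """Return False only when TVM_SPANFILLING is explicitly set to "0".

    Assumption: any other value, or an unset variable, keeps filling enabled.
    """
    return environ.get("TVM_SPANFILLING", "1") != "0"


def set_span_if_enabled(sym, span, fill, environ=os.environ):
    """Apply fill(sym, span) when span filling is enabled, else return sym as-is.

    `fill` stands in for the real span-filling routine (hypothetical parameter).
    """
    return fill(sym, span) if should_fill_span(environ) else sym


# Usage with a stand-in fill function that just pairs the value with the span.
tag = lambda sym, span: (sym, span)
enabled = set_span_if_enabled("expr", "layer_0", tag, environ={})
disabled = set_span_if_enabled("expr", "layer_0", tag, environ={"TVM_SPANFILLING": "0"})
```

Passing the environment as a parameter keeps the gate easy to exercise in tests without mutating the process environment.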
The details of set_span are as follows. The constructor now accepts both a string and a Span as its source information. The function fill accepts only the types in its whitelist, to prevent unexpected symbols. The function visit stops traversing deeper once the flow hits an expression that already has a span. In a dispatched visit function such as visit_call, SpanFiller reconstructs and returns a new expression with the given span.
    class SpanFiller(ExprMutator):
        """SpanFiller"""

        def __init__(self, span):
            ExprMutator.__init__(self)
            if isinstance(span, tvm.relay.Span):
                self._span = span
            elif isinstance(span, str):
                self._span = tvm.relay.Span(tvm.relay.SourceName(span), 0, 0, 0, 0)
            else:
                assert False, f"unsupported span type: {type(span)}"

        def visit(self, expr):
            # stop at expressions that already carry a span
            if hasattr(expr, "span") and expr.span:
                return expr
            # otherwise dispatch to the typed visit_* functions
            return super().visit(expr)

        def visit_call(self, call):
            new_args = [self.visit(arg) for arg in call.args]
            return _expr.Call(call.op, new_args, call.attrs, call.type_args, self._span)

        #...

        def fill(self, sym):
            if isinstance(sym, _expr.TupleWrapper):
                return _expr.TupleWrapper(self.visit(sym.tuple_value), sym.size)
            elif isinstance(sym, _expr.RelayExpr):
                return self.visit(sym)
            elif isinstance(sym, list):
                assert all(
                    isinstance(expr, _expr.TupleGetItem) for expr in sym
                ), f"unexpected relay expressions in {sym}"
                return [self.visit(expr) for expr in sym]
            elif isinstance(sym, tuple):
                assert all(
                    isinstance(expr, _expr.RelayExpr) for expr in sym
                ), f"unexpected relay expressions in {sym}"
                return tuple(self.visit(expr) for expr in sym)
            assert False, f"unsupported type {type(sym)}"
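The stop-at-existing-span rule can be illustrated without TVM. The sketch below is a toy model, not the Relay API: Node and fill_span are hypothetical stand-ins for Relay expressions and SpanFiller. It rebuilds every untagged node with the given span and leaves already tagged subtrees untouched, which mirrors the Pack example where the earlier Shape node keeps its own span.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class Node:
    """Toy expression: an op name, its arguments, and an optional span."""
    op: str
    args: Tuple["Node", ...] = ()
    span: Optional[str] = None


def fill_span(node, span):
    """Return a copy of `node` in which every untagged node carries `span`.

    Traversal stops at nodes that already have a span, so spans set by
    earlier conversions are preserved.
    """
    if node.span is not None:
        return node
    new_args = tuple(fill_span(arg, span) for arg in node.args)
    return Node(node.op, new_args, span)


# An already converted producer, followed by a one-to-many conversion.
shape = Node("shape_of", (Node("input", span="input"),), span="Shape")
expanded = Node("expand_dims", (shape,))
const_expanded = Node("expand_dims", (Node("const"),))
out = Node("concatenate", (expanded, const_expanded))

tagged = fill_span(out, "stack")
```

After the call, the concatenate node, both expand_dims nodes, and the untagged constant carry "stack", while the shape_of subtree is returned unchanged.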
Derived from Span, SequentialSpan can accept a sequence of Span objects and put them into its tvm::Array. For many-to-n (n>=1) transformations, SequentialSpan is a good container to carry their sources. When comparing two SequentialSpan objects for equality, we simply fall back to the equality of each contained span and obtain the result iteratively.

    class SequentialSpanNode : public SpanNode {
     public:
      /*! \brief A list of spans that used to compose a sequential span. */
      tvm::Array<Span> spans;

      static constexpr const char* _type_key = "SequentialSpan";

      bool SEqualReduce(const SequentialSpanNode* other, SEqualReducer equal) const;

      TVM_DECLARE_FINAL_OBJECT_INFO(SequentialSpanNode, SpanNode);
    };

    class SequentialSpan : public Span {
     public:
      TVM_DLL SequentialSpan(Array<Span> spans);
      TVM_DLL SequentialSpan(std::initializer_list<Span> init);
    };
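As a rough illustration of the container and its equality rule, here is a toy model in plain Python. The classes below are hypothetical stand-ins, not the TVM objects, and the comparison assumes that two sequential spans are equal only when they have the same length and their spans match position by position.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Span:
    """Toy span: only the fields needed for the illustration."""
    source_name: str
    line: int = 0
    column: int = 0


@dataclass(frozen=True)
class SequentialSpan:
    """Toy container that carries several source spans."""
    spans: Tuple[Span, ...]


def sequential_spans_equal(lhs, rhs):
    """Compare two sequential spans by comparing their spans one by one.

    Assumption: order matters and lengths must match.
    """
    if len(lhs.spans) != len(rhs.spans):
        return False
    return all(a == b for a, b in zip(lhs.spans, rhs.spans))


# Many-to-one case: the selected branch and the discarded If both contribute.
branch_span = Span("true_branch")
if_span = Span("if_expr")
merged = SequentialSpan((branch_span, if_span))
same = SequentialSpan((Span("true_branch"), Span("if_expr")))
swapped = SequentialSpan((if_span, branch_span))
```

In this model, merged compares equal to same and unequal to swapped.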
SIBuilder provides two functionalities for both Relay and TIR pass transformations. One is recursively filling spans into newly generated expressions that have no span. The other is collecting source spans from a contiguous sequence of expressions. The following class declaration gives an overview of SIBuilder:

    class SIBuilder {
     public:
      explicit SIBuilder(const Span& span = Span());

      /*!
       * \brief Create SIBuilder via a subgraph,
       *        will construct span based on the exprs falls in the subgraph
       *
       * \param entry Entry expr for subgraph
       * \param inputs End exprs for subgraph
       */
      template <typename T, typename = std::enable_if_t<std::is_base_of<BaseExpr, T>::value>>
      explicit SIBuilder(const T& entry, const tvm::Array<T>& inputs = {});
      explicit SIBuilder(const tir::Stmt& entry, const tvm::Array<PrimExpr>& inputs = {});
      explicit SIBuilder(const tir::Stmt& entry, const tvm::Array<tir::Stmt>& inputs = {});

      ~SIBuilder();

      SIBuilder(const SIBuilder&) = delete;
      SIBuilder& operator=(const SIBuilder&) = delete;

      /*!
       * \brief create new source info based on span_buffer_.
       *
       * \return The span.
       */
      Span CreateSpan() const;

      /*!
       * \brief Recursively fill subgraphs exprs' span
       *
       * \param entry Entry expr for subgraph
       * \param inputs End exprs for subgraph
       */
      template <typename T, typename = std::enable_if_t<std::is_base_of<BaseExpr, T>::value>>
      void RecursivelyFillSpan(
          const T& entry, const std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>& inputs) const;

      void RecursivelyFillSpan(
          const tir::Stmt& entry,
          const std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual>& inputs) const;
      void RecursivelyFillSpan(
          const tir::Stmt& entry,
          const std::unordered_set<tir::Stmt, ObjectPtrHash, ObjectPtrEqual>& inputs) const;

     private:
      struct Impl;
      std::unique_ptr<Impl> impl_;

      std::unique_ptr<Impl> CreateImpl(const Span& span);
    };
Starting with RecursivelyFillSpan, we describe how to fill a given span into newly generated expressions. Take RelayRecursivelyFill for the Relay type as an example: it inherits from ExprMutator to traverse the given expressions. If the visited expression is one of the inputs, it stops the traversal. Otherwise RecursivelyFillSpan dispatches to the corresponding type, sets up the span, and traverses deeper.
    class RelayRecursivelyFill : public relay::ExprMutator {
     public:
      RelayRecursivelyFill(const Span& span, const RelayExprSet& inputs = {})
          : span_(span), inputs_(inputs) {}

      void Fill(const relay::Expr& entry);

      relay::Expr VisitExpr(const relay::Expr& expr) final;
      relay::Expr VisitExpr_(const relay::CallNode* call_node) final;
      // other types...

     private:
      const Span& span_;
      const RelayExprSet& inputs_;
    };

    relay::Expr RelayRecursivelyFill::VisitExpr(const relay::Expr& expr) {
      //...
      if (inputs_.find(expr) != inputs_.end()) {
        return expr;
      }
      //...
    }

    relay::Expr RelayRecursivelyFill::VisitExpr_(const relay::CallNode* call_node) {
      call_node->span = span_;
      return relay::ExprMutator::VisitExpr_(call_node);
    }
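The traversal rule can be sketched in a few lines of plain Python. This is a toy model under stated assumptions, not the SIBuilder implementation: Expr and recursively_fill_span are hypothetical names, inputs are matched by object identity, and every visited node is tagged unconditionally, as in the VisitExpr_ shown above.

```python
class Expr:
    """Toy mutable expression with arguments and an optional span."""

    def __init__(self, name, args=()):
        self.name = name
        self.args = list(args)
        self.span = None


def recursively_fill_span(entry, inputs, span):
    """Depth-first fill of `span` from `entry`, stopping at any of `inputs`."""
    stop = {id(expr) for expr in inputs}
    seen = set()
    stack = [entry]
    while stack:
        expr = stack.pop()
        if id(expr) in stop or id(expr) in seen:
            continue  # inputs keep their own span; shared nodes are visited once
        seen.add(id(expr))
        expr.span = span
        stack.extend(expr.args)


# A BatchNorm-like unpacking: data and moving_var are inputs, the rest is new.
data = Expr("data")
moving_var = Expr("moving_var")
epsilon = Expr("epsilon")
sqrt_var = Expr("sqrt", [Expr("add", [moving_var, epsilon])])
scale = Expr("divide", [Expr("one"), sqrt_var])
out = Expr("multiply", [data, scale])

recursively_fill_span(out, [data, moving_var], "batch_norm")
```

After the call, every newly created node between out and the inputs carries "batch_norm", while data and moving_var are left as they were.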
On the other hand, the constructor of SIBuilder accepts an entry and a set of inputs to collect all of the source information. The core functionality for Relay is implemented by the class RelayCollapse, which inherits from ExprVisitor. The visitor function Collapse acts in a similar way to RecursivelyFill: it starts from the entry, puts the span of each expression into its array member variable, and continues the traversal until it hits the inputs. The collected spans can be obtained by invoking the CreateSpan function of the SIBuilder instance.
    class RelayCollapse : public relay::ExprVisitor {
     public:
      RelayCollapse(const RelayExprSet& inputs = {}) : inputs_(inputs) {}

      Span Collapse(const relay::Expr& entry);

      void VisitExpr(const relay::Expr& expr) final;

     private:
      tvm::Array<Span> spans_;
      const RelayExprSet& inputs_;
    };

    void RelayCollapse::VisitExpr(const relay::Expr& expr) {
      // ...
      if (expr->span.defined()) {
        spans_.push_back(expr->span);
      }
      if (inputs_.find(expr) != inputs_.end()) {
        visit_counter_.emplace(expr.get(), 1);
        return;
      }
      // ...
    }

    Span RelayCollapse::Collapse(const relay::Expr& entry) {
      VisitExpr(entry);
      return SequentialSpan(spans_);
    }
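A toy version of the collecting traversal follows, again in plain Python with hypothetical names (Expr, collapse). It mirrors the order of checks in the VisitExpr above: the span of a node is recorded first and the input check comes second, so the span of an input is collected but nothing upstream of it is visited.

```python
class Expr:
    """Toy expression with arguments and an optional span."""

    def __init__(self, name, args=(), span=None):
        self.name = name
        self.args = list(args)
        self.span = span


def collapse(entry, inputs):
    """Collect spans from `entry` down to `inputs`, in visiting order."""
    stop = {id(expr) for expr in inputs}
    seen = set()
    spans = []

    def visit(expr):
        if id(expr) in seen:
            return
        seen.add(id(expr))
        if expr.span is not None:
            spans.append(expr.span)  # record before the input check
        if id(expr) in stop:
            return  # do not walk past an input
        for arg in expr.args:
            visit(arg)

    visit(entry)
    return spans


# Two consecutive reshapes on top of x, as in the SimplifyReshape pattern.
upstream = Expr("upstream", span="upstream")
x = Expr("x", [upstream], span="x")
inner = Expr("reshape", [x], span="reshape_a")
outer = Expr("reshape", [inner], span="reshape_b")

collected = collapse(outer, [x])
```

The resulting list holds the spans of both reshapes and of x, and could then be wrapped in a sequential span for the rewritten expression.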
Finally, SIBuilder can be disabled through the ir.enable_si_builder option in the config of PassContext:

    TVM_REGISTER_PASS_CONFIG_OPTION("ir.enable_si_builder", Bool);
Schedule Record:

To inspect the series of Schedule transformations, new member variables are introduced to store the objects.

    // ${TVM}/include/tvm/te/schedule.h
    class ScheduleNode : public Object {
     public:
      ...
      /*!
       * \brief list of all schedules during primitives applied to stages.
       */
      Array<Schedule> schedule_record;
      /*!
       * \brief Flag to keep schedule record or not.
       */
      bool keep_schedule_record;
      ...
    };
Every Stage inside a Schedule needs to know what the current Schedule is, and appends a snapshot of the Schedule after a primitive is applied.

    // ${TVM}/include/tvm/te/schedule.h
    class Stage : public ObjectRef {
     public:
      ...
      explicit Stage(Operation op, Schedule& sch);
      ...
      /*!
       * \brief Not functional currently.
       */
      TVM_DLL void EnterWithScope();
      /*!
       * \brief Store current schedule after primitive being applied.
       */
      TVM_DLL void ExitWithScope();
      ...
    };
The "With" scope semantic is used here:

    // ${TVM}/src/te/schedule/schedule_lang.cc
    void Schedule::EnterWithScope() {}

    void Schedule::ExitWithScope() {
      ScheduleNode* sch_node = operator->();
      if (sch_node->keep_schedule_record) {
        sch_node->schedule_record.push_back(copy());
      }
    }
All primitives can leverage the mechanism above to record the status of the Schedule. Take the "parallel" primitive as an example (the added line is marked with +):

    Stage& Stage::parallel(IterVar var) {  // NOLINT(*)
    +  With<Schedule> sch_scope(operator->()->attach_sch);
       SetAttrIterType(operator->(), var, kParallelized);
       return *this;
    }
The effect can be explained with the following snippet:

    def schedule_record_with_gemm():
        M, K, N = 1024, 1024, 1024
        k = te.reduce_axis((0, K), "k")
        A = te.placeholder((M, K), name="A")
        B = te.placeholder((K, N), name="B")
        C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), name="C")
        s = te.create_schedule(C.op)
        # currently there are no other applied primitives
        # size of schedule record is expected to be 1 (vanilla schedule)
        assert len(s.schedule_record) == 1
        # let's apply sequential optimization primitives
        block_size, factor = 32, 8
        # tile -> split + split + reorder
        mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], block_size, block_size)
        ko, ki = s[C].split(k, factor=factor)
        s[C].reorder(mo, ko, no, mi, ki, ni)
        s[C].vectorize(ni)
        s[C].parallel(mo)
        # the primitives inside schedule record are (primitive type and its store order):
        # vanilla(1), split(2), split(3), reorder(4), split(5), reorder(6), vectorize(7), parallel(8)
        assert len(s.schedule_record) == 8
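The snapshot-on-scope-exit idea can also be modeled without TVM. The following is a toy sketch in which a Python context manager plays the role of the With scope; ToySchedule and its methods are hypothetical and only track primitive names. As in the snippet above, tile is expressed as split + split + reorder, so it contributes three snapshots.

```python
from contextlib import contextmanager


class ToySchedule:
    """Toy schedule that records a snapshot after every applied primitive."""

    def __init__(self, keep_schedule_record=True):
        self.primitives = []
        self.keep_schedule_record = keep_schedule_record
        self.schedule_record = []
        self._snapshot()  # the vanilla schedule is the first record

    def _snapshot(self):
        if self.keep_schedule_record:
            self.schedule_record.append(list(self.primitives))

    @contextmanager
    def _scope(self):
        yield
        self._snapshot()  # plays the role of ExitWithScope

    def _apply(self, name):
        with self._scope():
            self.primitives.append(name)

    def split(self):
        self._apply("split")

    def reorder(self):
        self._apply("reorder")

    def vectorize(self):
        self._apply("vectorize")

    def parallel(self):
        self._apply("parallel")

    def tile(self):
        # tile -> split + split + reorder, each recorded separately
        self.split()
        self.split()
        self.reorder()


s = ToySchedule()
s.tile()
s.split()
s.reorder()
s.vectorize()
s.parallel()
record_sizes = [len(snapshot) for snapshot in s.schedule_record]
```

Comparing the last two snapshots isolates the effect of the final primitive, which is the same comparison the TEDD example performs with schs[-2] and schs[-1].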
Schedule Propagation:

In the TVM build flow (Relay to a target executable), the Schedule instance is stored in an attribute of the CallNode inside MakeLoweredCall and retrieved in the GraphExecutorCodegen process (i.e. schedules are finally kept in the corresponding graph nodes).

Finally, a series of APIs will be created accordingly for users to access the Schedule instance from the Relay build module.
The source_name member inside Span is used to achieve the source mapping mechanism; how can other members like line or col be leveraged?

The collection of extra debug information can be controlled by an environment variable to minimize the effect on performance.

TEDD to visualize the effect of every schedule primitive.

A fundamental set_span function has been introduced to the TVM repo in the TensorFlow frontend. The new implementation we propose can resolve the following problems:
After investigation, we can support multiple frontends and resolve problem 1. Based on the set_span derived from ExprMutator, we can properly handle problems 2 and 3.
SequentialSpan extends the capability of Span so that it can handle transformations with multiple sources. SIBuilder is a new helper class for developers when they are tagging spans in a pass.
This functionality extends the original design with some interface changes; the existing tool, TEDD, will also be modified slightly. With this new feature, users can better understand an op implementation by visualizing the effect of primitives.
Schedule data structure rather than using TEDD only.

We plan to have the following PRs with corresponding test cases:
- TEDD modification (PR)