blob: f5d89a85092b880c289aca36b81484fcb5bc357c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"
namespace tvm {
namespace meta_schedule {
using tir::Instruction;
using tir::InstructionKind;
using tir::Trace;
/*! \brief A mutator that mutates the thread binding factor decision of SampleCategorical */
class MutateThreadBindingNode : public MutatorNode {
public:
/*! \brief JSON representation of the workload */
std::string json_mod_;
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "meta_schedule.MutateThreadBinding";
TVM_DECLARE_FINAL_OBJECT_INFO(MutateThreadBindingNode, MutatorNode);
public:
// Inherit from `MutatorNode`
void InitializeWithTuneContext(const TuneContext& context) final {
this->json_mod_ = SaveJSON(context->mod.value());
}
// Inherit from `MutatorNode`
Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
// Inherit from `MutatorNode`
Mutator Clone() const final {
ObjectPtr<MutateThreadBindingNode> n = make_object<MutateThreadBindingNode>(*this);
return Mutator(n);
}
private:
struct Candidate {
/*! \brief The sampling instruction to be mutated */
Instruction inst;
/*! \brief The probability */
std::vector<double> probs;
/*! \brief The decision made */
int decision;
explicit Candidate(Instruction inst, std::vector<double> probs, int decision)
: inst(std::move(inst)), probs(std::move(probs)), decision(std::move(decision)) {}
};
std::vector<Candidate> FindCandidates(const Trace& trace, TRandState* rand_state);
};
/*!
* \brief Find Candidate with the following pattern:
* \code
* v = sch.sample_categorical(...)
* l1, l2 = sch.split(loop=l0, factors=[None, v])
* sch.bind(loop=l2, thread_axis="threadIdx.x")
* \endcode
*
* \param trace The trace from which to find the instructions
* \return All the candidate instructions
*/
std::vector<MutateThreadBindingNode::Candidate> MutateThreadBindingNode::FindCandidates(
const Trace& trace, TRandState* rand_state) {
using tir::InstructionNode;
static InstructionKind inst_sample_categorical = InstructionKind::Get("SampleCategorical");
static InstructionKind inst_split = InstructionKind::Get("Split");
static InstructionKind inst_bind = InstructionKind::Get("Bind");
std::vector<MutateThreadBindingNode::Candidate> candidates;
std::unordered_map<const PrimExprNode*, const tir::InstructionNode*> sample_insts;
std::unordered_map<const tir::LoopRVNode*, const tir::InstructionNode*> sampled_split_insts;
std::vector<const InstructionNode*> bind_insts;
auto is_split_by_sample = [&sample_insts](const Instruction& inst) -> bool {
if (!inst->kind.same_as(inst_split)) {
return false;
}
// Only consider cases with 2 factors and the first one is None
if (inst->inputs.size() != 3 || inst->inputs[1].defined()) return false;
ICHECK(inst->inputs[2].defined());
return sample_insts.find(Downcast<PrimExpr>(inst->inputs[2]).get()) != sample_insts.end();
};
auto is_thread_binding_by_sample = [&sampled_split_insts](const Instruction& inst) -> bool {
if (!inst->kind.same_as(inst_bind)) {
return false;
}
ICHECK_EQ(inst->inputs.size(), 1);
ICHECK_EQ(inst->attrs.size(), 1);
if (Downcast<String>(inst->attrs[0]) != "threadIdx.x") return false;
return sampled_split_insts.find(Downcast<tir::LoopRV>(inst->inputs[0]).get()) !=
sampled_split_insts.end();
};
for (const Instruction& inst : trace->insts) {
if (inst->kind.same_as(inst_sample_categorical)) {
ICHECK_EQ(inst->outputs.size(), 1);
const PrimExprNode* var_rv = TVM_TYPE_AS(inst->outputs[0], PrimExprNode);
sample_insts[var_rv] = inst.get();
} else if (is_split_by_sample(inst)) {
CHECK_EQ(inst->outputs.size(), 2);
// Only consider the inner loop, which can be bound to threadIdx.x
const tir::LoopRVNode* var_rv = TVM_TYPE_AS(inst->outputs[1], tir::LoopRVNode);
sampled_split_insts[var_rv] = inst.get();
} else if (is_thread_binding_by_sample(inst)) {
bind_insts.push_back(inst.get());
}
}
for (const InstructionNode* bind_inst : bind_insts) {
const auto* loop_rv = TVM_TYPE_AS(bind_inst->inputs[0], tir::LoopRVNode);
auto split_it = sampled_split_insts.find(loop_rv);
ICHECK(split_it != sampled_split_insts.end());
const InstructionNode* split_inst = split_it->second;
const auto* expr_rv = TVM_TYPE_AS(split_inst->inputs[2], PrimExprNode);
auto sample_it = sample_insts.find(expr_rv);
ICHECK(sample_it != sample_insts.end());
const InstructionNode* sample_inst = sample_it->second;
int decision = Downcast<Integer>(trace->decisions[GetRef<Instruction>(sample_inst)])->value;
std::vector<double> probs =
support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(sample_inst->attrs[1]));
candidates.emplace_back(GetRef<Instruction>(sample_inst), probs, decision);
}
return candidates;
}
Optional<Trace> MutateThreadBindingNode::Apply(const Trace& trace, TRandState* rand_state) {
std::vector<Candidate> candidates = FindCandidates(trace, rand_state);
if (candidates.empty()) {
return NullOpt;
}
Candidate candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())];
// Remove the current decision
candidate.probs.erase(candidate.probs.begin() + candidate.decision);
int result = tir::MakeMultinomialSampler(rand_state, candidate.probs)();
if (result >= candidate.decision) {
result += 1;
}
return trace->WithDecision(candidate.inst, Integer(result), /*remove_postproc=*/true);
}
Mutator Mutator::MutateThreadBinding() { return Mutator(make_object<MutateThreadBindingNode>()); }
TVM_REGISTER_NODE_TYPE(MutateThreadBindingNode);
TVM_REGISTER_GLOBAL("meta_schedule.MutateThreadBinding")
.set_body_typed(Mutator::MutateThreadBinding);
} // namespace meta_schedule
} // namespace tvm