blob: 3efb38d44bf5e1bb982b8e883b854c10f328c8c9 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/distributed/struct_info.h>
#include <tvm/relax/struct_info.h>
#include <tvm/script/ir_builder/relax/ir.h>
#include <tvm/tir/op.h>
#include "./utils.h"
namespace tvm {
namespace relax {
/*!
 * \brief Build a `relax.call_tir` call whose outputs are described by distributed
 *        (DTensor) struct info.
 *
 * \param func The callee passed as the first argument of `relax.call_tir`.
 * \param args The tuple of call arguments passed as the second argument.
 * \param out_sinfo_list Struct info of each output. Every entry must use a
 *        `ShapeExpr` as its tensor shape; otherwise a CHECK failure is raised.
 * \param packed_ints Optional extra argument. When defined, it is appended as the
 *        third argument of the call; when undefined, the call has only two arguments.
 * \return The constructed `Call` expression. Its single sinfo argument is the one
 *         output struct info when exactly one output is given, and a
 *         `TupleStructInfo` over all outputs otherwise.
 */
Expr MakeCallTIRDist(Expr func, Tuple args,
                     ffi::Array<distributed::DTensorStructInfo> out_sinfo_list,
                     ffi::Optional<Expr> packed_ints) {
  // Reject outputs whose tensor shape is not an explicit ShapeExpr.
  for (const distributed::DTensorStructInfo& out_info : out_sinfo_list) {
    const bool has_shape_expr =
        out_info->tensor_sinfo->shape.as<ShapeExprNode>() != nullptr;
    CHECK(has_shape_expr) << "out_sinfo of call_tir should have defined ShapeExpr as shape. "
                             "However, one given structure info is "
                          << out_info;
  }

  // A lone output is used directly; any other count is wrapped in a tuple.
  const StructInfo result_sinfo =
      out_sinfo_list.size() == 1
          ? StructInfo(out_sinfo_list[0])
          : StructInfo(TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}));

  // Assemble the operands: (func, args[, packed_ints]).
  ffi::Array<Expr> call_operands{func, args};
  if (packed_ints.defined()) {
    call_operands.push_back(packed_ints.value());
  }

  static const Op& call_tir_op = Op::Get("relax.call_tir");
  return Call(call_tir_op, call_operands, {}, {result_sinfo});
}
// Expose MakeCallTIRDist as a global function so the Relax IR builder can
// construct distributed call_tir expressions by name.
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def(
      "script.ir_builder.relax.distributed.call_tir_dist", MakeCallTIRDist);
}
} // namespace relax
} // namespace tvm