blob: f0f58be84d10d4e81667173757dc934987d33e8d [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file intrin_rule_opencl.cc
* \brief OpenCL intrinsic rules.
*/
#include <tvm/arith/analyzer.h>
#include <tvm/tirx/op_attr_types.h>
#include "../../../target/intrin_rule.h"
namespace tvm {
namespace codegen {
namespace intrin {
using tirx::FLowerIntrinsic;
// There is no warp shuffle instruction in standard OpenCL. When a shuffle is used, assume Intel's
// subgroup shuffle extension and lower to it.
//
// Expects `e` to be a call with exactly five operands:
//   (mask, value, warp_id, width, warp_size)
// and produces
//   call_pure_extern("intel_sub_group_shuffle", value, warp_id)
// with the same dtype as the original call. The `mask` operand is not forwarded.
// Only the case where `width` is provably equal to `warp_size` is supported; otherwise a
// TVM_FFI_ICHECK fails.
static PrimExpr DispatchIntelShuffle(const PrimExpr& e) {
  const CallNode* call = e.as<CallNode>();
  TVM_FFI_ICHECK(call != nullptr);
  TVM_FFI_ICHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
  // `analyzer` is a local object (not a handle), so its members are accessed with `.`.
  arith::Analyzer analyzer;
  TVM_FFI_ICHECK(analyzer.CanProve(call->args[3] == call->args[4]))
      << "Intel warp shuffle does not support width != warp_size";
  ffi::Array<PrimExpr> opencl_args{
      {StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}};
  return Call(call->dtype, builtin::call_pure_extern(), opencl_args);
}
void RegisterOpenCLIntrinRules() {
static bool registered = false;
if (registered) return;
registered = true;
// clang-format off
TVM_REGISTER_OP("tirx.clz")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.floor")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.ceil")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.trunc")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.fabs")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.round")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
// OpenCL's rint() uses ties-to-even, matching constant-folding semantics.
const tirx::CallNode* call = e.as<tirx::CallNode>();
TVM_FFI_ICHECK(call != nullptr);
ffi::Array<PrimExpr> new_args = {tirx::StringImm("rint")};
for (auto arg : call->args) {
new_args.push_back(arg);
}
return tirx::Call(call->dtype, tirx::builtin::call_pure_extern(), new_args);
});
TVM_REGISTER_OP("tirx.nearbyint")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.exp")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.erf")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.exp2")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.exp10")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.log")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.log2")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.log10")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.tanh")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.sqrt")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.pow")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.popcount")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.fmod")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.sin")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.sinh")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.cos")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.cosh")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.tvm_warp_shuffle")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchIntelShuffle);
// clang-format on
}
} // namespace intrin
} // namespace codegen
} // namespace tvm