| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file tensor.cc |
| */ |
| #include <tvm/runtime/registry.h> |
| #include <tvm/te/operation.h> |
| #include <tvm/te/tensor.h> |
| #include <tvm/te/tensor_intrin.h> |
| |
| #include <memory> |
| |
| namespace tvm { |
| namespace te { |
| |
| IterVar thread_axis(Range dom, std::string tag) { |
| return IterVar(dom, Var(tag), kThreadIndex, tag); |
| } |
| |
| IterVar reduce_axis(Range dom, std::string name) { return IterVar(dom, Var(name), kCommReduce); } |
| |
| Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } |
| |
| // Tensor |
| PrimExpr Tensor::operator()(Array<Var> indices) const { |
| Array<PrimExpr> arr(indices.begin(), indices.end()); |
| return operator()(arr); |
| } |
| |
| PrimExpr Tensor::operator()(Array<PrimExpr> indices) const { |
| if (ndim() != 0) { |
| CHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read" |
| << "ndim = " << ndim() << ", indices.size=" << indices.size(); |
| } |
| |
| return ProducerLoad((*this), indices); |
| } |
| |
| String TensorNode::GetNameHint() const { |
| return op->num_outputs() == 1 ? op->name : (op->name + ".v" + std::to_string(value_index)); |
| } |
| |
| Tensor Operation::output(size_t i) const { |
| auto node = make_object<TensorNode>(); |
| node->op = *this; |
| node->value_index = i; |
| node->dtype = (*this)->output_dtype(i); |
| node->shape = (*this)->output_shape(i); |
| return Tensor(node); |
| } |
| |
| Tensor::Tensor(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index) { |
| auto n = make_object<TensorNode>(); |
| n->shape = std::move(shape); |
| n->dtype = dtype; |
| n->op = op; |
| n->value_index = value_index; |
| data_ = std::move(n); |
| } |
| |
| TVM_REGISTER_GLOBAL("te.Tensor") |
| .set_body_typed([](Array<PrimExpr> shape, DataType dtype, Operation op, int value_index) { |
| return Tensor(shape, dtype, op, value_index); |
| }); |
| |
| TVM_REGISTER_NODE_TYPE(TensorNode); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<TensorNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* t = static_cast<const TensorNode*>(node.get()); |
| p->stream << "Tensor(shape=" << t->shape << ", op.name=" << t->op->name << ')'; |
| }); |
| |
| // TensorIntrin |
| TensorIntrin::TensorIntrin(std::string name, Operation op, Array<Tensor> inputs, |
| Array<Buffer> buffers, Array<Var> scalar_params, Stmt body, |
| Stmt reduce_init, Stmt reduce_update) { |
| auto n = make_object<TensorIntrinNode>(); |
| n->name = std::move(name); |
| n->op = std::move(op); |
| n->inputs = std::move(inputs); |
| n->buffers = std::move(buffers); |
| n->scalar_params = std::move(scalar_params); |
| n->body = std::move(body); |
| n->reduce_init = std::move(reduce_init); |
| n->reduce_update = std::move(reduce_update); |
| data_ = std::move(n); |
| } |
| |
| TVM_REGISTER_GLOBAL("te.TensorIntrin") |
| .set_body_typed([](std::string name, Operation op, Array<Tensor> inputs, Array<Buffer> buffers, |
| Array<Var> scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update) { |
| return TensorIntrin(name, op, inputs, buffers, scalar_params, body, reduce_init, |
| reduce_update); |
| }); |
| |
| TVM_REGISTER_NODE_TYPE(TensorIntrinNode); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<TensorIntrinNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const TensorIntrinNode*>(node.get()); |
| p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")"; |
| }); |
| |
| // TensorIntrinCall |
| TensorIntrinCall::TensorIntrinCall(TensorIntrin intrin, Array<Tensor> tensors, |
| Array<Region> regions, Array<IterVar> reduce_axis, |
| Array<PrimExpr> scalar_inputs) { |
| auto n = make_object<TensorIntrinCallNode>(); |
| n->intrin = std::move(intrin); |
| n->tensors = std::move(tensors); |
| n->regions = std::move(regions); |
| n->reduce_axis = std::move(reduce_axis); |
| n->scalar_inputs = std::move(scalar_inputs); |
| data_ = std::move(n); |
| } |
| |
| TVM_REGISTER_GLOBAL("te.TensorIntrinCall") |
| .set_body_typed([](TensorIntrin intrin, Array<Tensor> tensors, Array<Region> regions, |
| Array<IterVar> reduce_axis, Array<PrimExpr> scalar_inputs) { |
| return TensorIntrinCall(intrin, tensors, regions, reduce_axis, scalar_inputs); |
| }); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* n = static_cast<const TensorIntrinCallNode*>(node.get()); |
| p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")"; |
| }); |
| |
| TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode); |
| |
| // Other tensor ops. |
| TVM_REGISTER_GLOBAL("te.TensorEqual").set_body_method(&Tensor::operator==); |
| |
| TVM_REGISTER_GLOBAL("te.TensorHash").set_body_typed([](Tensor tensor) -> int64_t { |
| return static_cast<int64_t>(std::hash<Tensor>()(tensor)); |
| }); |
| |
| TVM_REGISTER_GLOBAL("te.OpGetOutput").set_body_typed([](Operation op, int64_t output) { |
| return op.output(static_cast<size_t>(output)); |
| }); |
| |
| TVM_REGISTER_GLOBAL("te.OpNumOutputs").set_body_method<Operation>(&OperationNode::num_outputs); |
| |
| TVM_REGISTER_GLOBAL("te.OpInputTensors").set_body_method<Operation>(&OperationNode::InputTensors); |
| |
| } // namespace te |
| } // namespace tvm |