| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file tvm/tir/function.h |
| * \brief TIR Function. |
| */ |
| #ifndef TVM_TIR_FUNCTION_H_ |
| #define TVM_TIR_FUNCTION_H_ |
| |
| #include <tvm/ir/function.h> |
| #include <tvm/runtime/ndarray.h> |
| #include <tvm/tir/buffer.h> |
| #include <tvm/tir/expr.h> |
| #include <tvm/tir/stmt.h> |
| |
| #include <string> |
| |
| namespace tvm { |
| namespace tir { |
| |
| /*! |
| * \brief Primitive functions that contains TIR statements. |
| * |
| * The PrimFunc provides low-level code representation does not |
| * automatically manage |
| * |
| * \sa PrimFunc |
| */ |
| class PrimFuncNode : public BaseFuncNode { |
| public: |
| /*! \brief Function parameters */ |
| Array<tir::Var> params; |
| /*! \brief The body of the function */ |
| tir::Stmt body; |
| /*! \brief The return type of the function. */ |
| Type ret_type; |
| /*! |
| * \brief Maps some parameters to specific Buffer data structures. |
| * |
| * buffer_map provides a way to express data structure's field and shape |
| * constraints. The provided information is used in the program analysis |
| * and the code generation. |
| * |
| * - It defines the vars in the Buffer (m, n) in the cases below when |
| * they appears in the buffer_map for the first time. |
| * - When a var appears multiple times, they translate into runtime |
| * assertion to check the field constraint. |
| * |
| * \code |
| * |
| * # The corresponding fields of f are as follows |
| * # |
| * # - f.params = [a, b] |
| * # - f.buffer_map = {a: A, b: B} |
| * # - A = decl_buffer(shape=[m, n]) |
| * # - B = decl_buffer(shape=[m, n]) |
| * |
| * def f(a, b): |
| * m, n = var(), var() |
| * A = bind_buffer(a, shape=[m, n]) |
| * B = bind_buffer(b, shape=[m, n]) |
| * # body |
| * |
| * \endcode |
| * |
| * buffer_map is a sugar to express: |
| * - Parameter unpacking: e.g. I can load a.shape[0] to get value of m |
| * - Constraint checking: a.shape[0] must equal b.shape[0] because they |
| * both corresponds to m. |
| |
| * While we could have express parameter unpacking and constraint using |
| * normal statements, making buffer_map as first class citizen of PrimFunc |
| * will make program analysis much easier. |
| */ |
| Map<tir::Var, Buffer> buffer_map; |
| |
| /*! \brief The buffer map prior to flattening. |
| * |
| * This contains the buffers as they exists prior to flattening, and |
| * is used for validating an input tensor passed into the packed |
| * API. Any buffer that is present in `buffer_map` but not present |
| * in `preflattened_buffer_map` is assumed to be the same before |
| * and after flattening (e.g. a 1-d tensor that is backed by 1-d |
| * flat memory). |
| * |
| * TODO(Lunderberg): Remove preflattened_buffer_map, and instead |
| * declare each flattened buffer as aliasing the original tensor |
| * shape. This should include improving the StmtExprMutator to |
| * provide easier interactions with Buffer objects, so that the |
| * bookkeeping of relationships between buffers doesn't need to be |
| * repeated across several transforms. |
| */ |
| Map<tir::Var, Buffer> preflattened_buffer_map; |
| |
| void VisitAttrs(tvm::AttrVisitor* v) { |
| v->Visit("params", ¶ms); |
| v->Visit("body", &body); |
| v->Visit("ret_type", &ret_type); |
| v->Visit("buffer_map", &buffer_map); |
| v->Visit("preflattened_buffer_map", &preflattened_buffer_map); |
| v->Visit("attrs", &attrs); |
| v->Visit("span", &span); |
| v->Visit("_checked_type_", &checked_type_); |
| } |
| |
| bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const { |
| // visit params and buffer_map first as they contains defs. |
| return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) && |
| equal(preflattened_buffer_map, other->preflattened_buffer_map) && |
| equal(ret_type, other->ret_type) && equal(body, other->body) && |
| equal(attrs, other->attrs); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce.DefHash(params); |
| hash_reduce(buffer_map); |
| hash_reduce(preflattened_buffer_map); |
| hash_reduce(ret_type); |
| hash_reduce(body); |
| hash_reduce(attrs); |
| } |
| /*! |
| * \brief Return the derived function annotation of this function. |
| * |
| * \return The function type annotation. |
| * \note The function type annotation of PrimExpr is |
| * directly derived from the Vars without the need of type inference. |
| */ |
| TVM_DLL FuncType func_type_annotation() const; |
| |
| static constexpr const char* _type_key = "tir.PrimFunc"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to PrimFuncNode. |
| * \sa PrimFuncNode |
| */ |
| class PrimFunc : public BaseFunc { |
| public: |
| /*! |
| * \brief Constructor |
| * |
| * \param params The parameters of the function. |
| * |
| * \param body The body of the function. |
| * |
| * \param ret_type The return type of the function. |
| * |
| * \param buffer_map The buffer map for parameter buffer unpacking. |
| * This contains buffer objects as they appear in the body of the |
| * PrimFunc. (e.g. a buffer of shape ``[1024]`` originally |
| * generated as a tensor of shape ``[32, 32]``) |
| * |
| * \param preflattened_buffer_map The buffer map for |
| * parameter buffer unpacking. This contains buffer |
| * objects as they are expected to be passed in by the |
| * callee. (e.g. a buffer of shape ``[32, 32]`` originally |
| * generated as a tensor of shape ``[32, 32]``) |
| * |
| * \param attrs Additional function attributes. |
| * |
| * \param span The location of this object in the source code. |
| */ |
| TVM_DLL PrimFunc( |
| Array<tir::Var> params, Stmt body, Type ret_type = VoidType(), |
| Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(), |
| Optional<Map<tir::Var, Buffer>> preflattened_buffer_map = Optional<Map<tir::Var, Buffer>>(), |
| DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span()); |
| |
| TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode); |
| TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode); |
| }; |
| |
| /*! |
| * \brief Tensor intrinsics for tensorization |
| */ |
| class TensorIntrinNode : public Object { |
| public: |
| /*! \brief The function to describe the computation. */ |
| PrimFunc desc; |
| /*! \brief The function of the implementation for the execution. */ |
| PrimFunc impl; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("desc", &desc); |
| v->Visit("impl", &impl); |
| } |
| |
| static constexpr const char* _type_key = "tir.TensorIntrin"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); |
| }; |
| |
| /*! |
| * \brief Managed reference to TensorIntrinNode. |
| */ |
| class TensorIntrin : public ObjectRef { |
| public: |
| /*! |
| * \brief Constructor |
| * \param desc The function to describe the computation. |
| * \param impl The function of the implementation for the execution. |
| */ |
| TVM_DLL explicit TensorIntrin(PrimFunc desc, PrimFunc impl); |
| |
| /*! |
| * \brief Create and register a TensorIntrin. After registration, the TensorIntrin can be looked |
| * up with its name. |
| * \param name The name of the TensorIntrin to register |
| * \param intrin The TensorIntrin to register. |
| * \param override Whether override existing intrinsic. |
| * \throws This method throws an exception if the TensorIntrin with the specified name already |
| * exists. |
| */ |
| TVM_DLL static void Register(String name, TensorIntrin intrin, bool override = false); |
| |
| /*! |
| * \brief Look up TensorIntrin by name. Raises an exception if not found. |
| * \param name The name of the TensorIntrin. |
| * \return The TensorIntrin with the specified name. |
| * \throws This method throws an exception if the TensorIntrin does not exist. |
| */ |
| TVM_DLL static TensorIntrin Get(String name); |
| |
| TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode) |
| }; |
| |
| /* |
| * \brief Specialize parameters of PrimFunc. |
| * \param func The PrimFunc to be specialized. |
| * \param param_map The mapping from function params to the instance. |
| * \return The new function with parameter specialized. |
| * \note We can define a Meta TIR function with symbolic shape: |
| * |
| * \code |
| * @T.prim_func |
| * def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None: |
| * A = T.match_buffer(a, (m, n), "float32") |
| * B = T.match_buffer(b, (m, n), "float32") |
| * for i, j in T.grid(m, n): |
| * with T.block(): |
| * vi, vj = T.axis.remap("SS", [i, j]) |
| * B[vi, vj] = A[vi, vj] |
| * \endcode |
| * |
| * Then we can make it specialized with given shapes or buffers. |
| * |
| * \code |
| * a, _, m, n = mem_copy.params |
| * func = mem_copy.specialize({a: tir.decl_buffer((16, 16))}) |
| * # or |
| * func = mem_copy.specialize({n: 16, m: 16}) |
| * \endcode |
| * |
| * \code {.language-id} |
| * @T.prim_func |
| * def mem_copy_16_16(a: T.handle, b: T.handle) -> None: |
| * A = T.match_buffer(a, (16, 16), "float32") |
| * B = T.match_buffer(b, (16, 16), "float32") |
| * for i, j in T.grid(16, 16): |
| * with T.block(): |
| * vi, vj = T.axis.remap("SS", [i, j]) |
| * B[vi, vj] = A[vi, vj] |
| * \endcode |
| */ |
| PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map); |
| |
| /*! |
| * \brief PrimFunc specific attribute names. |
| * |
| * \sa tvm::attr |
| */ |
| namespace attr { |
| /*! |
| * \brief List of thread IterVar that a DeviceLaunch function corresponds to. |
| * |
| * Type: Array<tir::IterVar> |
| * |
| * We call a device kernel launch function f using the following convention: |
| * |
| * Call(f, |
| * [arg1, arg2, ..., arg_n, |
| * work_size_1, work_size_2, ... work_size_m, dyn_shmem_size]) |
| * |
| * Here n = len(arg), m = len(work_size) = len(device_thread_axis). |
| * |
| * When kDeviceUseDynSharedMemory is not set, dyn_shmem_size argument is omitted. |
| * |
| * The list of device_thread_axis indicates how can be bind the |
| * work_size arguments to the corresponding threads. |
| * |
| * \sa tvm::CallingConv::kDeviceKernelLaunch |
| */ |
| constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis"; |
| |
| /*! |
| * \brief Whether or not use dynamic shared memory. |
| * |
| * Type: Integer |
| */ |
| constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory"; |
| |
| /*! |
| * \brief Whether to set noalias rule on the function arguments. |
| * |
| * Type: Integer |
| */ |
| constexpr const char* kNoAlias = "tir.noalias"; |
| |
| /*! |
| * \brief Mark the function as the entry function of |
| * the final generated runtime module. |
| * |
| * Type: Integer |
| * |
| * \note There can only be one entry function per module. |
| */ |
| constexpr const char* kIsEntryFunc = "tir.is_entry_func"; |
| |
| /*! |
| * \brief Mark the function as the global function called from the host. |
| * |
| * Type: Integer |
| */ |
| constexpr const char* kIsGlobalFunc = "tir.is_global_func"; |
| |
| } // namespace attr |
| } // namespace tir |
| } // namespace tvm |
| #endif // TVM_TIR_FUNCTION_H_ |