blob: c4171e9f60cc79dee2d62e332232af8c9f8df35a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file expression.h
* \brief definitions of abstract expressions and expression templates
* \author Tianqi Chen, Bing Xu
*/
#ifndef MSHADOW_EXPRESSION_H_
#define MSHADOW_EXPRESSION_H_
#include "./base.h"
namespace mshadow {
/*!
* \brief namespace for abstract expressions and expression templates,
* which have no dependency on tensor.h.
* These data structures take no part in the actual computation;
* they are only used to define operations and to represent expressions symbolically.
*/
namespace expr {
/*! \brief type of expressions */
namespace type {
// Expression types are encoded as nested bitmasks: each more general type
// contains all bits of the less general ones, so OR-ing the types of
// sub-expressions (as MakeExp does) yields the most general of them.
// Subtype relationship: kRValue < kMapper < kChainer < kComplex
/*!
 * \brief this expression directly corresponds to a data class,
 *  and can be used to assign data
 */
const int kRValue = 0;
/*!
 * \brief expression contains element-wise tensor operations,
 *  mapping an expression to one of the same shape
 */
const int kMapper = 1;
/*!
 * \brief expression that can be chained with other expressions.
 *  Usually it has a function Eval(i, j) defined, which pulls the result (i, j)
 *  from the input expression and outputs the result at a certain position.
 *  Its bit pattern contains kMapper (value 3).
 */
const int kChainer = kMapper | 2;
/*!
 * \brief other cases, e.g. dot product.
 *  Its bit pattern contains kChainer (value 7).
 */
const int kComplex = kChainer | 4;
}  // namespace type
/*!
* \brief expression engine that actually interprets these expressions.
* This class template is only declared here; a definition (presumably one
* specialization per RValue kind, elsewhere in the library) must be visible
* wherever it is used. Code in this file calls it as
* ExpEngine<Saver, RValue, DType>::Eval(RValue *dst, const EType &exp).
* \tparam Saver the save method, e.g. sv::saveto or sv::plusto
* \tparam RValue the type of RValue to be saved into
* \tparam DType the element data type of the expression
* \sa namespace sv
*/
template<typename Saver, typename RValue, typename DType>
struct ExpEngine;
/*! \brief defines how expression exp can be evaluated and stored into dst */
// Expected static member of ExpEngine, kept here as reference only:
// template<typename EType>
// inline static void Eval(RValue *dst, const EType &exp);
/*!
 * \brief base class of every expression (curiously recurring template).
 *  A derived expression passes its own type as SubType, so code that only
 *  holds an Exp<...> can recover the concrete expression statically.
 * \tparam SubType the derived class; it must inherit from Exp<SubType, ...>
 * \tparam DType the data type of each element in the expression
 * \tparam exp_type expression type, see namespace type
 */
template<typename SubType, typename DType, int exp_type>
struct Exp {
  /*! \return const reference to this object viewed as its derived SubType */
  inline const SubType& self(void) const {
    const SubType &derived = static_cast<const SubType&>(*this);
    return derived;
  }
  /*! \return pointer to this object viewed as its derived SubType */
  inline SubType* ptrself(void) {
    SubType *derived = static_cast<SubType*>(this);
    return derived;
  }
};
/*!
 * \brief expression wrapping a single scalar value; it is an element-wise
 *  (type::kMapper) operand
 * \tparam DType the data type of the scalar
 */
template<typename DType>
struct ScalarExp: public Exp<ScalarExp<DType>, DType, type::kMapper> {
  /*! \brief the wrapped scalar value */
  DType scalar_;
  /*!
   * \brief converting constructor from a plain value.
   *  Deliberately implicit (MUST NOT BE explicit) so a bare DType can be
   *  used wherever a ScalarExp<DType> is expected.
   */
  ScalarExp(DType value) : scalar_(value) {}  // NOLINT(*)
};
/*!
 * \brief create a scalar expression
 * \param s the scalar value
 * \return ScalarExp wrapping s
 */
template<typename DType>
inline ScalarExp<DType> scalar(DType s) {
  // relies on the implicit converting constructor above
  return s;
}
/*!
 * \brief typecast expression, converting each element of the wrapped
 *  expression from SrcDType to DstDType
 * \tparam DstDType the target type we want to cast into
 * \tparam SrcDType the source type we want to cast from
 * \tparam EType the type of the source expression
 * \tparam etype the type of expression after cast
 */
template<typename DstDType, typename SrcDType, typename EType, int etype>
struct TypecastExp:
    public Exp<TypecastExp<DstDType, SrcDType, EType, etype>, DstDType, etype> {
  /*!
   * \brief expression to be typecasted; held by reference, so the source
   *  expression must outlive this object
   */
  const EType &exp;
  /*! \brief constructor */
  explicit TypecastExp(const EType &src) : exp(src) {}
};
/*!
 * \brief create a typecast expression, usage tcast<DstDType>(exp)
 * \param exp source expression whose elements are of type SrcDType
 * \return TypecastExp producing DstDType elements; its expression type is
 *  the source type OR-ed with type::kMapper
 * \tparam DstDType target element type; must be given explicitly, the
 *  remaining template arguments are deduced
 */
template<typename DstDType, typename SrcDType,
         typename EType, int etype>
inline TypecastExp<DstDType, SrcDType, EType, (etype|type::kMapper)>
tcast(const Exp<EType, SrcDType, etype> &exp) {
  typedef TypecastExp<DstDType, SrcDType, EType, (etype|type::kMapper)> ResultType;
  return ResultType(exp.self());
}
/*!
 * \brief lazy transpose of a container (constructed e.g. by RValueExp::T())
 * \tparam EType type of the expression being transposed
 * \tparam DType element data type
 */
template<typename EType, typename DType>
struct TransposeExp: public Exp<TransposeExp<EType, DType>,
                                DType, type::kChainer> {
  /*! \brief expression to be transposed, held by reference */
  const EType &exp;
  /*! \brief constructor */
  explicit TransposeExp(const EType &src) : exp(src) {}
  /*! \brief transposing again yields the original expression */
  inline const EType &T(void) const {
    const EType &original = this->exp;
    return original;
  }
};
/*!
 * \brief base class of all rvalues: expressions that directly refer to a
 *  data container and can therefore be assigned into
 * \tparam Container the actual class of the data container, e.g. Tensor1D
 * \tparam DType the element data type of each element in the container
 */
template<typename Container, typename DType>
class RValueExp: public Exp<Container, DType, type::kRValue> {
 public:
  /*!
   * \brief transpose of a matrix
   * \return transpose expression referring to the current container
   */
  inline const TransposeExp<Container, DType> T(void) const {
    const Container &src = this->self();
    return TransposeExp<Container, DType>(src);
  }
  /*! \brief add scalar s to every element in place */
  inline Container &operator+=(DType s) {
    Container *dst = this->ptrself();
    ExpEngine<sv::plusto, Container, DType>::Eval(dst, scalar<DType>(s));
    return *dst;
  }
  /*! \brief subtract scalar s from every element in place */
  inline Container &operator-=(DType s) {
    Container *dst = this->ptrself();
    ExpEngine<sv::minusto, Container, DType>::Eval(dst, scalar<DType>(s));
    return *dst;
  }
  /*! \brief multiply every element by scalar s in place */
  inline Container &operator*=(DType s) {
    Container *dst = this->ptrself();
    ExpEngine<sv::multo, Container, DType>::Eval(dst, scalar<DType>(s));
    return *dst;
  }
  /*! \brief divide every element by scalar s in place */
  inline Container &operator/=(DType s) {
    Container *dst = this->ptrself();
    ExpEngine<sv::divto, Container, DType>::Eval(dst, scalar<DType>(s));
    return *dst;
  }
  /*! \brief assign scalar s to every element */
  inline Container &__assign(DType s) {
    Container *dst = this->ptrself();
    ExpEngine<sv::saveto, Container, DType>::Eval(dst, scalar<DType>(s));
    return *dst;
  }
  /*!
   * \brief evaluate expression exp and store the result into this container.
   *  NOTE(review): presumably named __assign because container = container
   *  cannot be defined generically via operator= in this base class.
   */
  template<typename E, int etype>
  inline Container &__assign(const Exp<E, DType, etype> &exp) {
    Container *dst = this->ptrself();
    ExpEngine<sv::saveto, Container, DType>::Eval(dst, exp.self());
    return *dst;
  }
  /*!
   * \brief assign from a container of the same type; as a non-template it is
   *  preferred over the template overload for an exact match.
   *  Only declared here; defined out of class.
   */
  inline Container &__assign(const Exp<Container, DType, type::kRValue> &exp);
  /*! \brief in-place element-wise addition of expression exp */
  template<typename E, int etype>
  inline Container &operator+=(const Exp<E, DType, etype> &exp) {
    Container *dst = this->ptrself();
    ExpEngine<sv::plusto, Container, DType>::Eval(dst, exp.self());
    return *dst;
  }
  /*! \brief in-place element-wise subtraction of expression exp */
  template<typename E, int etype>
  inline Container &operator-=(const Exp<E, DType, etype> &exp) {
    Container *dst = this->ptrself();
    ExpEngine<sv::minusto, Container, DType>::Eval(dst, exp.self());
    return *dst;
  }
  /*! \brief in-place element-wise multiplication by expression exp */
  template<typename E, int etype>
  inline Container &operator*=(const Exp<E, DType, etype> &exp) {
    Container *dst = this->ptrself();
    ExpEngine<sv::multo, Container, DType>::Eval(dst, exp.self());
    return *dst;
  }
  /*! \brief in-place element-wise division by expression exp */
  template<typename E, int etype>
  inline Container &operator/=(const Exp<E, DType, etype> &exp) {
    Container *dst = this->ptrself();
    ExpEngine<sv::divto, Container, DType>::Eval(dst, exp.self());
    return *dst;
  }
};
/*!
 * \brief matrix multiplication expression dot(lhs[.T], rhs[.T]), with the
 *  result multiplied by scale_
 * \tparam TA type of lhs
 * \tparam TB type of rhs
 * \tparam ltrans whether lhs is transposed
 * \tparam rtrans whether rhs is transposed
 * \tparam DType the data type of the scalar
 */
template<typename TA, typename TB, bool ltrans, bool rtrans, typename DType>
struct DotExp: public Exp<DotExp<TA, TB, ltrans, rtrans, DType>,
                          DType, type::kComplex> {
  /*! \brief left operand, held by reference */
  const TA &lhs_;
  /*! \brief right operand, held by reference */
  const TB &rhs_;
  /*! \brief scale applied to the result */
  DType scale_;
  /*! \brief constructor */
  explicit DotExp(const TA &left, const TB &right, DType scale)
      : lhs_(left), rhs_(right), scale_(scale) {}
};
// definition of dot expression
/*! \brief dot(lhs, rhs): no operand transposed, scale 1 */
template<typename TA, typename TB, typename DType>
inline DotExp<TA, TB, false, false, DType>
dot(const RValueExp<TA, DType> &lhs, const RValueExp<TB, DType> &rhs) {
  typedef DotExp<TA, TB, false, false, DType> ResultType;
  return ResultType(lhs.self(), rhs.self(), DType(1.0f));
}
/*! \brief dot(lhs.T(), rhs): left operand transposed, scale 1 */
template<typename TA, typename TB, typename DType>
inline DotExp<TA, TB, true, false, DType>
dot(const TransposeExp<TA, DType> &lhs, const RValueExp<TB, DType> &rhs) {
  typedef DotExp<TA, TB, true, false, DType> ResultType;
  return ResultType(lhs.exp, rhs.self(), DType(1.0f));
}
/*! \brief dot(lhs, rhs.T()): right operand transposed, scale 1 */
template<typename TA, typename TB, typename DType>
inline DotExp<TA, TB, false, true, DType>
dot(const RValueExp<TA, DType> &lhs, const TransposeExp<TB, DType> &rhs) {
  typedef DotExp<TA, TB, false, true, DType> ResultType;
  return ResultType(lhs.self(), rhs.exp, DType(1.0f));
}
/*! \brief dot(lhs.T(), rhs.T()): both operands transposed, scale 1 */
template<typename TA, typename TB, typename DType>
inline DotExp<TA, TB, true, true, DType>
dot(const TransposeExp<TA, DType> &lhs, const TransposeExp<TB, DType> &rhs) {
  typedef DotExp<TA, TB, true, true, DType> ResultType;
  return ResultType(lhs.exp, rhs.exp, DType(1.0f));
}
/*!
 * \brief batch_dot operator def; transposition of each operand is selected
 *  explicitly through template arguments, e.g. batch_dot<false, true>(a, b).
 *  The result scale is 1.
 */
template<bool transpose_left, bool transpose_right,
         typename TA, typename TB, typename DType>
inline DotExp<TA, TB, transpose_left, transpose_right, DType>
batch_dot(const RValueExp<TA, DType> &lhs, const RValueExp<TB, DType> &rhs) {
  typedef DotExp<TA, TB, transpose_left, transpose_right, DType> ResultType;
  return ResultType(lhs.self(), rhs.self(), DType(1.0f));
}
//---------------
// TernaryMapExp
// --------------
/*!
 * \brief ternary map expression op(item1, item2, item3)
 * \tparam OP operator
 * \tparam TA type of item1
 * \tparam TB type of item2
 * \tparam TC type of item3
 * \tparam DType element data type shared by all operands
 * \tparam etype expression type, see namespace type
 */
template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype>
struct TernaryMapExp: public Exp<TernaryMapExp<OP, TA, TB, TC, DType, etype>,
                                 DType, etype> {
  /*! \brief first operand, held by reference */
  const TA &item1_;
  /*! \brief second operand, held by reference */
  const TB &item2_;
  /*! \brief third operand, held by reference */
  const TC &item3_;
  /*! \brief constructor */
  explicit TernaryMapExp(const TA &first, const TB &second, const TC &third)
      : item1_(first), item2_(second), item3_(third) {}
};
/*!
 * \brief make a ternary expression; its expression type is the OR of the
 *  operand types and type::kMapper
 */
template<typename OP, typename TA, typename TB, typename TC, typename DType,
         int ta, int tb, int tc>
inline TernaryMapExp<OP, TA, TB, TC, DType, (ta|tb|tc|type::kMapper)>
MakeExp(const Exp<TA, DType, ta> &item1, const Exp<TB, DType, tb> &item2,
        const Exp<TC, DType, tc> &item3) {
  typedef TernaryMapExp<OP, TA, TB, TC, DType, (ta|tb|tc|type::kMapper)> ResultType;
  return ResultType(item1.self(), item2.self(), item3.self());
}
/*!
 * \brief shorthand for MakeExp, usage F<op>(item1, item2, item3);
 *  creates a ternary operation expression
 * \param item1 first operand
 * \param item2 second operand
 * \param item3 third operand
 * \return the result expression
 * \tparam OP ternary operator
 * \tparam TA item1 expression
 * \tparam ta item1 expression type
 * \tparam TB item2 expression
 * \tparam tb item2 expression type
 * \tparam TC item3 expression
 * \tparam tc item3 expression type
 * \sa mshadow::op
 */
// Ternary
template<typename OP, typename TA, typename TB, typename TC, typename DType,
         int ta, int tb, int tc>
inline TernaryMapExp<OP, TA, TB, TC, DType, (ta|tb|tc|type::kMapper)>
F(const Exp<TA, DType, ta> &item1,
  const Exp<TB, DType, tb> &item2,
  const Exp<TC, DType, tc> &item3) {
  return MakeExp<OP>(item1, item2, item3);
}
//---------------
// BinaryMapExp
// --------------
/*!
 * \brief binary map expression lhs [op] rhs
 * \tparam OP operator
 * \tparam TA type of lhs
 * \tparam TB type of rhs
 * \tparam DType element data type shared by both operands
 * \tparam etype expression type, see namespace type
 */
template<typename OP, typename TA, typename TB, typename DType, int etype>
struct BinaryMapExp: public Exp<BinaryMapExp<OP, TA, TB, DType, etype>,
                                DType, etype> {
  /*! \brief left operand, held by reference */
  const TA &lhs_;
  /*! \brief right operand, held by reference */
  const TB &rhs_;
  /*! \brief constructor */
  explicit BinaryMapExp(const TA &left, const TB &right)
      : lhs_(left), rhs_(right) {}
};
/*!
 * \brief make a binary expression; its expression type is the OR of the
 *  operand types and type::kMapper
 */
template<typename OP, typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<OP, TA, TB, DType, (ta|tb|type::kMapper)>
MakeExp(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
  typedef BinaryMapExp<OP, TA, TB, DType, (ta|tb|type::kMapper)> ResultType;
  return ResultType(lhs.self(), rhs.self());
}
/*!
 * \brief shorthand for MakeExp, usage F<op>(lhs, rhs);
 *  creates a binary operation expression
 * \param lhs left operand
 * \param rhs right operand
 * \return the result expression
 * \tparam OP binary operator
 * \tparam TA lhs expression
 * \tparam ta lhs expression type
 * \tparam TB rhs expression
 * \tparam tb rhs expression type
 * \sa mshadow::op
 */
template<typename OP, typename TA, typename TB, typename DType,
         int ta, int tb>
inline BinaryMapExp<OP, TA, TB, DType, (ta|tb|type::kMapper)>
F(const Exp<TA, DType, ta> &lhs,
  const Exp<TB, DType, tb> &rhs) {
  return MakeExp<OP>(lhs, rhs);
}
// operator rules
/*! \brief element-wise lhs + rhs, built as a BinaryMapExp with op::plus */
template<typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<op::plus, TA, TB, DType, (ta|tb|type::kMapper)>
operator+(const Exp<TA, DType, ta> &lhs,
          const Exp<TB, DType, tb> &rhs) {
  return MakeExp<op::plus>(lhs, rhs);
}
/*! \brief element-wise lhs - rhs, built as a BinaryMapExp with op::minus */
template<typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<op::minus, TA, TB, DType, (ta|tb|type::kMapper)>
operator-(const Exp<TA, DType, ta> &lhs,
          const Exp<TB, DType, tb> &rhs) {
  return MakeExp<op::minus>(lhs, rhs);
}
/*! \brief element-wise lhs * rhs, built as a BinaryMapExp with op::mul */
template<typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<op::mul, TA, TB, DType, (ta|tb|type::kMapper)>
operator*(const Exp<TA, DType, ta> &lhs,
          const Exp<TB, DType, tb> &rhs) {
  return MakeExp<op::mul>(lhs, rhs);
}
/*! \brief element-wise lhs / rhs, built as a BinaryMapExp with op::div */
template<typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<op::div, TA, TB, DType, (ta|tb|type::kMapper)>
operator/(const Exp<TA, DType, ta> &lhs,
          const Exp<TB, DType, tb> &rhs) {
  return MakeExp<op::div>(lhs, rhs);
}
//---------------
// UnaryMapExp
// --------------
/*!
 * \brief unary map expression op(src)
 * \tparam OP operator
 * \tparam TA type of src
 * \tparam DType element data type
 * \tparam etype expression type, see namespace type
 */
template<typename OP, typename TA, typename DType, int etype>
struct UnaryMapExp: public Exp<UnaryMapExp<OP, TA, DType, etype>,
                               DType, etype> {
  /*! \brief source expression, held by reference */
  const TA &src_;
  /*! \brief constructor */
  explicit UnaryMapExp(const TA &source) : src_(source) {}
};
/*!
 * \brief make a unary expression; its expression type is the source type
 *  OR-ed with type::kMapper
 */
template<typename OP, typename TA, typename DType, int ta>
inline UnaryMapExp<OP, TA, DType, (ta|type::kMapper)>
MakeExp(const Exp<TA, DType, ta> &src) {
  typedef UnaryMapExp<OP, TA, DType, (ta|type::kMapper)> ResultType;
  return ResultType(src.self());
}
/*!
 * \brief shorthand for MakeExp, usage F<op>(src);
 *  creates a unary operation expression
 * \param src source expression
 * \return the result expression
 * \tparam OP unary operator
 * \tparam TA source expression
 * \tparam ta source expression type
 * \sa mshadow::op
 */
template<typename OP, typename TA, typename DType,
         int ta>
inline UnaryMapExp<OP, TA, DType, (ta|type::kMapper)>
F(const Exp<TA, DType, ta> &src) {
  return MakeExp<OP>(src);
}
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXPRESSION_H_