Add broadcasted() flag to Tensor, broadcast inputs in Mult, refactor GPU traverse-unary transform
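
Tensor::broadcasted() reports whether a tensor is a broadcast view, i.e.
whether any of its strides is 0 so that data is repeated along a dimension
instead of being stored separately. Mult() now broadcasts its two inputs
against each other while ignoring the trailing two (matrix) dimensions, so
batched matmul such as (2,3,4) x (4,5) -> (2,3,5) works, and the CUDA
broadcast kernel is generalized into traverse_unary_transform, which
Transform() uses to materialize a broadcast view into a contiguous tensor.

For intuition, numpy uses the same zero-stride convention for broadcast
views; this is exactly what broadcasted() checks on the singa side (a short
illustration, not singa code):

    import numpy as np
    # A broadcast view repeats rows without copying: the first stride is 0.
    v = np.broadcast_to(np.arange(4, dtype=np.float32), (3, 4))
    print(v.strides)  # (0, 4)

A minimal usage sketch of the Python API exercised by the new test; the
shapes and the device setup via device.get_default_device() are illustrative
assumptions, not taken from the test:

    import numpy as np
    from singa import autograd, device, tensor

    dev = device.get_default_device()

    # W has no batch dimension; it is broadcast against A's leading
    # dimension before the batched matrix multiplication.
    A = np.random.random((2, 3, 4)).astype(np.float32)
    W = np.random.random((4, 5)).astype(np.float32)

    a = tensor.from_numpy(A)
    w = tensor.from_numpy(W)
    a.to_device(dev)
    w.to_device(dev)

    y = autograd.matmul(a, w)  # result shape (2, 3, 5)
    np.testing.assert_array_almost_equal(tensor.to_numpy(y),
                                         np.matmul(A, W), 3)
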
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 47a73b9..aea988d 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -114,6 +114,15 @@
     return false;
   }
 
+  /// Return true if this tensor is a broadcast view of another tensor's
+  /// data, i.e. at least one stride is 0 so elements are repeated along
+  /// that dimension rather than being stored separately.
+  bool broadcasted() const {
+    int strideProduct = 1;
+    for (const auto &i : stride_) strideProduct *= i;
+    return strideProduct == 0;
+  }
+
   const vector<int> &stride() const { return stride_; }
 
   /// Return true if the content of the tensor is initialized
@@ -250,7 +259,7 @@
 
   /// Return a view of the input tensor whose shape is broadcasted to be
   /// compatible with the given shape
-  Tensor &Broadcast(const Shape &shape);
+  Tensor &Broadcast(const Shape &shape, const int ignore_last_dim = 0);
 
   /// Reset the shape, device, and data type as given tensor.
   /// If block size changes, then reallocate a new block.
@@ -332,7 +341,8 @@
 
 /// Return a view of the input tensor whose shape is broadcasted to be
 /// compatible with the given shape
-Tensor Broadcast(const Tensor &in, const Shape &shape);
+Tensor Broadcast(const Tensor &in, const Shape &shape,
+                 const int ignore_last_dim = 0);
 
 /// Change the axes
 Tensor Transpose(const Tensor &in, const vector<size_t> &axes);
diff --git a/src/core/tensor/math_kernel.cu b/src/core/tensor/math_kernel.cu
index 3777a06..43be56d 100644
--- a/src/core/tensor/math_kernel.cu
+++ b/src/core/tensor/math_kernel.cu
@@ -69,7 +69,9 @@
 }
 */
 
-__global__ void KernelBroadcastTo(const size_t n, size_t nDim, const float *in,const float* shape, const float* stride, float *out) {
+__global__ void KernelTraverseUnaryTransform(const size_t n, size_t nDim,
+                                             const float *in, const int *shape,
+                                             const int *stride, float *out) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
        i += blockDim.x * gridDim.x) {
     int shape_accu = n;
@@ -77,10 +79,10 @@
     int remains = i;
 
     for (int k = 0; k < nDim; k++) {
-      shape_accu = shape_accu/shape[k];
-      int idx = remains/shape_accu;
-      remains = remains%shape_accu;
-      offset = offset + idx*stride[k];
+      shape_accu = shape_accu / shape[k];
+      int idx = remains / shape_accu;
+      remains = remains % shape_accu;
+      offset = offset + idx * stride[k];
     }
     out[i] = in[offset];
   }
@@ -651,8 +653,11 @@
   KernelAdd <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, x, out);
 }
 
-void broadcast_to(const size_t n, size_t nDim,const float *in,const float* shape, const float* stride, float *out, cudaStream_t s) {
-  KernelBroadcastTo <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, nDim, in, shape, stride, out);
+void traverse_unary_transform(const size_t n, size_t nDim, const float *in,
+                              const int *shape, const int *stride, float *out,
+                              cudaStream_t s) {
+  KernelTraverseUnaryTransform<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>>(
+      n, nDim, in, shape, stride, out);
 }
 
 void mult(const size_t n, const float *in, const float x, float *out,
diff --git a/src/core/tensor/math_kernel.h b/src/core/tensor/math_kernel.h
index 206fa1a..69e5047 100644
--- a/src/core/tensor/math_kernel.h
+++ b/src/core/tensor/math_kernel.h
@@ -84,9 +84,9 @@
 void mult(const size_t n, const float *in, const float x, float *out,
           cudaStream_t s);
 
-void broadcast_to(const size_t n, size_t nDim, const float *in,
-                  const float *shape, const float *stride, float *out,
-                  cudaStream_t s);
+void traverse_unary_transform(const size_t n, size_t nDim, const float *in,
+                              const int *shape, const int *stride, float *out,
+                              cudaStream_t s);
 
 void div(const size_t n, const float x, const float *in, float *out,
          cudaStream_t s);
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 530288e..518ad48 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -430,15 +430,19 @@
   return;
 }
 
-Tensor &Tensor::Broadcast(const Shape &shape) {
+Tensor &Tensor::Broadcast(const Shape &shape, const int ignore_last_dim) {
   // TODO(wangwei) do we need to transform the mem layout if the tensor was
   // transposed?
   auto m = shape_.size() - 1, n = shape.size() - 1;
-  for (size_t i = 0; i <= std::min(m, n); i++) {
-    if ((shape.at(n - i) != shape_.at(m - i)) && (shape.at(n - i) != 1)) {
-      CHECK_EQ(shape_.at(m - i), 1) << "i= " << i << "\n";  // << Backtrace();
-      shape_.at(m - i) = shape.at(n - i);
-      stride_.at(m - i) = 0;
+  // ignore_last_dim skips the trailing matrix dimensions when broadcasting
+  // for batched Mult, e.g. (2,3,4)x(4,5) becomes (2,3,4)x(2,4,5)
+  if (ignore_last_dim < std::min(m, n)) {
+    for (size_t i = ignore_last_dim; i <= std::min(m, n); i++) {
+      if ((shape.at(n - i) != shape_.at(m - i)) && (shape.at(n - i) != 1)) {
+        CHECK_EQ(shape_.at(m - i), 1) << "i= " << i << "\n";  // << Backtrace();
+        shape_.at(m - i) = shape.at(n - i);
+        stride_.at(m - i) = 0;
+      }
     }
   }
   if (m < n) {
@@ -450,9 +454,10 @@
   return *this;
 }
 
-Tensor Broadcast(const Tensor &in, const Shape &shape) {
+Tensor Broadcast(const Tensor &in, const Shape &shape,
+                 const int ignore_last_dim) {
   Tensor out(in);
-  return out.Broadcast(shape);
+  return out.Broadcast(shape, ignore_last_dim);
 }
 
 Tensor &Tensor::T() {
@@ -710,8 +715,8 @@
         { __VA_ARGS__ }                                        \
         break;                                                 \
       }                                                        \
-      case ((kInt << _SwitchShift) + kCuda): {             \
-        typedef int DType;                                   \
+      case ((kInt << _SwitchShift) + kCuda): {                 \
+        typedef int DType;                                     \
         typedef lang::Cuda Lang;                               \
         { __VA_ARGS__ }                                        \
         break;                                                 \
@@ -723,7 +728,7 @@
         break;                                                 \
       }                                                        \
       case ((kInt << _SwitchShift) + kCpp): {                  \
-        typedef int DType;                                   \
+        typedef int DType;                                     \
         typedef lang::Cpp Lang;                                \
         { __VA_ARGS__ }                                        \
         break;                                                 \
@@ -1547,19 +1552,15 @@
 }
 
 Tensor Mult(const Tensor &A, const Tensor &B) {
-  Shape s;
-  s.push_back(A.shape(0));
-  if (B.nDim() == 2) s.push_back(B.shape(1));
-  if (A.nDim() > 2) {
-    // for n>2 dim
-    // A {..., m1, m2} x B {..., m2, m3} = C {..., m1, m3}
-    s = A.shape();
-    s.pop_back();
-    s.push_back(B.shape(B.nDim() - 1));
-  }
+  auto A_ = Broadcast(A, B.shape(), 2);
+  auto B_ = Broadcast(B, A.shape(), 2);
+
+  Shape s = A_.shape();
+  s.pop_back();
+  s.push_back(B.shape(B.nDim() - 1));
 
   Tensor out(s, A.device(), A.data_type());
-  Mult(A, B, &out);
+  Mult(A_, B_, &out);
   return out;
 }
 
@@ -1611,14 +1612,14 @@
       Tensor A_tmp;
       Tensor B_tmp;
 
-      if (A.transpose()) {
+      if (A.transpose() || A.broadcasted()) {
         A_tmp = Tensor(A.shape(), A.device(), A.data_type());
         singa::Transform(A, &A_tmp);
       } else {
         A_tmp = A;
       }
 
-      if (B.transpose()) {
+      if (B.transpose() || B.broadcasted()) {
         B_tmp = Tensor(B.shape(), B.device(), B.data_type());
         singa::Transform(B, &B_tmp);
       } else {
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index f3a3173..bbc7395 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -206,30 +206,36 @@
                              generate_tensor_nd_desc(*out), outPtr));
 }
 
-inline Tensor get_broadcasted_tensor(const Tensor& in1, Context* ctx) {
-  Tensor in1Bc(in1.shape(), in1.device(), in1.data_type());
-  Tensor shape(Shape{in1.nDim()}, in1.device(), in1.data_type());
-  Tensor stride(Shape{in1.nDim()}, in1.device(), in1.data_type());
-  const vector<float> strideVec(in1.stride().begin(), in1.stride().end());
-  const vector<float> shapeVec(in1.shape().begin(), in1.shape().end());
+template <typename T>
+void TraverseUnaryTransformImpl(const Tensor& in1, Tensor* in1Bc,
+                                Context* ctx) {
+  Tensor shape(Shape{in1.nDim()}, in1.device(), kInt);
+  Tensor stride(Shape{in1.nDim()}, in1.device(), kInt);
+  const vector<int> strideVec(in1.stride().begin(), in1.stride().end());
+  const vector<int> shapeVec(in1.shape().begin(), in1.shape().end());
   shape.CopyDataFromHostPtr(shapeVec.data(), in1.nDim());
   stride.CopyDataFromHostPtr(strideVec.data(), in1.nDim());
+  const int* shapePtr = static_cast<const int*>(shape.block()->data());
+  const int* stridePtr = static_cast<const int*>(stride.block()->data());
 
-  const float* shapePtr = static_cast<const float*>(shape.block()->data());
-  const float* stridePtr = static_cast<const float*>(stride.block()->data());
-  const float* inPtr1 = static_cast<const float*>(in1.block()->data());
-  float* inBcPtr1 = static_cast<float*>(in1Bc.block()->mutable_data());
+  const T* inPtr1 = static_cast<const T*>(in1.block()->data());
+  T* inBcPtr1 = static_cast<T*>(in1Bc->block()->mutable_data());
 
-  const size_t n = Product(in1Bc.shape());
+  const size_t n = Product(in1Bc->shape());
 
-  cuda::broadcast_to(n, in1.nDim(), inPtr1, shapePtr, stridePtr, inBcPtr1,
-                     ctx->stream);
-
-  return in1Bc;
+  cuda::traverse_unary_transform(n, in1.nDim(), inPtr1, shapePtr, stridePtr,
+                                 inBcPtr1, ctx->stream);
 }
+template void TraverseUnaryTransformImpl<float>(const Tensor& in1,
+                                                Tensor* in1Bc, Context* ctx);
 
 template <>
 void Transform<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
+  if (in.broadcasted()) {
+    TraverseUnaryTransformImpl<float>(in, out, ctx);
+    return;
+  }
+
   const float* inPtr = static_cast<const float*>(in.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
 
@@ -251,13 +257,7 @@
     float* outPtr = static_cast<float*>(out->block()->mutable_data());    \
     const size_t num = out->Size();                                       \
                                                                           \
-    int strideProduct1 = 1;                                               \
-    for (const auto& i : in1.stride()) strideProduct1 *= i;               \
-                                                                          \
-    int strideProduct2 = 1;                                               \
-    for (const auto& i : in2.stride()) strideProduct2 *= i;               \
-                                                                          \
-    if ((strideProduct1 * strideProduct2) != 0) {                         \
+    if (!in1.broadcasted() && !in2.broadcasted()) {                       \
       if (!in1.transpose() && !in2.transpose() &&                         \
           (in1.stride() == in2.stride())) {                               \
         kernel(num, inPtr1, inPtr2, outPtr, ctx->stream);                 \
@@ -278,16 +278,18 @@
         }                                                                 \
       }                                                                   \
     } else {                                                              \
-      Tensor in1Bc;                                                       \
-      Tensor in2Bc;                                                       \
-      if (strideProduct1 == 0) {                                          \
-        in1Bc = get_broadcasted_tensor(in1, ctx);                         \
-        inPtr1 = static_cast<const float*>(in1Bc.block()->data());        \
+      Tensor in1bc;                                                       \
+      Tensor in2bc;                                                       \
+      if (in1.broadcasted()) {                                            \
+        in1bc = Tensor(in1.shape(), in1.device(), in1.data_type());       \
+        Transform<float, lang::Cuda>(in1, &in1bc, ctx);                   \
+        inPtr1 = static_cast<const float*>(in1bc.block()->data());        \
       }                                                                   \
                                                                           \
-      if (strideProduct2 == 0) {                                          \
-        in2Bc = get_broadcasted_tensor(in2, ctx);                         \
-        inPtr2 = static_cast<const float*>(in2Bc.block()->data());        \
+      if (in2.broadcasted()) {                                            \
+        in2bc = Tensor(in2.shape(), in2.device(), in2.data_type());       \
+        Transform<float, lang::Cuda>(in2, &in2bc, ctx);                   \
+        inPtr2 = static_cast<const float*>(in2bc.block()->data());        \
       }                                                                   \
                                                                           \
       kernel(num, inPtr1, inPtr2, outPtr, ctx->stream);                   \
@@ -524,7 +526,6 @@
   cuda::eq(num, outPtr, 0.0, outPtr, ctx->stream);
 }
 
-
 /// Natural logarithm, base e (Napier's constant): out[i] = ln(in[i]).
 template <>
 void Log<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
diff --git a/test/python/test_tensor.py b/test/python/test_tensor.py
index 4d83d32..1afd12a 100644
--- a/test/python/test_tensor.py
+++ b/test/python/test_tensor.py
@@ -161,6 +161,31 @@
         y = 2 / x
         self.assertEqual(tensor.average(y), 2.)
 
+    def matmul_high_dim_helper(self, dev):
+        configs = [
+            [(1, 12, 7, 64), (1, 12, 64, 7)],
+            [(1, 7, 768), (768, 768)],
+        ]
+        for config in configs:
+            X = np.random.random(config[0]).astype(np.float32)
+            x = tensor.from_numpy(X)
+            x.to_device(dev)
+
+            W = np.random.random(config[1]).astype(np.float32)
+            w = tensor.from_numpy(W)
+            w.to_device(dev)
+
+            y_t = np.matmul(X, W)
+            y = autograd.matmul(x, w)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(y), y_t, 3)
+
+    def test_matmul_high_dim_cpu(self):
+        self.matmul_high_dim_helper(cpu_dev)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_matmul_high_dim_gpu(self):
+        self.matmul_high_dim_helper(gpu_dev)
+
     def test_tensor_inplace_api(self):
         """ tensor inplace methods alter internal state and also return self
         """