#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

/*
 * Utility functions.
 */

channel_sums = function(matrix[double] X, int C, int Hin, int Win)
    return (matrix[double] out) {
  /*
   * Adds up all values belonging to each channel of a 4D input that
   * is stored as a 2D matrix.
   *
   * Inputs:
   *  - X: Inputs, of shape (N, C*Hin*Win).
   *  - C: Number of input channels (dimensionality of input depth).
   *  - Hin: Input height.
   *  - Win: Input width.
   *
   * Outputs:
   *  - out: Per-channel totals, of shape (C, 1).
   */
  # Step 1: collapse the example (row) dimension, leaving one total per column of X.
  col_totals = colSums(X)  # shape (1, C*Hin*Win)

  # Step 2: lay the totals out with one row per channel.
  per_channel = matrix(col_totals, rows=C, cols=Hin*Win)  # shape (C, Hin*Win)

  # Step 3: collapse the spatial positions of each channel.
  out = rowSums(per_channel)  # shape (C, 1)
}

predict_class = function(matrix[double] Prob, int C, int H, int W) return (matrix[double] Prediction) {
  /*
   * Picks, for every example (and every spatial position when H*W > 1),
   * the index of the largest entry along the class axis.
   *
   * Inputs:
   *  - Prob: Input probabilities, one example per row.  In the spatial
   *      case each row is reshaped to (C, H*W).
   *  - C: Number of output labels.
   *  - H: Height of the output map.
   *  - W: Width of the output map.
   *
   * Outputs:
   *  - Prediction: Class labels (one-based), of shape (N, 1) when
   *      H == 1 and W == 1, and of shape (N, H*W) otherwise.
   */
  if (H != 1 | W != 1) {
    # Spatial case: handle each example on its own.
    num_examples = nrow(Prob);
    Prediction = matrix(0, rows=num_examples, cols=H*W);
    parfor(i in 1:num_examples) {
      class_by_pos = matrix(Prob[i,], rows=C, cols=H*W);
      # Transpose so that classes run along the columns, then take the
      # arg-max of each row (= each spatial position).
      pos_by_class = t(class_by_pos);
      Prediction[i,] = t(rowIndexMax(pos_by_class)); # assuming one-based label mapping
    }
  }
  else {
    # Non-spatial case: the classes already run along the columns.
    Prediction = rowIndexMax(Prob); # assuming one-based label mapping
  }
}

im2col = function(matrix[double] img, int Hin, int Win, int Hf, int Wf, int strideh, int stridew)
    return (matrix[double] img_cols) {
  /*
   * Unrolls every filter-sized window of an image into one column.
   *
   * The image is expected to be padded already; no padding is applied
   * here.
   *
   * Inputs:
   *  - img: Input image, of shape (C, Hin*Win), where C is the number
   *      of input channels (depth).
   *  - Hin: Input height, including padding.
   *  - Win: Input width, including padding.
   *  - Hf: Filter height.
   *  - Wf: Filter width.
   *  - strideh: Stride over height.
   *  - stridew: Stride over width.
   *
   * Outputs:
   *  - img_cols: Local spatial regions (patches) of the image stretched
   *      out into columns, of shape (C*Hf*Wf, Hout*Wout).
   */
  C = nrow(img)
  Hout = as.integer(floor((Hin-Hf)/strideh + 1))
  Wout = as.integer(floor((Win-Wf)/stridew + 1))

  # Patches are collected one per ROW so that the loop body can use
  # row-wise left-indexing; the result is transposed once at the end.
  patch_rows = matrix(0, rows=Hout*Wout, cols=C*Hf*Wf)
  parfor (hout in 1:Hout, check=0) {  # output rows
    row0 = (hout-1)*strideh + 1  # first input row covered by the window
    parfor (wout in 1:Wout, check=0) {  # output columns
      col0 = (wout-1)*stridew + 1  # first input column covered by the window
      # Gather the window of every channel, one channel per row.
      patch = matrix(0, rows=C, cols=Hf*Wf)
      parfor (c in 1:C) {  # channels
        channel_img = matrix(img[c,], rows=Hin, cols=Win)
        window = channel_img[row0:row0+Hf-1, col0:col0+Wf-1]
        patch[c,] = matrix(window, rows=1, cols=Hf*Wf)
      }
      # Flatten the (C, Hf*Wf) patch into a single row of length C*Hf*Wf.
      patch_rows[(hout-1)*Wout + wout,] = matrix(patch, rows=1, cols=C*Hf*Wf)
    }
  }
  img_cols = t(patch_rows)
}

col2im = function(matrix[double] img_cols, int C, int Hin, int Win, int Hf, int Wf,
                  int strideh, int stridew, string reduction)
    return (matrix[double] img) {
  /*
   * Rebuilds an image from columns of filter-sized patches.
   *
   * Patches can overlap, and `reduction` selects what happens where
   * they do.  With "add", contributions of overlapping patches are
   * summed, which is what is needed when propagating gradients on the
   * patches back to the image.  With any other value (documented
   * option: "none"), a later patch simply overwrites what an earlier
   * one wrote, which recreates an image from the output of `im2col`.
   *
   * The original image is expected to have been padded already.
   *
   * Inputs:
   *  - img_cols: Local spatial regions (patches) of the image stretched
   *      out into columns, of shape (C*Hf*Wf, Hout*Wout).
   *  - C: Number of input channels (dimensionality of input depth).
   *  - Hin: Input height, including padding.
   *  - Win: Input width, including padding.
   *  - Hf: Filter height.
   *  - Wf: Filter width.
   *  - strideh: Stride over height.
   *  - stridew: Stride over width.
   *  - reduction: The reduction strategy to use for overlapping
   *      patches.  Valid options are "add" and "none".
   *
   * Outputs:
   *  - img: Input image, of shape (C, Hin*Win).
   */
  Hout = as.integer(floor((Hin-Hf)/strideh + 1))
  Wout = as.integer(floor((Win-Wf)/stridew + 1))

  img = matrix(0, rows=C, cols=Hin*Win)
  for (hout in 1:Hout) {  # output rows
    row0 = (hout-1)*strideh + 1  # first image row covered by this patch
    for (wout in 1:Wout) {  # output columns
      col0 = (wout-1)*stridew + 1  # first image column covered by this patch
      # Take the column of the current patch and reshape it to one row per channel.
      patch = matrix(img_cols[,(hout-1)*Wout + wout], rows=C, cols=Hf*Wf)
      parfor (c in 1:C) {  # channels
        patch_2d = matrix(patch[c,], rows=Hf, cols=Wf)
        if (reduction == "add") {
          # Place the patch on an empty canvas and accumulate it into the image.
          canvas = matrix(0, rows=Hin, cols=Win)
          canvas[row0:row0+Hf-1, col0:col0+Wf-1] = patch_2d
          img[c,] = img[c,] + matrix(canvas, rows=1, cols=Hin*Win)
        } else {
          # Write the patch over the current contents of the channel.
          canvas = matrix(img[c,], rows=Hin, cols=Win)
          canvas[row0:row0+Hf-1, col0:col0+Wf-1] = patch_2d
          img[c,] = matrix(canvas, rows=1, cols=Hin*Win)
        }
      }
    }
  }
}

pad_image = function(matrix[double] img, int Hin, int Win, int padh, int padw, double pad_value)
    return (matrix[double] img_padded) {
  /*
   * Surrounds every channel of an image with a border filled with
   * `pad_value`, along the height and width dimensions.
   *
   * Inputs:
   *  - img: Input image, of shape (C, Hin*Win), where C is the number
   *      of input channels (depth).
   *  - Hin: Input height.
   *  - Win: Input width.
   *  - padh: Padding for top and bottom sides.
   *  - padw: Padding for left and right sides.
   *  - pad_value: Value to use for the padding.
   *      A typical value is 0.
   *
   * Outputs:
   *  - img_padded: The input image padded along the height and width
   *      dimensions, of shape (C, (Hin+2*padh)*(Win+2*padw)).
   */
  C = nrow(img)
  Hpad = Hin + 2*padh  # padded height
  Wpad = Win + 2*padw  # padded width

  img_padded = matrix(0, rows=C, cols=Hpad*Wpad)
  parfor (c in 1:C) {
    # Start from a canvas filled with the padding value ...
    canvas = matrix(pad_value, rows=Hpad, cols=Wpad)
    # ... and drop the original channel into its interior.
    channel_img = matrix(img[c,], rows=Hin, cols=Win)
    canvas[padh+1:padh+Hin, padw+1:padw+Win] = channel_img
    img_padded[c,] = matrix(canvas, rows=1, cols=Hpad*Wpad)  # flatten back to a row
  }
}

unpad_image = function(matrix[double] img_padded, int Hin, int Win, int padh, int padw)
    return (matrix[double] img) {
  /*
   * Strips the border from every channel of a padded image, along the
   * height and width dimensions.
   *
   * Inputs:
   *  - img_padded: The input image padded along the height and width
   *      dimensions, of shape (C, (Hin+2*padh)*(Win+2*padw)).
   *  - Hin: Input height of unpadded image.
   *  - Win: Input width of unpadded image.
   *  - padh: Padding for top and bottom sides.
   *  - padw: Padding for left and right sides.
   *
   * Outputs:
   *  - img: Input image, of shape (C, Hin*Win), where C is the number
   *      of input channels (depth).
   */
  C = nrow(img_padded)
  Hpad = Hin + 2*padh  # padded height
  Wpad = Win + 2*padw  # padded width

  img = matrix(0, rows=C, cols=Hin*Win)
  parfor (c in 1:C) {
    padded_2d = matrix(img_padded[c,], rows=Hpad, cols=Wpad)
    # Keep only the interior region that holds the original image.
    interior = padded_2d[padh+1:padh+Hin, padw+1:padw+Win]
    img[c,] = matrix(interior, rows=1, cols=Hin*Win)  # flatten back to a row
  }
}

threshold = function(matrix[double] X, double thresh)
    return (matrix[double] out) {
  /*
   * Builds a 0/1 indicator matrix marking the entries of X that are
   * strictly greater than the given threshold.
   *
   * Inputs:
   *  - X: Inputs, of shape (any, any).
   *  - thresh: Input threshold.
   *
   * Outputs:
   *  - out: Outputs, of same shape as X.
   */
  # An entry is flagged exactly when the threshold lies strictly below it.
  out = thresh < X
}

transpose_NCHW_to_CNHW = function(matrix[double] X, int C)
    return (matrix[double] out) {
  /*
   * Reshape util for tensors in NCHW format that swaps the first two
   * (N and C) dimensions.
   *
   * Inputs:
   *  - X: Inputs, of shape (N, C*H*W).
   *  - C: Number of channels (dimensionality of depth).
   *
   * Outputs:
   *  - out: Outputs with the N and C axes transposed, of
   *      shape (C, N*H*W).
   */
  N = nrow(X)
  D = ncol(X) / C  # number of values per (example, channel) pair

  # View X as N*C rows of length D; row (n-1)*C + c then holds channel c
  # of example n.  Each channel's values stay contiguous, so the whole
  # task reduces to re-ordering these rows and reshaping at the end.
  X_rows = matrix(X, rows=N*C, cols=D)

  # Source-row index for every target row.  Entry (c, n) of the outer
  # sum is c + (n-1)*C, so reading the (C, N) result row by row yields
  #   [1, 1+C, ..., 1+(N-1)C,
  #    2, 2+C, ..., 2+(N-1)C,
  #    ...
  #    C, 2C,  ..., NC]'
  # i.e. all examples of channel 1 first, then channel 2, and so on.
  src_grid = outer(seq(1,C), C*t(seq(0,N-1)), "+")
  src_rows = matrix(src_grid, rows=N*C, cols=1)

  # Permutation matrix with a 1 at (target row, source row).
  perm = table(seq(1, N*C), src_rows, N*C, N*C)

  # Re-order the rows, then fold into one row per channel.
  reordered = perm %*% X_rows
  out = matrix(reordered, rows=C, cols=N*D)
}

top_k_row = function(matrix[double] X, integer r, integer k)
    return (matrix[double] values, matrix[double] indices) {
  /*
   * Finds the k largest entries (i.e. probabilities) in row r of X,
   * together with their column positions (i.e. classes), in
   * decreasing order of value.
   *
   * Inputs:
   *  - X: Inputs, of shape (N, D).
   *  - r: Input row number of X to look for.
   *  - k: Input number of top elements to look for.
   *
   * Outputs:
   *  - values: The top k values at the rth row, of shape
   *    (1, k).
   *  - indices: The class indices, of shape (1, k).
   */

  #TODO: do r & k need to be checked in the valid range
  row = X[r, ]

  # Sort the row (as a column vector) in decreasing order, asking for
  # the original positions rather than the sorted values.
  ranked = order(target=t(row), by=1, decreasing=TRUE, index.return=TRUE)
  ranked_row = t(ranked)
  indices = ranked_row[1, 1:k]  # keep the k best positions

  # Look up the value stored at each of the selected positions.
  values = matrix(0, rows=1, cols=k)
  for (j in 1:k) {
    src_col = as.scalar(indices[1, j])
    values[1, j] = row[1, src_col]
  }
}

top_k = function(matrix[double] X, integer k)
     return (matrix[double] values, matrix[double] indices) {
  /*
   * Computes the top k values (i.e. probabilities) and associated
   * indices (i.e. classes) for every row of the input matrix X.
   *
   * Each row is handled independently by `top_k_row`, and the per-row
   * results are stacked in the same row order as X.
   *
   * Inputs:
   *  - X: Inputs, of shape (N, D).
   *  - k: Input number of top elements to look for.
   *
   * Outputs:
   *  - values: The top k values of each row, of shape (N, k).
   *  - indices: The indices of classes, of shape (N, k).
   */
  N = nrow(X)
  values = matrix(0, rows=N, cols=k)
  indices = matrix(0, rows=N, cols=k)

  # Rows are independent of each other, so they can be processed in parallel;
  # every iteration writes only its own row of the two outputs.
  parfor (r in 1:N) {
    [value, index] = top_k_row(X, r, k)
    values[r, ] = value
    indices[r, ] = index
  }
}

top_k2d = function(matrix[double] X, int k, int C, int Hin, int Win)
     return (matrix[double] values, matrix[double] indices) {
  /*
   * Computes, for every example and every spatial position, the top k
   * values (i.e. probabilities) along the channel axis and their
   * associated indices (i.e. classes).
   *
   * Inputs:
   *  - X: Inputs, of shape (N, C*Hin*Win).
   *  - k: Input number of top elements to look for.
   *  - C: Number of input channels (dimensionality of input depth).
   *  - Hin: Input height.
   *  - Win: Input width.
   *
   * Outputs:
   *  - values: The top k values along the channel dimension, of shape
   *    (N, k*Hin*Win).
   *  - indices: The indices of classes, of shape (N, k*Hin*Win).
   */
  N = nrow(X)

  # Move the channel axis to the columns, so that each row holds the C
  # values of one (example, position) pair:
  # (N, C*Hin*Win) -> (C, N*Hin*Win) -> (N*Hin*Win, C)
  by_channel = transpose_NCHW_to_CNHW(X, C)
  by_position = t(by_channel)

  # Row-wise top k over the channel values.
  [top_vals, top_idx] = top_k(by_position, k)  # shape: (N*Hin*Win, k)

  # Undo the layout change for the values:
  # (N*Hin*Win, k) -> (k, N*Hin*Win) -> (N, k*Hin*Win)
  vals_k_major = t(top_vals)
  values = transpose_NCHW_to_CNHW(vals_k_major, N)

  # ... and likewise for the indices.
  idx_k_major = t(top_idx)
  indices = transpose_NCHW_to_CNHW(idx_k_major, N)
}

