#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

# Hyperband hyper-parameter search, hard-coded for lmCG as the learner and the
# number of lmCG iterations (maxi) as the resource that is distributed.
#
# For every bracket s = s_max, ..., 0 a set of n candidate configurations is
# drawn uniformly from paramRanges. Each bracket then runs s+1 rounds: in
# round i the remaining n_i candidates are trained from scratch with lmCG for
# at most r_i iterations, scored by the sum of squared errors on the
# validation data, sorted by that score, and only the best floor(n_i/eta)
# candidates are kept. The best candidate over all brackets is returned.
#
# INPUT:
# -----------------------------------------------------------------------------
# X_train      Training features, passed to lmCG as X
# y_train      Training targets, passed to lmCG as y
# X_val        Validation features, passed to lmpredict
# y_val        Validation targets used to compute the validation loss
# params       Names of the hyper parameters to optimize. Only "tol" and "reg"
#              are forwarded to lmCG; other names are stored in the internal
#              argument list but not used.
# paramRanges  One row per entry of params: column 1 is the minimum, column 2
#              the maximum of the uniform distribution values are drawn from
# R            Resource budget; the initial number of iterations per candidate
#              in bracket s is R * eta^(-s)
# eta          Reduction factor: roughly 1/eta of the candidates survive each
#              round
# verbose      If TRUE, print per-bracket details and the final result
# -----------------------------------------------------------------------------
#
# OUTPUT:
# -----------------------------------------------------------------------------
# bestWeights      Transposed weight row of the winning candidate
#                  (ncol(X_train) x 1)
# bestHyperParams  Frame holding the hyper parameter values of the winning
#                  candidate, one row per entry of params
# -----------------------------------------------------------------------------

m_hyperband = function(Matrix[Double] X_train, Matrix[Double] y_train, 
  Matrix[Double] X_val, Matrix[Double] y_val, List[String] params, 
  Matrix[Double] paramRanges, Scalar[int] R = 81, Scalar[int] eta = 3, 
  Boolean verbose = TRUE) 
  return (Matrix[Double] bestWeights, Frame[Unknown] bestHyperParams) 
{
  # variable names follow publication where algorithm is introduced

  numParams = length(params);
  
  # sanity checks on the shapes of the inputs
  assert(numParams == nrow(paramRanges));
  assert(ncol(paramRanges) == 2);
  assert(nrow(X_train) == nrow(y_train));
  assert(nrow(X_val) == nrow(y_val));
  assert(ncol(X_train) == ncol(X_val));
  assert(ncol(y_train) == ncol(y_val));

  s_max = floor(log(R,eta));
  B = (s_max + 1) * R;
  # one row per bracket: column 1 is the validation loss, the remaining
  # columns hold the hyper parameter values of the bracket winner
  bracketWinners = matrix(0, s_max+1, numParams+1);
  # one row per bracket holding the weights of the bracket winner
  # NOTE(review): assumes lmCG with icpt=0 returns ncol(X_train) weights
  winnerWeights = matrix(0, s_max+1, ncol(X_train));

  # brackets do not depend on each other and are evaluated in parallel
  parfor( s in s_max:0 ) {
    debugMsgs = "--------------------------";
    
    if( verbose ) {
      debugMsgs = append(debugMsgs, "BRACKET s = " + s + "\n");
    }
    
    # n: number of candidates the bracket starts with
    # r: number of iterations granted to each candidate in the first round
    n = ceil(floor(B/R/(s+1)) * eta^s);
    r = R * eta^(-s);
    
    # scoreboard layout: validation loss | parameter values
    scoreboard = matrix(0,n,1+numParams);
    candidateWeights = matrix(0,n,ncol(X_train));
    # candidateWeights is not read until last round, as models are retrained 
    # from zero in every trial at the moment

    # draw parameter values from uniform distribution
    # draw e.g. regularisation factor for all the candidates at once
    for( curParam in 1:numParams ) { 
      scoreboard[,curParam+1] = 
        rand(rows=n, cols=1, min=as.scalar(paramRanges[curParam, 1]), 
          max=as.scalar(paramRanges[curParam, 2]), pdf="uniform");
    }
    
    # successive halving: s+1 rounds with fewer candidates and more iterations
    for( i in 0:s ) {
      n_i = as.integer(floor(n * eta^(-i)));
      r_i = as.integer(floor(r * eta^i)); 
      # When using the number of iterations as a resource, r_i has to be an
      # integer; when using other types of resources, like a portion of the
      # dataset, this is not the case. This implementation hard-codes
      # iterations as the resource. The floor() for r_i is not part of the
      # hyperband publication.
        
      if( verbose ) {
        debugMsgs = append(debugMsgs, "+++++++++++++++");
        debugMsgs = append(debugMsgs, "i: " + i + " (current round)");
        debugMsgs = append(debugMsgs, "n_i: " + n_i + " (number of configurations evaluated)");
        debugMsgs = append(debugMsgs, "r_i: " + r_i + " (maximum number of iterations)\n");
      }
      
      # train and score all remaining candidates of this round in parallel
      parfor( curCandidate in 1:n_i ) {
        # TODO argument list has to be passed from outside as well
        # args is a residue from the implementation with eval("lmCG", args)
        # init argument list
        args = list(X=X_train, y=y_train, icpt=0, reg=1e-7, 
          tol=1e-7, maxi=r_i, verbose=TRUE);
          
        for( curParam in 1:numParams ) {
          # replace default values with values of the candidate at the 
          # corresponding location
          args[as.scalar(params[curParam])] = 
            as.scalar(scoreboard[curCandidate,curParam+1]);
        }
        # original version
        # weights = eval(learnAlgo, arguments);
        
        # would be better to pass the whole list at once. The hyper parameters
        # to optimize are taken from args, where they were stored by name, so
        # the result is invariant to the order used at calling hyperband.
        # They are looked up by name as well: positional access (args[1],
        # args[2]) depends on the internal order of the named list, whose
        # first two declared entries are X and y rather than tol and reg.
        weights = eval("lmCG", list(X=X_train, y=y_train, icpt=0, 
          tol=as.scalar(args["tol"]), reg=as.scalar(args["reg"]), maxi=r_i, 
          verbose=FALSE));
        
        candidateWeights[curCandidate] = t(weights);
        preds = lmpredict(X=X_val, w=weights);
        # validation loss: sum of squared errors on the validation data
        scoreboard[curCandidate,1] = as.matrix(sum((y_val - preds)^2));
      }
      
      # reorder both matrices by same order
      reorder = order(target=scoreboard, index.return=TRUE);
      P = table(seq(1,n_i), reorder); # permutation matrix
      scoreboard = P %*% scoreboard;
      candidateWeights = P %*% candidateWeights;
      
      if( verbose ) {
        debugMsgs = append(debugMsgs, "validation loss | parameter values:");
        debugMsgs = append(debugMsgs, toString(scoreboard));
      }
      
      numToKeep = floor(n_i/eta);
      
      # in some cases, the list of remaining candidates would get emptied
      if( numToKeep >= 1 ) {
        scoreboard = scoreboard[1:numToKeep];
        candidateWeights = candidateWeights[1:numToKeep];
      }
    }
    
    if( verbose ) {
      debugMsgs = append(debugMsgs, "Winner of Bracket: ");
      debugMsgs = append(debugMsgs, toString(scoreboard[1]));
      print(debugMsgs); # make print atomic because of parfor
    }
    # after the last round the first row is the best candidate of the bracket
    bracketWinners[s+1] = scoreboard[1];
    winnerWeights[s+1] = candidateWeights[1];
  } 
  
  if( verbose ) {
    print("--------------------------");
    print("WINNERS OF EACH BRACKET (from s = 0 to s_max):");
    print("validation loss | parameter values:");
    print(toString(bracketWinners));
  }
  
  # reorder both matrices by same order
  reorder2 = order(target=bracketWinners, index.return=TRUE);
  P2 = table(seq(1,s_max+1), reorder2); # permutation matrix
  bracketWinners = P2 %*% bracketWinners;
  winnerWeights = P2 %*% winnerWeights;
  
  # columns 2..numParams+1 of the best row hold the hyper parameter values;
  # the upper bound is parenthesized so it is not misread as (2:1)+numParams
  bestHyperParams = as.frame(t(bracketWinners[1,2:(numParams+1)]));
  bestWeights = t(winnerWeights[1]);
  
  if( verbose ) {
    print("Hyper parameters returned:");
    print(toString(bestHyperParams));
    print("Weights returned:");
    print(toString(t(bestWeights)));
  }
}
