| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| # Hyper-parameter search over lmCG models, organised in brackets of |
| # successive rounds (the in-code notes refer to the publication that |
| # introduced the algorithm for the variable names). |
| # |
| # For every bracket s (from s_max down to 0) n candidate configurations are |
| # drawn uniformly from paramRanges. In each of the s+1 rounds of a bracket |
| # every remaining candidate is trained with lmCG (maxi = r_i), scored by the |
| # sum of squared errors of lmpredict on the validation data, the candidates |
| # are sorted, and only the first floor(n_i/eta) of them are kept. The winners |
| # of all brackets are sorted once more and the first one is returned. |
| # |
| # INPUT: |
| # X_train, y_train  training data, passed to lmCG as X and y; must have the |
| #                   same number of rows |
| # X_val, y_val      validation data used for lmpredict and the loss; must |
| #                   have the same number of rows and the same number of |
| #                   columns as X_train / y_train respectively |
| # params            list of hyper-parameter names; each entry is used as a |
| #                   key into the local argument list "args" |
| # paramRanges       numParams x 2 matrix; row k holds the min (column 1) and |
| #                   max (column 2) used to sample values for params[k] |
| # R                 resource parameter (default 81); determines s_max, B and |
| #                   the per-round iteration limit r_i handed to lmCG as maxi |
| # eta               reduction factor (default 3); floor(n_i/eta) candidates |
| #                   survive each round |
| # verbose           if TRUE, per-bracket and final debug output is printed |
| # |
| # OUTPUT: |
| # bestWeights       transposed first row of the sorted bracket-winner weights |
| # bestHyperParams   frame built from the parameter columns of the first row |
| #                   of the sorted bracket-winner scoreboard |
| m_hyperband = function(Matrix[Double] X_train, Matrix[Double] y_train, |
| Matrix[Double] X_val, Matrix[Double] y_val, List[String] params, |
| Matrix[Double] paramRanges, Scalar[int] R = 81, Scalar[int] eta = 3, |
| Boolean verbose = TRUE) |
| return (Matrix[Double] bestWeights, Frame[Unknown] bestHyperParams) |
| { |
| # variable names follow publication where algorithm is introduced |
| |
| numParams = length(params); |
| |
| # shape checks: one range row per parameter, matching train/val dimensions |
| assert(numParams == nrow(paramRanges)); |
| assert(ncol(paramRanges) == 2); |
| assert(nrow(X_train) == nrow(y_train)); |
| assert(nrow(X_val) == nrow(y_val)); |
| assert(ncol(X_train) == ncol(X_val)); |
| assert(ncol(y_train) == ncol(y_val)); |
| |
| # s_max+1 brackets are run; B is only used below to derive n per bracket |
| # NOTE(review): s_max is the floor of a floating-point logarithm; if |
| # log(R,eta) evaluates slightly below an integer for an exact power of eta, |
| # one bracket would be lost - confirm for the R/eta values in use |
| s_max = floor(log(R,eta)); |
| B = (s_max + 1) * R; |
| # one row per bracket (row s+1 belongs to bracket s): |
| # bracketWinners = [loss, parameter values], winnerWeights = model weights |
| bracketWinners = matrix(0, s_max+1, numParams+1); |
| winnerWeights = matrix(0, s_max+1, ncol(X_train)); |
| |
| parfor( s in s_max:0 ) { |
| # debug output is collected per bracket and printed in one call at the end |
| debugMsgs = "--------------------------"; |
| |
| if( verbose ) { |
| debugMsgs = append(debugMsgs, "BRACKET s = " + s + "\n"); |
| } |
| |
| # n: number of candidates the bracket starts with |
| # r: base for the per-round iteration limit r_i |
| n = ceil(floor(B/R/(s+1)) * eta^s); |
| r = R * eta^(-s); |
| |
| # scoreboard: column 1 = validation loss, columns 2.. = parameter values |
| scoreboard = matrix(0,n,1+numParams); |
| candidateWeights = matrix(0,n,ncol(X_train)); |
| # candidateWeights is not read until last round, as models are retrained |
| # from zero in every trial at the moment |
| |
| # draw parameter values from uniform distribution |
| # draw e.g. regularisation factor for all the candidates at once |
| for( curParam in 1:numParams ) { |
| scoreboard[,curParam+1] = |
| rand(rows=n, cols=1, min=as.scalar(paramRanges[curParam, 1]), |
| max=as.scalar(paramRanges[curParam, 2]), pdf="uniform"); |
| } |
| |
| for( i in 0:s ) { |
| n_i = as.integer(floor(n * eta^(-i))); |
| r_i = as.integer(floor(r * eta^i)); |
| # When the number of iterations is used as the resource, r_i has to be |
| # an integer; for other types of resources, like a portion of the |
| # dataset, this is not the case. This implementation hard-codes |
| # iterations as the resource. The floor() for r_i is not included in |
| # the publication of hyperband. |
| |
| if( verbose ) { |
| debugMsgs = append(debugMsgs, "+++++++++++++++"); |
| debugMsgs = append(debugMsgs, "i: " + i + " (current round)"); |
| debugMsgs = append(debugMsgs, "n_i: " + n_i + " (number of configurations evaluated)"); |
| debugMsgs = append(debugMsgs, "r_i: " + r_i + " (maximum number of iterations)\n"); |
| } |
| |
| # train and score the first n_i candidates; each iteration writes only |
| # its own row of scoreboard and candidateWeights |
| parfor( curCandidate in 1:n_i ) { |
| # TODO argument list has to be passed from outside as well |
| # args is a residue from the implementation with eval("lmCG", args) |
| # init argument list |
| args = list(X=X_train, y=y_train, icpt=0, reg=1e-7, |
| tol=1e-7, maxi=r_i, verbose=TRUE); |
| |
| for( curParam in 1:numParams ) { |
| # replace default values with values of the candidate at the |
| # corresponding location |
| args[as.scalar(params[curParam])] = |
| as.scalar(scoreboard[curCandidate,curParam+1]); |
| } |
| # original version |
| # weights = eval(learnAlgo, arguments); |
| |
| # would be better to pass the whole list at once, this solution is error |
| # prone depending on the order of the list. hyper parameters to optimize |
| # are taken from args, as there they are reordered to be invariant to the |
| # order used at calling hyperband |
| # NOTE(review): args is created with X and y as its first two entries, |
| # yet args[1] and args[2] are read positionally as tol and reg - confirm |
| # that the keyed assignments above really leave tol/reg at positions 1/2 |
| # only tol, reg and maxi vary per call; icpt=0 and verbose=FALSE are fixed |
| weights = eval("lmCG", list(X=X_train, y=y_train, icpt=0, |
| tol=as.scalar(args[1]), reg=as.scalar(args[2]), maxi=r_i, verbose=FALSE)); |
| |
| # store the transposed weights in the candidate's row and record the |
| # sum of squared validation errors as its loss |
| candidateWeights[curCandidate] = t(weights) |
| preds = lmpredict(X=X_val, w=weights); |
| scoreboard[curCandidate,1] = as.matrix(sum((y_val - preds)^2)); |
| } |
| |
| # reorder both matrices by same order |
| # order() is called without by/decreasing, so its defaults apply |
| # (presumably ascending on column 1, the loss - confirm) |
| reorder = order(target=scoreboard, index.return=TRUE); |
| P = table(seq(1,n_i), reorder); # permutation matrix |
| scoreboard = P %*% scoreboard; |
| candidateWeights = P %*% candidateWeights; |
| |
| if( verbose ) { |
| debugMsgs = append(debugMsgs, "validation loss | parameter values:"); |
| debugMsgs = append(debugMsgs, toString(scoreboard)); |
| } |
| |
| numToKeep = floor(n_i/eta); |
| |
| # in some cases, the list of remaining candidates would get emptied; |
| # if floor(n_i/eta) is below 1 the current candidates are kept unchanged |
| if( numToKeep >= 1 ) { |
| scoreboard = scoreboard[1:numToKeep] |
| candidateWeights = candidateWeights[1:numToKeep]; |
| } |
| } |
| |
| if( verbose ) { |
| debugMsgs = append(debugMsgs, "Winner of Bracket: "); |
| debugMsgs = append(debugMsgs, toString(scoreboard[1])); |
| print(debugMsgs); # make print atomic because of parfor |
| } |
| # first row after the last sort is the bracket winner; bracket s -> row s+1 |
| bracketWinners[s+1] = scoreboard[1]; |
| winnerWeights[s+1] = candidateWeights[1]; |
| } |
| |
| if( verbose ) { |
| print("--------------------------"); |
| print("WINNERS OF EACH BRACKET (from s = 0 to s_max):"); |
| print("validation loss | parameter values:"); |
| print(toString(bracketWinners)); |
| } |
| |
| # reorder both matrices by same order |
| reorder2 = order(target=bracketWinners, index.return=TRUE); |
| P2 = table(seq(1,s_max+1), reorder2); # permutation matrix |
| bracketWinners = P2 %*% bracketWinners; |
| winnerWeights = P2 %*% winnerWeights; |
| |
| # first row after sorting is the overall winner; column 1 (the loss) is |
| # dropped from the returned hyper-parameters |
| # NOTE(review): the column range is presumably read as 2:(1+numParams) - |
| # confirm the index-range parsing |
| bestHyperParams = as.frame(t(bracketWinners[1,2:1+numParams])); |
| bestWeights = t(winnerWeights[1]); |
| |
| if( verbose ) { |
| print("Hyper parameters returned:"); |
| print(toString(bestHyperParams)); |
| print("Weights returned:"); |
| print(toString(t(bestWeights))); |
| } |
| } |