blob: 0aedfbfdce6bc6341de1371603d6b06ffd47bbe7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.selection.paramgrid;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* Generates tuples of hyper parameter values by given map.
*
* In the given map keys are names of hyper parameters
* and values are arrays of values for hyper parameter presented in the key.
*/
public class ParameterSetGenerator {
/** Size of parameter vector. Default value is 100. */
private int sizeOfParamVector = 100;
/** Params. */
private List<Double[]> params = new ArrayList<>();
/** The given map of hyper parameters and its values. */
private Map<Integer, Double[]> map;
/**
* Creates an instance of the generator.
*
* @param map In the given map keys are names of hyper parameters
* and values are arrays of values for hyper parameter presented in the key.
*/
public ParameterSetGenerator(Map<Integer, Double[]> map) {
assert map != null;
assert !map.isEmpty();
this.map = map;
sizeOfParamVector = map.size();
}
/**
* Returns the list of tuples presented as arrays.
*/
public List<Double[]> generate() {
Double[] nextPnt = new Double[sizeOfParamVector];
traverseTree(map, nextPnt, -1);
return Collections.unmodifiableList(params);
}
/**
* Traverse tree on the current level and starts procedure of child traversing.
*
* @param map The current state of the data.
* @param nextPnt Next point.
* @param dimensionNum Dimension number.
*/
private void traverseTree(Map<Integer, Double[]> map, Double[] nextPnt, int dimensionNum) {
dimensionNum++;
if (dimensionNum == sizeOfParamVector){
Double[] paramSet = Arrays.copyOf(nextPnt, sizeOfParamVector);
params.add(paramSet);
return;
}
Double[] valuesOfCurrDimension = map.get(dimensionNum);
for (Double specificValue : valuesOfCurrDimension) {
nextPnt[dimensionNum] = specificValue;
traverseTree(map, nextPnt, dimensionNum);
}
}
}