# blob: b093e9efea3e794b3a7c0019bd9b8e7d397852cf
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import plpy
import math
from utilities.utilities import _assert
from utilities.control import MinWarning
class AutoMLSchema:
    """Column names used for the hyperband schedule table and schedule rows."""

    # Bracket index column.
    BRACKET = "s"
    # Round index (within a bracket) column.
    ROUND = "i"
    # Number-of-configurations column.
    CONFIGURATIONS = "n_i"
    # Resources-per-configuration column.
    RESOURCES = "r_i"
@MinWarning("warning")
class HyperbandSchedule():
    """The utility class for loading a hyperband schedule table with algorithm inputs.

    On construction the inputs are validated and the full schedule is
    computed into ``schedule_vals``; calling ``load()`` then creates the
    output table and inserts one row per (bracket, round) pair.

    Attributes:
        schedule_table (string): Name of output table containing hyperband schedule.
        R (int): Maximum number of resources (iterations) that can be allocated
            to a single configuration.
        eta (int): Controls the proportion of configurations discarded in
            each round of successive halving.
        skip_last (int): The number of last rounds to skip.
        s_max (int): Largest bracket index, i.e. the largest integer ``s``
            such that ``eta**s <= R``.
        schedule_vals (list): One dict per scheduled (bracket, round) pair,
            keyed by the AutoMLSchema column names.
    """

    def __init__(self, schedule_table, R, eta=3, skip_last=0):
        self.schedule_table = schedule_table  # table name to store hyperband schedule
        self.R = R  # maximum iterations/epochs allocated to a configuration
        self.eta = eta  # defines downsampling rate
        self.skip_last = skip_last
        # Must run before s_max is computed: it guarantees eta > 1 and R >= eta.
        self.validate_inputs()
        # number of unique executions of Successive Halving (minus one)
        self.s_max = self._compute_s_max()
        self.validate_s_max()
        self.schedule_vals = []
        self.calculate_schedule()

    def load(self):
        """
        The entry point for loading the hyperband schedule table.
        """
        self.create_schedule_table()
        self.insert_into_schedule_table()

    def validate_inputs(self):
        """
        Validates user input values
        """
        _assert(self.eta > 1, "DL: eta must be greater than 1")
        _assert(self.R >= self.eta, "DL: R should not be less than eta")

    def _compute_s_max(self):
        """
        Returns the largest integer s such that eta**s <= R.

        Uses repeated multiplication instead of floor(log(R, eta)): the
        floating point logarithm can land just below an exact integer
        (e.g. math.log(243, 3) == 4.999999999999999), which would silently
        drop a bracket. Relies on validate_inputs() having ensured eta > 1,
        so the loop terminates.
        """
        s_max = 0
        power = self.eta
        while power <= self.R:
            s_max += 1
            power *= self.eta
        return s_max

    def validate_s_max(self):
        """
        Validates that skip_last lies in the range [0, s_max]
        """
        upper_bound = self.s_max + 1
        _assert(0 <= self.skip_last < upper_bound,
                "DL: skip_last must be non-negative and less than {0}".format(upper_bound))

    def calculate_schedule(self):
        """
        Calculates the hyperband schedule (number of configs and allocated resources)
        in each round of each bracket and skips the number of last rounds specified in 'skip_last'

        Appends one dict per (bracket, round) pair to self.schedule_vals,
        iterating brackets from s_max down to 0.
        """
        num_brackets = self.s_max + 1
        rounds_to_skip = int(self.skip_last)
        for s in reversed(range(num_brackets)):
            # initial number of configurations in this bracket
            n = int(math.ceil((num_brackets // (s + 1)) * math.pow(self.eta, s)))
            for i in range((s + 1) - rounds_to_skip):
                # Computing the number of configurations and the resources per
                # configuration for round i. Dividing by eta**i (rather than
                # multiplying by the inexact reciprocal eta**-i) keeps whole
                # results exact, so int() never truncates e.g. 0.999... to 0.
                n_i = n / math.pow(self.eta, i)
                r_i = self.R / math.pow(self.eta, s - i)
                self.schedule_vals.append({AutoMLSchema.BRACKET: s,
                                           AutoMLSchema.ROUND: i,
                                           AutoMLSchema.CONFIGURATIONS: int(n_i),
                                           AutoMLSchema.RESOURCES: int(round(r_i))})

    def create_schedule_table(self):
        """Initializes the output schedule table"""
        create_query = """
            CREATE TABLE {self.schedule_table} (
                {s} INTEGER,
                {i} INTEGER,
                {n_i} INTEGER,
                {r_i} INTEGER,
                unique ({s}, {i})
            );
        """.format(self=self,
                   s=AutoMLSchema.BRACKET,
                   i=AutoMLSchema.ROUND,
                   n_i=AutoMLSchema.CONFIGURATIONS,
                   r_i=AutoMLSchema.RESOURCES)
        with MinWarning('warning'):
            plpy.execute(create_query)

    def insert_into_schedule_table(self):
        """Insert everything in self.schedule_vals into the output schedule table."""
        insert_template = """
            INSERT INTO
                {schedule_table}(
                    {s_col},
                    {i_col},
                    {n_i_col},
                    {r_i_col}
                )
            VALUES (
                {sd_s},
                {sd_i},
                {sd_n_i},
                {sd_r_i}
            )
        """
        for sd in self.schedule_vals:
            # All inserted values are ints produced by calculate_schedule().
            insert_query = insert_template.format(
                schedule_table=self.schedule_table,
                s_col=AutoMLSchema.BRACKET,
                i_col=AutoMLSchema.ROUND,
                n_i_col=AutoMLSchema.CONFIGURATIONS,
                r_i_col=AutoMLSchema.RESOURCES,
                sd_s=sd[AutoMLSchema.BRACKET],
                sd_i=sd[AutoMLSchema.ROUND],
                sd_n_i=sd[AutoMLSchema.CONFIGURATIONS],
                sd_r_i=sd[AutoMLSchema.RESOURCES])
            plpy.execute(insert_query)