# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer for the chicago_taxi demo."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import tensorflow as tf
import tensorflow_model_analysis as tfma
import tensorflow_transform as tft
from trainer import model
from trainer import taxi
SERVING_MODEL_DIR = 'serving_model_dir'
EVAL_MODEL_DIR = 'eval_model_dir'
TRAIN_BATCH_SIZE = 40
EVAL_BATCH_SIZE = 40
# Number of nodes in the first layer of the DNN
FIRST_DNN_LAYER_SIZE = 100
NUM_DNN_LAYERS = 4
DNN_DECAY_FACTOR = 0.7


def train_and_maybe_evaluate(hparams):
  """Run the training and evaluate using the high level API.

  Args:
    hparams: Holds hyperparameters used to train the model as name/value pairs.

  Returns:
    The estimator that was used for training (and maybe eval).
  """
  schema = taxi.read_schema(hparams.schema_file)
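  # TFTransformOutput wraps the artifacts produced by the tf.Transform
  # preprocessing step (the transform graph and the transformed metadata).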
  tf_transform_output = tft.TFTransformOutput(hparams.tf_transform_dir)

  train_input = lambda: model.input_fn(
      hparams.train_files,
      tf_transform_output,
      batch_size=TRAIN_BATCH_SIZE)

  eval_input = lambda: model.input_fn(
      hparams.eval_files,
      tf_transform_output,
      batch_size=EVAL_BATCH_SIZE)
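
  # Training stops once hparams.train_steps global steps have run; evaluation
  # and export are configured by the EvalSpec below.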
  train_spec = tf.estimator.TrainSpec(
      train_input, max_steps=hparams.train_steps)
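
  # FinalExporter writes a single serving SavedModel at the end of training,
  # using the receiver built by model.example_serving_receiver_fn (which is
  # given the transform output and schema so it can accept raw examples).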
  serving_receiver_fn = lambda: model.example_serving_receiver_fn(
      tf_transform_output, schema)
  exporter = tf.estimator.FinalExporter('chicago-taxi', serving_receiver_fn)

  eval_spec = tf.estimator.EvalSpec(
      eval_input,
      steps=hparams.eval_steps,
      exporters=[exporter],
      name='chicago-taxi-eval')
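
  # Checkpoint every 999 train steps and keep only the most recent checkpoint;
  # the evaluation/export cycle above is driven off these checkpoints.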
  run_config = tf.estimator.RunConfig(
      save_checkpoints_steps=999, keep_checkpoint_max=1)

  serving_model_dir = os.path.join(hparams.output_dir, SERVING_MODEL_DIR)
  run_config = run_config.replace(model_dir=serving_model_dir)

  estimator = model.build_estimator(
      tf_transform_output,
      # Construct layer sizes with exponential decay: each hidden layer has
      # DNN_DECAY_FACTOR times as many nodes as the previous one, with a
      # floor of 2 nodes per layer.
      hidden_units=[
          max(2, int(FIRST_DNN_LAYER_SIZE * DNN_DECAY_FACTOR**i))
          for i in range(NUM_DNN_LAYERS)
      ],
      config=run_config)
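
  # train_and_evaluate interleaves training with periodic evaluation; in a
  # distributed setting it picks up the cluster layout from the TF_CONFIG
  # environment variable.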
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

  return estimator


def run_experiment(hparams):
  """Train the model then export it for tf.model_analysis evaluation.

  Args:
    hparams: Holds hyperparameters used to train the model as name/value pairs.
  """
  estimator = train_and_maybe_evaluate(hparams)

  schema = taxi.read_schema(hparams.schema_file)
  tf_transform_output = tft.TFTransformOutput(hparams.tf_transform_dir)

  # Save a model for tfma eval. Unlike the serving model, the EvalSavedModel
  # also captures the label and features so TensorFlow Model Analysis can
  # compute metrics (including over data slices) after training.
  eval_model_dir = os.path.join(hparams.output_dir, EVAL_MODEL_DIR)
  receiver_fn = lambda: model.eval_input_receiver_fn(
      tf_transform_output, schema)

  tfma.export.export_eval_savedmodel(
      estimator=estimator,
      export_dir_base=eval_model_dir,
      eval_input_receiver_fn=receiver_fn)


def main():
  parser = argparse.ArgumentParser()

  # Input Arguments
  parser.add_argument(
      '--train-files',
      help='GCS or local paths to training data',
      nargs='+',
      required=True)
  parser.add_argument(
      '--tf-transform-dir',
      help='Tf-transform directory with model from preprocessing step',
      required=True)
  parser.add_argument(
      '--output-dir',
      help="""\
      Directory under which the serving model (under /serving_model_dir)\
      and the tf-model-analysis model (under /eval_model_dir) will be written\
      """,
      required=True)
  parser.add_argument(
      '--eval-files',
      help='GCS or local paths to evaluation data',
      nargs='+',
      required=True)
  # Training arguments
  parser.add_argument(
      '--job-dir',
      help='GCS location to write checkpoints and export models',
      required=True)

  # Argument to turn on all logging
  parser.add_argument(
      '--verbosity',
      choices=['DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN'],
      default='INFO')

  # Experiment arguments
  parser.add_argument(
      '--train-steps',
      help='Count of steps to run the training job for',
      required=True,
      type=int)
  parser.add_argument(
      '--eval-steps',
      help='Number of steps to run evaluation for at each checkpoint',
      default=100,
      type=int)
  parser.add_argument(
      '--schema-file',
      help='File holding the schema for the input data')
  args = parser.parse_args()

  # Set python level verbosity.
  tf.logging.set_verbosity(args.verbosity)
  # Set C++ graph execution level verbosity. Note the integer division: with
  # `from __future__ import division` in effect, `/` would produce a float and
  # set the variable to e.g. '2.0' rather than '2'.
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(
      tf.logging.__dict__[args.verbosity] // 10)

  # Run the training job.
  hparams = tf.contrib.training.HParams(**args.__dict__)
  run_experiment(hparams)
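

# A sketch of a typical invocation; all placeholder paths and step counts
# below are illustrative, not required values:
#
#   python task.py \
#     --train-files <transformed training examples> \
#     --eval-files <transformed eval examples> \
#     --tf-transform-dir <output dir of the preprocessing step> \
#     --output-dir <dir for serving_model_dir and eval_model_dir> \
#     --job-dir <dir for checkpoints> \
#     --schema-file <path to the schema file> \
#     --train-steps 10000 --eval-steps 100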

if __name__ == '__main__':
  main()