blob: 937827a000745737ab5b6c88ae63dd160c3fc9e4 [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import collections
import six
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import run_config
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import (
basic_session_run_hooks,
device_setter,
session_run_hook,
training_util,
)
# TODO(b/64848083) Remove once uid bug is fixed
class RunConfig(tf.contrib.learn.RunConfig):
    """`tf.contrib.learn.RunConfig` whose `uid` renders the cluster spec by content.

    Only `uid` is overridden. The `_cluster_spec` attribute is rendered from
    its `as_dict()` contents instead of the object's own repr.
    """

    def uid(self, whitelist=None):
        """Generates a 'Unique Identifier' based on all internal fields.

        Caller should use the uid string to check `RunConfig` instance integrity
        in one session use, but should not rely on the implementation details,
        which are subject to change.

        Args:
          whitelist: A list of the string names of the properties uid should not
            include. If `None`, defaults to `_DEFAULT_UID_WHITE_LIST`, which
            includes most properties user allowed to change.

        Returns:
          A uid string of comma-separated `name=repr(value)` pairs, ordered by
          attribute name.
        """
        if whitelist is None:
            whitelist = run_config._DEFAULT_UID_WHITE_LIST

        # Each whitelisted property name `k` is matched against the instance
        # attribute "_k"; dunder attributes are always skipped.
        excluded = set("_" + name for name in whitelist)
        fields = dict(
            (attr, value)
            for attr, value in self.__dict__.items()
            if not attr.startswith("__") and attr not in excluded
        )

        # An object without a custom __repr__ would be rendered with its memory
        # address, so the cluster spec is replaced by its job dict, ordered by
        # job name to keep the string deterministic.
        if "_cluster_spec" in fields:
            jobs = fields["_cluster_spec"].as_dict()
            fields["_cluster_spec"] = collections.OrderedDict(
                (job, jobs[job]) for job in sorted(jobs)
            )

        pieces = []
        for attr in sorted(fields):
            pieces.append("%s=%r" % (attr, fields[attr]))
        return ", ".join(pieces)
class ExamplesPerSecondHook(session_run_hook.SessionRunHook):
    """Hook to print out examples per second.

    Elapsed training time and step counts are accumulated across logging
    intervals. The running average of examples per second is
    ``batch_size * total_steps / total_time``. The examples per second for the
    most recent interval is also logged.
    """

    def __init__(
        self,
        batch_size,
        every_n_steps=100,
        every_n_secs=None,
    ):
        """Initializer for ExamplesPerSecondHook.

        Args:
          batch_size: Total batch size used to calculate examples/second from
            global time.
          every_n_steps: Log stats every n steps. Because this defaults to 100,
            pass `every_n_steps=None` when using `every_n_secs`.
          every_n_secs: Log stats every n seconds.

        Raises:
          ValueError: if both or neither of `every_n_steps` and `every_n_secs`
            are provided.
        """
        if (every_n_steps is None) == (every_n_secs is None):
            raise ValueError("exactly one of every_n_steps and every_n_secs should be provided.")
        self._timer = basic_session_run_hooks.SecondOrStepTimer(
            every_steps=every_n_steps, every_secs=every_n_secs
        )
        self._step_train_time = 0
        self._total_steps = 0
        self._batch_size = batch_size
        # Resolved in `begin`, once the graph is available.
        self._global_step_tensor = None

    def begin(self):
        """Looks up the global step tensor.

        Raises:
          RuntimeError: if no global step has been created.
        """
        self._global_step_tensor = training_util.get_global_step()
        if self._global_step_tensor is None:
            raise RuntimeError("Global step should be created to use ExamplesPerSecondHook.")

    def before_run(self, run_context):  # pylint: disable=unused-argument
        """Requests the global step value alongside every session run."""
        return basic_session_run_hooks.SessionRunArgs(self._global_step_tensor)

    def after_run(self, run_context, run_values):
        """Accumulates interval timings and logs throughput when the timer fires."""
        _ = run_context
        global_step = run_values.results
        if not self._timer.should_trigger_for_step(global_step):
            return

        elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(global_step)
        # No elapsed time is reported on the first trigger; it only sets a baseline.
        if elapsed_time is None:
            return

        self._step_train_time += elapsed_time
        self._total_steps += elapsed_steps

        # A zero-length interval (e.g. coarse clock resolution) would otherwise
        # raise ZeroDivisionError; skip logging for it but keep its steps.
        if elapsed_time <= 0 or self._step_train_time <= 0:
            return

        steps_per_sec = elapsed_steps / elapsed_time
        average_examples_per_sec = self._batch_size * (self._total_steps / self._step_train_time)
        current_examples_per_sec = steps_per_sec * self._batch_size
        # Average examples/sec followed by current examples/sec. Note that
        # `_total_steps` counts steps accumulated since the first trigger, not
        # the global step.
        logging.info(
            "%s: %g (%g), step = %g",
            "Average examples/sec",
            average_examples_per_sec,
            current_examples_per_sec,
            self._total_steps,
        )
def local_device_setter(
    num_devices=1, ps_device_type="cpu", worker_device="/cpu:0", ps_ops=None, ps_strategy=None
):
    """Returns a device function that places variable ops on local ps devices.

    Ops whose type is listed in `ps_ops` are assigned to
    `/<ps_device_type>:<index>`, where the index comes from `ps_strategy(op)`.
    All other ops are assigned to `worker_device`. In both cases the op's own
    device string is merged over the chosen spec via `DeviceSpec.merge_from`,
    so fields the op already specifies take precedence.

    Args:
      num_devices: Number of ps devices for the default round-robin strategy.
        Ignored when `ps_strategy` is provided.
      ps_device_type: Device type used for variable placement, e.g. "cpu".
      worker_device: Device string for non-variable ops; a falsy value means
        no default placement for those ops.
      ps_ops: Op type names to place on ps devices. Defaults to
        ["Variable", "VariableV2", "VarHandleOp"].
      ps_strategy: Callable mapping an op to a device index. Defaults to
        `device_setter._RoundRobinStrategy(num_devices)`.

    Returns:
      A function taking an op (or `NodeDef`) and returning a device string.

    Raises:
      TypeError: if `ps_strategy` is not callable.
    """
    if ps_ops is None:
        ps_ops = ["Variable", "VariableV2", "VarHandleOp"]

    if ps_strategy is None:
        ps_strategy = device_setter._RoundRobinStrategy(num_devices)
    if not callable(ps_strategy):
        raise TypeError("ps_strategy must be callable")

    def _local_device_chooser(op):
        current_device = pydev.DeviceSpec.from_string(op.device or "")

        # Accept either a graph Operation or a raw NodeDef.
        node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
        if node_def.op in ps_ops:
            ps_device_spec = pydev.DeviceSpec.from_string(
                "/{}:{}".format(ps_device_type, ps_strategy(op))
            )
            ps_device_spec.merge_from(current_device)
            return ps_device_spec.to_string()

        worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "")
        worker_device_spec.merge_from(current_device)
        return worker_device_spec.to_string()

    return _local_device_chooser