| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| import collections |
| |
| import six |
| import tensorflow as tf |
| from tensorflow.contrib.learn.python.learn import run_config |
| from tensorflow.core.framework import node_def_pb2 |
| from tensorflow.python.framework import device as pydev |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.training import ( |
| basic_session_run_hooks, |
| device_setter, |
| session_run_hook, |
| training_util, |
| ) |
| |
| |
| # TODO(b/64848083) Remove once uid bug is fixed |
# TODO(b/64848083) Remove once uid bug is fixed
class RunConfig(tf.contrib.learn.RunConfig):
    def uid(self, whitelist=None):
        """Builds a deterministic 'Unique Identifier' string for this config.

        The uid is derived from every internal field not excluded by
        `whitelist`, so two configs with the same relevant state produce the
        same uid within a session. Callers should treat the string as opaque;
        its exact format may change.

        Args:
          whitelist: Property names (without the leading underscore) to leave
            out of the uid. When `None`, `run_config._DEFAULT_UID_WHITE_LIST`
            is used, which covers most user-changeable properties.

        Returns:
          The uid string.
        """
        excluded = run_config._DEFAULT_UID_WHITE_LIST if whitelist is None else whitelist

        fields = {name: value for name, value in vars(self).items() if not name.startswith("__")}
        # Drop the whitelisted properties (stored with a leading underscore).
        for name in excluded:
            fields.pop("_" + name, None)

        ordered = collections.OrderedDict(sorted(fields.items(), key=lambda item: item[0]))
        # ClusterSpec has no __repr__, so render it as a sorted OrderedDict
        # instead of letting %r fall back to the object address.
        if "_cluster_spec" in ordered:
            cluster_items = ordered["_cluster_spec"].as_dict().items()
            ordered["_cluster_spec"] = collections.OrderedDict(
                sorted(cluster_items, key=lambda item: item[0])
            )
        return ", ".join("%s=%r" % (name, value) for name, value in ordered.items())
| |
| |
class ExamplesPerSecondHook(session_run_hook.SessionRunHook):
    """Hook to print out examples per second.

    Total time is tracked and then divided by the total number of steps
    to get the average step time and then batch_size is used to determine
    the running average of examples per second. The examples per second for the
    most recent interval is also logged.
    """

    def __init__(
        self,
        batch_size,
        every_n_steps=100,
        every_n_secs=None,
    ):
        """Initializer for ExamplesPerSecondHook.

        Args:
          batch_size: Total batch size used to calculate examples/second from
            global time.
          every_n_steps: Log stats every n steps.
          every_n_secs: Log stats every n seconds.

        Raises:
          ValueError: Unless exactly one of `every_n_steps` / `every_n_secs`
            is set.
        """
        # Exactly one trigger must be given: XOR of the two None-checks.
        if not ((every_n_steps is None) ^ (every_n_secs is None)):
            raise ValueError("exactly one of every_n_steps and every_n_secs should be provided.")
        self._timer = basic_session_run_hooks.SecondOrStepTimer(
            every_steps=every_n_steps, every_secs=every_n_secs
        )

        # Running totals across all triggered intervals.
        self._step_train_time = 0
        self._total_steps = 0
        self._batch_size = batch_size

    def begin(self):
        # Resolve the global-step tensor once; it must already exist.
        self._global_step_tensor = training_util.get_global_step()
        if self._global_step_tensor is None:
            raise RuntimeError("Global step should be created to use StepCounterHook.")

    def before_run(self, run_context):  # pylint: disable=unused-argument
        # Ask the session to also fetch the global step on every run.
        return basic_session_run_hooks.SessionRunArgs(self._global_step_tensor)

    def after_run(self, run_context, run_values):
        _ = run_context

        step = run_values.results
        if not self._timer.should_trigger_for_step(step):
            return
        elapsed_secs, elapsed_steps = self._timer.update_last_triggered_step(step)
        if elapsed_secs is None:
            return

        steps_per_sec = elapsed_steps / elapsed_secs
        self._step_train_time += elapsed_secs
        self._total_steps += elapsed_steps

        avg_examples_per_sec = self._batch_size * (self._total_steps / self._step_train_time)
        cur_examples_per_sec = steps_per_sec * self._batch_size
        # Average examples/sec followed by current examples/sec
        logging.info(
            "%s: %g (%g), step = %g",
            "Average examples/sec",
            avg_examples_per_sec,
            cur_examples_per_sec,
            self._total_steps,
        )
| |
| |
def local_device_setter(
    num_devices=1, ps_device_type="cpu", worker_device="/cpu:0", ps_ops=None, ps_strategy=None
):
    """Returns a device chooser that places ops on local devices.

    Ops whose type is in `ps_ops` (variable-holding ops) are placed on a
    parameter-server device of type `ps_device_type`, selected by
    `ps_strategy`; every other op is placed on `worker_device`. Device fields
    already set on the op are preserved via `merge_from`.

    Args:
      num_devices: Number of ps devices to round-robin over when `ps_strategy`
        is not supplied.
      ps_device_type: Device type for parameter-server placement, e.g. "cpu"
        or "gpu".
      worker_device: Device string for all non-variable ops.
      ps_ops: Op type names treated as variables. Defaults to the standard
        variable op types.
      ps_strategy: Callable mapping an op to a ps device index. Defaults to
        `device_setter._RoundRobinStrategy(num_devices)`.

    Returns:
      A function from op to device string, usable as a `tf.device` device_fn.

    Raises:
      TypeError: If `ps_strategy` is provided but not callable.
    """
    if ps_ops is None:  # `is None`, not `== None` (PEP 8 E711).
        ps_ops = ["Variable", "VariableV2", "VarHandleOp"]

    if ps_strategy is None:
        ps_strategy = device_setter._RoundRobinStrategy(num_devices)
    # Builtin `callable` is available on all supported Pythons; the six shim
    # is unnecessary here.
    if not callable(ps_strategy):
        raise TypeError("ps_strategy must be callable")

    def _local_device_chooser(op):
        """Chooses a device string for `op` (an Operation or NodeDef)."""
        current_device = pydev.DeviceSpec.from_string(op.device or "")

        node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
        if node_def.op in ps_ops:
            ps_device_spec = pydev.DeviceSpec.from_string(
                "/{}:{}".format(ps_device_type, ps_strategy(op))
            )
            # Keep any device constraints the op already carries.
            ps_device_spec.merge_from(current_device)
            return ps_device_spec.to_string()

        worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "")
        worker_device_spec.merge_from(current_device)
        return worker_device_spec.to_string()

    return _local_device_chooser