#!/usr/bin/env python
"""
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from abc import abstractmethod, ABCMeta
import tensorflow as tf
flags = tf.app.flags
# Flags for configuring the task
flags.DEFINE_string("job_name", None, "job name: worker or ps")
flags.DEFINE_integer("task_index", 0, "Worker task index, should be >= 0")
flags.DEFINE_string("ps_hosts", "", "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts", "", "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("ckp_dir", None, "Directory for storing the checkpoints")
flags.DEFINE_string("work_dir", "/tmp/tf_on_yarn", "Work directory")
FLAGS = flags.FLAGS
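# Example invocation for one worker container (hostnames, script name, and
# the checkpoint path below are illustrative only; in a YARN deployment the
# launcher typically supplies these flags when it starts each task):
#
#   python my_trainer.py \
#       --job_name=worker --task_index=0 \
#       --ps_hosts=ps0.example.com:2222 \
#       --worker_hosts=worker0.example.com:2222,worker1.example.com:2222 \
#       --ckp_dir=hdfs:///tmp/tf_on_yarn/checkpoints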
class YarnBootstrap(object):
    """Base class for TensorFlow jobs bootstrapped on YARN; subclasses
    implement the worker and parameter-server logic."""

    __metaclass__ = ABCMeta

    def __init__(self):
        pass

    @abstractmethod
    def worker_do(self, server, cluster_spec, task_id):
        """Run the worker logic for task `task_id`."""
        pass

    @abstractmethod
    def ps_do(self, server, cluster_spec, task_id):
        """Run the parameter-server logic for task `task_id`."""
        pass
    def device_and_server(self):
        # If FLAGS.job_name is not set, we're running single-machine
        # TensorFlow: no device function, no server, no cluster spec.
        if FLAGS.job_name is None:
            print("Running single-machine training")
            return (None, "", "")

        # Otherwise we're running distributed TensorFlow.
        print("Running distributed training")
        if FLAGS.task_index is None or FLAGS.task_index < 0:
            raise ValueError("Must specify an explicit non-negative `task_index`")
        if FLAGS.ps_hosts is None or FLAGS.ps_hosts == "":
            raise ValueError("Must specify an explicit `ps_hosts`")
        if FLAGS.worker_hosts is None or FLAGS.worker_hosts == "":
            raise ValueError("Must specify an explicit `worker_hosts`")

        cluster_spec = tf.train.ClusterSpec({
            "ps": FLAGS.ps_hosts.split(","),
            "worker": FLAGS.worker_hosts.split(","),
        })
        server = tf.train.Server(
            cluster_spec, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

        # Give the other tasks in the cluster time to come up.
        time.sleep(60)

        # Parameter servers run their hook and then block forever serving
        # variables; `server.join()` never returns.
        if FLAGS.job_name == "ps":
            self.ps_do(server, cluster_spec, FLAGS.task_index)
            server.join()

        # Workers pin their ops to their own task and let
        # `replica_device_setter` place variables on the parameter servers.
        worker_device = "/job:worker/task:{}".format(FLAGS.task_index)
        return (
            tf.train.replica_device_setter(
                worker_device=worker_device,
                cluster=cluster_spec),
            server, cluster_spec
        )
    def start(self, unused_args):
        if FLAGS.ckp_dir is None or FLAGS.ckp_dir == "":
            raise ValueError("Must specify an explicit `ckp_dir`")
        device, server, cluster_spec = self.device_and_server()
        # Place ops according to the device function (a no-op when `device`
        # is None) and hand control to the subclass's worker logic.
        with tf.device(device):
            self.worker_do(server, cluster_spec, FLAGS.task_index)
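
# Minimal usage sketch (assumes TensorFlow 1.x tf.app/tf.train APIs;
# `MyTrainer` and `build_loss` below are hypothetical and not part of this
# module):
#
#   class MyTrainer(YarnBootstrap):
#       def worker_do(self, server, cluster_spec, task_id):
#           loss = build_loss()  # user-defined model graph, assumed to exist
#           train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
#           with tf.train.MonitoredTrainingSession(
#                   master=server.target if server else "",
#                   is_chief=(task_id == 0),
#                   checkpoint_dir=FLAGS.ckp_dir) as sess:
#               while not sess.should_stop():
#                   sess.run(train_op)
#
#       def ps_do(self, server, cluster_spec, task_id):
#           pass  # variables are served by server.join() in device_and_server
#
#   if __name__ == "__main__":
#       tf.app.run(main=MyTrainer().start)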