# blob: 80775f891802b9d2fab8dba26f80f1c1395ac495 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tensorflow as tf
from submarine.ml.tensorflow_v2.optimizer import get_optimizer
def _get_session_config_from_env_var(params):
    """
    Build a tf.ConfigProto whose device_filters match the task in TF_CONFIG.

    The TF_CONFIG environment variable is parsed as JSON (an unset variable is
    treated as ``{}``). A session config is produced only when the parsed value
    has a "task" entry carrying both "type" and "index", and the type is either
    "master" or "worker"; in every other case None is returned.

    :param params: model parameters; ``params["resource"]["num_thread"]`` is used
        for both the intra-op and inter-op parallelism thread counts
    :return: a ``tf.compat.v1.ConfigProto`` for master/worker tasks, otherwise None
    """
    tf.compat.v1.disable_v2_behavior()
    env_config = json.loads(os.environ.get("TF_CONFIG", "{}"))

    # Guard clauses: without a complete task description there is nothing to filter on.
    if not env_config or "task" not in env_config:
        return None
    task = env_config["task"]
    if "type" not in task or "index" not in task:
        return None

    role = task["type"]
    if role == "master":
        # Master should only communicate with itself and ps.
        filters = ["/job:ps", "/job:master"]
    elif role == "worker":
        # Worker should only communicate with itself and ps.
        filters = ["/job:ps", "/job:worker/task:%d" % task["index"]]
    else:
        # Any other task type gets no dedicated session config.
        return None

    num_thread = params["resource"]["num_thread"]
    return tf.compat.v1.ConfigProto(
        device_filters=filters,
        intra_op_parallelism_threads=num_thread,
        inter_op_parallelism_threads=num_thread,
    )
def get_tf_config(params):
    """
    Build the Estimator run configuration for the requested training mode.

    Only ``"local"`` mode is implemented. ``"distributed"`` mode is recognised
    but has not been ported to TensorFlow 2 yet, so it fails with an explicit
    ``NotImplementedError`` instead of an obscure ``UnboundLocalError``.

    :param params: model parameters. Reads ``params["training"]["mode"]`` and, in
        local mode, ``params["training"]["log_steps"]`` plus the ``"num_gpu"``,
        ``"num_cpu"`` and ``"num_thread"`` entries of ``params["resource"]``
    :type params: Dictionary
    :return: The ``tf.estimator.RunConfig`` that specifies the configuration for
        an Estimator run
    :raises NotImplementedError: if the mode is ``"distributed"``
    :raises ValueError: if the mode is neither ``"local"`` nor ``"distributed"``
    """
    mode = params["training"]["mode"]

    if mode == "local":  # local mode
        tf.compat.v1.disable_v2_behavior()
        num_thread = params["resource"]["num_thread"]
        log_steps = params["training"]["log_steps"]
        session_config = tf.compat.v1.ConfigProto(
            device_count={
                "GPU": params["resource"]["num_gpu"],
                "CPU": params["resource"]["num_cpu"],
            },
            intra_op_parallelism_threads=num_thread,
            inter_op_parallelism_threads=num_thread,
        )
        return tf.estimator.RunConfig().replace(
            session_config=session_config,
            log_step_count_steps=log_steps,
            save_summary_steps=log_steps,
        )

    if mode == "distributed":
        # TODO: needs to be rewritten for tf2. Sketch of the intended configuration:
        # ParameterServerStrategy = tf.distribute.experimental.ParameterServerStrategy()
        # tf_config = tf.estimator.RunConfig(
        #     train_distribute=ParameterServerStrategy,
        #     eval_distribute=ParameterServerStrategy,
        #     session_config=_get_session_config_from_env_var(params),
        #     save_summary_steps=params["training"]["log_steps"],
        #     log_step_count_steps=params["training"]["log_steps"],
        # )
        raise NotImplementedError("distributed mode is not yet supported with TensorFlow 2")

    raise ValueError("mode should be local or distributed")
def get_estimator_spec(logit, labels, mode, params):
    """
    Returns the `EstimatorSpec` that a model_fn can return for the given mode.

    A sigmoid is applied to ``logit`` to produce the ``"probabilities"``
    prediction. For EVAL and TRAIN the loss is the mean sigmoid cross-entropy
    between ``logit`` and ``labels``.

    :param logit: logits `Tensor` to be used.
    :param labels: Labels `Tensor`, or `dict` of same.
    :param mode: Estimator's `ModeKeys`.
    :param params: dict of hyperparameters. Reads ``params["training"]["learning_rate"]``,
        ``params["training"]["optimizer"]`` and ``params["output"]["metric"]``.
    :return: a ``tf.estimator.EstimatorSpec`` for PREDICT, EVAL or TRAIN
    :raises TypeError: if the configured metric is not ``"auc"`` (checked for
        EVAL and TRAIN only; PREDICT returns before the metric is inspected)
    :raises ValueError: if ``mode`` is not one of PREDICT, EVAL or TRAIN
    """
    learning_rate = params["training"]["learning_rate"]
    optimizer = params["training"]["optimizer"]
    metric = params["output"]["metric"]

    output = tf.sigmoid(logit)
    predictions = {"probabilities": output}
    export_outputs = {
        # https://github.com/psf/black/issues/2434
        # fmt: off
        tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            tf.estimator.export.PredictOutput(predictions)
        # fmt: on
    }

    # Provide an estimator spec for `ModeKeys.PREDICT`
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, export_outputs=export_outputs)

    with tf.compat.v1.name_scope("Loss"):
        loss = tf.reduce_mean(
            input_tensor=tf.nn.sigmoid_cross_entropy_with_logits(logits=logit, labels=labels)
        )

    # Only "auc" is supported; reject anything else with a single readable message.
    eval_metric_ops = {}
    if metric == "auc":
        eval_metric_ops["auc"] = tf.compat.v1.metrics.auc(labels, output)
    else:
        raise TypeError(f"Invalid metric: {metric!r}")

    # Provide an estimator spec for `ModeKeys.EVAL`
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode, predictions=predictions, loss=loss, eval_metric_ops=eval_metric_ops
        )

    # Fail loudly on an unknown mode instead of building training ops and
    # implicitly returning None.
    if mode != tf.estimator.ModeKeys.TRAIN:
        raise ValueError(f"Unsupported mode: {mode!r}")

    # Provide an estimator spec for `ModeKeys.TRAIN`
    with tf.compat.v1.name_scope("Train"):
        op = get_optimizer(optimizer, learning_rate)
        train_op = op.minimize(loss, global_step=tf.compat.v1.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, loss=loss, train_op=train_op)