# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, unused-argument, no-init, import-outside-toplevel
"""
Tensorflow Model Helpers
========================
Some helper definitions for tensorflow models.
"""
import re
import os.path
import collections
import numpy as np
# Tensorflow imports
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tvm.contrib.download import download_testdata
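# The TF 1.x API lives under tf.compat.v1 on TF 2.x; fall back to plain tf
# on releases that predate the compat module.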
try:
tf_compat_v1 = tf.compat.v1
except (ImportError, AttributeError):
tf_compat_v1 = tf
######################################################################
# Some helper functions
# ---------------------
def ProcessGraphDefParam(graph_def):
"""Type-checks and possibly canonicalizes `graph_def`.
Parameters
----------
graph_def : Obj
tensorflow graph definition.
Returns
-------
graph_def : Obj
tensorflow graph definition
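
    Examples
    --------
    A minimal sketch: a message that is already a ``GraphDef`` passes
    through unchanged::

        >>> gd = ProcessGraphDefParam(graph_pb2.GraphDef())
        >>> isinstance(gd, graph_pb2.GraphDef)
        True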
"""
if not isinstance(graph_def, graph_pb2.GraphDef):
# `graph_def` could be a dynamically-created message, so try a duck-typed
# approach
try:
old_graph_def = graph_def
graph_def = graph_pb2.GraphDef()
graph_def.MergeFrom(old_graph_def)
except TypeError:
raise TypeError("graph_def must be a GraphDef proto.")
return graph_def
def convert_to_list(x):
    """Wrap a non-list value into a single-element list."""
    if not isinstance(x, list):
        x = [x]
    return x
def AddShapesToGraphDef(session, out_node):
"""Add shapes attribute to nodes of the graph.
Input graph here is the default graph in context.
Parameters
----------
session : tf.Session
Tensorflow session
out_node : String or List
Final output node of the graph.
Returns
-------
graph_def : Obj
tensorflow graph definition with shapes attribute added to nodes.
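
    Examples
    --------
    A hedged sketch, assuming the session's default graph defines an output
    node named "softmax" (the name is hypothetical)::

        >>> # graph_def = AddShapesToGraphDef(sess, "softmax")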
"""
graph_def = tf_compat_v1.graph_util.convert_variables_to_constants(
session,
session.graph.as_graph_def(add_shapes=True),
convert_to_list(out_node),
)
return graph_def
class NodeLookup(object):
"""Converts integer node ID's to human readable labels."""
def __init__(self, label_lookup_path=None, uid_lookup_path=None):
self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
def load(self, label_lookup_path, uid_lookup_path):
"""Loads a human readable English name for each softmax node.
Parameters
----------
        label_lookup_path: String
            File containing string UID to integer node ID mapping.
        uid_lookup_path: String
            File containing string UID to human-readable string mapping.
Returns
-------
node_id_to_name: dict
dict from integer node ID to human-readable string.
"""
        if not tf_compat_v1.gfile.Exists(uid_lookup_path):
            tf_compat_v1.logging.fatal("File does not exist %s", uid_lookup_path)
        if not tf_compat_v1.gfile.Exists(label_lookup_path):
            tf_compat_v1.logging.fatal("File does not exist %s", label_lookup_path)
# Loads mapping from string UID to human-readable string
proto_as_ascii_lines = tf_compat_v1.gfile.GFile(uid_lookup_path).readlines()
uid_to_human = {}
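        # Example (hypothetical synset-style line): "n01440764\ttench, Tinca tinca"
        # findall() yields ["n01440764", "", "tench, Tinca tinca", ""], so index 0
        # holds the UID and index 2 the human-readable name.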
p = re.compile(r"[n\d]*[ \S,]*")
for line in proto_as_ascii_lines:
parsed_items = p.findall(line)
uid = parsed_items[0]
human_string = parsed_items[2]
uid_to_human[uid] = human_string
# Loads mapping from string UID to integer node ID.
node_id_to_uid = {}
proto_as_ascii = tf_compat_v1.gfile.GFile(label_lookup_path).readlines()
for line in proto_as_ascii:
if line.startswith(" target_class:"):
target_class = int(line.split(": ")[1])
if line.startswith(" target_class_string:"):
target_class_string = line.split(": ")[1]
node_id_to_uid[target_class] = target_class_string[1:-2]
# Loads the final mapping of integer node ID to human-readable string
node_id_to_name = {}
for key, val in node_id_to_uid.items():
if val not in uid_to_human:
tf.logging.fatal("Failed to locate: %s", val)
name = uid_to_human[val]
node_id_to_name[key] = name
return node_id_to_name
    def id_to_string(self, node_id):
        """Return the human-readable label for node_id ('' if unknown)."""
if node_id not in self.node_lookup:
return ""
return self.node_lookup[node_id]
def get_workload_official(model_url, model_sub_path):
"""Import workload from tensorflow official
Parameters
----------
model_url: str
URL from where it will be downloaded.
    model_sub_path: str
        Sub path in the extracted tar for the frozen protobuf file.
Returns
-------
model_path: str
Full path to saved model file
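
    Examples
    --------
    A hedged sketch; the URL and archive layout are hypothetical::

        >>> # pb_path = get_workload_official(
        >>> #     "https://example.org/models/mobilenet.tgz",
        >>> #     "mobilenet/frozen_graph.pb")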
"""
model_tar_name = os.path.basename(model_url)
model_path = download_testdata(model_url, model_tar_name, module=["tf", "official"])
dir_path = os.path.dirname(model_path)
if model_path.endswith("tgz") or model_path.endswith("gz"):
import tarfile
tar = tarfile.open(model_path)
tar.extractall(path=dir_path)
tar.close()
elif model_path.endswith("zip"):
import zipfile
zip_object = zipfile.ZipFile(model_path)
zip_object.extractall(path=dir_path)
zip_object.close()
    else:
        raise RuntimeError("Unsupported archive format: " + model_path)
return os.path.join(dir_path, model_sub_path)
def get_workload(model_path, model_sub_path=None, inputs_dict=None, output=None):
"""Import workload from frozen protobuf
Parameters
----------
    model_path: str
        Path of the model file in the remote repository.
    model_sub_path: str, optional
        Model path within the compressed archive.
    inputs_dict: dict, optional
        Mapping from input tensor names to placeholders, used to rewrite
        the graph inputs.
    output: str or list, optional
        Output node name(s); required when inputs_dict is given.
Returns
-------
graph_def: graphdef
graph_def is the tensorflow workload.
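
    Examples
    --------
    A hedged sketch; the protobuf file name is hypothetical::

        >>> # graph_def = get_workload("some_model/frozen_graph.pb")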
"""
if model_sub_path:
path_model = get_workload_official(model_path, model_sub_path)
else:
repo_base = "https://github.com/dmlc/web-data/raw/main/tensorflow/models/"
model_url = os.path.join(repo_base, model_path)
path_model = download_testdata(model_url, model_path, module="tf")
# Creates graph from saved graph_def.pb.
    with tf_compat_v1.gfile.GFile(path_model, "rb") as f:
graph_def = tf_compat_v1.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf_compat_v1.import_graph_def(graph_def, name="", input_map=inputs_dict)
if inputs_dict is not None:
# graph is changed so generate graph_def again
with tf_compat_v1.Session(graph=graph) as sess:
graph_def = AddShapesToGraphDef(sess, output)
return graph_def
#######################################################################
# PTB LSTMBlockCell Model
# -----------------------
class PTBSmallConfig(object):
"""Small config.
This configurations are used when training the model
"""
num_layers = 2
num_steps = 1
hidden_size = 200
batch_size = 1
vocab_size = 10000
init_scale = 0.1
def get_config():
"""Configuration used for training the model"""
return PTBSmallConfig()
def pick_from_weight(weight, pows=1.0):
"""Identify token from Softmax output.
This token will be mapped to word in the vocabulary.
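
    Examples
    --------
    With weights [0.1, 0.2, 0.7] the cumulative sum is [0.1, 0.3, 1.0], and
    half the total (0.5) falls into the last bucket::

        >>> pick_from_weight(np.array([0.1, 0.2, 0.7]))
        2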
"""
weight = weight ** pows
t = np.cumsum(weight)
s = np.sum(weight)
return int(np.searchsorted(t, 0.5 * s))
def do_tf_sample(session, data, in_states, num_samples):
"""Sampled from the model"""
samples = []
sample = None
# Cell inputs c and h should be passed for each layer explicitly.
state_input_name = [
"Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros:0",
"Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1:0",
"Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros:0",
"Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1:0",
]
state = in_states
    # Graph nodes to be fetched as run output. Tensorflow LSTMBlockCell creates
    # internal nodes for intermediate operations (gates) in the cell during run.
    # Cell state (c) is ':1' and cell output (h) is ':6' for each layer.
fetches = [
[
"Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:1",
"Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:6",
"Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:1",
"Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:6",
],
"Model/Softmax:0",
]
def _get_feed_dict(input_name, input_data):
"""Create feed dict"""
feed_dict = {}
if isinstance(input_data, list):
for i, e in enumerate(input_name):
feed_dict[e] = input_data[i]
else:
feed_dict[input_name] = input_data
return feed_dict
for x in data:
feed_dict = _get_feed_dict(state_input_name, state)
feed_dict["Model/Placeholder:0"] = [[x]]
state, probs = session.run(fetches, feed_dict)
sample = pick_from_weight(probs[0])
if sample is not None:
samples.append(sample)
else:
samples.append(0)
k = 1
while k < num_samples:
feed_dict = _get_feed_dict(state_input_name, state)
feed_dict["Model/Placeholder:0"] = [[samples[-1]]]
state, probs = session.run(fetches, feed_dict)
sample = pick_from_weight(probs[0])
samples.append(sample)
k += 1
return samples, state
def _create_ptb_vocabulary(data_dir):
"""Read the PTB sample data input to create vocabulary"""
data_path = os.path.join(data_dir, "simple-examples/data/")
file_name = "ptb.train.txt"
def _read_words(filename):
"""Read the data for creating vocabulary"""
with tf_compat_v1.gfile.GFile(filename, "r") as f:
            return f.read().replace("\n", "<eos>").split()
def _build_vocab(filename):
"""Create vocabulary"""
data = _read_words(filename)
counter = collections.Counter(data)
count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*count_pairs))
word_to_id = dict(zip(words, range(len(words))))
        # invert the mapping to recover words from integer ids
        id_to_word = {v: k for k, v in word_to_id.items()}
return word_to_id, id_to_word
def ptb_raw_data(data_path, file_name):
"""Read the sample data and create vocabulary"""
train_path = os.path.join(data_path, file_name)
word_to_id, id_2_word = _build_vocab(train_path)
return word_to_id, id_2_word
return ptb_raw_data(data_path, file_name)
def get_workload_ptb():
"""Import ptb workload from frozen protobuf
Returns
-------
graph_def: graphdef
graph_def is the tensorflow workload for ptb.
word_to_id : dict
English word to integer id mapping
id_to_word : dict
Integer id to English word mapping
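
    Examples
    --------
    The three return values are typically unpacked together (this downloads
    the PTB data and model on first use)::

        >>> # word_to_id, id_to_word, graph_def = get_workload_ptb()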
"""
sample_repo = "http://www.fit.vutbr.cz/~imikolov/rnnlm/"
sample_data_file = "simple-examples.tgz"
sample_url = sample_repo + sample_data_file
ptb_model_file = "RNN/ptb/ptb_model_with_lstmblockcell.pb"
# pylint: disable=import-outside-toplevel
import tarfile
file_path = download_testdata(sample_url, sample_data_file, module=["data", "ptb_data"])
dir_path = os.path.dirname(file_path)
    with tarfile.open(file_path, "r") as t:
        t.extractall(dir_path)
word_to_id, id_to_word = _create_ptb_vocabulary(dir_path)
dtype = "float32"
shape = (1, 200)
    # Convert states of LSTMBlockCell to placeholders so TVM can feed data
state_name = [
"Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros:0",
"Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1:0",
"Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros:0",
"Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1:0",
]
inputs_dict = {
state_name[0]: tf_compat_v1.placeholder(dtype, shape, state_name[0].split(":")[0]),
state_name[1]: tf_compat_v1.placeholder(dtype, shape, state_name[1].split(":")[0]),
state_name[2]: tf_compat_v1.placeholder(dtype, shape, state_name[2].split(":")[0]),
state_name[3]: tf_compat_v1.placeholder(dtype, shape, state_name[3].split(":")[0]),
}
return (
word_to_id,
id_to_word,
get_workload(ptb_model_file, inputs_dict=inputs_dict, output="Model/Softmax"),
)