Replace hard coded paths with args
diff --git a/tf-ner-poc/src/main/python/namefinder.py b/tf-ner-poc/src/main/python/namefinder.py
index c55d835..bc203a7 100644
--- a/tf-ner-poc/src/main/python/namefinder.py
+++ b/tf-ner-poc/src/main/python/namefinder.py
@@ -1,4 +1,4 @@
-
+#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,12 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+#
# This poc is based on source code taken from:
# https://github.com/guillaumegenthial/sequence_tagging
+import sys
from math import floor
-
import tensorflow as tf
import re
import numpy as np
@@ -324,17 +325,19 @@
def main():
+ if len(sys.argv) != 5:
+ print("Usage namefinder.py embedding_file train_file dev_file test_file")
+ return
+
name_finder = NameFinder()
- # word_dict, rev_word_dict, embeddings = name_finder.load_glove("/home/burn/Downloads/glove.840B.300d.txt")
- word_dict, rev_word_dict, embeddings = name_finder.load_glove("/home/blue/Downloads/fastText/memorial.vec")
- sentences, labels, char_set = name_finder.load_data(word_dict, "train.txt")
- #sentences_test, labels_test, char_set_test = name_finder.load_data(word_dict,"conll03.testa")
- sentences_test, labels_test, char_set_test = name_finder.load_data(word_dict,"dev.txt")
+ word_dict, rev_word_dict, embeddings = name_finder.load_glove(sys.argv[1])
+ sentences, labels, char_set = name_finder.load_data(word_dict, sys.argv[2])
+ sentences_dev, labels_dev, char_set_dev = name_finder.load_data(word_dict, sys.argv[3])
embedding_ph, token_ids_ph, char_ids_ph, word_lengths_ph, sequence_lengths_ph, labels_ph, train_op \
- = name_finder.create_graph(len(char_set | char_set_test), embeddings)
+ = name_finder.create_graph(len(char_set | char_set_dev), embeddings)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
log_device_placement=True))
@@ -363,11 +366,11 @@
accs = []
correct_preds, total_correct, total_preds = 0., 0., 0.
- for batch_index in range(floor(len(sentences_test) / batch_size)):
+ for batch_index in range(floor(len(sentences_dev) / batch_size)):
sentences_test_batch, chars_batch_test, word_length_batch_test, \
labels_test_batch, length_test = name_finder.mini_batch(rev_word_dict,
- sentences_test,
- labels_test,
+ sentences_dev,
+ labels_dev,
batch_size,
batch_index)