Fix mini batch and char encoding (#36)

Keep the original tokens alongside the word ids when loading data so that
mini_batch can encode characters directly from the token text instead of
looking it up through rev_word_dict, drop the now-unused rev_word_dict
parameter from mini_batch, and add an optional use_lower_case_embeddings
flag that lower-cases tokens for the embedding lookup only.
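
For illustration only (not part of the patch), a minimal sketch of how the
new (word ids, original tokens) layout drives the char encoding in
mini_batch, using toy dictionary values:

    # Toy values; the real char_dict is built from chars_set during load_data.
    char_dict = {c: i + 1 for i, c in enumerate("JohnSmit")}  # 0 is reserved for padding
    sentences = [([17, 42], ["John", "Smith"])]               # (word ids, original tokens)

    word_ids, tokens = sentences[0]
    max_word_length = max(len(t) for t in tokens)             # 5, from "Smith"
    # Encode each token's chars and right-pad with 0 to the max word length
    char_ids = [[char_dict[c] for c in t] + [0] * (max_word_length - len(t)) for t in tokens]
    word_lengths = [len(t) for t in tokens]                   # [4, 5]
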
diff --git a/tf-ner-poc/src/main/python/namefinder/namefinder.py b/tf-ner-poc/src/main/python/namefinder/namefinder.py
index 548a3a9..8cd1503 100644
--- a/tf-ner-poc/src/main/python/namefinder/namefinder.py
+++ b/tf-ner-poc/src/main/python/namefinder/namefinder.py
@@ -68,8 +68,9 @@
class NameFinder:
label_dict = {}
- def __init__(self, vector_size=100):
+ def __init__(self, use_lower_case_embeddings=False, vector_size=100):
self.__vector_size = vector_size
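+ # Remember whether tokens should be lower-cased for the embedding lookup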
+ self.__use_lower_case_embeddings = use_lower_case_embeddings
def load_data(self, word_dict, file):
with open(file) as f:
@@ -82,12 +83,21 @@
for line in raw_data:
name_sample = NameSample(line)
sentence = []
+ tokens = []
if len(name_sample.tokens) == 0:
continue
for token in name_sample.tokens:
- vector = 0
+
+ chars_set.update(token) # Add every char of the token to the set
+ tokens.append(token) # Add original token so chars can be encoded correctly later
+
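+ # Lower-case only for the embedding lookup; chars_set and tokens keep the original casing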
+ if self.__use_lower_case_embeddings:
+ token = token.lower()
+
+ # TODO: implement NUM encoding
+
if word_dict.get(token) is not None:
vector = word_dict[token]
else:
@@ -95,15 +105,13 @@
sentence.append(vector)
- for c in token:
- chars_set.add(c)
-
label = ["other"] * len(name_sample.tokens)
for name in name_sample.names:
label[name[0]] = name[2] + "-start"
for i in range(name[0] + 1, name[1]):
label[i] = name[2] + "-cont"
- sentences.append(sentence)
+
+ sentences.append((sentence, tokens)) # Append a tuple of (word ids, original tokens) so the chars can be encoded later
labels.append(label)
for label_string in label:
@@ -119,14 +127,14 @@
return label_ids
- def mini_batch(self, rev_word_dict, char_dict, sentences, labels, batch_size, batch_index):
+ def mini_batch(self, char_dict, sentences, labels, batch_size, batch_index):
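+ # sentences entries are (word ids, original tokens) tuples, so rev_word_dict is no longer needed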
begin = batch_size * batch_index
end = min(batch_size * (batch_index + 1), len(labels))
# Determine the max sentence length in the batch
max_length = 0
for i in range(begin, end):
- length = len(sentences[i])
+ length = len(sentences[i][0])
if length > max_length:
max_length = length
@@ -134,15 +142,15 @@
lb = []
seq_length = []
for i in range(begin, end):
- sb.append(sentences[i] + [0] * max(max_length - len(sentences[i]), 0))
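+ # Index 0 of each sentence entry holds the word ids; right-pad them with 0 to the batch max length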
+ sb.append(sentences[i][0] + [0] * max(max_length - len(sentences[i][0]), 0))
lb.append(self.encode_labels(labels[i]) + [0] * max(max_length - len(labels[i]), 0))
- seq_length.append(len(sentences[i]))
+ seq_length.append(len(sentences[i][0]))
# Determine the max word length in the batch
max_word_length = 0
for i in range(begin, end):
- for word in sentences[i]:
- length = len(rev_word_dict[word])
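+ # Index 1 holds the original tokens, so the word length is just the string length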
+ for word in sentences[i][1]:
+ length = len(word)
if length > max_word_length:
max_word_length = length
@@ -151,11 +159,11 @@
for i in range(begin, end):
sentence_word_length = []
sentence_word_chars = []
- for word in sentences[i]:
+ for word in sentences[i][1]:
word_chars = []
- for c in rev_word_dict[word]:
- word_chars.append(char_dict[c]) # TODO: This fails if c is not present
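+ # load_data now adds every char of every token to chars_set, so the old missing-key TODO no longer applies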
+ for c in word:
+ word_chars.append(char_dict[c])
sentence_word_length.append(len(word_chars))
word_chars = word_chars + [0] * max(max_word_length - len(word_chars), 0)
@@ -412,7 +420,7 @@
# mini_batch should also return char_ids and word length ...
sentences_batch, chars_batch, word_length_batch, labels_batch, lengths = \
- name_finder.mini_batch(rev_word_dict, char_dict, sentences, labels, batch_size, batch_index)
+ name_finder.mini_batch(char_dict, sentences, labels, batch_size, batch_index)
feed_dict = {token_ids_ph: sentences_batch, char_ids_ph: chars_batch,
word_lengths_ph: word_length_batch, sequence_lengths_ph: lengths,
@@ -424,8 +432,7 @@
correct_preds, total_correct, total_preds = 0., 0., 0.
for batch_index in range(floor(len(sentences_dev) / batch_size)):
sentences_test_batch, chars_batch_test, word_length_batch_test, \
- labels_test_batch, length_test = name_finder.mini_batch(rev_word_dict,
- char_dict,
+ labels_test_batch, length_test = name_finder.mini_batch(char_dict,
sentences_dev,
labels_dev,
batch_size,