Make batch size for normalizer inference dynamic
diff --git a/tf-ner-poc/src/main/java/org/apache/opennlp/normalizer/Normalizer.java b/tf-ner-poc/src/main/java/org/apache/opennlp/normalizer/Normalizer.java
index 2ad4809..52d44cf 100644
--- a/tf-ner-poc/src/main/java/org/apache/opennlp/normalizer/Normalizer.java
+++ b/tf-ner-poc/src/main/java/org/apache/opennlp/normalizer/Normalizer.java
@@ -73,8 +73,6 @@
 
   public String[] normalize(String[] texts) {
 
-    // TODO: Batch size is hard coded in the graph, make it dynamic or at padding here
-
     int textLengths[] = Arrays.stream(texts).mapToInt(String::length).toArray();
     int maxLength = Arrays.stream(textLengths).max().getAsInt();
 
@@ -89,11 +87,13 @@
     }
 
     try (Tensor<?> charTensor = Tensor.create(charIds);
-         Tensor<?> textLength = Tensor.create(textLengths)) {
+         Tensor<?> textLength = Tensor.create(textLengths);
+         Tensor<?> batchSize = Tensor.create(texts.length)) {
 
       List<Tensor<?>> result = session.runner()
           .feed("encoder_char_ids", charTensor)
           .feed("encoder_lengths", textLength)
+          .feed("batch_size", batchSize)
           .fetch("decode", 0).run();
 
       try (Tensor<?> translationTensor = result.get(0)) {
diff --git a/tf-ner-poc/src/main/python/normalizer/normalizer.py b/tf-ner-poc/src/main/python/normalizer/normalizer.py
index 86e735e..f721bb0 100644
--- a/tf-ner-poc/src/main/python/normalizer/normalizer.py
+++ b/tf-ner-poc/src/main/python/normalizer/normalizer.py
@@ -78,9 +78,11 @@
     encoder_char_dim = 100
     num_units = 256
 
+    batch_size_ph = tf.placeholder_with_default(batch_size, shape=(), name="batch_size")
+
     # Encoder
-    encoder_char_ids_ph = tf.placeholder(tf.int32, shape=[batch_size, None], name="encoder_char_ids")
-    encoder_lengths_ph = tf.placeholder(tf.int32, shape=[batch_size], name="encoder_lengths")
+    encoder_char_ids_ph = tf.placeholder(tf.int32, shape=[None, None], name="encoder_char_ids")
+    encoder_lengths_ph = tf.placeholder(tf.int32, shape=[None], name="encoder_lengths")
 
     encoder_embedding_weights = tf.get_variable(name="char_embeddings", dtype=tf.float32,
                         shape=[encoder_nchars, encoder_char_dim])
@@ -90,7 +92,7 @@
     encoder_emb_inp = tf.transpose(encoder_emb_inp, perm=[1, 0, 2])
 
     encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
-    initial_state = encoder_cell.zero_state(batch_size, dtype=tf.float32)
+    initial_state = encoder_cell.zero_state(batch_size_ph, dtype=tf.float32)
 
     encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
         encoder_cell, encoder_emb_inp, initial_state=initial_state,
@@ -98,8 +100,8 @@
         time_major=True, swap_memory=True)
 
     # Decoder
-    decoder_char_ids_ph = tf.placeholder(tf.int32, shape=[batch_size, None], name="decoder_char_ids")
-    decoder_lengths = tf.placeholder(tf.int32, shape=[batch_size], name="decoder_lengths")
+    decoder_char_ids_ph = tf.placeholder(tf.int32, shape=[None, None], name="decoder_char_ids")
+    decoder_lengths = tf.placeholder(tf.int32, shape=[None], name="decoder_lengths")
 
     # decoder output (decoder_input shifted to the left by one)
 
@@ -121,7 +123,7 @@
         attention_layer_size=num_units)
 
     # decoder_initial_state = encoder_state
-    decoder_initial_state = decoder_cell.zero_state(dtype=tf.float32, batch_size=batch_size)
+    decoder_initial_state = decoder_cell.zero_state(dtype=tf.float32, batch_size=batch_size_ph)
 
     if "TRAIN" == mode:
 
@@ -166,16 +168,14 @@
     if "EVAL" == mode:
         helperE = tf.contrib.seq2seq.GreedyEmbeddingHelper(
             decoder_embedding_weights,
-            tf.fill([batch_size], decoder_nchars-2), decoder_nchars-1)
+            tf.fill([batch_size_ph], decoder_nchars-2), decoder_nchars-1)
         decoderE = tf.contrib.seq2seq.BasicDecoder(
             decoder_cell, helperE, decoder_initial_state,
             output_layer=projection_layer)
-        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoderE, maximum_iterations=15)
-
+        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoderE, maximum_iterations=20)
 
         translations = tf.identity(outputs.sample_id, name="decode")
 
-        # the outputs don't decode anything ...
         return encoder_char_ids_ph, encoder_lengths_ph, translations
 
 def encode_chars(names):