Merge branch 'devel' of github.com:jganitkevitch/joshua into cleanup
diff --git a/src/joshua/decoder/JoshuaConfiguration.java b/src/joshua/decoder/JoshuaConfiguration.java
index 70d72bb..de3d2fa 100644
--- a/src/joshua/decoder/JoshuaConfiguration.java
+++ b/src/joshua/decoder/JoshuaConfiguration.java
@@ -29,7 +29,7 @@
// old format specifying attributes of a single language model separately
public static String lm_type = "kenlm";
- public static double lm_ceiling_cost = 100;
+ public static float lm_ceiling_cost = 100;
public static boolean use_left_equivalent_state = false;
public static boolean use_right_equivalent_state = false;
public static int lm_order = 3;
@@ -271,7 +271,7 @@
}
} else if (parameter.equals(normalize_key("lm_ceiling_cost"))) {
- lm_ceiling_cost = Double.parseDouble(fds[1]);
+ lm_ceiling_cost = Float.parseFloat(fds[1]);
logger.finest(String.format("lm_ceiling_cost: %s", lm_ceiling_cost));
} else if (parameter.equals(normalize_key("use_left_equivalent_state"))) {
diff --git a/src/joshua/decoder/ff/TargetBigram.java b/src/joshua/decoder/ff/TargetBigram.java
index add084a..15a752e 100644
--- a/src/joshua/decoder/ff/TargetBigram.java
+++ b/src/joshua/decoder/ff/TargetBigram.java
@@ -41,13 +41,13 @@
int sentID) {
NgramDPState state = (NgramDPState) tailNode.getDPState(this.getStateComputer());
- Integer leftWord = state.getLeftLMStateWords().get(0);
- List<Integer> rightContext = state.getRightLMStateWords();
- Integer rightWord = rightContext.get(rightContext.size() - 1);
+ int leftWord = state.getLeftLMStateWords()[0];
+ int[] rightContext = state.getRightLMStateWords();
+ int rightWord = rightContext[rightContext.length - 1];
FeatureVector features = new FeatureVector();
- features.put("<s> " + leftWord.toString(), 1.0f);
- features.put(rightWord.toString() + " </s>", 1.0f);
+ features.put("<s> " + leftWord, 1.0f);
+ features.put(rightWord + " </s>", 1.0f);
return features;
}
@@ -66,7 +66,6 @@
* @param features
*/
private FeatureVector computeTransition(int[] enWords, List<HGNode> tailNodes) {
-
List<Integer> currentNgram = new LinkedList<Integer>();
FeatureVector features = new FeatureVector();
@@ -75,42 +74,28 @@
if (Vocabulary.nt(curID)) {
int index = -(curID + 1);
-
NgramDPState state = (NgramDPState) tailNodes.get(index)
.getDPState(this.getStateComputer());
- List<Integer> leftContext = state.getLeftLMStateWords();
- List<Integer> rightContext = state.getRightLMStateWords();
- if (leftContext.size() != rightContext.size()) {
- throw new RuntimeException(
- "computeTransition: left and right contexts have unequal lengths");
- }
+ int[] leftContext = state.getLeftLMStateWords();
+ int[] rightContext = state.getRightLMStateWords();
// Left context.
- for (int i = 0; i < leftContext.size(); i++) {
- int t = leftContext.get(i);
+ for (int t : leftContext) {
currentNgram.add(t);
-
if (currentNgram.size() == 2) {
- // System.err.println(String.format("NGRAM(%s) = %.5f",
- // Vocabulary.getWords(currentNgram), prob));
-
String ngram = join(currentNgram);
if (features.containsKey(ngram))
features.put(ngram, 1);
else
features.put(ngram, features.get(ngram) + 1);
-
currentNgram.remove(0);
}
}
-
- // Right context.
+ // Replace right context.
int tSize = currentNgram.size();
- for (int i = 0; i < rightContext.size(); i++) {
- // replace context
- currentNgram.set(tSize - rightContext.size() + i, rightContext.get(i));
- }
-
+ for (int i = 0; i < rightContext.length; i++)
+ currentNgram.set(tSize - rightContext.length + i, rightContext[i]);
+
} else { // terminal words
currentNgram.add(curID);
if (currentNgram.size() == 2) {
@@ -119,7 +104,6 @@
features.put(ngram, 1);
else
features.put(ngram, features.get(ngram) + 1);
-
currentNgram.remove(0);
}
}
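
Illustration (not part of the patch): the bigram-count accumulation left unchanged as context above assigns a count of 1 when the feature is already present and increments when it is absent, which looks inverted. A minimal sketch of the presumably intended counting, with a plain HashMap and the hypothetical class name BigramCountSketch standing in for Joshua's FeatureVector:

import java.util.HashMap;
import java.util.Map;

// Sketch only: initialize a feature count on first sight, increment afterwards.
public class BigramCountSketch {
  public static void main(String[] args) {
    Map<String, Float> features = new HashMap<String, Float>();
    String[] ngrams = { "the cat", "cat sat", "the cat" };
    for (String ngram : ngrams) {
      if (!features.containsKey(ngram))
        features.put(ngram, 1.0f);
      else
        features.put(ngram, features.get(ngram) + 1.0f);
    }
    System.out.println(features); // "the cat" -> 2.0, "cat sat" -> 1.0 (iteration order unspecified)
  }
}
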
diff --git a/src/joshua/decoder/ff/lm/AbstractLM.java b/src/joshua/decoder/ff/lm/AbstractLM.java
index e5e9b99..d9977db 100644
--- a/src/joshua/decoder/ff/lm/AbstractLM.java
+++ b/src/joshua/decoder/ff/lm/AbstractLM.java
@@ -35,17 +35,17 @@
}
- public final double sentenceLogProbability(List<Integer> sentence, int order, int startIndex) {
+ public final float sentenceLogProbability(int[] sentence, int order, int startIndex) {
return super.sentenceLogProbability(sentence, order, startIndex);
}
- public final double ngramLogProbability(int[] ngram) {
+ public final float ngramLogProbability(int[] ngram) {
return super.ngramLogProbability(ngram);
}
- public final double ngramLogProbability(int[] ngram, int order) {
+ public final float ngramLogProbability(int[] ngram, int order) {
if (ngram.length > order) {
throw new RuntimeException("ngram length is greather than the max order");
}
@@ -59,28 +59,28 @@
throw new RuntimeException("Error: history size is " + historySize);
// return 0;
}
- double probability = ngramLogProbability_helper(ngram, order);
+ float probability = ngramLogProbability_helper(ngram, order);
if (probability < -JoshuaConfiguration.lm_ceiling_cost) {
probability = -JoshuaConfiguration.lm_ceiling_cost;
}
return probability;
}
- protected abstract double ngramLogProbability_helper(int[] ngram, int order);
+ protected abstract float ngramLogProbability_helper(int[] ngram, int order);
/**
* @deprecated this function is much slower than the int[] version
*/
@Deprecated
- public final double logProbOfBackoffState(List<Integer> ngram, int order,
+ public final float logProbOfBackoffState(List<Integer> ngram, int order,
int qtyAdditionalBackoffWeight) {
return logProbabilityOfBackoffState(Support.subIntArray(ngram, 0, ngram.size()), order,
qtyAdditionalBackoffWeight);
}
- public final double logProbabilityOfBackoffState(int[] ngram, int order,
+ public final float logProbabilityOfBackoffState(int[] ngram, int order,
int qtyAdditionalBackoffWeight) {
if (ngram.length > order) {
throw new RuntimeException("ngram length is greather than the max order");
@@ -91,12 +91,12 @@
if (qtyAdditionalBackoffWeight > 0) {
return logProbabilityOfBackoffState_helper(ngram, order, qtyAdditionalBackoffWeight);
} else {
- return 0.0;
+ return 0;
}
}
- protected abstract double logProbabilityOfBackoffState_helper(int[] ngram, int order,
+ protected abstract float logProbabilityOfBackoffState_helper(int[] ngram, int order,
int qtyAdditionalBackoffWeight);
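
Illustration (not part of the patch): ngramLogProbability above floors scores at -lm_ceiling_cost, now a float. A minimal sketch of that clamp, with a hypothetical constant standing in for JoshuaConfiguration.lm_ceiling_cost:

// Sketch only: clamp a log-probability at -lm_ceiling_cost, as AbstractLM does above.
public class CeilingClampSketch {
  static final float LM_CEILING_COST = 100.0f; // stand-in for JoshuaConfiguration.lm_ceiling_cost

  static float clamp(float logProb) {
    return (logProb < -LM_CEILING_COST) ? -LM_CEILING_COST : logProb;
  }

  public static void main(String[] args) {
    System.out.println(clamp(-3.7f));   // -3.7
    System.out.println(clamp(-250.0f)); // -100.0 (floored at -lm_ceiling_cost)
  }
}
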
diff --git a/src/joshua/decoder/ff/lm/DefaultNGramLanguageModel.java b/src/joshua/decoder/ff/lm/DefaultNGramLanguageModel.java
index 7735087..a6223d0 100644
--- a/src/joshua/decoder/ff/lm/DefaultNGramLanguageModel.java
+++ b/src/joshua/decoder/ff/lm/DefaultNGramLanguageModel.java
@@ -15,12 +15,12 @@
*/
package joshua.decoder.ff.lm;
+import java.util.Arrays;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import joshua.corpus.Vocabulary;
-import joshua.decoder.Support;
/**
* This class provides a default implementation for the Equivalent LM State optimization (namely,
@@ -65,17 +65,17 @@
return false;
}
- public double sentenceLogProbability(List<Integer> sentence, int order, int startIndex) {
- if (sentence == null) return 0.0;
- int sentenceLength = sentence.size();
- if (sentenceLength <= 0) return 0.0;
+ public float sentenceLogProbability(int[] sentence, int order, int startIndex) {
+ if (sentence == null) return 0.0f;
+ int sentenceLength = sentence.length;
+ if (sentenceLength <= 0) return 0.0f;
- double probability = 0.0;
- // partial ngrams at the begining
+ float probability = 0.0f;
+ // partial ngrams at the beginning
for (int j = startIndex; j < order && j <= sentenceLength; j++) {
// TODO: startIndex dependents on the order, e.g., this.ngramOrder-1 (in srilm, for 3-gram lm,
// start_index=2. othercase, need to check)
- int[] ngram = Support.subIntArray(sentence, 0, j);
+ int[] ngram = Arrays.copyOfRange(sentence, 0, j);
double logProb = ngramLogProbability(ngram, order);
if (logger.isLoggable(Level.FINE)) {
String words = Vocabulary.getWords(ngram);
@@ -86,7 +86,7 @@
// regular-order ngrams
for (int i = 0; i <= sentenceLength - order; i++) {
- int[] ngram = Support.subIntArray(sentence, i, i + order);
+ int[] ngram = Arrays.copyOfRange(sentence, i, i + order);
double logProb = ngramLogProbability(ngram, order);
if (logger.isLoggable(Level.FINE)) {
String words = Vocabulary.getWords(ngram);
@@ -98,26 +98,18 @@
return probability;
}
-
- /** @deprecated this function is much slower than the int[] version */
- @Deprecated
- public double ngramLogProbability(List<Integer> ngram, int order) {
- return ngramLogProbability(Support.subIntArray(ngram, 0, ngram.size()), order);
- }
-
-
- public double ngramLogProbability(int[] ngram) {
+ public float ngramLogProbability(int[] ngram) {
return this.ngramLogProbability(ngram, this.ngramOrder);
}
- public abstract double ngramLogProbability(int[] ngram, int order);
+ public abstract float ngramLogProbability(int[] ngram, int order);
/**
* Will never be called, because BACKOFF_LEFT_LM_STATE_SYM_ID token will never exist. However,
* were it to be called, it should return a probability of 1 (logprob of 0).
*/
- public double logProbOfBackoffState(List<Integer> ngram, int order, int qtyAdditionalBackoffWeight) {
+ public float logProbOfBackoffState(List<Integer> ngram, int order, int qtyAdditionalBackoffWeight) {
return 0; // log(1) == 0;
}
@@ -125,7 +117,7 @@
* Will never be called, because BACKOFF_LEFT_LM_STATE_SYM_ID token will never exist. However,
* were it to be called, it should return a probability of 1 (logprob of 0).
*/
- public double logProbabilityOfBackoffState(int[] ngram, int order, int qtyAdditionalBackoffWeight) {
+ public float logProbabilityOfBackoffState(int[] ngram, int order, int qtyAdditionalBackoffWeight) {
return 0; // log(1) == 0;
}
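
Illustration (not part of the patch): the patch swaps Support.subIntArray(list, from, to) for Arrays.copyOfRange(array, from, to), which appears to follow the same inclusive-start, exclusive-end convention, so the window arithmetic above carries over unchanged. A minimal sketch (hypothetical class name SubArraySketch):

import java.util.Arrays;

// Sketch only: slide an order-sized window over a sentence with Arrays.copyOfRange.
public class SubArraySketch {
  public static void main(String[] args) {
    int[] sentence = { 101, 102, 103, 104, 105 };
    int order = 3;
    for (int i = 0; i <= sentence.length - order; i++) {
      int[] ngram = Arrays.copyOfRange(sentence, i, i + order); // [i, i + order)
      System.out.println(Arrays.toString(ngram));
    }
  }
}
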
diff --git a/src/joshua/decoder/ff/lm/LanguageModelFF.java b/src/joshua/decoder/ff/lm/LanguageModelFF.java
index 1a480bc..7168dcc 100644
--- a/src/joshua/decoder/ff/lm/LanguageModelFF.java
+++ b/src/joshua/decoder/ff/lm/LanguageModelFF.java
@@ -6,12 +6,13 @@
import java.util.logging.Logger;
import joshua.corpus.Vocabulary;
+import joshua.decoder.Support;
import joshua.decoder.chart_parser.SourcePath;
import joshua.decoder.ff.FeatureVector;
import joshua.decoder.ff.StatefulFF;
-import joshua.decoder.ff.state_maintenance.StateComputer;
import joshua.decoder.ff.state_maintenance.DPState;
import joshua.decoder.ff.state_maintenance.NgramDPState;
+import joshua.decoder.ff.state_maintenance.StateComputer;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;
@@ -25,6 +26,7 @@
* </ol>
*
* @author Matt Post <post@cs.jhu.edu>
+ * @author Juri Ganitkevitch <juri@cs.jhu.edu>
* @author Zhifei Li, <zhifei.work@gmail.com>
*/
public class LanguageModelFF extends StatefulFF {
@@ -174,7 +176,7 @@
*/
private float computeTransition(int[] enWords, List<HGNode> tailNodes) {
- List<Integer> currentNgram = new LinkedList<Integer>();
+ LinkedList<Integer> currentNgram = new LinkedList<Integer>();
float transitionLogP = 0.0f;
for (int c = 0; c < enWords.length; c++) {
@@ -185,16 +187,12 @@
NgramDPState state = (NgramDPState) tailNodes.get(index)
.getDPState(this.getStateComputer());
- List<Integer> leftContext = state.getLeftLMStateWords();
- List<Integer> rightContext = state.getRightLMStateWords();
- if (leftContext.size() != rightContext.size()) {
- throw new RuntimeException(
- "computeTransition: left and right contexts have unequal lengths");
- }
+ int[] leftContext = state.getLeftLMStateWords();
+ int[] rightContext = state.getRightLMStateWords();
// Left context.
- for (int i = 0; i < leftContext.size(); i++) {
- int t = leftContext.get(i);
+ for (int i = 0; i < leftContext.length; i++) {
+ int t = leftContext[i];
currentNgram.add(t);
// Always calculate logP for <bo>: additional backoff weight
@@ -206,31 +204,33 @@
currentNgram.size(), numAdditionalBackoffWeight);
if (currentNgram.size() == this.ngramOrder) {
- currentNgram.remove(0);
+ currentNgram.removeFirst();
}
} else if (currentNgram.size() == this.ngramOrder) {
- // compute the current word probablity, and remove it
- float prob = (float) this.lmGrammar.ngramLogProbability(currentNgram, this.ngramOrder);
+ // Compute the current word probability, and remove it.
+ float prob = this.lmGrammar.ngramLogProbability(this.toArray(currentNgram),
+ this.ngramOrder);
// System.err.println(String.format("NGRAM(%s) = %.5f",
// Vocabulary.getWords(currentNgram), prob));
transitionLogP += prob;
- currentNgram.remove(0);
+ currentNgram.removeFirst();
}
}
// Right context.
int tSize = currentNgram.size();
- for (int i = 0; i < rightContext.size(); i++) {
+ for (int i = 0; i < rightContext.length; i++) {
// replace context
- currentNgram.set(tSize - rightContext.size() + i, rightContext.get(i));
+ currentNgram.set(tSize - rightContext.length + i, rightContext[i]);
}
} else { // terminal words
currentNgram.add(curID);
if (currentNgram.size() == this.ngramOrder) {
// compute the current word probablity, and remove it
- float prob = (float) this.lmGrammar.ngramLogProbability(currentNgram, this.ngramOrder);
+ float prob = this.lmGrammar.ngramLogProbability(this.toArray(currentNgram),
+ this.ngramOrder);
transitionLogP += prob;
// System.err.println(String.format("NGRAM(%s) = %.5f", Vocabulary.getWords(currentNgram),
// prob));
@@ -252,21 +252,16 @@
private float computeFinalTransition(NgramDPState state) {
float res = 0.0f;
- List<Integer> currentNgram = new LinkedList<Integer>();
- List<Integer> leftContext = state.getLeftLMStateWords();
- List<Integer> rightContext = state.getRightLMStateWords();
-
- if (leftContext.size() != rightContext.size()) {
- throw new RuntimeException(
- "LMModel.compute_equiv_state_final_transition: left and right contexts have unequal lengths");
- }
+ LinkedList<Integer> currentNgram = new LinkedList<Integer>();
+ int[] leftContext = state.getLeftLMStateWords();
+ int[] rightContext = state.getRightLMStateWords();
// ================ left context
if (addStartAndEndSymbol)
currentNgram.add(START_SYM_ID);
- for (int i = 0; i < leftContext.size(); i++) {
- int t = leftContext.get(i);
+ for (int i = 0; i < leftContext.length; i++) {
+ int t = leftContext[i];
currentNgram.add(t);
if (t == BACKOFF_LEFT_LM_STATE_SYM_ID) {// calculate logP for <bo>: additional backoff weight
@@ -279,29 +274,31 @@
} else { // partial ngram
// compute the current word probablity
if (currentNgram.size() >= 2) { // start from bigram
- float prob = (float) this.lmGrammar
- .ngramLogProbability(currentNgram, currentNgram.size());
+ float prob = this.lmGrammar.ngramLogProbability(this.toArray(currentNgram),
+ currentNgram.size());
// System.err.println(String.format("NGRAM(%s) = %.5f", Vocabulary.getWords(currentNgram),
// prob));
res += prob;
}
}
if (currentNgram.size() == this.ngramOrder) {
- currentNgram.remove(0);
+ currentNgram.removeFirst();
}
}
// ================ right context
// switch context, we will never score the right context probability because they are either
- // duplicate or partional ngrams
+ // duplicate or partial ngrams
if (addStartAndEndSymbol) {
int tSize = currentNgram.size();
- for (int i = 0; i < rightContext.size(); i++) {// replace context
- currentNgram.set(tSize - rightContext.size() + i, rightContext.get(i));
- }
+ for (int i = 0; i < rightContext.length; i++)
+ currentNgram.removeLast();
+ for (int i = 0; i < rightContext.length; i++)
+ currentNgram.add(rightContext[i]);
currentNgram.add(STOP_SYM_ID);
- float prob = (float) this.lmGrammar.ngramLogProbability(currentNgram, currentNgram.size());
+ float prob = this.lmGrammar.ngramLogProbability(this.toArray(currentNgram),
+ currentNgram.size());
res += prob;
// System.err.println(String.format("NGRAM(%s) = %.5f", Vocabulary.getWords(currentNgram),
// prob));
@@ -327,7 +324,6 @@
int currentWord = enWords[c];
if (Vocabulary.nt(currentWord)) {
estimate += scoreChunkLogP(words, considerIncompleteNgrams, skipStart);
- considerIncompleteNgrams = true;
words.clear();
skipStart = false;
} else {
@@ -344,13 +340,14 @@
private float estimateStateLogProb(NgramDPState state, boolean addStart, boolean addEnd) {
float res = 0.0f;
- List<Integer> leftContext = state.getLeftLMStateWords();
+ int[] leftContext = state.getLeftLMStateWords();
if (null != leftContext) {
List<Integer> words = new ArrayList<Integer>();
if (addStart == true)
words.add(START_SYM_ID);
- words.addAll(leftContext);
+ for (int w : leftContext)
+ words.add(w);
boolean considerIncompleteNgrams = true;
boolean skipStart = true;
@@ -361,8 +358,10 @@
}
if (addEnd == true) {
- List<Integer> rightContext = state.getRightLMStateWords();
- List<Integer> list = new ArrayList<Integer>(rightContext);
+ int[] rightContext = state.getRightLMStateWords();
+ List<Integer> list = new ArrayList<Integer>(rightContext.length);
+ for (int w : rightContext)
+ list.add(w);
list.add(STOP_SYM_ID);
float tem = scoreChunkLogP(list, false, false);
res += tem;
@@ -390,8 +389,16 @@
startIndex = 1;
}
// System.err.println("Estimate: " + Vocabulary.getWords(words));
- return (float) this.lmGrammar.sentenceLogProbability(words, this.ngramOrder, startIndex);
+ return this.lmGrammar.sentenceLogProbability(
+ Support.subIntArray(words, 0, words.size()), this.ngramOrder, startIndex);
}
}
+ private final int[] toArray(List<Integer> input) {
+ int[] output = new int[input.size()];
+ int i = 0;
+ for (int v : input)
+ output[i++] = v;
+ return output;
+ }
}
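
Illustration (not part of the patch): the new private toArray(List<Integer>) helper unboxes the LinkedList n-gram before it is handed to the int[]-based ngramLogProbability. A minimal usage sketch with a dummy scorer (all names hypothetical):

import java.util.LinkedList;
import java.util.List;

// Sketch only: unbox an Integer list into an int[] before scoring it.
public class ToArraySketch {
  static int[] toArray(List<Integer> input) {
    int[] output = new int[input.size()];
    int i = 0;
    for (int v : input)
      output[i++] = v;
    return output;
  }

  // Placeholder for lmGrammar.ngramLogProbability(int[], int); not the real model.
  static float dummyNgramLogProbability(int[] ngram, int order) {
    return -0.5f * Math.min(ngram.length, order);
  }

  public static void main(String[] args) {
    LinkedList<Integer> currentNgram = new LinkedList<Integer>();
    currentNgram.add(7);
    currentNgram.add(42);
    currentNgram.add(13);
    System.out.println(dummyNgramLogProbability(toArray(currentNgram), 3)); // -1.5
  }
}
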
diff --git a/src/joshua/decoder/ff/lm/NGramLanguageModel.java b/src/joshua/decoder/ff/lm/NGramLanguageModel.java
index 2da7ff3..75358f3 100644
--- a/src/joshua/decoder/ff/lm/NGramLanguageModel.java
+++ b/src/joshua/decoder/ff/lm/NGramLanguageModel.java
@@ -55,7 +55,6 @@
boolean registerWord(String token, int id);
- // BUG: why do we pass the order? Does this method reduce the order as well?
/**
* @param sentence the sentence to be scored
* @param order the order of N-grams for the LM
@@ -63,17 +62,11 @@
* get the prob for the whole sentence, then startIndex should be 1
* @return the LogP of the whole sentence
*/
- double sentenceLogProbability(List<Integer> sentence, int order, int startIndex);
+ float sentenceLogProbability(int[] sentence, int order, int startIndex);
+ float ngramLogProbability(int[] ngram, int order);
- /**
- * @param order used to temporarily reduce the order used by the model.
- */
- double ngramLogProbability(List<Integer> ngram, int order);
-
- double ngramLogProbability(int[] ngram, int order);
-
- double ngramLogProbability(int[] ngram);
+ float ngramLogProbability(int[] ngram);
// ===============================================================
@@ -86,9 +79,9 @@
* for each such token, and then call ngramLogProbability for the remaining actual N-gram.
*/
// TODO Is this really the best interface?
- double logProbOfBackoffState(List<Integer> ngram, int order, int qtyAdditionalBackoffWeight);
+ float logProbOfBackoffState(List<Integer> ngram, int order, int qtyAdditionalBackoffWeight);
- double logProbabilityOfBackoffState(int[] ngram, int order, int qtyAdditionalBackoffWeight);
+ float logProbabilityOfBackoffState(int[] ngram, int order, int qtyAdditionalBackoffWeight);
int[] leftEquivalentState(int[] originalState, int order, double[] cost);
diff --git a/src/joshua/decoder/ff/lm/berkeley_lm/LMGrammarBerkeley.java b/src/joshua/decoder/ff/lm/berkeley_lm/LMGrammarBerkeley.java
index e69e460..6dbd2c1 100644
--- a/src/joshua/decoder/ff/lm/berkeley_lm/LMGrammarBerkeley.java
+++ b/src/joshua/decoder/ff/lm/berkeley_lm/LMGrammarBerkeley.java
@@ -28,13 +28,11 @@
import java.util.logging.Logger;
import joshua.corpus.Vocabulary;
-import joshua.decoder.JoshuaConfiguration;
-import joshua.decoder.Support;
import joshua.decoder.ff.lm.DefaultNGramLanguageModel;
import edu.berkeley.nlp.lm.ArrayEncodedNgramLanguageModel;
import edu.berkeley.nlp.lm.ConfigOptions;
-import edu.berkeley.nlp.lm.WordIndexer;
import edu.berkeley.nlp.lm.StringWordIndexer;
+import edu.berkeley.nlp.lm.WordIndexer;
import edu.berkeley.nlp.lm.cache.ArrayEncodedCachingLmWrapper;
import edu.berkeley.nlp.lm.io.LmReaders;
import edu.berkeley.nlp.lm.util.StrUtils;
@@ -123,17 +121,17 @@
}
@Override
- public double sentenceLogProbability(List<Integer> sentence, int order, int startIndex) {
- if (sentence == null) return 0.0;
- int sentenceLength = sentence.size();
- if (sentenceLength <= 0) return 0.0;
+ public float sentenceLogProbability(int[] sentence, int order, int startIndex) {
+ if (sentence == null) return 0;
+ int sentenceLength = sentence.length;
+ if (sentenceLength <= 0) return 0;
- double probability = 0.0;
+ float probability = 0;
// partial ngrams at the begining
for (int j = startIndex; j < order && j <= sentenceLength; j++) {
- // TODO: startIndex dependents on the order, e.g., this.ngramOrder-1 (in srilm, for 3-gram lm,
+ // TODO: startIndex depends on the order, e.g., this.ngramOrder-1 (in srilm, for 3-gram lm,
// start_index=2. othercase, need to check)
- int[] ngram = Support.subIntArray(sentence, 0, j);
+ int[] ngram = Arrays.copyOfRange(sentence, 0, j);
double logProb = ngramLogProbability_helper(ngram, false);
if (logger.isLoggable(Level.FINE)) {
String words = Vocabulary.getWords(ngram);
@@ -144,7 +142,7 @@
// regular-order ngrams
for (int i = 0; i <= sentenceLength - order; i++) {
- int[] ngram = Support.subIntArray(sentence, i, i + order);
+ int[] ngram = Arrays.copyOfRange(sentence, i, i + order);
double logProb = ngramLogProbability_helper(ngram, false);
if (logger.isLoggable(Level.FINE)) {
String words = Vocabulary.getWords(ngram);
@@ -156,7 +154,7 @@
return probability;
}
- protected double ngramLogProbability_helper(int[] ngram, boolean log) {
+ protected float ngramLogProbability_helper(int[] ngram, boolean log) {
int[] mappedNgram = arrayScratch.get();
if (mappedNgram.length < ngram.length) {
@@ -182,15 +180,15 @@
logHandler = handler;
}
- public double ngramLogProbability(int[] ngram) {
+ public float ngramLogProbability(int[] ngram) {
return ngramLogProbability_helper(ngram,true);
}
- public double logProbOfBackoffState(List<Integer> ngram, int order, int qtyAdditionalBackoffWeight) {
+ public float logProbOfBackoffState(List<Integer> ngram, int order, int qtyAdditionalBackoffWeight) {
return 0;
}
- public double logProbabilityOfBackoffState(int[] ngram, int order, int qtyAdditionalBackoffWeight) {
+ public float logProbabilityOfBackoffState(int[] ngram, int order, int qtyAdditionalBackoffWeight) {
return 0;
}
@@ -203,7 +201,7 @@
}
@Override
- public double ngramLogProbability(int[] ngram, int order) {
+ public float ngramLogProbability(int[] ngram, int order) {
return ngramLogProbability(ngram);
}
diff --git a/src/joshua/decoder/ff/lm/bloomfilter_lm/BloomFilterLanguageModel.java b/src/joshua/decoder/ff/lm/bloomfilter_lm/BloomFilterLanguageModel.java
index 1fc3634..0ab8fe9 100644
--- a/src/joshua/decoder/ff/lm/bloomfilter_lm/BloomFilterLanguageModel.java
+++ b/src/joshua/decoder/ff/lm/bloomfilter_lm/BloomFilterLanguageModel.java
@@ -165,20 +165,20 @@
*
* @return the linearly-interpolated Witten-Bell smoothed probability of an ngram
*/
- private double wittenBell(int[] ngram, int ngramOrder) {
+ private float wittenBell(int[] ngram, int ngramOrder) {
int end = ngram.length;
double p = p0; // current calculated probability
// note that p0 and lambda0 are independent of the given
// ngram so they are calculated ahead of time.
int MAX_QCOUNT = getCount(ngram, ngram.length - 1, ngram.length, maxQ);
if (MAX_QCOUNT == 0) // OOV!
- return p;
+ return (float) p;
double pML = Math.log(unQuantize(MAX_QCOUNT)) - numTokens;
// p += lambda0 * pML;
p = logAdd(p, (lambda0 + pML));
if (ngram.length == 1) { // if it's a unigram, we're done
- return p;
+ return (float) p;
}
// otherwise we calculate the linear interpolation
// with higher order models.
@@ -188,7 +188,7 @@
// terms in the interpolation must be zero, so we
// are done here.
if (historyCnt == 0) {
- return p;
+ return (float) p;
}
int historyTypesAfter = getTypesAfter(ngram, i, end, historyCnt);
// unQuantize the counts we got from the BF
@@ -202,11 +202,11 @@
int wordCount = getCount(ngram, i + 1, end, historyTypesAfter);
double WC = unQuantize(wordCount);
// p += lambda * p_ML(w|h)
- if (WC == 0) return p;
+ if (WC == 0) return (float) p;
p = logAdd(p, lambda + Math.log(WC) - Math.log(HC));
MAX_QCOUNT = wordCount;
}
- return p;
+ return (float) p;
}
/**
@@ -561,7 +561,7 @@
}
@Override
- protected double logProbabilityOfBackoffState_helper(int[] ngram, int order,
+ protected float logProbabilityOfBackoffState_helper(int[] ngram, int order,
int qtyAdditionalBackoffWeight) {
throw new UnsupportedOperationException(
"probabilityOfBackoffState_helper undefined for bloom filter LM");
@@ -577,7 +577,7 @@
* @return the language model score of the ngram
*/
@Override
- protected double ngramLogProbability_helper(int[] ngram, int order) {
+ protected float ngramLogProbability_helper(int[] ngram, int order) {
int[] lm_ngram = new int[ngram.length];
for (int i = 0; i < ngram.length; i++) {
lm_ngram[i] = Vocabulary.id(Vocabulary.word(ngram[i]));
diff --git a/src/joshua/decoder/ff/lm/buildin_lm/LMGrammarJAVA.java b/src/joshua/decoder/ff/lm/buildin_lm/LMGrammarJAVA.java
index c3bfb54..09089a1 100644
--- a/src/joshua/decoder/ff/lm/buildin_lm/LMGrammarJAVA.java
+++ b/src/joshua/decoder/ff/lm/buildin_lm/LMGrammarJAVA.java
@@ -109,14 +109,14 @@
* srilm may be smaller than by java, this happens only when the LM file have "<unk>" in backoff
* state
*/
- protected double ngramLogProbability_helper(int[] ngram, int order) {
+ protected float ngramLogProbability_helper(int[] ngram, int order) {
Double res;
int[] ngram_wrds = replace_with_unk(ngram); // TODO
// TODO: wrong implementation in hiero
if (ngram_wrds[ngram_wrds.length - 1] == UNK_SYM_ID) {
- res = -JoshuaConfiguration.lm_ceiling_cost;
+ res = (double) -JoshuaConfiguration.lm_ceiling_cost;
} else {
// TODO: untranslated words
if (null == root) {
@@ -149,7 +149,7 @@
}
res = prob + bow_sum;
}
- return res;
+ return res.floatValue();
}
private Double get_valid_prob(LMHash pos, int wrd) {
@@ -343,10 +343,10 @@
return (null != prob && prob <= MIN_LOG_P);
}
- protected double logProbabilityOfBackoffState_helper(int[] ngram_wrds, int order,
+ protected float logProbabilityOfBackoffState_helper(int[] ngram_wrds, int order,
int n_additional_bow) {
int[] backoff_wrds = Support.sub_int_array(ngram_wrds, 0, ngram_wrds.length - 1);
- double[] sum_bow = new double[1];
+ float[] sum_bow = new float[1];
check_backoff_weight(backoff_wrds, sum_bow, n_additional_bow);
return sum_bow[0];
}
@@ -355,10 +355,10 @@
// backoff weight
// if there is no backoff weight for backoff_words, then, we can return the
// finalized backoff weight
- private boolean check_backoff_weight(int[] backoff_words, double[] sum_bow, int num_backoff) {
+ private boolean check_backoff_weight(int[] backoff_words, float[] sum_bow, int num_backoff) {
if (backoff_words.length <= 0) return false;
- double sum = 0;
+ float sum = 0;
LMHash pos = root;
// the start index that backoff should be applied
diff --git a/src/joshua/decoder/ff/lm/kenlm/jni/KenLM.java b/src/joshua/decoder/ff/lm/kenlm/jni/KenLM.java
index 469c630..31583b2 100644
--- a/src/joshua/decoder/ff/lm/kenlm/jni/KenLM.java
+++ b/src/joshua/decoder/ff/lm/kenlm/jni/KenLM.java
@@ -2,7 +2,6 @@
import java.util.List;
-import joshua.decoder.Support;
import joshua.decoder.ff.lm.NGramLanguageModel;
// TODO(Joshua devs): include my state object with your LM state then
@@ -61,40 +60,28 @@
return probString(pointer, words, start - 1);
}
- /* implement NGramLanguageModel */
- /**
- * @deprecated pass int arrays to prob instead.
- */
- @Deprecated
- public double sentenceLogProbability(List<Integer> sentence, int order, int startIndex) {
- return probString(Support.subIntArray(sentence, 0, sentence.size()), startIndex);
+ @Override
+ public float sentenceLogProbability(int[] sentence, int order, int startIndex) {
+ return probString(sentence, startIndex);
}
- public double ngramLogProbability(int[] ngram, int order) {
+ public float ngramLogProbability(int[] ngram, int order) {
if (order != N && order != ngram.length)
throw new RuntimeException("Lower order not supported.");
return prob(ngram);
}
- public double ngramLogProbability(int[] ngram) {
+ public float ngramLogProbability(int[] ngram) {
return prob(ngram);
}
- /**
- * @deprecated pass int arrays to prob instead.
- */
- @Deprecated
- public double ngramLogProbability(List<Integer> ngram, int order) {
- return prob(Support.subIntArray(ngram, 0, ngram.size()));
- }
-
// TODO(Joshua devs): fix the rest of your code to use LM state properly.
// Then fix this.
- public double logProbOfBackoffState(List<Integer> ngram, int order, int qtyAdditionalBackoffWeight) {
+ public float logProbOfBackoffState(List<Integer> ngram, int order, int qtyAdditionalBackoffWeight) {
return 0;
}
- public double logProbabilityOfBackoffState(int[] ngram, int order, int qtyAdditionalBackoffWeight) {
+ public float logProbabilityOfBackoffState(int[] ngram, int order, int qtyAdditionalBackoffWeight) {
return 0;
}
diff --git a/src/joshua/decoder/ff/similarity/EdgePhraseSimilarityFF.java b/src/joshua/decoder/ff/similarity/EdgePhraseSimilarityFF.java
index a70ee1b..31283b9 100644
--- a/src/joshua/decoder/ff/similarity/EdgePhraseSimilarityFF.java
+++ b/src/joshua/decoder/ff/similarity/EdgePhraseSimilarityFF.java
@@ -82,7 +82,7 @@
int lm_state_size = 0;
for (HGNode node : tailNodes) {
NgramDPState state = (NgramDPState) node.getDPState(stateComputer);
- lm_state_size += state.getLeftLMStateWords().size() + state.getRightLMStateWords().size();
+ lm_state_size += state.getLeftLMStateWords().length + state.getRightLMStateWords().length;
}
ArrayList<int[]> batch = new ArrayList<int[]>();
diff --git a/src/joshua/decoder/ff/state_maintenance/DPState.java b/src/joshua/decoder/ff/state_maintenance/DPState.java
index 310be34..dc892a6 100644
--- a/src/joshua/decoder/ff/state_maintenance/DPState.java
+++ b/src/joshua/decoder/ff/state_maintenance/DPState.java
@@ -17,11 +17,16 @@
package joshua.decoder.ff.state_maintenance;
/**
- * No longer necessary, actually, since it doesn't enforce anything.
+ * Abstract class enforcing explicit implementation of hashCode, equals, and toString.
*
* @author Zhifei Li, <zhifei.work@gmail.com>
* @author Juri Ganitkevitch, <juri@cs.jhu.edu>
*/
-public interface DPState {
- // Nothing.
+public abstract class DPState {
+
+ public abstract String toString();
+
+ public abstract int hashCode();
+
+ public abstract boolean equals(Object other);
}
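
Illustration (not part of the patch): DPState is now an abstract class because an interface cannot force implementers to override Object's methods. A minimal hypothetical subclass showing the three overrides the change requires; NgramDPState below follows the same pattern with int[] contexts:

// Sketch only: DPStateSketch mirrors the new abstract DPState; IntDPState is hypothetical.
abstract class DPStateSketch {
  public abstract String toString();
  public abstract int hashCode();
  public abstract boolean equals(Object other);
}

class IntDPState extends DPStateSketch {
  private final int value;

  IntDPState(int value) { this.value = value; }

  @Override public String toString() { return "<" + value + ">"; }

  @Override public int hashCode() { return 31 + value; }

  @Override public boolean equals(Object other) {
    return (other instanceof IntDPState) && ((IntDPState) other).value == value;
  }

  public static void main(String[] args) {
    System.out.println(new IntDPState(7).equals(new IntDPState(7))); // true
  }
}
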
diff --git a/src/joshua/decoder/ff/state_maintenance/NgramDPState.java b/src/joshua/decoder/ff/state_maintenance/NgramDPState.java
index 23bc92c..fe0391b 100644
--- a/src/joshua/decoder/ff/state_maintenance/NgramDPState.java
+++ b/src/joshua/decoder/ff/state_maintenance/NgramDPState.java
@@ -1,6 +1,6 @@
package joshua.decoder.ff.state_maintenance;
-import java.util.List;
+import java.util.Arrays;
import joshua.corpus.Vocabulary;
@@ -8,75 +8,75 @@
* @author Zhifei Li, <zhifei.work@gmail.com>
* @author Juri Ganitkevitch, <juri@cs.jhu.edu>
*/
-public class NgramDPState implements DPState {
+public class NgramDPState extends DPState {
- private List<Integer> leftLMStateWords;
- private List<Integer> rightLMStateWords;
+ private int[] left;
+ private int[] right;
private int hash = 0;
- public NgramDPState(List<Integer> leftLMStateWords, List<Integer> rightLMStateWords) {
- this.leftLMStateWords = leftLMStateWords;
- this.rightLMStateWords = rightLMStateWords;
+ public NgramDPState(int[] l, int[] r) {
+ left = l;
+ right = r;
+ assertLengths();
}
- public void setLeftLMStateWords(List<Integer> words_) {
- this.leftLMStateWords = words_;
+ public void setLeftLMStateWords(int[] words) {
+ left = words;
+ assertLengths();
}
- public List<Integer> getLeftLMStateWords() {
- return this.leftLMStateWords;
+ public int[] getLeftLMStateWords() {
+ return left;
}
- public void setRightLMStateWords(List<Integer> words_) {
- this.rightLMStateWords = words_;
+ public void setRightLMStateWords(int[] words) {
+ right = words;
+ assertLengths();
}
- public List<Integer> getRightLMStateWords() {
- return this.rightLMStateWords;
+ public int[] getRightLMStateWords() {
+ return right;
+ }
+
+ private final void assertLengths() {
+ if (left.length != right.length)
+ throw new RuntimeException("Unequal lengths in left and right state: < "
+ + Vocabulary.getWords(left) + " | " + Vocabulary.getWords(right) + " >");
}
@Override
public int hashCode() {
if (hash == 0) {
- hash = 31 + stateHash(leftLMStateWords);
- hash = hash * 19 + stateHash(rightLMStateWords);
+ hash = 31 + Arrays.hashCode(left);
+ hash = hash * 19 + Arrays.hashCode(right);
}
return hash;
}
-
+
@Override
public boolean equals(Object other) {
if (other instanceof NgramDPState) {
NgramDPState that = (NgramDPState) other;
- if (this.leftLMStateWords.size() != that.leftLMStateWords.size()) return false;
- if (this.rightLMStateWords.size() != that.rightLMStateWords.size()) return false;
- for (int i = 0; i < this.leftLMStateWords.size(); ++i)
- if (!this.leftLMStateWords.get(i).equals(that.leftLMStateWords.get(i)))
- return false;
- for (int i = 0; i < this.rightLMStateWords.size(); ++i)
- if (!this.rightLMStateWords.get(i).equals(that.rightLMStateWords.get(i)))
+ if (this.left.length != that.left.length)
+ return false;
+ if (this.right.length != that.right.length)
+ return false;
+ for (int i = 0; i < left.length; ++i)
+ if (this.left[i] != that.left[i] || this.right[i] != that.right[i])
return false;
return true;
}
return false;
}
- private int stateHash(List<Integer> state) {
- int state_hash = 17;
- if (null != state)
- for (int i : state)
- state_hash = state_hash * 19 + i;
- return state_hash;
- }
-
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("<");
- for (int id : leftLMStateWords)
+ for (int id : left)
sb.append(" " + Vocabulary.word(id));
sb.append(" |");
- for (int id : rightLMStateWords)
+ for (int id : right)
sb.append(" " + Vocabulary.word(id));
sb.append(" >");
return sb.toString();
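
Illustration (not part of the patch): NgramDPState now stores bare int[] contexts, checks their lengths up front, and hashes them with Arrays.hashCode. A minimal sketch of that pattern (Vocabulary lookups omitted; the class name NgramStateSketch is hypothetical):

import java.util.Arrays;

// Sketch only: length check at construction, Arrays-based hashCode/equals over int[] state.
public class NgramStateSketch {
  private final int[] left;
  private final int[] right;

  NgramStateSketch(int[] left, int[] right) {
    if (left.length != right.length)
      throw new RuntimeException("Unequal lengths in left and right state");
    this.left = left;
    this.right = right;
  }

  @Override public int hashCode() {
    return (31 + Arrays.hashCode(left)) * 19 + Arrays.hashCode(right);
  }

  @Override public boolean equals(Object other) {
    return (other instanceof NgramStateSketch)
        && Arrays.equals(left, ((NgramStateSketch) other).left)
        && Arrays.equals(right, ((NgramStateSketch) other).right);
  }

  public static void main(String[] args) {
    NgramStateSketch a = new NgramStateSketch(new int[] { 3, 4 }, new int[] { 8, 9 });
    NgramStateSketch b = new NgramStateSketch(new int[] { 3, 4 }, new int[] { 8, 9 });
    System.out.println(a.equals(b) && a.hashCode() == b.hashCode()); // true
  }
}
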
diff --git a/src/joshua/decoder/ff/state_maintenance/NgramStateComputer.java b/src/joshua/decoder/ff/state_maintenance/NgramStateComputer.java
index 57c914c..37117fb 100644
--- a/src/joshua/decoder/ff/state_maintenance/NgramStateComputer.java
+++ b/src/joshua/decoder/ff/state_maintenance/NgramStateComputer.java
@@ -1,6 +1,6 @@
package joshua.decoder.ff.state_maintenance;
-import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -10,8 +10,8 @@
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;
-
-public class NgramStateComputer implements StateComputer<NgramDPState>, Comparable {
+public class NgramStateComputer implements StateComputer<NgramDPState>,
+ Comparable<NgramStateComputer> {
private int ngramOrder;
@@ -28,121 +28,58 @@
}
@Override
- public int compareTo(Object otherState) {
+ public int compareTo(NgramStateComputer otherState) {
if (this == otherState)
return 0;
else
return -1;
}
- public NgramDPState computeFinalState(HGNode tailNode, int i, int j,
- SourcePath srcPath) {
- // no state is required
+ public NgramDPState computeFinalState(HGNode tailNode, int i, int j, SourcePath srcPath) {
+ // No state is required.
return null;
}
+ public NgramDPState computeState(Rule rule, List<HGNode> tail_nodes, int span_start,
+ int span_end, SourcePath src_path) {
+ int[] tgt = rule.getEnglish();
- public NgramDPState computeState(Rule rule, List<HGNode> tailNodes, int spanStart, int spanEnd, SourcePath srcPath) {
+ int[] left = new int[ngramOrder - 1];
+ int lcount = 0;
- List<Integer> leftStateSequence = new ArrayList<Integer>();
- List<Integer> currentNgram = new ArrayList<Integer>();
-
- int hypLen = 0;
- int[] enWords = rule.getEnglish();
-
- for (int c = 0; c < enWords.length; c++) {
- int curID = enWords[c];
+ for (int c = 0; c < tgt.length && lcount < left.length; ++c) {
+ int curID = tgt[c];
if (Vocabulary.idx(curID)) {
- // == get left- and right-context
int index = -(curID + 1);
-
- if (logger.isLoggable(Level.FINEST))
+ if (logger.isLoggable(Level.FINEST))
logger.finest("Looking up state at: " + index);
-
- NgramDPState tailState = (NgramDPState) tailNodes.get(index).getDPState(this);
- List<Integer> leftContext = tailState.getLeftLMStateWords();
- List<Integer> rightContext = tailState.getRightLMStateWords();
-
- if (leftContext.size() != rightContext.size()) {
- throw new RuntimeException(
- "NgramStateComputer.computeState: left and right contexts have unequal lengths");
- }
-
- // ================ left context
- for (int i = 0; i < leftContext.size(); i++) {
- int t = leftContext.get(i);
- currentNgram.add(t);
-
- // always calculate cost for <bo>: additional backoff weight
- /*
- * if (t == BACKOFF_LEFT_LM_STATE_SYM_ID) { int numAdditionalBackoffWeight =
- * currentNgram.size() - (i+1);//number of non-state words
- *
- * //compute additional backoff weight transitionCost -=
- * this.lmGrammar.logProbOfBackoffState(currentNgram, currentNgram.size(),
- * numAdditionalBackoffWeight);
- *
- * if (currentNgram.size() == this.ngramOrder) { currentNgram.remove(0); } } else
- */if (currentNgram.size() == this.ngramOrder) {
- // compute the current word probablity, and remove it
- // transitionCost -= this.lmGrammar.ngramLogProbability(currentNgram, this.ngramOrder);
-
- currentNgram.remove(0);
- }
-
- if (leftStateSequence.size() < this.ngramOrder - 1) {
- leftStateSequence.add(t);
- }
- }
-
- // ================ right context
- // note: left_state_org_wrds will never take words from right context because it is either
- // duplicate or out of range
- // also, we will never score the right context probablity because they are either duplicate
- // or partional ngram
- int tSize = currentNgram.size();
- for (int i = 0; i < rightContext.size(); i++) {
- // replace context
- currentNgram.set(tSize - rightContext.size() + i, rightContext.get(i));
- }
-
- } else {// terminal words
- hypLen++;
- currentNgram.add(curID);
- if (currentNgram.size() == this.ngramOrder) {
- // compute the current word probablity, and remove it
- // transitionCost -= this.lmGrammar.ngramLogProbability(currentNgram, this.ngramOrder);
-
-
- currentNgram.remove(0);
- }
- if (leftStateSequence.size() < this.ngramOrder - 1) {
- leftStateSequence.add(curID);
- }
+ NgramDPState tail_state = (NgramDPState) tail_nodes.get(index).getDPState(this);
+ int[] leftContext = tail_state.getLeftLMStateWords();
+ for (int i = 0; i < leftContext.length && lcount < left.length; i++)
+ left[lcount++] = leftContext[i];
+ } else {
+ left[lcount++] = curID;
}
}
+ int[] right = new int[ngramOrder - 1];
+ int rcount = right.length - 1;
- // ===== get left euquiv state
- // double[] lmLeftCost = new double[2];
- // int[] equivLeftState =
- // this.lmGrammar.leftEquivalentState(Support.subIntArray(leftLMStateWrds, 0,
- // leftLMStateWrds.size()), this.ngramOrder, lmLeftCost);
-
-
- // ===== trabsition and estimate cost
- // transitionCost += lmLeftCost[0];//add finalized cost for the left state words
- // left and right should always have the same size
- List<Integer> rightStateSequence = currentNgram;
- if (leftStateSequence.size() > rightStateSequence.size()) {
- throw new RuntimeException("left has a bigger size right; " + "; left="
- + leftStateSequence.size() + "; right=" + rightStateSequence.size());
+ for (int c = tgt.length - 1; c >= 0 && rcount >= 0; --c) {
+ int curID = tgt[c];
+ if (Vocabulary.idx(curID)) {
+ int index = -(curID + 1);
+ if (logger.isLoggable(Level.FINEST))
+ logger.finest("Looking up state at: " + index);
+ NgramDPState tail_state = (NgramDPState) tail_nodes.get(index).getDPState(this);
+ int[] rightContext = tail_state.getRightLMStateWords();
+ for (int i = rightContext.length - 1; i >= 0 && rcount >= 0; --i)
+ right[rcount--] = rightContext[i];
+ } else {
+ right[rcount--] = curID;
+ }
}
- while (rightStateSequence.size() > leftStateSequence.size()) {
- rightStateSequence.remove(0);// TODO: speed up
- }
-
- return new NgramDPState(leftStateSequence, rightStateSequence);
+ return new NgramDPState(Arrays.copyOfRange(left, 0, lcount), Arrays.copyOfRange(right,
+ rcount + 1, right.length));
}
-
}
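
Illustration (not part of the patch): the rewritten computeState keeps the first ngramOrder-1 target-side words as the left state and the last ngramOrder-1 as the right state, expanding nonterminals through their tail-node states. A simplified sketch for a terminal-only target side (no tail-node lookups; names hypothetical):

import java.util.Arrays;

// Sketch only: extract left/right LM state from a terminal-only target side.
public class StateExtractionSketch {
  public static void main(String[] args) {
    int ngramOrder = 3;
    int[] tgt = { 11, 22, 33, 44, 55 }; // target-side word ids

    // Left state: first ngramOrder-1 words, filled left to right.
    int[] left = new int[ngramOrder - 1];
    int lcount = 0;
    for (int c = 0; c < tgt.length && lcount < left.length; ++c)
      left[lcount++] = tgt[c];

    // Right state: last ngramOrder-1 words, filled right to left.
    int[] right = new int[ngramOrder - 1];
    int rcount = right.length - 1;
    for (int c = tgt.length - 1; c >= 0 && rcount >= 0; --c)
      right[rcount--] = tgt[c];

    System.out.println(Arrays.toString(Arrays.copyOfRange(left, 0, lcount)));                  // [11, 22]
    System.out.println(Arrays.toString(Arrays.copyOfRange(right, rcount + 1, right.length)));  // [44, 55]
  }
}
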