Added lazy phrase implementation...
...but it's not *quite* working, and on further thought it won't help with what
I need to do, so I'm abandoning it for the moment.
diff --git a/src/joshua/decoder/Decoder.java b/src/joshua/decoder/Decoder.java
index b49ac60..c4de111 100644
--- a/src/joshua/decoder/Decoder.java
+++ b/src/joshua/decoder/Decoder.java
@@ -532,7 +532,7 @@
} else if (format.equals("phrase")) {
joshuaConfiguration.phrase_based = true;
- grammar = new PhraseTable(file, owner, joshuaConfiguration);
+ grammar = new PhraseTable(file, owner, joshuaConfiguration, featureFunctions);
} else {
// thrax, hiero, samt
diff --git a/src/joshua/decoder/phrase/LazyRuleCollection.java b/src/joshua/decoder/phrase/LazyRuleCollection.java
new file mode 100644
index 0000000..d02ba9d
--- /dev/null
+++ b/src/joshua/decoder/phrase/LazyRuleCollection.java
@@ -0,0 +1,98 @@
+package joshua.decoder.phrase;
+
+import java.util.List;
+
+import joshua.corpus.Vocabulary;
+import joshua.decoder.ff.tm.BasicRuleCollection;
+import joshua.decoder.ff.tm.BilingualRule;
+import joshua.decoder.ff.tm.Rule;
+import joshua.decoder.ff.tm.format.HieroFormatReader;
+
+/**
+ * A rule collection that stores the raw target-side strings of a source phrase and only parses
+ * them into {@link BilingualRule} objects the first time {@link #getRules()} is called.
+ */
+public class LazyRuleCollection extends BasicRuleCollection {
+
+  private List<String> ruleStrings;
+  private int lhs;
+  private int owner;
+
+  /**
+   * Constructs an initially empty rule collection.
+   * 
+   * @param owner Owner id that is set on every rule created by this collection
+   * @param arity Number of nonterminals in the source pattern
+   * @param sourceTokens Sequence of terminals and nonterminals in the source
+   *          pattern
+   */
+  public LazyRuleCollection(int owner, int arity, int[] sourceTokens) {
+    super(arity, sourceTokens);
+    this.owner = owner;
+    this.lhs = Vocabulary.id("[X]");
+  }
+
+  /**
+   * Constructs a collection whose rules are parsed on demand from unprocessed strings.
+   * @param targetSides one "target ||| features [||| alignment]" string per rule
+   */
+  public LazyRuleCollection(int owner, int arity, int[] sourceTokens, List<String> targetSides) {
+    super(arity, sourceTokens);
+    this.owner = owner;
+    this.ruleStrings = targetSides;
+    this.lhs = Vocabulary.id("[X]");
+  }
+
+  static String fieldDelimiter = "\\s+\\|{3}\\s+";
+
+  /**
+   * This function transforms the unprocessed strings (read from the text file)
+   * into {@link BilingualRule} objects. These have not yet been scored.
+   */
+  public List<Rule> getRules() {
+    // ruleStrings is null when the collection was built without target sides.
+    if (ruleStrings != null && ruleStrings.size() > rules.size()) {
+      for (String line : ruleStrings) {
+        String[] fields = line.split(fieldDelimiter);
+
+        // foreign side
+        int[] french = new int[sourceTokens.length + 1];
+        french[0] = lhs;
+        System.arraycopy(sourceTokens, 0, french, 1, sourceTokens.length);
+
+        // English side
+        String[] englishWords = fields[0].split("\\s+");
+        int[] english = new int[englishWords.length + 1];
+        english[0] = -1;
+        for (int i = 0; i < englishWords.length; i++) {
+          english[i + 1] = Vocabulary.id(englishWords[i]);
+        }
+
+        // transform feature values
+        StringBuilder values = new StringBuilder();
+        for (String value : fields[1].split(" ")) {
+          float f = Float.parseFloat(value);
+          values.append(String.format("%f ", f <= 0.0 ? -100 : -Math.log(f)));
+        }
+        String sparse_features = values.toString().trim();
+
+        // alignments: read whenever the third field (the one indexed below) is present
+        byte[] alignment = null;
+        if (fields.length > 2) {
+          alignment = HieroFormatReader.readAlignment(fields[2]);
+        }
+
+        BilingualRule rule = new BilingualRule(lhs, french, english, sparse_features, arity,
+            alignment);
+        rule.setOwner(owner);
+        rules.add(rule);
+      }
+    }
+
+    return this.rules;
+  }
+
+  public boolean isSorted() {
+    return sorted;
+  }
+}
diff --git a/src/joshua/decoder/phrase/PhraseChart.java b/src/joshua/decoder/phrase/PhraseChart.java
index 23a6ba5..b3cf099 100644
--- a/src/joshua/decoder/phrase/PhraseChart.java
+++ b/src/joshua/decoder/phrase/PhraseChart.java
@@ -6,7 +6,7 @@
import joshua.decoder.Decoder;
import joshua.decoder.ff.FeatureFunction;
-import joshua.decoder.ff.tm.RuleCollection;
+import joshua.decoder.ff.tm.Rule;
import joshua.decoder.segment_file.Sentence;
/**
@@ -19,10 +19,12 @@
private int max_source_phrase_length;
// Banded array: different source lengths are next to each other.
- private List<TargetPhrases> entries;
+ private List<List<Rule>> entries;
// number of translation options
- int numOptions = 20;
+ private int numOptions = 20;
+
+ private List<FeatureFunction> features;
/**
* Create a new PhraseChart object, which represents all phrases that are
@@ -34,6 +36,8 @@
*/
public PhraseChart(PhraseTable[] tables, List<FeatureFunction> features, Sentence source, int num_options) {
+ this.features = features;
+
float startTime = System.currentTimeMillis();
max_source_phrase_length = 0;
@@ -42,32 +46,22 @@
tables[i].getMaxSourcePhraseLength());
sentence_length = source.length();
-// System.err.println(String.format(
-// "PhraseChart()::Initializing chart for sentlen %d max %d from %s", sentence_length,
-// max_source_phrase_length, source));
-
- entries = new ArrayList<TargetPhrases>();
+ entries = new ArrayList<List<Rule>>(sentence_length * max_source_phrase_length);
for (int i = 0; i < sentence_length * max_source_phrase_length; i++)
entries.add(null);
// There's some unreachable ranges off the edge. Meh.
- for (int begin = 0; begin != sentence_length; ++begin) {
+ for (int begin = 0; begin < sentence_length; ++begin) {
for (int end = begin + 1; (end != sentence_length + 1)
&& (end <= begin + max_source_phrase_length); ++end) {
if (source.hasPath(begin, end)) {
for (PhraseTable table : tables)
- SetRange(begin, end,
- table.Phrases(Arrays.copyOfRange(source.intSentence(), begin, end)));
+ addToRange(begin, end,
+ table.getPhrases(Arrays.copyOfRange(source.intSentence(), begin, end)));
}
-
}
}
- for (TargetPhrases phrases: entries) {
- if (phrases != null)
- phrases.finish(features, Decoder.weights, num_options);
- }
-
System.err.println(String.format("[%d] Collecting options took %.3f seconds", source.id(),
(System.currentTimeMillis() - startTime) / 1000.0f));
}
@@ -101,16 +95,16 @@
*/
public TargetPhrases getRange(int begin, int end) {
int index = offset(begin, end);
-// System.err.println(String.format("PhraseChart::Range(%d,%d): found %d entries", begin, end,
+// System.err.println(String.format("PhraseChart::getRange(%d,%d): found %d entries", begin, end,
// entries.get(index) == null ? 0 : entries.get(index).size()));
-// if (entries.get(index) != null)
-// for (Rule phrase: entries.get(index))
-// System.err.println(" RULE: " + phrase);
if (index < 0 || index >= entries.size() || entries.get(index) == null)
return null;
- return entries.get(index);
+ TargetPhrases phrases = new TargetPhrases(entries.get(index));
+ phrases.finish(features, Decoder.weights, numOptions);
+
+ return phrases;
}
/**
@@ -120,14 +114,15 @@
* @param end
* @param to
*/
- private void SetRange(int begin, int end, RuleCollection to) {
+ private void addToRange(int begin, int end, List<Rule> to) {
if (to != null) {
+// System.err.println(String.format("PhraseChart::addToRange(%d, %d) = %d targets", begin, end, to.size()));
+
try {
int offset = offset(begin, end);
if (entries.get(offset) == null)
- entries.set(offset, new TargetPhrases(to.getRules()));
- else
- entries.get(offset).addAll(to.getRules());
+ entries.set(offset, new ArrayList<Rule>());
+ entries.get(offset).addAll(to);
} catch (java.lang.IndexOutOfBoundsException e) {
System.err.println(String.format("Whoops! %s [%d-%d] too long (%d)", to, begin, end,
entries.size()));
diff --git a/src/joshua/decoder/phrase/PhraseTable.java b/src/joshua/decoder/phrase/PhraseTable.java
index ad3cbc6..3b90a20 100644
--- a/src/joshua/decoder/phrase/PhraseTable.java
+++ b/src/joshua/decoder/phrase/PhraseTable.java
@@ -1,16 +1,21 @@
package joshua.decoder.phrase;
import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
import java.util.List;
import joshua.corpus.Vocabulary;
import joshua.decoder.JoshuaConfiguration;
import joshua.decoder.ff.FeatureFunction;
+import joshua.decoder.ff.tm.BasicRuleCollection;
import joshua.decoder.ff.tm.BilingualRule;
+import joshua.decoder.ff.tm.Grammar;
+import joshua.decoder.ff.tm.Rule;
import joshua.decoder.ff.tm.RuleCollection;
import joshua.decoder.ff.tm.Trie;
-import joshua.decoder.ff.tm.format.HieroFormatReader;
-import joshua.decoder.ff.tm.hash_based.MemoryBasedBatchGrammar;
+import joshua.util.io.LineReader;
/**
* Represents a phrase table. Inherits from grammars so we can code-share with the syntax-
@@ -20,8 +25,16 @@
*
*/
-public class PhraseTable extends MemoryBasedBatchGrammar {
+public class PhraseTable implements Grammar {
+ private String grammarFile;
+ private int owner;
+ private JoshuaConfiguration config;
+ private HashMap<PhraseWrapper, RuleCollection> entries;
+ private int numRules;
+ private List<FeatureFunction> features;
+ private int maxSourceLength;
+
/**
* Chain to the super with a number of defaults. For example, we only use a single nonterminal,
* and there is no span limit.
@@ -31,22 +44,74 @@
* @param config
* @throws IOException
*/
- public PhraseTable(String grammarFile, String owner, JoshuaConfiguration config) throws IOException {
- super("phrase", grammarFile, owner, "[X]", -1, config);
+ public PhraseTable(String grammarFile, String owner, JoshuaConfiguration config, List<FeatureFunction> features) throws IOException {
+ this.config = config;
+ this.owner = Vocabulary.id(owner);
+ this.grammarFile = grammarFile;
+ this.features = features;
+ this.maxSourceLength = 0;
+ Vocabulary.id("[X]");
+
+ this.entries = new HashMap<PhraseWrapper, RuleCollection>();
+
+ loadPhraseTable();
}
- public PhraseTable(String owner, JoshuaConfiguration config) {
- super(owner, config);
+ public PhraseTable(String owner, JoshuaConfiguration config, List<FeatureFunction> features) {
+ this.config = config;
+ this.owner = Vocabulary.id(owner);
+ this.features = features;
+ this.maxSourceLength = 0;
+
+ this.entries = new HashMap<PhraseWrapper, RuleCollection>();
}
+  /**
+   * Reads the grammar file, grouping the text after each line's source side by source phrase and
+   * storing one {@link LazyRuleCollection} per run of consecutive lines with the same source.
+   */
+  private void loadPhraseTable() throws IOException {
+    String prevSourceSide = null;
+    List<String> rules = new ArrayList<String>();
+    int[] french = null;
+
+    for (String line: new LineReader(this.grammarFile)) {
+      int sourceEnd = line.indexOf(" ||| ");
+      String source = line.substring(0, sourceEnd);
+      String rest = line.substring(sourceEnd + 5);
+
+      if (prevSourceSide == null || ! source.equals(prevSourceSide)) {
+        // New source side: store the rules accumulated for the previous one
+        if (prevSourceSide != null) {
+          System.err.println(String.format("loadPhraseTable: %s -> %d rules", Vocabulary.getWords(french), rules.size()));
+          entries.put(new PhraseWrapper(french), new LazyRuleCollection(owner, 1, french, rules));
+          rules = new ArrayList<String>();
+        }
+        String[] foreignWords = source.split("\\s+");
+        french = new int[foreignWords.length];
+        for (int i = 0; i < foreignWords.length; i++)
+          french[i] = Vocabulary.id(foreignWords[i]);
+        maxSourceLength = Math.max(french.length, getMaxSourcePhraseLength());
+        prevSourceSide = source;
+      }
+      // Add only after the flush above, so the line is filed under its own source side
+      rules.add(rest);
+      numRules++;
+    }
+
+    if (french != null) {
+      entries.put(new PhraseWrapper(french), new LazyRuleCollection(owner, 1, french, rules));
+      System.err.println(String.format("loadPhraseTable: %s -> %d rules", Vocabulary.getWords(french), rules.size()));
+    }
+  }
+
/**
* Returns the longest source phrase read, subtracting off the nonterminal that was added.
*
* @return
*/
- @Override
public int getMaxSourcePhraseLength() {
- return maxSourcePhraseLength - 1;
+ return maxSourceLength;
}
/**
@@ -55,33 +120,140 @@
* @param sourceWords the sequence of source words
* @return the rules
*/
- public RuleCollection Phrases(int[] sourceWords) {
- if (sourceWords.length != 0) {
- Trie pointer = getTrieRoot().match(Vocabulary.id("[X]"));
- int i = 0;
- while (pointer != null && i < sourceWords.length)
- pointer = pointer.match(sourceWords[i++]);
-
- if (pointer != null && pointer.hasRules())
- return pointer.getRuleCollection();
+ public List<Rule> getPhrases(int[] sourceWords) {
+ RuleCollection rules = entries.get(new PhraseWrapper(sourceWords));
+ if (rules != null) {
+// System.err.println(String.format("PhraseTable::getPhrases(%s) = %d of them", Vocabulary.getWords(sourceWords),
+// rules.getRules().size()));
+ return rules.getSortedRules(features);
}
-
return null;
}
+  /**
+   * Registers {@link Hypothesis#END_RULE} under the single-word source phrase {@code </s>} and
+   * makes sure the recorded maximum source phrase length is at least one.
+   */
+  public void addEOSRule() {
+    int[] french = { Vocabulary.id("[X]"), Vocabulary.id("</s>") };
+
+    maxSourceLength = Math.max(getMaxSourcePhraseLength(), 1);
+
+    RuleCollection rules = new BasicRuleCollection(1, french);
+    rules.getRules().add(Hypothesis.END_RULE);
+    entries.put(new PhraseWrapper(new int[] { Vocabulary.id("</s>") }), rules);
+  }
+
@Override
- public void addOOVRules(int sourceWord, List<FeatureFunction> featureFunctions) {
+ public void addOOVRules(int sourceWord, List<FeatureFunction> features) {
// TODO: _OOV shouldn't be outright added, since the word might not be OOV for the LM (but now almost
// certainly is)
- int targetWord = joshuaConfiguration.mark_oovs
- ? Vocabulary.id(Vocabulary.word(sourceWord) + "_OOV")
- : sourceWord;
+ int[] french = { Vocabulary.id("[X]"), sourceWord };
+
+ String targetWord = (config.mark_oovs
+ ? Vocabulary.word(sourceWord) + "_OOV"
+ : Vocabulary.word(sourceWord));
- String ruleString = String.format("[X] ||| [X,1] %s ||| [X,1] %s ||| -1 ||| 0-0 1-1",
- Vocabulary.word(sourceWord), Vocabulary.word(targetWord));
- BilingualRule oovRule = new HieroFormatReader().parseLine(ruleString);
- oovRule.setOwner(Vocabulary.id("oov"));
- addRule(oovRule);
- oovRule.estimateRuleCost(featureFunctions);
+ int[] english = { -1, Vocabulary.id(targetWord) };
+ final byte[] align = { 0, 0 };
+
+ maxSourceLength = Math.max(getMaxSourcePhraseLength(), 1);
+
+ BilingualRule oovRule = new BilingualRule(Vocabulary.id("[X]"), french, english, "", 1, align);
+ oovRule.setOwner(owner);
+ oovRule.estimateRuleCost(features);
+
+// List<String> rules = new ArrayList<String>();
+// rules.add(String.format("[X,1] %s ||| -1 ||| 0-0 1-1", targetWord));
+// entries.put(new PhraseWrapper(new int[] { sourceWord }), new LazyRuleCollection(owner, 1, french, rules));
+
+ RuleCollection rules = new BasicRuleCollection(1, french);
+ rules.getRules().add(oovRule);
+ entries.put(new PhraseWrapper(new int[] { sourceWord }), rules);
+ }
+
+ /**
+ * The phrase table doesn't use a trie.
+ */
+ @Override
+ public Trie getTrieRoot() {
+ return null;
+ }
+
+ /**
+ * We don't pre-sort grammars!
+ */
+ @Override
+ public void sortGrammar(List<FeatureFunction> models) {
+ }
+
+ /**
+ * We never pre-sort grammars! Why would you?
+ */
+ @Override
+ public boolean isSorted() {
+ return false;
+ }
+
+ @Override
+ public boolean hasRuleForSpan(int startIndex, int endIndex, int pathLength) {
+ // No limit on maximum phrase length
+ return true;
+ }
+
+ @Override
+ public int getNumRules() {
+ return numRules;
+ }
+
+ @Override
+ public Rule constructManualRule(int lhs, int[] sourceWords, int[] targetWords, float[] scores,
+ int aritity) {
+ return null;
+ }
+
+ @Override
+ public void writeGrammarOnDisk(String file) {
+ }
+
+ @Override
+ public boolean isRegexpGrammar() {
+ return false;
+ }
+
+  /**
+   * A simple wrapper around an int[] used for hashing. Declared static so that the map keys do
+   * not each hold a reference to the enclosing table.
+   */
+  private static class PhraseWrapper {
+    private final int[] words;
+
+    /**
+     * Wraps a private copy of the given word ids. The array is copied in full (no element is
+     * dropped), so keys must be built from the same word sequence that lookups use.
+     * 
+     * @param source the source phrase, e.g., [17, 91283]
+     */
+    public PhraseWrapper(int[] source) {
+      this.words = Arrays.copyOf(source, source.length);
+    }
+
+    @Override
+    public int hashCode() {
+      return Arrays.hashCode(words);
+    }
+
+    /**
+     * Two wrappers are equal when they hold the same word ids in the same order, which keeps
+     * equality consistent with {@link #hashCode()}.
+     */
+    @Override
+    public boolean equals(Object other) {
+      if (this == other)
+        return true;
+      if (!(other instanceof PhraseWrapper))
+        return false;
+      return Arrays.equals(words, ((PhraseWrapper) other).words);
+    }
   }
}
diff --git a/src/joshua/decoder/phrase/Stacks.java b/src/joshua/decoder/phrase/Stacks.java
index a752f29..03d79af 100644
--- a/src/joshua/decoder/phrase/Stacks.java
+++ b/src/joshua/decoder/phrase/Stacks.java
@@ -56,10 +56,10 @@
if (grammars[i] instanceof PhraseTable)
phraseTables[j++] = (PhraseTable) grammars[i];
- phraseTables[phraseTables.length - 2] = new PhraseTable("null", config);
- phraseTables[phraseTables.length - 2].addRule(Hypothesis.END_RULE);
+ phraseTables[phraseTables.length - 2] = new PhraseTable("null", config, featureFunctions);
+ phraseTables[phraseTables.length - 2].addEOSRule();
- phraseTables[phraseTables.length - 1] = new PhraseTable("oov", config);
+ phraseTables[phraseTables.length - 1] = new PhraseTable("oov", config, featureFunctions);
AbstractGrammar.addOOVRules(phraseTables[phraseTables.length - 1], sentence.intLattice(), featureFunctions, config.true_oovs_only);
this.chart = new PhraseChart(phraseTables, featureFunctions, sentence, config.num_translation_options);
diff --git a/src/joshua/decoder/phrase/TargetPhrases.java b/src/joshua/decoder/phrase/TargetPhrases.java
index 6d1b893..84dfa3f 100644
--- a/src/joshua/decoder/phrase/TargetPhrases.java
+++ b/src/joshua/decoder/phrase/TargetPhrases.java
@@ -43,21 +43,21 @@
* some trouble and should probably be reworked.
*/
public void finish(List<FeatureFunction> features, FeatureVector weights, int num_options) {
- for (Rule rule: this) {
- if (rule.getPrecomputableCost() <= Float.NEGATIVE_INFINITY) {
- float score = rule.getFeatureVector().innerProduct(weights);
- rule.setPrecomputableCost(score);
- }
- rule.estimateRuleCost(features);
-// System.err.println("TargetPhrases:finish(): " + rule);
- }
+// for (Rule rule: this) {
+// if (rule.getPrecomputableCost() <= Float.NEGATIVE_INFINITY) {
+// float score = rule.getFeatureVector().innerProduct(weights);
+// rule.setPrecomputableCost(score);
+// }
+// rule.estimateRuleCost(features);
+//// System.err.println("TargetPhrases:finish(): " + rule);
+// }
Collections.sort(this, Rule.EstimatedCostComparator);
if (this.size() > num_options)
this.removeRange(num_options, this.size());
-// System.err.println("TargetPhrases::finish()");
-// for (Rule rule: this)
-// System.err.println(" " + rule);
+ System.err.println("TargetPhrases::finish()");
+ for (Rule rule: this)
+ System.err.println(" " + rule);
}
}