// blob: b0d944ca09ada8f646bc580dec788ef55c2e4ff1 [file] [log] [blame]
package joshua.decoder.ff.lm;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import joshua.corpus.Vocabulary;
import joshua.decoder.JoshuaConfiguration;
import joshua.decoder.chart_parser.SourcePath;
import joshua.decoder.ff.FeatureVector;
import joshua.decoder.ff.lm.kenlm.jni.KenLM;
import joshua.decoder.ff.lm.kenlm.jni.KenLM.StateProbPair;
import joshua.decoder.ff.state_maintenance.DPState;
import joshua.decoder.ff.state_maintenance.KenLMState;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.segment_file.Sentence;
/**
 * Wrapper for KenLM LMs with left-state minimization. We inherit from the regular
 * {@link LanguageModelFF} feature function and override the pieces (LM initialization, rule cost
 * estimation, and edge/final transition computation) that are delegated to KenLM.
 *
 * @author Matt Post <post@cs.jhu.edu>
 * @author Juri Ganitkevitch <juri@cs.jhu.edu>
 */
public class StateMinimizingLanguageModel extends LanguageModelFF {

  // Maps from sentence numbers to KenLM-side pools used to allocate state. The map is static, so
  // it is shared by every instance of this class; all accesses go through single atomic map
  // operations (putIfAbsent / remove) so that a pool is never leaked or destroyed twice.
  private static final ConcurrentHashMap<Integer, Long> poolMap = new ConcurrentHashMap<Integer, Long>();

  /**
   * Creates the feature function and forces the LM type to "kenlm". If the feature line
   * explicitly requests a different {@code lm_type}, a fatal message is printed and the JVM
   * exits with status -1.
   *
   * @param weights the feature weights, passed through to {@link LanguageModelFF}
   * @param args the feature-line arguments, passed through to {@link LanguageModelFF}
   * @param config the decoder configuration, passed through to {@link LanguageModelFF}
   */
  public StateMinimizingLanguageModel(FeatureVector weights, String[] args, JoshuaConfiguration config) {
    super(weights, args, config);
    this.type = "kenlm";
    if (parsedArgs.containsKey("lm_type") && ! parsedArgs.get("lm_type").equals("kenlm")) {
      System.err.println("* FATAL: StateMinimizingLanguageModel only supports 'kenlm' lm_type backend");
      System.err.println("*        Remove lm_type from line or set to 'kenlm'");
      System.exit(-1);
    }
  }

  /**
   * Initializes the underlying language model. A {@link KenLM} instance is always constructed
   * (from the n-gram order and path held by this object), since only KenLM supports left-state
   * minimization. The model is then registered with the {@link Vocabulary}, and the start- and
   * end-of-sentence symbol IDs are recorded on {@link LanguageModelFF}.
   */
  @Override
  public void initializeLM() {
    // Override type (only KenLM supports left-state minimization)
    this.languageModel = new KenLM(ngramOrder, path);

    Vocabulary.registerLanguageModel(this.languageModel);
    // The returned ID is not needed here; presumably this call just makes sure the default
    // nonterminal is present in the vocabulary.
    Vocabulary.id(config.default_non_terminal);

    LanguageModelFF.START_SYM_ID = Vocabulary.id(Vocabulary.START_SYM);
    LanguageModelFF.STOP_SYM_ID = Vocabulary.id(Vocabulary.STOP_SYM);
  }

  /**
   * Estimates the cost of a rule. We override here since KenLM can do it more efficiently
   * than the default {@link LanguageModelFF} class.
   *
   * Most of this function implementation is redundant with compute().
   *
   * @param rule the rule whose target side is scored
   * @param sentence the sentence being translated (not used by this implementation)
   * @return the feature weight multiplied by KenLM's estimate for the rule
   */
  @Override
  public float estimateCost(Rule rule, Sentence sentence) {
    int[] ruleWords = rule.getEnglish();

    // The IDs we'll pass to KenLM
    long[] words = new long[ruleWords.length];

    for (int x = 0; x < ruleWords.length; x++) {
      int id = ruleWords[x];
      if (Vocabulary.nt(id)) {
        // For the estimate, we can just mark negative values
        words[x] = -1;
      } else {
        // Terminal: just add it
        words[x] = id;
      }
    }

    // Get the estimated probability of applying the rule
    return weight * ((KenLM) languageModel).estimateRule(words);
  }

  /**
   * Computes the features incurred along this edge. Note that these features are unweighted costs
   * of the feature; they are the feature cost, not the model cost, or the inner product of them.
   *
   * @param rule the rule applied along the edge
   * @param tailNodes the antecedent nodes, indexed by the rule's nonterminals
   * @param i start of the span, forwarded to {@code getTags} when source annotations are enabled
   * @param j end of the span, forwarded to {@code getTags} when source annotations are enabled
   * @param sourcePath the source path (not used by this implementation)
   * @param sentence the sentence being translated; its ID selects the KenLM state pool
   * @param acc accumulator that receives the (unweighted) LM probability under this feature's name
   * @return the KenLM state resulting from applying the rule
   */
  @Override
  public DPState compute(Rule rule, List<HGNode> tailNodes, int i, int j, SourcePath sourcePath,
      Sentence sentence, Accumulator acc) {

    int[] ruleWords = config.source_annotations
        ? getTags(rule, i, j, sentence)
        : rule.getEnglish();

    // The IDs we'll pass to KenLM
    long[] words = new long[ruleWords.length];

    for (int x = 0; x < ruleWords.length; x++) {
      int id = ruleWords[x];
      if (Vocabulary.nt(id)) {
        // Nonterminal: retrieve the KenLM long that records the state. Nonterminal IDs are
        // mapped to tail-node positions via -(id + 1), and the state is passed negated.
        int index = -(id + 1);
        KenLMState state = (KenLMState) tailNodes.get(index).getDPState(stateIndex);
        words[x] = -state.getState();
      } else {
        // Terminal: just add it
        words[x] = id;
      }
    }

    long pool = getOrCreatePool(sentence.id());

    // Get the probability of applying the rule and the new state
    StateProbPair pair = ((KenLM) languageModel).probRule(words, pool);

    // Record the prob
    acc.add(name, pair.prob);

    // Return the state
    return pair.state;
  }

  /**
   * Returns the KenLM pool registered for the given sentence, creating and registering one if
   * none exists yet. Registration uses {@code putIfAbsent}, so if two callers race on the same
   * sentence only one pool ends up in the map; the losing caller destroys the pool it created
   * instead of leaking it.
   */
  private static long getOrCreatePool(int sentId) {
    Long pool = poolMap.get(sentId);
    if (pool == null) {
      long created = KenLM.createPool();
      Long existing = poolMap.putIfAbsent(sentId, created);
      if (existing == null) {
        pool = created;
      } else {
        // Somebody else registered a pool first: release ours and share theirs.
        KenLM.destroyPool(created);
        pool = existing;
      }
    }
    return pool;
  }

  /**
   * Destroys the pool created to allocate state for this sentence. Called from the
   * {@link joshua.decoder.Translation} class after outputting the sentence or k-best list. Hosting
   * this map here statically allows pools to be shared across KenLM instances.
   *
   * The entry is removed from the map atomically and only the caller that actually removed it
   * destroys the pool, so calling this more than once for the same sentence (e.g., once per
   * language model instance) is safe and is a no-op after the first call.
   *
   * @param sentId the ID of the sentence whose pool should be released
   */
  public void destroyPool(int sentId) {
    Long pool = poolMap.remove(sentId);
    if (pool != null) {
      KenLM.destroyPool(pool);
    }
  }

  /**
   * This function differs from regular transitions because we incorporate the cost of incomplete
   * left-hand ngrams, as well as including the start- and end-of-sentence markers (if they were
   * requested when the object was created).
   *
   * KenLM already includes the prefix probabilities (of shorter n-grams on the left-hand side), so
   * there's nothing that needs to be done: nothing is added to the accumulator.
   *
   * @return a newly constructed, default {@link KenLMState}
   */
  @Override
  public DPState computeFinal(HGNode tailNode, int i, int j, SourcePath sourcePath, Sentence sentence,
      Accumulator acc) {
    // No rule is applied and no cost is recorded for the final transition.
    return new KenLMState();
  }

  /**
   * KenLM probs already include the prefix probabilities (they are subtracted out when merging
   * states), so this doesn't need to do anything.
   *
   * @return always {@code 0.0f}
   */
  @Override
  public float estimateFutureCost(Rule rule, DPState currentState, Sentence sentence) {
    return 0.0f;
  }
}