blob: 29e9d7f2bfcc98ac203ef840343466ff36e8b89a [file] [log] [blame]
package joshua.discriminative.training.risk_annealer.nbest;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import joshua.discriminative.FileUtilityOld;
import joshua.discriminative.training.NbestMerger;
import joshua.discriminative.training.risk_annealer.AbstractMinRiskMERT;
import joshua.discriminative.training.risk_annealer.DeterministicAnnealer;
import joshua.discriminative.training.risk_annealer.GradientComputer;
import joshua.discriminative.training.risk_annealer.hypergraph.MRConfig;
import joshua.util.FileUtility;
/**
* @author Zhifei Li, <zhifei.work@gmail.com>
* @version $LastChangedDate: 2008-10-20 00:12:30 -0400 $
*/
/**
 * Nbest-based minimum-risk training with deterministic annealing.
 * Each iteration decodes the training set with the current weights, merges the
 * new nbest list into the accumulated one, and re-optimizes the weights on the
 * merged list; training stops when no new hypotheses appear or after
 * {@code MRConfig.maxNumIter} iterations.
 */
public abstract class NbestMinRiskDAMert extends AbstractMinRiskMERT {

  /** Total number of hypotheses in the accumulated merged nbest; used to detect convergence. */
  int totalNumHyp = 0;

  /** Prefix for the per-iteration nbest files ("&lt;prefix&gt;.&lt;iter&gt;" and "&lt;prefix&gt;.merged.&lt;iter&gt;"). */
  String nbestPrefix;

  /** If true, the risk/gain is computed against the shortest reference translation. */
  boolean useShortestRef;

  private static final Logger logger = Logger.getLogger(NbestMinRiskDAMert.class.getSimpleName());

  /**
   * @param useShortestRef use the shortest reference for the gain computation
   * @param decoderConfigFile decoder configuration file (also supplies the baseline weights)
   * @param numSentInTrainSet number of sentences in the training set
   * @param refFiles reference translation files, one per reference set
   * @param nbestPrefix file-name prefix for the nbest lists written by the decoder
   */
  public NbestMinRiskDAMert(boolean useShortestRef, String decoderConfigFile, int numSentInTrainSet,
      String[] refFiles, String nbestPrefix) {
    super(decoderConfigFile, numSentInTrainSet, refFiles);
    this.nbestPrefix = nbestPrefix;
    this.useShortestRef = useShortestRef;
    initialize();
  }

  /** Writes a decoder config file with the given weights, based on the template config. */
  public abstract void writeConfigFile(double[] weights, String configFileTemplate, String configOutFile);

  /**
   * Main training loop: decode, merge nbests, re-optimize weights via the annealer.
   * Terminates early when an iteration contributes no new hypotheses.
   */
  public void mainLoop() {
    for (int iter = 1; iter <= MRConfig.maxNumIter; iter++) {
      // Re-normalize so the first feature's weight is fixed (weights are scale-invariant).
      normalizeWeightsByFirstFeature(lastWeightVector, 0);

      // ==== decoding: produce an nbest list using the current weight vector
      writeConfigFile(lastWeightVector, configFile, configFile + "." + iter);
      String curNbestFile = nbestPrefix + "." + iter;
      decodingTestSet(lastWeightVector, curNbestFile);

      // ==== merge nbest lists and check for convergence
      String oldNbestMergedFile = nbestPrefix + ".merged." + (iter - 1);
      String newNbestMergedFile = nbestPrefix + ".merged." + iter;
      if (iter == 1) {
        copyNbest(curNbestFile, newNbestMergedFile);
      } else {
        int newTotalNumHyp = NbestMerger.mergeNbest(oldNbestMergedFile, curNbestFile, newNbestMergedFile);
        // BUG FIX: haveNewHyp was initialized to true and never reset, so the
        // convergence break below was unreachable. New hypotheses exist exactly
        // when the merged hypothesis count changed.
        boolean haveNewHyp = (newTotalNumHyp != totalNumHyp);
        totalNumHyp = newTotalNumHyp;
        if (!haveNewHyp) {
          System.out.println("No new hypotheses generated at iteration " + iter);
          break;
        }
      }

      GradientComputer gradientComputer = new NbestRiskGradientComputer(newNbestMergedFile,
          referenceFiles, useShortestRef, numTrainingSentence,
          numPara, MRConfig.gainFactor, 1.0, 0.0, true, MRConfig.linearCorpusGainThetas);

      annealer = new DeterministicAnnealer(numPara, lastWeightVector, MRConfig.isMinimizer, gradientComputer,
          MRConfig.useL2Regula, MRConfig.varianceForL2, MRConfig.useModelDivergenceRegula,
          MRConfig.lambda, MRConfig.printFirstN);

      if (MRConfig.annealingMode == 0) { // no annealing
        lastWeightVector = annealer.runWithoutAnnealing(MRConfig.isScalingFactorTunable,
            MRConfig.startScaleAtNoAnnealing, MRConfig.temperatureAtNoAnnealing);
      } else if (MRConfig.annealingMode == 1) {
        lastWeightVector = annealer.runQuenching(1.0);
      } else if (MRConfig.annealingMode == 2) {
        lastWeightVector = annealer.runDAAndQuenching();
      } else {
        logger.severe("unsupported anneal mode, " + MRConfig.annealingMode);
        System.exit(1); // exit with failure status on a configuration error
      }
      // lastWeightVector is used as the initial weights in the next iteration
    }

    // ==== final output
    normalizeWeightsByFirstFeature(lastWeightVector, 0);
    writeConfigFile(lastWeightVector, configFile, configFile + ".final");
    System.out.println("#### Final weights are: ");
    annealer.getLBFGSRunner().printStatistics(-1, -1, null, lastWeightVector);
  }

  /**
   * Merges a new nbest file into the previously merged one, writing the union
   * to {@code newMergedNbestFile}. Both input files must list sentences in the
   * same order, grouped by sentence id.
   *
   * @return false if the new nbest file contributed no new hypothesis
   */
  // TODO: declare convergence when the number of newly generated hypotheses is very small
  // TODO: alternatively, stop when the weights barely change — if weights are stable,
  //       new hypotheses will be rare anyway
  public static boolean mergeNbest(String oldMergedNbestFile, String newNbestFile, String newMergedNbestFile) {
    boolean haveNewHyp = false;
    BufferedReader oldMergedNbestReader = FileUtilityOld.getReadFileStream(oldMergedNbestFile);
    BufferedReader newNbestReader = FileUtilityOld.getReadFileStream(newNbestFile);
    BufferedWriter newMergedNbestWriter = FileUtilityOld.getWriteFileStream(newMergedNbestFile);

    int oldSentID = -1;
    String line;
    String previousLineInNewNbest = FileUtilityOld.readLineLzf(newNbestReader);
    // key: the hypothesis string itself; value: the remaining fields excluding the sentence id
    HashMap<String, String> oldNbests = new HashMap<String, String>();

    while ((line = FileUtilityOld.readLineLzf(oldMergedNbestReader)) != null) {
      String[] fds = line.split("\\s+\\|{3}\\s+");
      int newSentID = Integer.parseInt(fds[0]);
      if (oldSentID != -1 && oldSentID != newSentID) {
        // Sentence boundary reached: merge in the new nbest for the finished sentence.
        boolean[] tHaveNewHyp = new boolean[1];
        previousLineInNewNbest = processNbest(newNbestReader, newMergedNbestWriter, oldSentID,
            oldNbests, previousLineInNewNbest, tHaveNewHyp);
        if (tHaveNewHyp[0]) {
          haveNewHyp = true;
        }
      }
      oldSentID = newSentID;
      oldNbests.put(fds[1], fds[2]); // the last field (if any) is not needed
    }

    // Flush the final sentence's nbest.
    boolean[] tHaveNewHyp = new boolean[1];
    previousLineInNewNbest = processNbest(newNbestReader, newMergedNbestWriter, oldSentID,
        oldNbests, previousLineInNewNbest, tHaveNewHyp);
    if (previousLineInNewNbest != null) {
      // The new nbest file had sentences beyond those in the old merged file.
      System.out.println("last line is not null, must be wrong");
      System.exit(1);
    }
    if (tHaveNewHyp[0]) {
      haveNewHyp = true;
    }

    FileUtilityOld.closeReadFile(oldMergedNbestReader);
    FileUtilityOld.closeReadFile(newNbestReader);
    FileUtilityOld.closeWriteFile(newMergedNbestWriter);
    return haveNewHyp;
  }

  /**
   * Consumes all lines of the new nbest belonging to {@code oldSentID}, adds
   * unseen hypotheses into {@code oldNbests}, then writes the merged nbest for
   * that sentence and clears the map.
   *
   * @param previousLine the first not-yet-consumed line of the new nbest (may be null at EOF)
   * @param haveNewHyp out-parameter: element 0 is set true iff a new hypothesis was found
   * @return the first line belonging to the next sentence, or null at end of file
   */
  private static String processNbest(BufferedReader newNbestReader, BufferedWriter newMergedNbestWriter,
      int oldSentID, HashMap<String, String> oldNbests, String previousLine, boolean[] haveNewHyp) {
    haveNewHyp[0] = false;
    String previousLineInNewNbest = previousLine;

    // Read the new nbest and merge hypotheses for this sentence.
    // BUG FIX: guard against null (empty/exhausted new nbest file); the original
    // while(true) dereferenced a possibly-null line.
    while (previousLineInNewNbest != null) {
      String[] tFds = previousLineInNewNbest.split("\\s+\\|{3}\\s+");
      int tNewID = Integer.parseInt(tFds[0]);
      if (tNewID != oldSentID) {
        break; // first line of the next sentence; hand it back to the caller
      }
      if (!oldNbests.containsKey(tFds[1])) { // unseen hypothesis
        haveNewHyp[0] = true;
        oldNbests.put(tFds[1], tFds[2]); // merge into the nbest map
      }
      previousLineInNewNbest = FileUtilityOld.readLineLzf(newNbestReader);
    }

    // Print the merged nbest: order is not important, and the trailing field is ignored.
    for (Map.Entry<String, String> entry : oldNbests.entrySet()) {
      FileUtilityOld.writeLzf(newMergedNbestWriter,
          oldSentID + " ||| " + entry.getKey() + " ||| " + entry.getValue() + "\n");
    }
    oldNbests.clear();
    return previousLineInNewNbest;
  }

  /**
   * Copies the new nbest file verbatim to the merged-nbest location
   * (used at iteration 1, when there is nothing to merge with).
   */
  public static void copyNbest(String newNbestFile, String newMergedNbestFile) {
    try {
      FileUtility.copyFile(newNbestFile, newMergedNbestFile);
    } catch (IOException e) {
      // Log instead of printStackTrace so the failure is visible in the training log.
      logger.severe("failed to copy nbest file " + newNbestFile + " to " + newMergedNbestFile
          + ": " + e);
    }
  }

  /**
   * Reads the MR configuration and the baseline feature weights, and sets
   * {@code lastWeightVector} and {@code numPara} accordingly.
   */
  private void initialize() {
    // ==== read configurations
    MRConfig.readConfigFile(this.configFile);

    logger.info("initialize features and weights");

    // ==== get the baseline weights
    List<Double> weights = readBaselineFeatureWeights(this.configFile);
    this.lastWeightVector = new double[weights.size()];
    for (int i = 0; i < lastWeightVector.length; i++) {
      lastWeightVector[i] = weights.get(i);
      logger.info("weight: " + lastWeightVector[i]);
    }
    this.numPara = this.lastWeightVector.length;
  }
}