add support for hyp_merge_mode=2 when using hg-based training
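
When hyp_merge_mode=2 is combined with use_kbest_hg=false, training still
runs on the real hypergraphs, but the per-iteration hypothesis merging and
the convergence check fall back to dedup-merging the nbest lists (see the
new NbestMerger). A minimal config sketch (key names as in
example.config.HG.javalm; the stop_hyp_ratio value below is only a
placeholder):

    use_kbest_hg=false
    hyp_merge_mode=2
    stop_hyp_ratio=0.01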
git-svn-id: https://joshua.svn.sf.net/svnroot/joshua/trunk@1415 0ae5e6b2-d358-4f09-a895-f82f13dd62a4
diff --git a/aexperiment/example.config.HG.javalm b/aexperiment/example.config.HG.javalm
index 14cbe73..a0157a4 100644
--- a/aexperiment/example.config.HG.javalm
+++ b/aexperiment/example.config.HG.javalm
@@ -56,7 +56,7 @@
#disk hg
save_disk_hg=true
-use_kbest_hg=true
+use_kbest_hg=false
forest_pruning=false
forest_pruning_threshold=150
@@ -93,7 +93,7 @@
#discriminative aexperiment/featureFile 1.0
#general
-maxNumIter=4
+maxNumIter=15
useSemiringV2=true
maxNumHGInQueue=100
numThreads=10
diff --git a/src/joshua/discriminative/training/NbestMerger.java b/src/joshua/discriminative/training/NbestMerger.java
new file mode 100644
index 0000000..8aa1acc
--- /dev/null
+++ b/src/joshua/discriminative/training/NbestMerger.java
@@ -0,0 +1,67 @@
+package joshua.discriminative.training;
+
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+
+import joshua.discriminative.bleu_approximater.NbestReader;
+import joshua.util.FileUtility;
+import joshua.util.Regex;
+
+public class NbestMerger {
+
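+ /**
+ * Merges two nbest files sentence by sentence, keeping only the first
+ * occurrence of each hypothesis string, and writes the result to
+ * nbestOutFile. Assumes both inputs cover the same sentences in the
+ * same order, so the two readers can be advanced in lockstep.
+ *
+ * @return the total number of unique hypotheses written
+ */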
+ public static int mergeNbest(String nbestFile1, String nbestFile2, String nbestOutFile){
+ int totalNumHyp = 0;
+ try {
+ NbestReader nbestReader1 = new NbestReader(nbestFile1);
+ NbestReader nbestReader2 = new NbestReader(nbestFile2);
+ BufferedWriter outWriter = FileUtility.getWriteFileStream(nbestOutFile);
+
+ while(nbestReader1.hasNext()){
+ List<String> nbest1 = nbestReader1.next();
+ List<String> nbest2 = nbestReader2.next();
+
+ List<String> newNbest = processOneSentence(nbest1, nbest2);
+ for(String hyp : newNbest){
+ outWriter.write(hyp+"\n");
+ }
+ totalNumHyp += newNbest.size();
+ }
+ outWriter.close();
+
+ } catch (IOException e) {
+ //report and fall through; totalNumHyp may then be a partial count
+ e.printStackTrace();
+ }
+ System.out.println("totalNumHyp="+totalNumHyp);
+ return totalNumHyp;
+ }
+
+ private static List<String> processOneSentence(List<String> nbest1, List<String> nbest2){
+
+ List<String> newNbest = new ArrayList<String>();
+ Set<String> uniqueNbests = new HashSet<String>();
+ processOneNbest(nbest1, uniqueNbests, newNbest);
+ processOneNbest(nbest2, uniqueNbests, newNbest);
+ return newNbest;
+ }
+
+ private static void processOneNbest(List<String> nbest, Set<String> uniqueNbests, List<String> newNbest){
+ for(String line : nbest){
+ String[] fds = Regex.threeBarsWithSpace.split(line);
+ String hypItself = fds[1];
+
+ //Set.add returns false when the hypothesis was already seen,
+ //so only the first occurrence of each hypothesis is kept
+ if(uniqueNbests.add(hypItself)){
+ newNbest.add(line);
+ }
+ }
+ }
+
+}
diff --git a/src/joshua/discriminative/training/risk_annealer/hypergraph/HGMinRiskDAMert.java b/src/joshua/discriminative/training/risk_annealer/hypergraph/HGMinRiskDAMert.java
index 6e01a98..9edddf1 100644
--- a/src/joshua/discriminative/training/risk_annealer/hypergraph/HGMinRiskDAMert.java
+++ b/src/joshua/discriminative/training/risk_annealer/hypergraph/HGMinRiskDAMert.java
@@ -28,9 +28,11 @@
import joshua.discriminative.feature_related.feature_template.TMFT;
import joshua.discriminative.feature_related.feature_template.TargetTMFT;
import joshua.discriminative.ranker.HGRanker;
+import joshua.discriminative.training.NbestMerger;
import joshua.discriminative.training.risk_annealer.AbstractMinRiskMERT;
import joshua.discriminative.training.risk_annealer.DeterministicAnnealer;
import joshua.discriminative.training.risk_annealer.GradientComputer;
+import joshua.discriminative.training.risk_annealer.nbest.NbestMinRiskDAMert;
import joshua.util.FileUtility;
public class HGMinRiskDAMert extends AbstractMinRiskMERT {
@@ -54,7 +56,7 @@
boolean haveRefereces = true;
- int totalNumHyp = 0;
+ int oldTotalNumHyp = 0;
//== for loss-augmented pruning
@@ -152,33 +154,46 @@
try {
String oldMergedFile = hypFilePrefix +".merged." + (iter-1);
String newMergedFile = hypFilePrefix +".merged." + (iter);
- if(iter ==1){
- FileUtility.copyFile(curHypFilePrefix+".hg.items", newMergedFile+".hg.items");
- FileUtility.copyFile(curHypFilePrefix+".hg.rules", newMergedFile+".hg.rules");
- }else{
- boolean saveModelCosts = true;
-
- boolean mergeWithDedup = false;
- if(MRConfig.hyp_merge_mode==2){
- mergeWithDedup = true;
- }
- /**TODO: this assumes that the feature values for the same hypothesis does not change,
- * though the weights for these features can change. In particular, this means
- * we cannot tune the weight for the aggregate discriminative model while we are tunining the individual
- * discriminative feature. This is also true for the bestHyperEdge pointer.*/
- int newTotalNumHyp = DiskHyperGraph.mergeDiskHyperGraphs(MRConfig.ngramStateID, saveModelCosts, this.numTrainingSentence,
- MRConfig.use_unique_nbest, MRConfig.use_tree_nbest,
- oldMergedFile, curHypFilePrefix, newMergedFile, mergeWithDedup);
- this.curHypFilePrefix = newMergedFile;
-
-
- if((newTotalNumHyp-totalNumHyp)*1.0/totalNumHyp<MRConfig.stop_hyp_ratio ) {
- System.out.println("No new hypotheses generated at iteration " + iter + " for stop_hyp_ratio=" + MRConfig.stop_hyp_ratio);
- break;
- }else{
- totalNumHyp = newTotalNumHyp;
+ int newTotalNumHyp =0;
+
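+ //hyp_merge_mode==2 asks for dedup-merge, which DiskHyperGraph cannot do on
+ //real hypergraphs, so in that case hypotheses are merged and counted at the
+ //nbest level while training itself still uses the hypergraphs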
+ if(MRConfig.use_kbest_hg==false && MRConfig.hyp_merge_mode==2){
+ System.out.println("use_kbest_hg==false && hyp_merge_mode==2; falling back to nbest-based merging");
+ if(iter ==1){
+ FileUtility.copyFile(curHypFilePrefix, newMergedFile);
+ newTotalNumHyp = FileUtilityOld.numberLinesInFile(newMergedFile);
+ }else{
+ newTotalNumHyp = NbestMerger.mergeNbest(oldMergedFile, curHypFilePrefix, newMergedFile);
+ }
+ }else{
+ if(iter ==1){
+ FileUtility.copyFile(curHypFilePrefix+".hg.items", newMergedFile+".hg.items");
+ FileUtility.copyFile(curHypFilePrefix+".hg.rules", newMergedFile+".hg.rules");
+ }else{
+ boolean saveModelCosts = true;
+
+ /**TODO: this assumes that the feature values for the same hypothesis do not change,
+ * though the weights for these features can change. In particular, this means
+ * we cannot tune the weight for the aggregate discriminative model while we are tuning the individual
+ * discriminative features. This is also true for the bestHyperEdge pointer.*/
+ newTotalNumHyp = DiskHyperGraph.mergeDiskHyperGraphs(MRConfig.ngramStateID, saveModelCosts, this.numTrainingSentence,
+ MRConfig.use_unique_nbest, MRConfig.use_tree_nbest,
+ oldMergedFile, curHypFilePrefix, newMergedFile, (MRConfig.hyp_merge_mode==2));
+
}
+
+ this.curHypFilePrefix = newMergedFile;
+ }
+
+ //check convergence: always continue for the first two iterations,
+ //since oldTotalNumHyp is still 0 at iter 1 and the ratio is not yet meaningful
+ double newRatio = (newTotalNumHyp-oldTotalNumHyp)*1.0/oldTotalNumHyp;
+ if(iter <=2 || newRatio > MRConfig.stop_hyp_ratio) {
+ System.out.println("oldTotalNumHyp=" + oldTotalNumHyp + "; newTotalNumHyp=" + newTotalNumHyp + "; newRatio="+ newRatio +"; at iteration " + iter);
+ oldTotalNumHyp = newTotalNumHyp;
+ }else{
+ System.out.println("No new hypotheses generated at iteration " + iter + " for stop_hyp_ratio=" + MRConfig.stop_hyp_ratio);
+ break;
}
+
} catch (IOException e) {
e.printStackTrace();
}
diff --git a/src/joshua/discriminative/training/risk_annealer/hypergraph/MRConfig.java b/src/joshua/discriminative/training/risk_annealer/hypergraph/MRConfig.java
index 8ecbdcb..b8e3848 100644
--- a/src/joshua/discriminative/training/risk_annealer/hypergraph/MRConfig.java
+++ b/src/joshua/discriminative/training/risk_annealer/hypergraph/MRConfig.java
@@ -351,8 +351,8 @@
}
if(use_kbest_hg==false && hyp_merge_mode==2){
- logger.severe("wrong config: use_kbest_hg==false && hyp_merge_mode==2, cannot do dedup-merge for real hypergraph-based training");
- System.exit(1);
+ logger.warning("use_kbest_hg==false && hyp_merge_mode==2: cannot do dedup-merge for real hypergraph-based training; falling back to nbest merge while still training on hypergraphs");
+ //System.exit(1);
}
}
diff --git a/src/joshua/discriminative/training/risk_annealer/nbest/NbestMinRiskDAMert.java b/src/joshua/discriminative/training/risk_annealer/nbest/NbestMinRiskDAMert.java
index 562427c..29e9d7f 100644
--- a/src/joshua/discriminative/training/risk_annealer/nbest/NbestMinRiskDAMert.java
+++ b/src/joshua/discriminative/training/risk_annealer/nbest/NbestMinRiskDAMert.java
@@ -2,16 +2,19 @@
import java.io.BufferedReader;
import java.io.BufferedWriter;
+import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import joshua.discriminative.FileUtilityOld;
+import joshua.discriminative.training.NbestMerger;
import joshua.discriminative.training.risk_annealer.AbstractMinRiskMERT;
import joshua.discriminative.training.risk_annealer.DeterministicAnnealer;
import joshua.discriminative.training.risk_annealer.GradientComputer;
import joshua.discriminative.training.risk_annealer.hypergraph.MRConfig;
+import joshua.util.FileUtility;
/**
@@ -19,6 +22,7 @@
* @version $LastChangedDate: 2008-10-20 00:12:30 -0400 $
*/
public abstract class NbestMinRiskDAMert extends AbstractMinRiskMERT {
+ int totalNumHyp = 0; //unique hypotheses in the merged nbest so far, used for convergence checking
String nbestPrefix;
boolean useShortestRef;
@@ -52,7 +56,16 @@
if(iter ==1){
copyNbest(curNbestFile, newNbestMergedFile);
}else{
- boolean haveNewHyp = mergeNbest(oldNbestMergedFile, curNbestFile, newNbestMergedFile);
+ boolean haveNewHyp = true;
+ //use the shared NbestMerger; the legacy mergeNbest path below is kept for reference
+ if(true){
+ int newTotalNumHyp = NbestMerger.mergeNbest(oldNbestMergedFile, curNbestFile, newNbestMergedFile);
+ haveNewHyp = (newTotalNumHyp != totalNumHyp);
+ totalNumHyp = newTotalNumHyp;
+ }else{
+ haveNewHyp = mergeNbest(oldNbestMergedFile, curNbestFile, newNbestMergedFile);
+ }
+
if(haveNewHyp==false) {
System.out.println("No new hypotheses generated at iteration " + iter);
break;
@@ -90,7 +103,7 @@
//return false: if the nbest does not add any new hyp
//TODO: decide converged if the number of new hyp generate is very small
//TODO: terminate decoding when the weights does not change much; this one makes more sense, as if the weights do not change much; then new hypotheses will be rare
- private boolean mergeNbest(String oldMergedNbestFile, String newNbestFile, String newMergedNbestFile){
+ public static boolean mergeNbest(String oldMergedNbestFile, String newNbestFile, String newMergedNbestFile){
boolean haveNewHyp =false;
BufferedReader oldMergedNbestReader = FileUtilityOld.getReadFileStream(oldMergedNbestFile);
BufferedReader newNbestReader = FileUtilityOld.getReadFileStream(newNbestFile);
@@ -99,22 +112,22 @@
int oldSentID=-1;
String line;
String previousLineInNewNbest = FileUtilityOld.readLineLzf(newNbestReader);;
- HashMap<String, String> nbests = new HashMap<String, String>();//key: hyp itself, value: remaining fds exlcuding sent_id
+ HashMap<String, String> oldNbests = new HashMap<String, String>();//key: hyp itself, value: remaining fields excluding sent_id
while((line=FileUtilityOld.readLineLzf(oldMergedNbestReader))!=null){
String[] fds = line.split("\\s+\\|{3}\\s+");
- int new_sent_id = new Integer(fds[0]);
- if(oldSentID!=-1 && oldSentID!=new_sent_id){
+ int newSentID = Integer.parseInt(fds[0]);
+ if(oldSentID!=-1 && oldSentID!=newSentID){
boolean[] t_have_new_hyp = new boolean[1];
- previousLineInNewNbest = processNbest(newNbestReader, newMergedNbestReader, oldSentID, nbests, previousLineInNewNbest, t_have_new_hyp);
+ previousLineInNewNbest = processNbest(newNbestReader, newMergedNbestReader, oldSentID, oldNbests, previousLineInNewNbest, t_have_new_hyp);
if(t_have_new_hyp[0]==true)
haveNewHyp = true;
}
- oldSentID = new_sent_id;
- nbests.put(fds[1], fds[2]);//last field is not needed
+ oldSentID = newSentID;
+ oldNbests.put(fds[1], fds[2]);//last field is not needed
}
//last nbest
boolean[] t_have_new_hyp = new boolean[1];
- previousLineInNewNbest= processNbest(newNbestReader, newMergedNbestReader, oldSentID, nbests, previousLineInNewNbest, t_have_new_hyp);
+ previousLineInNewNbest= processNbest(newNbestReader, newMergedNbestReader, oldSentID, oldNbests, previousLineInNewNbest, t_have_new_hyp);
if(previousLineInNewNbest!=null){
System.out.println("last line is not null, must be wrong");
System.exit(0);
@@ -128,7 +141,7 @@
return haveNewHyp;
}
- private String processNbest(BufferedReader newNbestReader, BufferedWriter newMergedNbestReader, int oldSentID, HashMap<String, String> nbests,
+ private static String processNbest(BufferedReader newNbestReader, BufferedWriter newMergedNbestReader, int oldSentID, HashMap<String, String> oldNbests,
String previousLine, boolean[] have_new_hyp){
have_new_hyp[0] = false;
String previousLineInNewNbest = previousLine;
@@ -137,9 +150,9 @@
String[] t_fds = previousLineInNewNbest.split("\\s+\\|{3}\\s+");
int t_new_id = new Integer(t_fds[0]);
if( t_new_id == oldSentID){
- if(nbests.containsKey(t_fds[1])==false){//new hyp
+ if(oldNbests.containsKey(t_fds[1])==false){//new hyp
have_new_hyp[0] = true;
- nbests.put(t_fds[1], t_fds[2]);//merge into nbests
+ oldNbests.put(t_fds[1], t_fds[2]);//merge into nbests
}
}else{
break;
@@ -149,15 +162,16 @@
break;
}
//#### print the nbest: order is not important; and the last field is ignored
- for (Map.Entry<String, String> entry : nbests.entrySet()){
+ for (Map.Entry<String, String> entry : oldNbests.entrySet()){
FileUtilityOld.writeLzf(newMergedNbestReader, oldSentID + " ||| " + entry.getKey() + " ||| " + entry.getValue() + "\n");
}
- nbests.clear();
+ oldNbests.clear();
return previousLineInNewNbest;
}
//return false: if the nbest does not add any new hyp
- private void copyNbest(String newNbestFile, String newMergedNbestFile){
+ public static void copyNbest(String newNbestFile, String newMergedNbestFile){
+ /*
BufferedReader newNbestReader = FileUtilityOld.getReadFileStream(newNbestFile);
BufferedWriter newMergedNbestReader = FileUtilityOld.getWriteFileStream(newMergedNbestFile);
@@ -167,7 +181,13 @@
FileUtilityOld.writeLzf(newMergedNbestReader, fds[0] + " ||| " + fds[1] + " ||| " + fds[2] + "\n");
}
FileUtilityOld.closeReadFile(newNbestReader);
- FileUtilityOld.closeWriteFile(newMergedNbestReader);
+ FileUtilityOld.closeWriteFile(newMergedNbestReader);*/
+ try {
+ FileUtility.copyFile(newNbestFile, newMergedNbestFile);
+ } catch (IOException e) {
+ //report and continue; a failed copy will surface when the merged nbest is read
+ e.printStackTrace();
+ }
}