add support for hyp_merge_mode=2 when using hg-based training
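
When use_kbest_hg=false and hyp_merge_mode=2, dedup-merging cannot be done on the
real hypergraphs, so MRConfig now only warns (instead of exiting) and
HGMinRiskDAMert falls back to merging the per-iteration nbest lists with the new
NbestMerger utility, while training itself still runs on the hypergraphs.
Convergence is checked by comparing the growth in the number of unique hypotheses
against stop_hyp_ratio.

A minimal sketch of the config combination this targets (file layout as in
aexperiment/example.config.HG.javalm; the hyp_merge_mode line itself is an
assumption here, it is not part of this diff):

    save_disk_hg=true
    use_kbest_hg=false
    hyp_merge_mode=2
    forest_pruning=false
    maxNumIter=15

NbestMerger.mergeNbest(oldMergedFile, curHypFile, newMergedFile) writes the
deduplicated union of the two nbest files and returns the total number of
hypotheses kept.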

git-svn-id: https://joshua.svn.sf.net/svnroot/joshua/trunk@1415 0ae5e6b2-d358-4f09-a895-f82f13dd62a4
diff --git a/aexperiment/example.config.HG.javalm b/aexperiment/example.config.HG.javalm
index 14cbe73..a0157a4 100644
--- a/aexperiment/example.config.HG.javalm
+++ b/aexperiment/example.config.HG.javalm
@@ -56,7 +56,7 @@
 
 #disk hg
 save_disk_hg=true
-use_kbest_hg=true
+use_kbest_hg=false
 forest_pruning=false
 forest_pruning_threshold=150
 
@@ -93,7 +93,7 @@
 #discriminative aexperiment/featureFile 1.0
 
 #general
-maxNumIter=4
+maxNumIter=15
 useSemiringV2=true
 maxNumHGInQueue=100
 numThreads=10
diff --git a/src/joshua/discriminative/training/NbestMerger.java b/src/joshua/discriminative/training/NbestMerger.java
new file mode 100644
index 0000000..8aa1acc
--- /dev/null
+++ b/src/joshua/discriminative/training/NbestMerger.java
@@ -0,0 +1,67 @@
+package joshua.discriminative.training;
+
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+
+import joshua.discriminative.bleu_approximater.NbestReader;
+import joshua.util.FileUtility;
+import joshua.util.Regex;
+
+public class NbestMerger {
+	
+	public static int mergeNbest(String nbestFile1, String nbestFile2, String nbestOutFile){
+		int totalNumHyp = 0;
+		try {
+			NbestReader nbestReader1 = new NbestReader(nbestFile1);
+			NbestReader nbestReader2 = new NbestReader(nbestFile2);
+			BufferedWriter outWriter = FileUtility.getWriteFileStream(nbestOutFile);
+			
+			while(nbestReader1.hasNext()){
+				List<String> nbest1 = nbestReader1.next();
+				List<String> nbest2 = nbestReader2.next();
+		 
+				List<String> newNbest = processOneSentence(nbest1, nbest2);
+				for(String hyp : newNbest){
+					outWriter.write(hyp+"\n");
+				}
+				totalNumHyp += newNbest.size();
+			}
+			outWriter.close();
+		
+		} catch (IOException e) {
+			//just report the failure; the caller only sees the hypothesis count
+			e.printStackTrace();
+		}
+		System.out.println("totalNumHyp="+totalNumHyp);
+		return totalNumHyp;
+	}
+	
+	private static List<String> processOneSentence(List<String> nbest1, List<String> nbest2){
+		
+		List<String> newNbest = new ArrayList<String>();
+		Set<String> uniqueNbests = new HashSet<String>();
+		processOneNbest(nbest1, uniqueNbests, newNbest);
+		processOneNbest(nbest2, uniqueNbests, newNbest);
+		return newNbest;
+	}
+	
+	private static void processOneNbest(List<String> nbest, Set<String> uniqueNbests, List<String> newNbest){
+		for(String line : nbest){
+			String[] fds = Regex.threeBarsWithSpace.split(line);
+			String hypItself = fds[1];
+			
+			if(uniqueNbests.contains(hypItself)){
+				//duplicate hypothesis: skip it
+			}else{
+				uniqueNbests.add(hypItself);
+				newNbest.add(line);
+			}
+		}
+	}
+	
+}
diff --git a/src/joshua/discriminative/training/risk_annealer/hypergraph/HGMinRiskDAMert.java b/src/joshua/discriminative/training/risk_annealer/hypergraph/HGMinRiskDAMert.java
index 6e01a98..9edddf1 100644
--- a/src/joshua/discriminative/training/risk_annealer/hypergraph/HGMinRiskDAMert.java
+++ b/src/joshua/discriminative/training/risk_annealer/hypergraph/HGMinRiskDAMert.java
@@ -28,9 +28,11 @@
 import joshua.discriminative.feature_related.feature_template.TMFT;
 import joshua.discriminative.feature_related.feature_template.TargetTMFT;
 import joshua.discriminative.ranker.HGRanker;
+import joshua.discriminative.training.NbestMerger;
 import joshua.discriminative.training.risk_annealer.AbstractMinRiskMERT;
 import joshua.discriminative.training.risk_annealer.DeterministicAnnealer;
 import joshua.discriminative.training.risk_annealer.GradientComputer;
+import joshua.discriminative.training.risk_annealer.nbest.NbestMinRiskDAMert;
 import joshua.util.FileUtility;
 
 public class HGMinRiskDAMert extends AbstractMinRiskMERT {
@@ -54,7 +56,7 @@
 	
 	boolean haveRefereces = true;
 	
-	int totalNumHyp = 0;
+	int oldTotalNumHyp = 0;
 	
 	
 	//== for loss-augmented pruning
@@ -152,33 +154,46 @@
         		try {
 		        	String oldMergedFile = hypFilePrefix +".merged." + (iter-1);
 		        	String newMergedFile = hypFilePrefix +".merged." + (iter);
-		        	if(iter ==1){		        		
-						FileUtility.copyFile(curHypFilePrefix+".hg.items", newMergedFile+".hg.items");						
-		        		FileUtility.copyFile(curHypFilePrefix+".hg.rules", newMergedFile+".hg.rules");
-		        	}else{
-		        		boolean saveModelCosts = true;
-		        		
-		        		boolean mergeWithDedup = false;
-		        		if(MRConfig.hyp_merge_mode==2){
-		        			mergeWithDedup = true;
-		        		}
-			            /**TODO: this assumes that the feature values for the same hypothesis does not change,
-			             * though the weights for these features can change. In particular, this means
-			             * we cannot tune the weight for the aggregate discriminative model while we are tunining the individual 
-			             * discriminative feature. This is also true for the bestHyperEdge pointer.*/
-			            int newTotalNumHyp = DiskHyperGraph.mergeDiskHyperGraphs(MRConfig.ngramStateID, saveModelCosts, this.numTrainingSentence, 
-			            		MRConfig.use_unique_nbest, MRConfig.use_tree_nbest,
-			            		oldMergedFile, curHypFilePrefix, newMergedFile, mergeWithDedup);
-			            this.curHypFilePrefix = newMergedFile;
-			            
-			     
-			            if((newTotalNumHyp-totalNumHyp)*1.0/totalNumHyp<MRConfig.stop_hyp_ratio ) {
-			            	System.out.println("No new hypotheses generated at iteration " + iter + " for stop_hyp_ratio=" + MRConfig.stop_hyp_ratio); 
-			            	break;
-			            }else{
-			            	totalNumHyp = newTotalNumHyp;
+		        	int newTotalNumHyp = 0;
+		        	
+		        	if(MRConfig.use_kbest_hg==false && MRConfig.hyp_merge_mode==2){
+		        		System.out.println("use_kbest_hg==false && hyp_merge_mode==2; merging the nbest lists instead of the hypergraphs");
+		            	if(iter ==1){
+		            		FileUtility.copyFile(curHypFilePrefix, newMergedFile);
+		            		newTotalNumHyp = FileUtilityOld.numberLinesInFile(newMergedFile);
+		            	}else{
+		            		newTotalNumHyp = NbestMerger.mergeNbest(oldMergedFile, curHypFilePrefix, newMergedFile);
+		                }
+	        		}else{
+			        	if(iter ==1){
+							FileUtility.copyFile(curHypFilePrefix+".hg.items", newMergedFile+".hg.items");
+			        		FileUtility.copyFile(curHypFilePrefix+".hg.rules", newMergedFile+".hg.rules");
+			        	}else{
+			        		boolean saveModelCosts = true;
+			        		
+				            /**TODO: this assumes that the feature values for the same hypothesis do not change,
+				             * though the weights for these features can change. In particular, this means
+				             * we cannot tune the weight for the aggregate discriminative model while we are tuning the individual 
+				             * discriminative feature. This is also true for the bestHyperEdge pointer.*/
+				            newTotalNumHyp = DiskHyperGraph.mergeDiskHyperGraphs(MRConfig.ngramStateID, saveModelCosts, this.numTrainingSentence, 
+				            		MRConfig.use_unique_nbest, MRConfig.use_tree_nbest,
+				            		oldMergedFile, curHypFilePrefix, newMergedFile, (MRConfig.hyp_merge_mode==2));
+				            
 			            }
+			        	
+			        	this.curHypFilePrefix = newMergedFile;
+		        	}
+		        	
+		        	//check convergence
+		        	double newRatio = (newTotalNumHyp-oldTotalNumHyp)*1.0/oldTotalNumHyp;
+		        	if(iter <=2 || newRatio > MRConfig.stop_hyp_ratio) {
+		        		System.out.println("oldTotalNumHyp=" + oldTotalNumHyp + "; newTotalNumHyp=" + newTotalNumHyp + "; newRatio="+ newRatio +"; at iteration " + iter);
+		        		oldTotalNumHyp = newTotalNumHyp; 
+		            }else{
+		            	System.out.println("No new hypotheses generated at iteration " + iter + " for stop_hyp_ratio=" + MRConfig.stop_hyp_ratio); 
+		            	break;
 		            }
+		        	
         		} catch (IOException e) {
 					e.printStackTrace();
 				}
diff --git a/src/joshua/discriminative/training/risk_annealer/hypergraph/MRConfig.java b/src/joshua/discriminative/training/risk_annealer/hypergraph/MRConfig.java
index 8ecbdcb..b8e3848 100644
--- a/src/joshua/discriminative/training/risk_annealer/hypergraph/MRConfig.java
+++ b/src/joshua/discriminative/training/risk_annealer/hypergraph/MRConfig.java
@@ -351,8 +351,8 @@
 		}
 		
 		if(use_kbest_hg==false && hyp_merge_mode==2){
-			logger.severe("wrong config: use_kbest_hg==false && hyp_merge_mode==2, cannot do dedup-merge for real hypergraph-based training");
-			System.exit(1);
+			logger.warning("use_kbest_hg==false && hyp_merge_mode==2: cannot do dedup-merge on real hypergraphs; falling back to nbest-based merging while still training on the hypergraphs");
+			//System.exit(1);
 		}
 	}
 
diff --git a/src/joshua/discriminative/training/risk_annealer/nbest/NbestMinRiskDAMert.java b/src/joshua/discriminative/training/risk_annealer/nbest/NbestMinRiskDAMert.java
index 562427c..29e9d7f 100644
--- a/src/joshua/discriminative/training/risk_annealer/nbest/NbestMinRiskDAMert.java
+++ b/src/joshua/discriminative/training/risk_annealer/nbest/NbestMinRiskDAMert.java
@@ -2,16 +2,19 @@
 
 import java.io.BufferedReader;
 import java.io.BufferedWriter;
+import java.io.IOException;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.logging.Logger;
 
 import joshua.discriminative.FileUtilityOld;
+import joshua.discriminative.training.NbestMerger;
 import joshua.discriminative.training.risk_annealer.AbstractMinRiskMERT;
 import joshua.discriminative.training.risk_annealer.DeterministicAnnealer;
 import joshua.discriminative.training.risk_annealer.GradientComputer;
 import joshua.discriminative.training.risk_annealer.hypergraph.MRConfig;
+import joshua.util.FileUtility;
 
 
 /** 
@@ -19,6 +22,7 @@
 * @version $LastChangedDate: 2008-10-20 00:12:30 -0400  $
 */
 public abstract class NbestMinRiskDAMert extends AbstractMinRiskMERT {		
+	int totalNumHyp = 0;
 	
 	String nbestPrefix;
 	boolean useShortestRef;
@@ -52,7 +56,16 @@
         	if(iter ==1){
         		copyNbest(curNbestFile, newNbestMergedFile);
         	}else{
-	            boolean haveNewHyp = mergeNbest(oldNbestMergedFile, curNbestFile, newNbestMergedFile);
+        		boolean haveNewHyp = false;
+        		if(true){//use the shared NbestMerger; the old in-place merge below is kept as a fallback
+        			int newTotalNumHyp = NbestMerger.mergeNbest(oldNbestMergedFile, curNbestFile, newNbestMergedFile);
+        			if(newTotalNumHyp!=totalNumHyp)
+        				haveNewHyp = true;
+        			totalNumHyp = newTotalNumHyp;
+        		}else{
+        			haveNewHyp = mergeNbest(oldNbestMergedFile, curNbestFile, newNbestMergedFile);
+        		}
+	            
 	            if(haveNewHyp==false) {
 	            	System.out.println("No new hypotheses generated at iteration " + iter); 
 	            	break;
@@ -90,7 +103,7 @@
 	//return false: if the nbest does not add any new hyp
 	//TODO: decide converged if the number of new hyp generate is very small
 	//TODO: terminate decoding when the weights does not change much; this one makes more sense, as if the weights do not change much; then new hypotheses will be rare
-	private boolean mergeNbest(String oldMergedNbestFile, String newNbestFile, String newMergedNbestFile){
+	public static boolean mergeNbest(String oldMergedNbestFile, String newNbestFile, String newMergedNbestFile){
 		boolean haveNewHyp =false;
 		BufferedReader oldMergedNbestReader = FileUtilityOld.getReadFileStream(oldMergedNbestFile);
 		BufferedReader newNbestReader = FileUtilityOld.getReadFileStream(newNbestFile);
@@ -99,22 +112,22 @@
 		int oldSentID=-1;
 		String line;
 		String previousLineInNewNbest = FileUtilityOld.readLineLzf(newNbestReader);;
-		HashMap<String, String> nbests = new HashMap<String, String>();//key: hyp itself, value: remaining fds exlcuding sent_id
+		HashMap<String, String> oldNbests = new HashMap<String, String>();//key: hyp itself, value: remaining fds excluding sent_id
 		while((line=FileUtilityOld.readLineLzf(oldMergedNbestReader))!=null){
 			String[] fds = line.split("\\s+\\|{3}\\s+");
-			int new_sent_id = new Integer(fds[0]);
-			if(oldSentID!=-1 && oldSentID!=new_sent_id){
+			int newSentID = new Integer(fds[0]);
+			if(oldSentID!=-1 && oldSentID!=newSentID){
 				boolean[] t_have_new_hyp = new boolean[1];
-				previousLineInNewNbest = processNbest(newNbestReader, newMergedNbestReader, oldSentID, nbests, previousLineInNewNbest, t_have_new_hyp);
+				previousLineInNewNbest = processNbest(newNbestReader, newMergedNbestReader, oldSentID, oldNbests, previousLineInNewNbest, t_have_new_hyp);
 				if(t_have_new_hyp[0]==true) 
 					haveNewHyp = true;
 			}
-			oldSentID = new_sent_id;
-			nbests.put(fds[1], fds[2]);//last field is not needed
+			oldSentID = newSentID;
+			oldNbests.put(fds[1], fds[2]);//last field is not needed
 		}
 		//last nbest
 		boolean[] t_have_new_hyp = new boolean[1];
-		previousLineInNewNbest= processNbest(newNbestReader, newMergedNbestReader, oldSentID, nbests, previousLineInNewNbest, t_have_new_hyp);
+		previousLineInNewNbest= processNbest(newNbestReader, newMergedNbestReader, oldSentID, oldNbests, previousLineInNewNbest, t_have_new_hyp);
 		if(previousLineInNewNbest!=null){
 			System.out.println("last line is not null, must be wrong"); 
 			System.exit(0);
@@ -128,7 +141,7 @@
 		return haveNewHyp;
 	}
 	
-	private String processNbest(BufferedReader newNbestReader, BufferedWriter newMergedNbestReader, int oldSentID, HashMap<String, String> nbests,
+	private static String processNbest(BufferedReader newNbestReader, BufferedWriter newMergedNbestReader, int oldSentID, HashMap<String, String> oldNbests,
 			String previousLine, boolean[] have_new_hyp){
 		have_new_hyp[0] = false;
 		String previousLineInNewNbest = previousLine;
@@ -137,9 +150,9 @@
 			String[] t_fds = previousLineInNewNbest.split("\\s+\\|{3}\\s+");
 			int t_new_id = new Integer(t_fds[0]); 
 			if( t_new_id == oldSentID){
-				if(nbests.containsKey(t_fds[1])==false){//new hyp
+				if(oldNbests.containsKey(t_fds[1])==false){//new hyp
 					have_new_hyp[0] = true;
-					nbests.put(t_fds[1], t_fds[2]);//merge into nbests
+					oldNbests.put(t_fds[1], t_fds[2]);//merge into nbests
 				}
 			}else{
 				break;
@@ -149,15 +162,16 @@
 				break;
 		}
 		//#### print the nbest: order is not important; and the last field is ignored
-		for (Map.Entry<String, String> entry : nbests.entrySet()){ 
+		for (Map.Entry<String, String> entry : oldNbests.entrySet()){ 
 		    FileUtilityOld.writeLzf(newMergedNbestReader, oldSentID + " ||| " + entry.getKey() + " ||| " + entry.getValue() + "\n");
 		}				
-		nbests.clear();
+		oldNbests.clear();
 		return previousLineInNewNbest;
 	}
 	
 	//return false: if the nbest does not add any new hyp
-	private void copyNbest(String newNbestFile, String newMergedNbestFile){
+	public static void copyNbest(String newNbestFile, String newMergedNbestFile){
+		/*
 		BufferedReader newNbestReader = FileUtilityOld.getReadFileStream(newNbestFile);
 		BufferedWriter newMergedNbestReader =	FileUtilityOld.getWriteFileStream(newMergedNbestFile);		
 		
@@ -167,7 +181,13 @@
 		    FileUtilityOld.writeLzf(newMergedNbestReader, fds[0] + " ||| " + fds[1] + " ||| " + fds[2] + "\n");		
 		}
 		FileUtilityOld.closeReadFile(newNbestReader);
-		FileUtilityOld.closeWriteFile(newMergedNbestReader);
+		FileUtilityOld.closeWriteFile(newMergedNbestReader);*/
+		try {
+			FileUtility.copyFile(newNbestFile, newMergedNbestFile);
+		} catch (IOException e) {
+			//just report the copy failure
+			e.printStackTrace();
+		}
 	}