| /* |
| |
| EGYPT Toolkit for Statistical Machine Translation |
| Written by Yaser Al-Onaizan, Jan Curin, Michael Jahr, Kevin Knight, John Lafferty, Dan Melamed, David Purdy, Franz Och, Noah Smith, and David Yarowsky. |
| |
| This program is free software; you can redistribute it and/or |
| modify it under the terms of the GNU General Public License |
| as published by the Free Software Foundation; either version 2 |
| of the License, or (at your option) any later version. |
| |
| This program is distributed in the hope that it will be useful, |
| but WITHOUT ANY WARRANTY; without even the implied warranty of |
| MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| GNU General Public License for more details. |
| |
| You should have received a copy of the GNU General Public License |
| along with this program; if not, write to the Free Software |
| Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, |
| USA. |
| |
| */ |
| /* --------------------------------------------------------------------------* |
| * * |
| * Module :ATables * |
| * * |
| * Prototypes File: ATables.h * |
| * * |
| * Objective: Defines classes and methods for handling I/O for distortion & * |
| * alignment tables. * |
| *****************************************************************************/ |
| |
| #ifndef _atables_h |
| #define _atables_h 1 |
| |
| #include "defs.h" |
| #include <cassert> |
| #include <iostream> |
| #include <algorithm> |
| #include <functional> |
| #include <map> |
| #include <set> |
| #include "Vector.h" |
| #include <utility> |
| #if __GNUC__>2 |
| #include <ext/hash_map> |
| using __gnu_cxx::hash_map; |
| #else |
| #include <hash_map> |
| #endif |
| #include <fstream> |
| #include "Array4.h" |
| #include "myassert.h" |
| #include "Globals.h" |
| #include "syncObj.h" |
| |
| extern bool CompactADTable; |
| extern float amodel_smooth_factor; |
| extern short NoEmptyWord; |
| |
| /* ------------------- Class Definitions of amodel --------------------------*/ |
| /* Class Name: amodel: |
| Objective: This defines the underlying data structure for distortion prob. |
| and count tables. They are defined as a hash table. Each entry in the hash |
| table is the probability d(j/l,m,i) (where j is the target word position, i is |
| the source word position connected to it, m is the target sentence length, and |
| l is the source sentence length) or the count collected for it. The probability |
| and the count are represented as log integer probabilities as |
| defined by the class LogProb. |
| |
| This class is used to represent the a (probability) and d (distortion) |
| tables and also their corresponding count tables. |
| |
| *--------------------------------------------------------------------------*/ |
| |
/* Absolute value of a signed int: a negative argument is negated,
   any other argument is returned unchanged. */
inline int Mabs(int a){
  return (a<0) ? -a : a;
}
| |
| template <class VALTYPE> |
| class amodel{ |
| public: |
| Array4<VALTYPE> a; |
| bool is_distortion ; |
| WordIndex MaxSentLength; |
| bool ignoreL, ignoreM; |
| VALTYPE get(WordIndex aj, WordIndex j, WordIndex l, WordIndex m)const{ |
| massert( (!is_distortion) || aj<=m );massert( (!is_distortion) || j<=l );massert( (!is_distortion) || aj!=0 ); |
| massert( is_distortion || aj<=l );massert( is_distortion || j<=m );massert( (is_distortion) || j!=0 ); |
| massert( l<MaxSentLength );massert( m<MaxSentLength ); |
| return a.get(aj, j, (CompactADTable&&is_distortion)?MaxSentLength:(l+1),(CompactADTable&&!is_distortion)?MaxSentLength:(m+1)); |
| } |
| |
| static float smooth_factor; |
| amodel(bool flag = false) |
| : a(MAX_SENTENCE_LENGTH+1,0.0), is_distortion(flag), MaxSentLength(MAX_SENTENCE_LENGTH) |
| {lock = new Mutex();}; |
| |
| ~amodel(){delete lock;}; |
| |
| protected: |
| VALTYPE&getRef(WordIndex aj, WordIndex j, WordIndex l, WordIndex m){ |
| massert( (!is_distortion) || aj<=m );massert( (!is_distortion) || j<=l ); |
| massert( is_distortion || aj<=l );massert( is_distortion || j<=m );massert( (is_distortion) || j!=0 ); |
| massert( l<MaxSentLength );massert( m<MaxSentLength ); |
| return a(aj, j, (CompactADTable&&is_distortion)?MaxSentLength:(l+1),(CompactADTable&&!is_distortion)?MaxSentLength:(m+1)); |
| } |
| public: |
| void setValue(WordIndex aj, WordIndex j, WordIndex l, WordIndex m, VALTYPE val) { |
| lock->lock(); |
| getRef(aj, j, l, m)=val; |
| lock->unlock(); |
| } |
| |
| Mutex* lock; |
| public: |
| /** |
| By Qin |
| */ |
| void addValue(WordIndex aj, WordIndex j, WordIndex l, WordIndex m, VALTYPE val) { |
| lock->lock(); |
| getRef(aj, j, l, m)+=val; |
| lock->unlock(); |
| } |
| bool merge(amodel<VALTYPE>& am); |
| VALTYPE getValue(WordIndex aj, WordIndex j, WordIndex l, WordIndex m) const{ |
| if( is_distortion==0 ) |
| return max(double(PROB_SMOOTH),amodel_smooth_factor/(l+1)+(1.0-amodel_smooth_factor)*get(aj, j, l, m)); |
| else |
| return max(double(PROB_SMOOTH),amodel_smooth_factor/m+(1.0-amodel_smooth_factor)*get(aj, j, l, m)); |
| } |
| |
| void printTable(const char* filename)const ; |
| void printRealTable(const char* filename)const ; |
| template<class COUNT> |
| void normalize(amodel<COUNT>& aTable)const |
| { |
| WordIndex i, j, l, m ; |
| COUNT total; |
| int nParam=0; |
| for(l=0;l<MaxSentLength;l++){ |
| for(m=0;m<MaxSentLength;m++){ |
| if( CompactADTable && l!=m ) |
| continue; |
| unsigned int L=((CompactADTable&&is_distortion)?MaxSentLength:(l+1))-1; |
| unsigned int M=((CompactADTable&&!is_distortion)?MaxSentLength:(m+1))-1; |
| if( is_distortion==0 ){ |
| for(j=1;j<=M; j++){ |
| total=0.0; |
| for(i=0;i<=L;i++){ |
| total+=get(i, j, L, M); |
| } |
| if( total ){ |
| for(i=0;i<=L;i++){ |
| nParam++; |
| aTable.getRef(i, j, L, M)=get(i, j, L, M)/total; |
| massert(aTable.getRef(i,j,L,M)<=1.0); |
| if( NoEmptyWord&&i==0 ) |
| aTable.getRef(i,j,L,M)=0; |
| } |
| } |
| } |
| }else{ |
| for(i=0;i<=L;i++){ |
| total=0.0; |
| for(j=1;j<=M;j++) |
| total+=get(j, i, L, M); |
| if( total ) |
| for(j=1;j<=M;j++){ |
| aTable.getRef(j, i, L, M)=amodel_smooth_factor/M+(1.0-amodel_smooth_factor)*get(j, i, L, M)/total; |
| nParam++; |
| massert(aTable.getRef(j,i,L,M)<=1.0); |
| if( NoEmptyWord&&i==0 ) |
| aTable.getRef(j,i,L,M)=0; |
| } |
| } |
| } |
| } |
| } |
| cout << "A/D table contains " << nParam << " parameters.\n"; |
| } |
| |
| bool readTable(const char *filename); |
| bool readAugTable(const char *filename); |
| void clear() |
| {a.clear();} |
| }; |
| |
| /* ------------------- End of amodel Class Definitions ----------------------*/ |
| |
| #endif |