// blob: f0526c82299b8f9a3d106ffe1c78124d028dad2d [file] [log] [blame]
package joshua.corpus;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import joshua.decoder.Decoder;
import joshua.decoder.ff.lm.NGramLanguageModel;
import joshua.util.FormatUtils;
import joshua.util.MurmurHash;
/**
 * Static singular vocabulary class. Supports vocabulary freezing and (de-)serialization into a
 * vocabulary file.
 *
 * <p>Tokens are keyed by their 64-bit MurmurHash. Index 0 of the id table is reserved for the
 * unknown word; tokens that {@link #nt(String)} classifies as nonterminals receive negative ids
 * whose absolute value is their index in the id table. Mutating operations and most readers
 * synchronize on a single private lock.
 *
 * @author Juri Ganitkevitch
 */
public class Vocabulary {
  /** Language models that are told about every (token, id) pair added through {@link #id}. */
  private static final ArrayList<NGramLanguageModel> lms = new ArrayList<NGramLanguageModel>();
  /** Maps a token hash to its (possibly negative) id. */
  private static TreeMap<Long, Integer> hashToId;
  /** Maps the absolute value of an id to its token; index 0 holds the unknown word. */
  private static ArrayList<String> idToString;
  /** Maps a token hash to the first token registered under that hash. */
  private static TreeMap<Long, String> hashToString;
  /** Monitor guarding the three tables above and the language model list. */
  private static final Object lock = new Object();
  private static final int UNKNOWN_ID = 0;
  private static final String UNKNOWN_WORD = "<unk>";
  public static final String START_SYM = "<s>";
  public static final String STOP_SYM = "</s>";

  static {
    clear();
  }

  /**
   * Registers a language model and notifies it of every word currently in the vocabulary
   * (excluding the unknown word at index 0), from the highest id down to 1.
   *
   * @param lm the language model to register
   * @return true if at least one {@code lm.registerWord} call returned true (the caller-visible
   *         meaning of that flag is defined by the language model; this class treats it as a
   *         collision indicator)
   */
  public static boolean registerLanguageModel(NGramLanguageModel lm) {
    synchronized (lock) {
      // Store the language model.
      lms.add(lm);
      // Notify it of all the existing words. The call has to be the left operand: with
      // "collision || lm.registerWord(...)" short-circuit evaluation would stop registering
      // words as soon as the first call had returned true.
      boolean collision = false;
      for (int i = idToString.size() - 1; i > 0; i--) {
        collision = lm.registerWord(idToString.get(i), i) || collision;
      }
      return collision;
    }
  }

  /**
   * Reads a vocabulary from file. This deletes any additions to the vocabulary made prior to
   * reading the file.
   *
   * <p>The file is expected to hold an int entry count followed by that many (int id, UTF
   * string) pairs, i.e. the layout produced by {@link #write(String)}. Reading stops at the
   * first entry whose stored id differs from the id assigned while re-adding the token.
   *
   * @param file_name path of the vocabulary file
   * @return Returns true if vocabulary was read without mismatches or collisions.
   * @throws IOException if the file cannot be opened or read
   */
  public static boolean read(String file_name) throws IOException {
    synchronized (lock) {
      File vocab_file = new File(file_name);
      DataInputStream vocab_stream =
          new DataInputStream(new BufferedInputStream(new FileInputStream(vocab_file)));
      // try/finally so the stream is released on every exit path, including I/O failures.
      try {
        int size = vocab_stream.readInt();
        Decoder.LOG(1, String.format("Read %d entries from the vocabulary", size));
        clear();
        for (int i = 0; i < size; i++) {
          int id = vocab_stream.readInt();
          String token = vocab_stream.readUTF();
          if (id != Math.abs(id(token))) {
            return false;
          }
        }
        // A token that hashed onto an existing entry adds nothing to the table, so the
        // final table size (plus the unknown word) reveals collisions and duplicates.
        return (size + 1 == idToString.size());
      } finally {
        vocab_stream.close();
      }
    }
  }

  /**
   * Writes the vocabulary to a file: the number of tokens (excluding the unknown word) followed
   * by one (int id, UTF string) pair per token, in ascending id order starting at 1.
   *
   * @param file_name path of the file to write
   * @throws IOException if the file cannot be opened or written
   */
  public static void write(String file_name) throws IOException {
    synchronized (lock) {
      File vocab_file = new File(file_name);
      DataOutputStream vocab_stream =
          new DataOutputStream(new BufferedOutputStream(new FileOutputStream(vocab_file)));
      // try/finally so the stream is released on every exit path, including I/O failures.
      try {
        vocab_stream.writeInt(idToString.size() - 1);
        Decoder.LOG(1, String.format("Writing vocabulary: %d tokens", idToString.size() - 1));
        for (int i = 1; i < idToString.size(); i++) {
          vocab_stream.writeInt(i);
          vocab_stream.writeUTF(idToString.get(i));
        }
      } finally {
        vocab_stream.close();
      }
    }
  }

  /**
   * Renumbers all tokens so that ids follow ascending hash order, starting at 1. The sign of
   * each existing id is preserved. Ids handed out before this call are no longer valid, and this
   * method does not call into the registered language models.
   */
  public static void freeze() {
    synchronized (lock) {
      int current_id = 1;
      TreeMap<Long, Integer> hash_to_id = new TreeMap<Long, Integer>();
      ArrayList<String> id_to_string = new ArrayList<String>(idToString.size() + 1);
      id_to_string.add(UNKNOWN_ID, UNKNOWN_WORD);
      // A TreeMap's entry set iterates in ascending key (hash) order.
      for (Map.Entry<Long, Integer> entry : hashToId.entrySet()) {
        String word = hashToString.get(entry.getKey());
        hash_to_id.put(entry.getKey(), (entry.getValue() < 0 ? -current_id : current_id));
        id_to_string.add(current_id, word);
        current_id++;
      }
      idToString = id_to_string;
      hashToId = hash_to_id;
    }
  }

  /**
   * Returns the id of the given token, adding the token to the vocabulary (and registering it
   * with every registered language model) if it has not been seen before.
   *
   * <p>If a different token with the same hash is already present, the collision is logged and
   * the id of the existing token is returned.
   *
   * @param token the token to look up or add
   * @return the token's id; negative if {@link #nt(String)} held when the token was added
   * @throws IllegalStateException if the token cannot be hashed
   */
  public static int id(String token) {
    synchronized (lock) {
      long hash;
      try {
        hash = MurmurHash.hash64(token);
      } catch (UnsupportedEncodingException e) {
        // Continuing with a placeholder hash would silently map every token onto one entry.
        throw new IllegalStateException("Unable to hash token '" + token + "'", e);
      }
      String hash_word = hashToString.get(hash);
      if (hash_word != null) {
        if (!token.equals(hash_word)) {
          Decoder.LOG(1, String.format("MurmurHash for the following symbols collides: '%s', '%s'", hash_word, token));
        }
        return hashToId.get(hash);
      } else {
        int id = idToString.size() * (nt(token) ? -1 : 1);
        // register this (token,id) mapping with each language
        // model, so that they can map it to their own private
        // vocabularies
        for (NGramLanguageModel lm : lms) {
          lm.registerWord(token, Math.abs(id));
        }
        idToString.add(token);
        hashToString.put(hash, token);
        hashToId.put(hash, id);
        return id;
      }
    }
  }

  /**
   * Tells whether the absolute value of the given id indexes an entry of the vocabulary.
   *
   * @param id a terminal (non-negative) or nonterminal (negative) id
   * @return true if an entry exists for {@code |id|}
   */
  public static boolean hasId(int id) {
    synchronized (lock) {
      id = Math.abs(id);
      // Math.abs(Integer.MIN_VALUE) is still negative, hence the explicit lower bound.
      return (id >= 0 && id < idToString.size());
    }
  }

  /**
   * Splits the sentence on runs of whitespace and adds every resulting token.
   *
   * @param sentence whitespace-separated tokens
   * @return the ids of the tokens, in order
   */
  public static int[] addAll(String sentence) {
    return addAll(sentence.split("\\s+"));
  }

  /**
   * Looks up (adding where necessary) the id of every token.
   *
   * @param tokens the tokens to add
   * @return the ids of the tokens, in order
   */
  public static int[] addAll(String[] tokens) {
    int[] ids = new int[tokens.length];
    for (int i = 0; i < tokens.length; i++) {
      ids[i] = id(tokens[i]);
    }
    return ids;
  }

  /**
   * Returns the token stored for the absolute value of the given id.
   *
   * @param id a terminal or nonterminal id
   * @return the token at index {@code |id|}
   * @throws IndexOutOfBoundsException if no such entry exists
   */
  public static String word(int id) {
    // Same lock as the writers, so a concurrent id() or freeze() cannot be observed midway.
    synchronized (lock) {
      id = Math.abs(id);
      return idToString.get(id);
    }
  }

  /**
   * Joins the tokens of the given ids with single spaces.
   *
   * @param ids the ids to translate
   * @return the space-separated tokens, or the empty string if {@code ids} is empty
   */
  public static String getWords(int[] ids) {
    if (ids.length == 0) return "";
    StringBuilder sb = new StringBuilder();
    for (int i = 0; i < ids.length - 1; i++) {
      sb.append(word(ids[i])).append(" ");
    }
    return sb.append(word(ids[ids.length - 1])).toString();
  }

  /**
   * Joins the tokens of the given ids with single spaces.
   *
   * @param ids the ids to translate
   * @return the space-separated tokens, or the empty string if {@code ids} yields no elements
   */
  public static String getWords(final Iterable<Integer> ids) {
    StringBuilder sb = new StringBuilder();
    for (int id : ids) {
      sb.append(word(id)).append(" ");
    }
    // Nothing appended means there is no trailing separator to strip.
    if (sb.length() == 0) return "";
    return sb.deleteCharAt(sb.length() - 1).toString();
  }

  /**
   * Bracket test used by {@link #getNonterminalIndices()}: true if the word starts with '[' and
   * ends with ']'. Safe for the empty string.
   */
  private static boolean isNonterminal(String word) {
    return word.startsWith("[") && word.endsWith("]");
  }

  /**
   * This method returns a list of all indices corresponding to Nonterminals in the Vocabulary
   *
   * @return the (non-negative) indices of all entries that start with '[' and end with ']', in
   *         ascending order
   */
  public static List<Integer> getNonterminalIndices() {
    synchronized (lock) {
      List<Integer> result = new ArrayList<Integer>();
      for (int i = 0; i < idToString.size(); i++) {
        String word = idToString.get(i);
        if (isNonterminal(word)) {
          result.add(i);
        }
      }
      return result;
    }
  }

  /** @return the id reserved for the unknown word */
  public static int getUnknownId() {
    return UNKNOWN_ID;
  }

  /** @return the token representing the unknown word */
  public static String getUnknownWord() {
    return UNKNOWN_WORD;
  }

  /**
   * Returns true if the Vocabulary ID represents a nonterminal.
   *
   * @param id the id to test
   * @return true if the id is negative
   */
  public static boolean nt(int id) {
    return (id < 0);
  }

  /**
   * Returns true if the word is a nonterminal according to {@code FormatUtils.isNonterminal}.
   *
   * @param word the token to test
   * @return the result of {@code FormatUtils.isNonterminal(word)}
   */
  public static boolean nt(String word) {
    return FormatUtils.isNonterminal(word);
  }

  /** @return the number of entries in the id table, including the unknown word */
  public static int size() {
    synchronized (lock) {
      return idToString.size();
    }
  }

  /**
   * Looks up the token for the id and passes it to {@code FormatUtils.getNonterminalIndex}.
   *
   * @param id a vocabulary id
   * @return the value computed by {@code FormatUtils.getNonterminalIndex} for the id's token
   */
  public static int getTargetNonterminalIndex(int id) {
    return FormatUtils.getNonterminalIndex(word(id));
  }

  /**
   * Resets all tables to a vocabulary holding only the unknown word. The list of registered
   * language models is kept. Callers other than the static initializer must hold the lock.
   */
  private static void clear() {
    hashToId = new TreeMap<Long, Integer>();
    hashToString = new TreeMap<Long, String>();
    idToString = new ArrayList<String>();
    idToString.add(UNKNOWN_ID, UNKNOWN_WORD);
  }

  /**
   * Used to indicate that a query has been made for a symbol that is not known.
   *
   * @author Lane Schwartz
   */
  public static class UnknownSymbolException extends RuntimeException {
    private static final long serialVersionUID = 1L;

    /**
     * Constructs an exception indicating that the specified identifier cannot be found in the
     * symbol table.
     *
     * @param id Integer identifier
     */
    public UnknownSymbolException(int id) {
      super("Identifier " + id + " cannot be found in the symbol table");
    }

    /**
     * Constructs an exception indicating that the specified symbol cannot be found in the symbol
     * table.
     *
     * @param symbol String symbol
     */
    public UnknownSymbolException(String symbol) {
      super("Symbol " + symbol + " cannot be found in the symbol table");
    }
  }

  /**
   * Used to indicate that word hashing has produced a collision.
   *
   * @author Juri Ganitkevitch
   */
  public static class HashCollisionException extends RuntimeException {
    private static final long serialVersionUID = 1L;

    /**
     * @param first one of the two colliding symbols
     * @param second the other colliding symbol
     */
    public HashCollisionException(String first, String second) {
      super("MurmurHash for the following symbols collides: '" + first + "', '" + second + "'");
    }
  }

  /**
   * Returns an iterator over the id table, starting with the unknown word at index 0.
   *
   * <p>The iterator is backed directly by the current table and is obtained without taking the
   * lock, so it must not be used while other threads add tokens, and it keeps iterating the old
   * table if {@link #freeze()} or {@link #read(String)} replaces it.
   *
   * @return an iterator over all tokens in ascending index order
   */
  public static Iterator<String> wordIterator() {
    return idToString.iterator();
  }
}