// Source blob: c11eec90ab3ad73b3c5f96472787453951611725
#ifndef LM_INTERPOLATE_TUNE_INSTANCE_H
#define LM_INTERPOLATE_TUNE_INSTANCE_H
#include "lm/interpolate/tune_matrix.hh"
#include "lm/word_index.hh"
#include "util/scoped.hh"
#include "util/stream/config.hh"
#include "util/string_piece.hh"
#include <boost/optional.hpp>
#include <vector>
// Forward declarations of the util::stream types that Instances refers to only
// through pointers or references, so their full headers are not needed here.
namespace util { namespace stream {
// NOTE(review): declared here with two template parameters and no default
// arguments, yet Instances names Sort with a single argument. That only
// compiles if a declaration supplying a default for T is already visible --
// confirm against the header that defines Sort.
template <class S, class T> class Sort;
class Chain;
class FileBuffer;
}} // namespaces
namespace lm { namespace interpolate {
// Identifies a tuning instance; used for Extension::instance and as the
// argument to Instances::Backoffs.
typedef uint32_t InstanceIndex;
// Identifies a model; used for Extension::model.
typedef uint32_t ModelIndex;
// A single record: the log probability one model assigns to one word in the
// context of one tuning instance.
struct Extension {
  // Tuning instance that this record is attached to.
  InstanceIndex instance;
  // Word being scored.
  WordIndex word;
  // Model that produced the score.
  ModelIndex model;
  // ln p_{model} (word | context(instance))
  float ln_prob;

  // Strict weak ordering by (instance, word, model), most significant first.
  // ln_prob is not part of the key, so records that agree on all three key
  // fields compare as equivalent.
  bool operator<(const Extension &other) const {
    return (instance != other.instance) ? (instance < other.instance)
         : (word != other.word)         ? (word < other.word)
         :                                (model < other.model);
  }
};
// Container for the data used while tuning interpolation weights: per-instance
// backoffs, the gradient term for the correct words, unigram log probabilities,
// and the stream of Extension records.
class Instances {
  private:
    // Row-major so that one instance's backoffs (a row) are contiguous.
    // Declared before the public section because Backoffs() names the row type
    // in its signature.
    typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> BackoffMatrix;

  public:
    // tune_file: presumably an open file descriptor for the tuning text -- TODO confirm.
    // model_names: one entry per model being interpolated.
    Instances(int tune_file, const std::vector<StringPiece> &model_names);

    // Row of ln_backoffs_ for the given instance: one entry per model.
    // ConstRowXpr is a member typedef of the matrix type, not of namespace Eigen.
    BackoffMatrix::ConstRowXpr Backoffs(InstanceIndex instance) const {
      return ln_backoffs_.row(instance);
    }

    // Per-model term that appears in the gradient; see neg_ln_correct_sum_.
    const Vector &CorrectGradientTerm() const { return neg_ln_correct_sum_; }

    // Unigram log probabilities indexed (word, model); see ln_unigrams_.
    const Matrix &LNUnigrams() const { return ln_unigrams_; }

    // Feeds the Extension records into the given chain. Presumably draws from
    // extensions_first_ on the first call and extensions_subsequent_ afterwards
    // (defined out of line) -- TODO confirm against the implementation.
    void ReadExtensions(util::stream::Chain &to);

  private:
    // ln_backoffs_(instance, model) is the backoff all the way to unigrams.
    BackoffMatrix ln_backoffs_;

    // neg_ln_correct_sum_(model) = -\sum_{instances} ln p_{model}(correct(instance) | context(instance)).
    // This appears as a term in the gradient.
    Vector neg_ln_correct_sum_;

    // ln_unigrams_(word, model) = ln p_{model}(word).
    Matrix ln_unigrams_;

    // Comparator over type-erased record pointers, used as the Sort template
    // argument below.
    struct ExtensionCompare {
      bool operator()(const void *f, const void *s) const {
        // f and s point at Extension records. Dereference the pointees; casting
        // the pointer variables themselves to Extension& would compare the
        // bytes of the addresses and read past the pointer objects.
        return *static_cast<const Extension*>(f) < *static_cast<const Extension*>(s);
      }
    };

    // This is the source of data for the first iteration.
    util::scoped_ptr<util::stream::Sort<ExtensionCompare> > extensions_first_;

    // Source of data for subsequent iterations. This contains already-sorted data.
    util::scoped_ptr<util::stream::FileBuffer> extensions_subsequent_;

    const util::stream::SortConfig sorting_config_;
};
}} // namespaces
#endif // LM_INTERPOLATE_TUNE_INSTANCE_H