#include "lm/interpolate/tune_derivatives.hh"
namespace lm { namespace interpolate {
ComputeDerivative::ComputeDerivative(const util::FixedArray<Instance> &instances, const Matrix &ln_unigrams, WordIndex bos)
: instances_(instances), ln_unigrams_(ln_unigrams), bos_(bos) {
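// Precompute the constant part of the gradient: for each model i, the negative
// sum over tuning instances of ln p_i(correct word | context).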
neg_correct_summed_ = Vector::Zero(ln_unigrams.cols());
for (const Instance *i = instances.begin(); i != instances.end(); ++i) {
neg_correct_summed_ -= i->ln_correct;
}
}
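// Evaluate the objective at the given weights: fill in its gradient and Hessian
// and return the perplexity of the tuning instances under the interpolated model.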
Accum ComputeDerivative::Iteration(const Vector &weights, Vector &gradient, Matrix &hessian) {
gradient = neg_correct_summed_;
hessian = Matrix::Zero(weights.rows(), weights.rows());
// TODO: loop instead to force low-memory evaluation
// Compute p_I(x).
Vector interp_uni((ln_unigrams_ * weights).array().exp());
// Even -inf doesn't work for <s> because weights can be negative, so manually set its interpolated probability to zero.
interp_uni(bos_) = 0.0;
Accum Z_epsilon = interp_uni.sum();
interp_uni /= Z_epsilon;
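// interp_uni(x) is now p_I(x), the normalized unigram-level interpolated distribution (with <s> excluded).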
// unigram_cross(i) = \sum_{all x} p_I(x) ln p_i(x)
Vector unigram_cross(ln_unigrams_.transpose() * interp_uni);
Accum sum_B_I = 0.0;
Accum sum_ln_Z_context = 0.0;
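// Declared outside the per-instance loop so their storage can be reused.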
Vector weighted_extensions;
Matrix convolve;
Vector full_cross;
for (const Instance *n = instances_.begin(); n != instances_.end(); ++n) {
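// weighted_backoffs = prod_i b_i(context)^{w_i}: the interpolated backoff factor for this instance's context.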
Accum ln_weighted_backoffs = n->ln_backoff.dot(weights);
Accum weighted_backoffs = exp(ln_weighted_backoffs);
// Compute \sum_{x: model does not back off to unigram} p_I(x)
Accum sum_x_p_I = 0.0;
for (std::vector<WordIndex>::const_iterator x = n->extension_words.begin(); x != n->extension_words.end(); ++x) {
sum_x_p_I += interp_uni(*x);
}
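// weighted_extensions(x) = prod_i p_i(x | context)^{w_i} for the words in extension_words.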
weighted_extensions = (n->ln_extensions * weights).array().exp();
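// Z_context normalizes this context: fully backed-off words contribute Z_epsilon * weighted_backoffs * (1 - sum_x_p_I); extension words contribute their weighted products directly.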
Accum Z_context = Z_epsilon * weighted_backoffs * (1.0 - sum_x_p_I) + weighted_extensions.sum();
sum_ln_Z_context += log(Z_context);
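// For words that back off in every model, p_I(x | context) = B_I * p_I(x).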
Accum B_I = Z_epsilon / Z_context * weighted_backoffs;
sum_B_I += B_I;
// This is the gradient term for this instance except for -ln p_i(w_n | w_1^{n-1}), which was already accounted for in neg_correct_summed_.
// full_cross(i) is \sum_{all x} p_I(x | context) ln p_i(x | context)
full_cross =
// Uncorrected term
B_I * (n->ln_backoff + unigram_cross)
// Correction term: add correct values
+ n->ln_extensions.transpose() * weighted_extensions / Z_context
// Subtract values that should not have been charged.
- sum_x_p_I * B_I * n->ln_backoff;
for (std::vector<WordIndex>::const_iterator x = n->extension_words.begin(); x != n->extension_words.end(); ++x) {
full_cross.noalias() -= interp_uni(*x) * B_I * ln_unigrams_.row(*x);
}
gradient += full_cross;
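// convolve(i, j) = unigram_cross(i) * ln b_j(context); the cross terms of the first Hessian block below.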
convolve = unigram_cross * n->ln_backoff.transpose();
// There's one missing term here, which is independent of context and done at the end.
hessian.noalias() +=
// First term of Hessian, assuming all models back off to unigram.
B_I * (convolve + convolve.transpose() + n->ln_backoff * n->ln_backoff.transpose())
// Second term of Hessian, with correct full probabilities.
- full_cross * full_cross.transpose();
// Adjust the first term of the Hessian to account for extension words.
for (std::size_t x = 0; x < n->extension_words.size(); ++x) {
WordIndex universal_x = n->extension_words[x];
hessian.noalias() +=
// Replacement terms.
weighted_extensions(x) / Z_context * n->ln_extensions.row(x).transpose() * n->ln_extensions.row(x)
// Presumed unigrams. TODO: individual terms with backoffs pulled out? Maybe faster?
- interp_uni(universal_x) * B_I * (ln_unigrams_.row(universal_x).transpose() + n->ln_backoff) * (ln_unigrams_.row(universal_x) + n->ln_backoff.transpose());
}
}
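// Add the Hessian term deferred above; it depends on the context only through B_I, so sum_B_I factors it over contexts.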
for (Matrix::Index x = 0; x < interp_uni.rows(); ++x) {
// \sum_{contexts} B_I(context) \sum_x p_I(x) log p_i(x) log p_j(x)
hessian.noalias() += sum_B_I * interp_uni(x) * ln_unigrams_.row(x).transpose() * ln_unigrams_.row(x);
}
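// Return the tuning-set perplexity: exp of the average negative log-likelihood of the correct words under the interpolated model.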
return exp((neg_correct_summed_.dot(weights) + sum_ln_Z_context) / static_cast<double>(instances_.size()));
}
}} // namespaces