// blob: a0db59cebf04216bdb90adc9ed89de352a3cc571 [file] [log] [blame]
#include "lm/interpolate/tune_instance.hh"
#include "util/file_stream.hh"
#include "util/file.hh"
#include "util/string_piece.hh"
#define BOOST_TEST_MODULE InstanceTest
#include <boost/test/unit_test.hpp>
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>
namespace lm { namespace interpolate { namespace {
// Returns the position of vocabulary id `word` within `words`.  Callers in this
// test pass an instance's extension_words and use the result as the row index
// into that instance's ln_extensions.  If `word` is absent, the current test
// case is aborted with a message naming the missing id.
Matrix::Index FindRow(const std::vector<WordIndex> &words, WordIndex word) {
  const std::vector<WordIndex>::const_iterator it = std::find(words.begin(), words.end(), word);
  BOOST_REQUIRE_MESSAGE(it != words.end(), "word id " << word << " is not among the extension words");
  return static_cast<Matrix::Index>(it - words.begin());
}
// End-to-end test of LoadInstances on two toy models, toy0 and toy1, using a
// tuning corpus made of the single sentence "c".
//
// Expected constants are written as log10 values and multiplied by M_LN10,
// because the loaded matrices hold natural logs.  Vocabulary ids checked
// below: 0 = <unk>, 1 = <s>, 2 = a, 3 = </s>, 4 = c.  Model index i (the
// second matrix index, and the argument of ln_backoff / ln_correct) follows
// the order of model_names: 0 = toy0, 1 = toy1.
//
// Optional command-line argument: the path of the toy0.1 file.  The models are
// then read from the directory containing it instead of tune_instance_data/.
BOOST_AUTO_TEST_CASE(Toy) {
  // Tuning corpus: one line containing the word "c".
  util::scoped_fd test_input(util::MakeTemp("temporary"));
  {
    // The temporary FileStream is destroyed at the end of this statement,
    // which presumably flushes the text to the file before it is read back.
    util::FileStream(test_input.get()) << "c\n";
  }

  StringPiece dir("tune_instance_data/");
  const StringPiece kZeroSuffix("toy0.1");
  if (boost::unit_test::framework::master_test_suite().argc == 2) {
    const StringPiece zero_file(boost::unit_test::framework::master_test_suite().argv[1]);
    // A bare "toy0.1" is accepted too: the directory prefix is then empty,
    // which resolves the model files relative to the working directory.
    BOOST_REQUIRE(zero_file.size() >= kZeroSuffix.size());
    const std::size_t dir_size = zero_file.size() - kZeroSuffix.size();
    BOOST_REQUIRE_EQUAL(kZeroSuffix, StringPiece(zero_file.data() + dir_size, kZeroSuffix.size()));
    dir = StringPiece(zero_file.data(), dir_size);
  }

  const std::string full0 = std::string(dir.data(), dir.size()) + "toy0";
  const std::string full1 = std::string(dir.data(), dir.size()) + "toy1";
  // model_names holds views into full0/full1, which are declared first and
  // therefore outlive it.
  std::vector<StringPiece> model_names;
  model_names.push_back(full0);
  model_names.push_back(full1);

  util::FixedArray<Instance> instances;
  Matrix ln_unigrams;
  // test_input gives up ownership of the descriptor here.  LoadInstances
  // returns the vocab id of <s>.
  BOOST_CHECK_EQUAL(static_cast<WordIndex>(1), LoadInstances(test_input.release(), model_names, instances, ln_unigrams));

  // BOOST_CHECK_CLOSE interprets its tolerance as a percentage.
  const double kTolerance = 0.001;

  // Unigrams: ln_unigrams(vocab id, model).
  // <unk>
  BOOST_CHECK_CLOSE(-0.90309 * M_LN10, ln_unigrams(0, 0), kTolerance);
  BOOST_CHECK_CLOSE(-1 * M_LN10, ln_unigrams(0, 1), kTolerance);
  // <s>: expected to be extremely unlikely (below -98) in both models.
  BOOST_CHECK_GT(-98.0, ln_unigrams(1, 0));
  BOOST_CHECK_GT(-98.0, ln_unigrams(1, 1));
  // a
  BOOST_CHECK_CLOSE(-0.46943438 * M_LN10, ln_unigrams(2, 0), kTolerance);
  BOOST_CHECK_CLOSE(-0.6146491 * M_LN10, ln_unigrams(2, 1), kTolerance);
  // </s>
  BOOST_CHECK_CLOSE(-0.5720968 * M_LN10, ln_unigrams(3, 0), kTolerance);
  BOOST_CHECK_CLOSE(-0.6146491 * M_LN10, ln_unigrams(3, 1), kTolerance);
  // c: model 0 scores it with the <unk> probability (same value as id 0).
  BOOST_CHECK_CLOSE(-0.90309 * M_LN10, ln_unigrams(4, 0), kTolerance);
  BOOST_CHECK_CLOSE(-0.7659168 * M_LN10, ln_unigrams(4, 1), kTolerance);
  // The unigram for b is not checked.

  // Two instances:
  //   instances[0]: <s> predicts c
  //   instances[1]: <s> c predicts </s>
  BOOST_REQUIRE_EQUAL(2u, instances.size());
  // Backoffs of <s>
  BOOST_CHECK_CLOSE(-0.30103 * M_LN10, instances[0].ln_backoff(0), kTolerance);
  BOOST_CHECK_CLOSE(-0.30103 * M_LN10, instances[0].ln_backoff(1), kTolerance);
  // Backoffs of <s> c
  BOOST_CHECK_CLOSE(0.0, instances[1].ln_backoff(0), kTolerance);
  BOOST_CHECK_CLOSE((-0.30103 - 0.30103) * M_LN10, instances[1].ln_backoff(1), kTolerance);

  // Three extensions of <s>: a, b, c
  BOOST_REQUIRE_EQUAL(3, instances[0].ln_extensions.rows());
  BOOST_REQUIRE_EQUAL(3u, instances[0].extension_words.size());
  // <s> a
  BOOST_CHECK_CLOSE(-0.37712017 * M_LN10, instances[0].ln_extensions(FindRow(instances[0].extension_words, 2), 0), kTolerance);
  // <s> c: model 0 uses the <unk> unigram plus the backoff of <s>.
  BOOST_CHECK_CLOSE((-0.90309 + -0.30103) * M_LN10, instances[0].ln_extensions(FindRow(instances[0].extension_words, 4), 0), kTolerance);
  BOOST_CHECK_CLOSE(-0.4740302 * M_LN10, instances[0].ln_extensions(FindRow(instances[0].extension_words, 4), 1), kTolerance);
  // <s> c </s>
  BOOST_CHECK_CLOSE(-0.09113217 * M_LN10, instances[1].ln_extensions(FindRow(instances[1].extension_words, 3), 1), kTolerance);

  // p_0(c | <s>) = p_0(c)b_0(<s>) = 10^(-0.90309 + -0.30103)
  BOOST_CHECK_CLOSE((-0.90309 + -0.30103) * M_LN10, instances[0].ln_correct(0), kTolerance);
  // p_1(c | <s>) = 10^-0.4740302
  BOOST_CHECK_CLOSE(-0.4740302 * M_LN10, instances[0].ln_correct(1), kTolerance);
}
}}} // namespaces