blob: d8faf948c55165e70af81bc0ae5f4928c8bfc021 [file] [log] [blame]
#include "lm/interpolate/backoff_reunification.hh"
#include "lm/common/ngram_stream.hh"
#define BOOST_TEST_MODULE InterpolateBackoffReunificationTest
#include <boost/test/unit_test.hpp>
namespace lm {
namespace interpolate {
namespace {
// none of this input actually makes sense, all we care about is making
// sure the merging works
template <uint8_t N>
struct Gram {
WordIndex ids[N];
float prob;
float boff;
};
template <uint8_t N>
struct Grams {
const static Gram<N> grams[];
};
template <>
const Gram<1> Grams<1>::grams[]
= {{{0}, 0.1f, 0.1f}, {{1}, 0.4f, 0.2f}, {{2}, 0.5f, 0.1f}};
template <>
const Gram<2> Grams<2>::grams[] = {{{0, 0}, 0.05f, 0.05f},
{{1, 0}, 0.05f, 0.02f},
{{1, 1}, 0.2f, 0.04f},
{{2, 2}, 0.2f, 0.01f}};
template <>
const Gram<3> Grams<3>::grams[] = {{{0, 0, 0}, 0.001f, 0.005f},
{{1, 0, 0}, 0.001f, 0.002f},
{{2, 0, 0}, 0.001f, 0.003f},
{{0, 1, 0}, 0.1f, 0.008f},
{{1, 1, 0}, 0.1f, 0.09f},
{{1, 1, 1}, 0.2f, 0.08f}};
template <uint8_t N>
class WriteInput {
public:
void Run(const util::stream::ChainPosition &position) {
lm::NGramStream<float> output(position);
for (std::size_t i = 0; i < sizeof(Grams<N>::grams) / sizeof(Gram<N>);
++i, ++output) {
std::copy(Grams<N>::grams[i].ids, Grams<N>::grams[i].ids + N,
output->begin());
output->Value() = Grams<N>::grams[i].prob;
}
output.Poison();
}
};
template <uint8_t N>
class WriteBackoffs {
public:
void Run(const util::stream::ChainPosition &position) {
util::stream::Stream output(position);
for (std::size_t i = 0; i < sizeof(Grams<N>::grams) / sizeof(Gram<N>);
++i, ++output) {
*reinterpret_cast<float *>(output.Get()) = Grams<N>::grams[i].boff;
}
output.Poison();
}
};
template <uint8_t N>
class CheckOutput {
public:
void Run(const util::stream::ChainPosition &position) {
lm::NGramStream<ProbBackoff> stream(position);
std::size_t i = 0;
for (; stream; ++stream, ++i) {
std::stringstream ss;
for (WordIndex *idx = stream->begin(); idx != stream->end(); ++idx)
ss << "(" << *idx << ")";
UTIL_THROW_IF2(
!std::equal(stream->begin(), stream->end(), Grams<N>::grams[i].ids),
"Mismatched id in CheckOutput<" << (int)N << ">: " << ss.str());
UTIL_THROW_IF2(stream->Value().prob != Grams<N>::grams[i].prob,
"Mismatched probability in CheckOutput<"
<< (int)N << ">, got " << stream->Value().prob
<< ", expected " << Grams<N>::grams[i].prob);
UTIL_THROW_IF2(stream->Value().backoff != Grams<N>::grams[i].boff,
"Mismatched backoff in CheckOutput<"
<< (int)N << ">, got " << stream->Value().backoff
<< ", expected " << Grams<N>::grams[i].boff);
}
UTIL_THROW_IF2(i != sizeof(Grams<N>::grams) / sizeof(Gram<N>),
"Did not get correct number of "
<< (int)N << "-grams: expected "
<< sizeof(Grams<N>::grams) / sizeof(Gram<N>)
<< ", got " << i);
}
};
}
BOOST_AUTO_TEST_CASE(BackoffReunificationTest) {
util::stream::ChainConfig config;
config.total_memory = 100;
config.block_count = 1;
util::stream::Chains prob_chains(3);
config.entry_size = NGram<float>::TotalSize(1);
prob_chains.push_back(config);
prob_chains.back() >> WriteInput<1>();
config.entry_size = NGram<float>::TotalSize(2);
prob_chains.push_back(config);
prob_chains.back() >> WriteInput<2>();
config.entry_size = NGram<float>::TotalSize(3);
prob_chains.push_back(config);
prob_chains.back() >> WriteInput<3>();
util::stream::Chains boff_chains(3);
config.entry_size = sizeof(float);
boff_chains.push_back(config);
boff_chains.back() >> WriteBackoffs<1>();
boff_chains.push_back(config);
boff_chains.back() >> WriteBackoffs<2>();
boff_chains.push_back(config);
boff_chains.back() >> WriteBackoffs<3>();
util::stream::ChainPositions prob_pos(prob_chains);
util::stream::ChainPositions boff_pos(boff_chains);
util::stream::Chains output_chains(3);
for (std::size_t i = 0; i < 3; ++i) {
config.entry_size = NGram<ProbBackoff>::TotalSize(i + 1);
output_chains.push_back(config);
}
ReunifyBackoff(prob_pos, boff_pos, output_chains);
output_chains[0] >> CheckOutput<1>();
output_chains[1] >> CheckOutput<2>();
output_chains[2] >> CheckOutput<3>();
prob_chains >> util::stream::kRecycle;
boff_chains >> util::stream::kRecycle;
output_chains.Wait();
}
}
}