| #include "lm/interpolate/backoff_reunification.hh" |
| #include "lm/common/ngram_stream.hh" |
| |
| #define BOOST_TEST_MODULE InterpolateBackoffReunificationTest |
| #include <boost/test/unit_test.hpp> |
| |
| namespace lm { |
| namespace interpolate { |
| |
| namespace { |
| |
// None of this input forms a meaningful language model; all we care about is
// making sure the merging works.
| template <uint8_t N> |
| struct Gram { |
| WordIndex ids[N]; |
| float prob; |
| float boff; |
| }; |
| |
| template <uint8_t N> |
| struct Grams { |
| const static Gram<N> grams[]; |
| }; |
| |
| template <> |
| const Gram<1> Grams<1>::grams[] |
| = {{{0}, 0.1f, 0.1f}, {{1}, 0.4f, 0.2f}, {{2}, 0.5f, 0.1f}}; |
| |
| template <> |
| const Gram<2> Grams<2>::grams[] = {{{0, 0}, 0.05f, 0.05f}, |
| {{1, 0}, 0.05f, 0.02f}, |
| {{1, 1}, 0.2f, 0.04f}, |
| {{2, 2}, 0.2f, 0.01f}}; |
| |
| template <> |
| const Gram<3> Grams<3>::grams[] = {{{0, 0, 0}, 0.001f, 0.005f}, |
| {{1, 0, 0}, 0.001f, 0.002f}, |
| {{2, 0, 0}, 0.001f, 0.003f}, |
| {{0, 1, 0}, 0.1f, 0.008f}, |
| {{1, 1, 0}, 0.1f, 0.09f}, |
| {{1, 1, 1}, 0.2f, 0.08f}}; |
| |
| template <uint8_t N> |
| class WriteInput { |
| public: |
| void Run(const util::stream::ChainPosition &position) { |
| lm::NGramStream<float> output(position); |
| |
| for (std::size_t i = 0; i < sizeof(Grams<N>::grams) / sizeof(Gram<N>); |
| ++i, ++output) { |
| std::copy(Grams<N>::grams[i].ids, Grams<N>::grams[i].ids + N, |
| output->begin()); |
| output->Value() = Grams<N>::grams[i].prob; |
| } |
| output.Poison(); |
| } |
| }; |
| |
| template <uint8_t N> |
| class WriteBackoffs { |
| public: |
| void Run(const util::stream::ChainPosition &position) { |
| util::stream::Stream output(position); |
| |
| for (std::size_t i = 0; i < sizeof(Grams<N>::grams) / sizeof(Gram<N>); |
| ++i, ++output) { |
| *reinterpret_cast<float *>(output.Get()) = Grams<N>::grams[i].boff; |
| } |
| output.Poison(); |
| } |
| }; |
| |
| template <uint8_t N> |
| class CheckOutput { |
| public: |
| void Run(const util::stream::ChainPosition &position) { |
| lm::NGramStream<ProbBackoff> stream(position); |
| |
| std::size_t i = 0; |
| for (; stream; ++stream, ++i) { |
| std::stringstream ss; |
| for (WordIndex *idx = stream->begin(); idx != stream->end(); ++idx) |
| ss << "(" << *idx << ")"; |
| |
| UTIL_THROW_IF2( |
| !std::equal(stream->begin(), stream->end(), Grams<N>::grams[i].ids), |
| "Mismatched id in CheckOutput<" << (int)N << ">: " << ss.str()); |
| |
| UTIL_THROW_IF2(stream->Value().prob != Grams<N>::grams[i].prob, |
| "Mismatched probability in CheckOutput<" |
| << (int)N << ">, got " << stream->Value().prob |
| << ", expected " << Grams<N>::grams[i].prob); |
| |
| UTIL_THROW_IF2(stream->Value().backoff != Grams<N>::grams[i].boff, |
| "Mismatched backoff in CheckOutput<" |
| << (int)N << ">, got " << stream->Value().backoff |
| << ", expected " << Grams<N>::grams[i].boff); |
| } |
| UTIL_THROW_IF2(i != sizeof(Grams<N>::grams) / sizeof(Gram<N>), |
| "Did not get correct number of " |
| << (int)N << "-grams: expected " |
| << sizeof(Grams<N>::grams) / sizeof(Gram<N>) |
| << ", got " << i); |
| } |
| }; |
| } |
| |
| BOOST_AUTO_TEST_CASE(BackoffReunificationTest) { |
| util::stream::ChainConfig config; |
| config.total_memory = 100; |
| config.block_count = 1; |
| |
| util::stream::Chains prob_chains(3); |
| config.entry_size = NGram<float>::TotalSize(1); |
| prob_chains.push_back(config); |
| prob_chains.back() >> WriteInput<1>(); |
| |
| config.entry_size = NGram<float>::TotalSize(2); |
| prob_chains.push_back(config); |
| prob_chains.back() >> WriteInput<2>(); |
| |
| config.entry_size = NGram<float>::TotalSize(3); |
| prob_chains.push_back(config); |
| prob_chains.back() >> WriteInput<3>(); |
| |
| util::stream::Chains boff_chains(3); |
| config.entry_size = sizeof(float); |
| boff_chains.push_back(config); |
| boff_chains.back() >> WriteBackoffs<1>(); |
| |
| boff_chains.push_back(config); |
| boff_chains.back() >> WriteBackoffs<2>(); |
| |
| boff_chains.push_back(config); |
| boff_chains.back() >> WriteBackoffs<3>(); |
| |
| util::stream::ChainPositions prob_pos(prob_chains); |
| util::stream::ChainPositions boff_pos(boff_chains); |
| |
| util::stream::Chains output_chains(3); |
| for (std::size_t i = 0; i < 3; ++i) { |
| config.entry_size = NGram<ProbBackoff>::TotalSize(i + 1); |
| output_chains.push_back(config); |
| } |
| |
| ReunifyBackoff(prob_pos, boff_pos, output_chains); |
| |
| output_chains[0] >> CheckOutput<1>(); |
| output_chains[1] >> CheckOutput<2>(); |
| output_chains[2] >> CheckOutput<3>(); |
| |
| prob_chains >> util::stream::kRecycle; |
| boff_chains >> util::stream::kRecycle; |
| |
| output_chains.Wait(); |
| } |
| } |
| } |