| #ifndef LM_VALUE_BUILD_H |
| #define LM_VALUE_BUILD_H |
| |
| #include "lm/weights.hh" |
| #include "lm/word_index.hh" |
| #include "util/bit_packing.hh" |
| |
| #include <vector> |
| |
| namespace lm { |
| namespace ngram { |
| |
| struct Config; |
| struct BackoffValue; |
| struct RestValue; |
| |
| class NoRestBuild { |
| public: |
| typedef BackoffValue Value; |
| |
| NoRestBuild() {} |
| |
| void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} |
| void SetRest(const WordIndex *, unsigned int, const ProbBackoff &) const {} |
| |
| template <class Second> bool MarkExtends(ProbBackoff &weights, const Second &) const { |
| util::UnsetSign(weights.prob); |
| return false; |
| } |
| |
| // Probing doesn't need to go back to unigram. |
| const static bool kMarkEvenLower = false; |
| }; |
| |
| class MaxRestBuild { |
| public: |
| typedef RestValue Value; |
| |
| MaxRestBuild() {} |
| |
| void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} |
| void SetRest(const WordIndex *, unsigned int, RestWeights &weights) const { |
| weights.rest = weights.prob; |
| util::SetSign(weights.rest); |
| } |
| |
| bool MarkExtends(RestWeights &weights, const RestWeights &to) const { |
| util::UnsetSign(weights.prob); |
| if (weights.rest >= to.rest) return false; |
| weights.rest = to.rest; |
| return true; |
| } |
| bool MarkExtends(RestWeights &weights, const Prob &to) const { |
| util::UnsetSign(weights.prob); |
| if (weights.rest >= to.prob) return false; |
| weights.rest = to.prob; |
| return true; |
| } |
| |
| // Probing does need to go back to unigram. |
| const static bool kMarkEvenLower = true; |
| }; |
| |
| template <class Model> class LowerRestBuild { |
| public: |
| typedef RestValue Value; |
| |
| LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab); |
| |
| ~LowerRestBuild(); |
| |
| void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} |
| void SetRest(const WordIndex *vocab_ids, unsigned int n, RestWeights &weights) const { |
| typename Model::State ignored; |
| if (n == 1) { |
| weights.rest = unigrams_[*vocab_ids]; |
| } else { |
| weights.rest = models_[n-2]->FullScoreForgotState(vocab_ids + 1, vocab_ids + n, *vocab_ids, ignored).prob; |
| } |
| } |
| |
| template <class Second> bool MarkExtends(RestWeights &weights, const Second &) const { |
| util::UnsetSign(weights.prob); |
| return false; |
| } |
| |
| const static bool kMarkEvenLower = false; |
| |
| std::vector<float> unigrams_; |
| |
| std::vector<const Model*> models_; |
| }; |
| |
| } // namespace ngram |
| } // namespace lm |
| |
| #endif // LM_VALUE_BUILD_H |