| #ifndef LM_VOCAB_H |
| #define LM_VOCAB_H |
| |
| #include "lm/enumerate_vocab.hh" |
| #include "lm/lm_exception.hh" |
| #include "lm/virtual_interface.hh" |
| #include "util/file_stream.hh" |
| #include "util/murmur_hash.hh" |
| #include "util/pool.hh" |
| #include "util/probing_hash_table.hh" |
| #include "util/sorted_uniform.hh" |
| #include "util/string_piece.hh" |
| |
| #include <limits> |
| #include <string> |
| #include <vector> |
| |
| namespace lm { |
| struct ProbBackoff; |
| class EnumerateVocab; |
| |
| namespace ngram { |
| struct Config; |
| |
| namespace detail { |
| uint64_t HashForVocab(const char *str, std::size_t len); |
| inline uint64_t HashForVocab(const StringPiece &str) { |
| return HashForVocab(str.data(), str.length()); |
| } |
| struct ProbingVocabularyHeader; |
| } // namespace detail |
| |
| // Writes words immediately to a file instead of buffering, because we know |
| // where in the file to put them. |
| class ImmediateWriteWordsWrapper : public EnumerateVocab { |
| public: |
| ImmediateWriteWordsWrapper(EnumerateVocab *inner, int fd, uint64_t start); |
| |
| void Add(WordIndex index, const StringPiece &str) { |
| stream_ << str << '\0'; |
| if (inner_) inner_->Add(index, str); |
| } |
| |
| private: |
| EnumerateVocab *inner_; |
| |
| util::FileStream stream_; |
| }; |
| |
| // When the binary size isn't known yet. |
| class WriteWordsWrapper : public EnumerateVocab { |
| public: |
| WriteWordsWrapper(EnumerateVocab *inner); |
| |
| void Add(WordIndex index, const StringPiece &str); |
| |
| const std::string &Buffer() const { return buffer_; } |
| void Write(int fd, uint64_t start); |
| |
| private: |
| EnumerateVocab *inner_; |
| |
| std::string buffer_; |
| }; |
| |
| // Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices. |
| class SortedVocabulary : public base::Vocabulary { |
| public: |
| SortedVocabulary(); |
| |
| WordIndex Index(const StringPiece &str) const { |
| const uint64_t *found; |
| if (util::BoundedSortedUniformFind<const uint64_t*, util::IdentityAccessor<uint64_t>, util::Pivot64>( |
| util::IdentityAccessor<uint64_t>(), |
| begin_ - 1, 0, |
| end_, std::numeric_limits<uint64_t>::max(), |
| detail::HashForVocab(str), found)) { |
| return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table. |
| } else { |
| return 0; |
| } |
| } |
| |
| // Size for purposes of file writing |
| static uint64_t Size(uint64_t entries, const Config &config); |
| |
| /* Read null-delimited words from file from_words, renumber according to |
| * hash order, write null-delimited words to to_words, and create a mapping |
| * from old id to new id. The 0th vocab word must be <unk>. |
| */ |
| static void ComputeRenumbering(WordIndex types, int from_words, int to_words, std::vector<WordIndex> &mapping); |
| |
| // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. |
| WordIndex Bound() const { return bound_; } |
| |
| // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. |
| void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); |
| |
| void Relocate(void *new_start); |
| |
| void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries); |
| |
| // Insert and FinishedLoading go together. |
| WordIndex Insert(const StringPiece &str); |
| // Reorders reorder_vocab so that the IDs are sorted. |
| void FinishedLoading(ProbBackoff *reorder_vocab); |
| |
| // Trie stores the correct counts including <unk> in the header. If this was previously sized based on a count exluding <unk>, padding with 8 bytes will make it the correct size based on a count including <unk>. |
| std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); } |
| |
| bool SawUnk() const { return saw_unk_; } |
| |
| void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); |
| |
| uint64_t *&EndHack() { return end_; } |
| |
| void Populated(); |
| |
| private: |
| template <class T> void GenericFinished(T *reorder); |
| |
| uint64_t *begin_, *end_; |
| |
| WordIndex bound_; |
| |
| bool saw_unk_; |
| |
| EnumerateVocab *enumerate_; |
| |
| // Actual strings. Used only when loading from ARPA and enumerate_ != NULL |
| util::Pool string_backing_; |
| |
| std::vector<StringPiece> strings_to_enumerate_; |
| }; |
| |
| #pragma pack(push) |
| #pragma pack(4) |
| struct ProbingVocabularyEntry { |
| uint64_t key; |
| WordIndex value; |
| |
| typedef uint64_t Key; |
| uint64_t GetKey() const { return key; } |
| void SetKey(uint64_t to) { key = to; } |
| |
| static ProbingVocabularyEntry Make(uint64_t key, WordIndex value) { |
| ProbingVocabularyEntry ret; |
| ret.key = key; |
| ret.value = value; |
| return ret; |
| } |
| }; |
| #pragma pack(pop) |
| |
| // Vocabulary storing a map from uint64_t to WordIndex. |
| class ProbingVocabulary : public base::Vocabulary { |
| public: |
| ProbingVocabulary(); |
| |
| WordIndex Index(const StringPiece &str) const { |
| Lookup::ConstIterator i; |
| return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; |
| } |
| |
| static uint64_t Size(uint64_t entries, float probing_multiplier); |
| // This just unwraps Config to get the probing_multiplier. |
| static uint64_t Size(uint64_t entries, const Config &config); |
| |
| // Vocab words are [0, Bound()). |
| WordIndex Bound() const { return bound_; } |
| |
| // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. |
| void SetupMemory(void *start, std::size_t allocated); |
| void SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { |
| SetupMemory(start, allocated); |
| } |
| |
| void Relocate(void *new_start); |
| |
| void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries); |
| |
| WordIndex Insert(const StringPiece &str); |
| |
| template <class Weights> void FinishedLoading(Weights * /*reorder_vocab*/) { |
| InternalFinishedLoading(); |
| } |
| |
| std::size_t UnkCountChangePadding() const { return 0; } |
| |
| bool SawUnk() const { return saw_unk_; } |
| |
| void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); |
| |
| private: |
| void InternalFinishedLoading(); |
| |
| typedef util::ProbingHashTable<ProbingVocabularyEntry, util::IdentityHash> Lookup; |
| |
| Lookup lookup_; |
| |
| WordIndex bound_; |
| |
| bool saw_unk_; |
| |
| EnumerateVocab *enumerate_; |
| |
| detail::ProbingVocabularyHeader *header_; |
| }; |
| |
| void MissingUnknown(const Config &config) throw(SpecialWordMissingException); |
| void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException); |
| |
| template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) throw(SpecialWordMissingException) { |
| if (!vocab.SawUnk()) MissingUnknown(config); |
| if (vocab.BeginSentence() == vocab.NotFound()) MissingSentenceMarker(config, "<s>"); |
| if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>"); |
| } |
| |
| class WriteUniqueWords { |
| public: |
| explicit WriteUniqueWords(int fd) : word_list_(fd) {} |
| |
| void operator()(const StringPiece &word) { |
| word_list_ << word << '\0'; |
| } |
| |
| private: |
| util::FileStream word_list_; |
| }; |
| |
| class NoOpUniqueWords { |
| public: |
| NoOpUniqueWords() {} |
| void operator()(const StringPiece &word) {} |
| }; |
| |
| template <class NewWordAction = NoOpUniqueWords> class GrowableVocab { |
| public: |
| static std::size_t MemUsage(WordIndex content) { |
| return Lookup::MemUsage(content > 2 ? content : 2); |
| } |
| |
| // Does not take ownership of write_wordi |
| template <class NewWordConstruct> GrowableVocab(WordIndex initial_size, const NewWordConstruct &new_word_construct = NewWordAction()) |
| : lookup_(initial_size), new_word_(new_word_construct) { |
| FindOrInsert("<unk>"); // Force 0 |
| FindOrInsert("<s>"); // Force 1 |
| FindOrInsert("</s>"); // Force 2 |
| } |
| |
| WordIndex Index(const StringPiece &str) const { |
| Lookup::ConstIterator i; |
| return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; |
| } |
| |
| WordIndex FindOrInsert(const StringPiece &word) { |
| ProbingVocabularyEntry entry = ProbingVocabularyEntry::Make(util::MurmurHashNative(word.data(), word.size()), Size()); |
| Lookup::MutableIterator it; |
| if (!lookup_.FindOrInsert(entry, it)) { |
| new_word_(word); |
| UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh"); |
| } |
| return it->value; |
| } |
| |
| WordIndex Size() const { return lookup_.Size(); } |
| |
| private: |
| typedef util::AutoProbing<ProbingVocabularyEntry, util::IdentityHash> Lookup; |
| |
| Lookup lookup_; |
| |
| NewWordAction new_word_; |
| }; |
| |
| } // namespace ngram |
| } // namespace lm |
| |
| #endif // LM_VOCAB_H |