// Source blob: bd82fe4c2962db8f48553f64640f6697891c8b3e
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "lm/enumerate_vocab.hh"
#include "lm/model.hh"
#include "lm/left.hh"
#include "lm/state.hh"
#include "util/murmur_hash.hh"

#include <jni.h>
#include <pthread.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

// Grr. Everybody's compiler is slightly different and I'm trying to not depend on boost.
#include <algorithm>
#include <exception>
#include <iostream>
#include <unordered_set>
#include <utility>
#include <vector>
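// Verify that jint and lm::WordIndex are the same size. The scoring code
// reinterprets jint buffers as lm::WordIndex buffers, so if this breaks for
// you, probString (and the other array-scoring entry points) need revising.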
namespace {
template<bool> struct StaticCheck {
template<> struct StaticCheck<true> {
typedef bool StaticAssertionPassed;
typedef StaticCheck<sizeof(jint) == sizeof(lm::WordIndex)>::StaticAssertionPassed FloatSize;
// Identifier handed out for an interned chart state. Could be widened to
// uint64_t if more than 32 bits of ids were ever needed.
using StateIndex = uint32_t;

// Storage for the chart states that StateIndex values refer to.
using StateVector = std::vector<lm::ngram::ChartState>;
class HashIndex : public std::unary_function<StateIndex, uint64_t> {
explicit HashIndex(const StateVector &vec) : vec_(vec) {}
uint64_t operator()(StateIndex index) const {
return hash_value(vec_[index]);
const StateVector &vec_;
class EqualIndex : public std::binary_function<StateIndex, StateIndex, bool> {
explicit EqualIndex(const StateVector &vec) : vec_(vec) {}
bool operator()(StateIndex first, StateIndex second) const {
return vec_[first] == vec_[second];
const StateVector &vec_;
// Set of indices into a StateVector, hashed and compared via the states the
// indices point at (see HashIndex / EqualIndex).
using Lookup = std::unordered_set<StateIndex, HashIndex, EqualIndex>;
* A Chart bundles together a unordered_multimap that maps ChartState signatures to a single
* object instantiated using a pool. This allows duplicate states to avoid allocating separate
* state objects at multiple places throughout a sentence, and also allows state to be shared
* across KenLMs for the same sentence. Multimap is used to avoid hash collisions which can
* return incorrect results, and cause out-of-bounds lookups when multiple KenLMs are in use.
class Chart {
Chart() : lookup_(1000, HashIndex(vec_), EqualIndex(vec_)) {}
StateIndex Intern(const lm::ngram::ChartState &state) {
std::pair<Lookup::iterator, bool> ins(lookup_.insert(vec_.size() - 1));
if (!ins.second) {
return *ins.first + 1; // +1 so that the first id is 1, not 0. We use sign bit to
// distinguish ChartState from vocab id.
const lm::ngram::ChartState &InterpretState(StateIndex index) const {
return vec_[index - 1];
StateVector vec_;
Lookup lookup_;
// Rewrites each id in [begin, end) in place with map[id].
// Vocab ids above what the vocabulary knows about (or negative ids) are
// unknown and are mapped to 0, the same value RegisterWord leaves in slots
// that were never registered, instead of reading outside `map`.
void MapArray(const std::vector<lm::WordIndex>& map, jint *begin, jint *end) {
  for (jint *i = begin; i < end; ++i) {
    const jint id = *i;
    const bool in_range = id >= 0 && static_cast<size_t>(id) < map.size();
    *i = in_range ? static_cast<jint>(map[static_cast<size_t>(id)]) : 0;
  }
}
char *PieceCopy(const StringPiece &str) {
char *ret = (char*) malloc(str.size() + 1);
memcpy(ret,, str.size());
ret[str.size()] = 0;
return ret;
// Rather than handle several different instantiations over JNI, we'll just
// do virtual calls C++-side.
class VirtualBase {
virtual ~VirtualBase() {
// compute/return n-gram probability for array of Joshua word ids
virtual float Prob(jint *begin, jint *end) const = 0;
// Compute/return n-gram probability for array of lm:WordIndexes
virtual float ProbForWordIndexArray(jint *begin, jint *end) const = 0;
// Returns the internal lm::WordIndex for a string
virtual uint GetLmId(const StringPiece& word) const = 0;
virtual bool IsLmOov(const int joshua_id) const = 0;
virtual bool IsKnownWordIndex(const lm::WordIndex& id) const = 0;
virtual float ProbRule(jlong *begin, jlong *end, lm::ngram::ChartState& state, const Chart &chart) const = 0;
virtual float ProbString(jint * const begin, jint * const end,
jint start) const = 0;
virtual float EstimateRule(jlong *begin, jlong *end) const = 0;
virtual uint8_t Order() const = 0;
virtual bool RegisterWord(const StringPiece& word, const int joshua_id) = 0;
VirtualBase() {
template<class Model> class VirtualImpl: public VirtualBase {
VirtualImpl(const char *name) :
m_(name) {
// Insert unknown id mapping.
~VirtualImpl() {
float Prob(jint * const begin, jint * const end) const {
// map Joshua word ids to lm::WordIndexes
MapArray(map_, begin, end);
return ProbForWordIndexArray(begin, end);
float ProbForWordIndexArray(jint * const begin, jint * const end) const {
std::reverse(begin, end - 1);
lm::ngram::State ignored;
return m_.FullScoreForgotState(
reinterpret_cast<const lm::WordIndex*>(begin),
reinterpret_cast<const lm::WordIndex*>(end - 1), *(end - 1),
uint GetLmId(const StringPiece& word) const {
return m_.GetVocabulary().Index(word);
bool IsLmOov(const int joshua_id) const {
if (map_.size() <= joshua_id) {
return true;
return !IsKnownWordIndex(map_[joshua_id]);
bool IsKnownWordIndex(const lm::WordIndex& id) const {
return id != m_.GetVocabulary().NotFound();
float ProbRule(jlong * const begin, jlong * const end, lm::ngram::ChartState& state, const Chart &chart) const {
if (begin == end) return 0.0;
lm::ngram::RuleScore<Model> ruleScore(m_, state);
if (*begin < 0) {
} else {
const lm::WordIndex word = map_[*begin];
if (word == m_.GetVocabulary().BeginSentence()) {
} else {
for (jlong* i = begin + 1; i != end; i++) {
long word = *i;
if (word < 0)
return ruleScore.Finish();
float EstimateRule(jlong * const begin, jlong * const end) const {
if (begin == end) return 0.0;
lm::ngram::ChartState nullState;
lm::ngram::RuleScore<Model> ruleScore(m_, nullState);
if (*begin < 0) {
} else {
const lm::WordIndex word = map_[*begin];
if (word == m_.GetVocabulary().BeginSentence()) {
} else {
for (jlong* i = begin + 1; i != end; i++) {
long word = *i;
if (word < 0)
return ruleScore.Finish();
float ProbString(jint * const begin, jint * const end, jint start) const {
MapArray(map_, begin, end);
float prob;
lm::ngram::State state;
if (start == 0) {
prob = 0;
state = m_.NullContextState();
} else {
std::reverse(begin, begin + start);
prob = m_.FullScoreForgotState(
reinterpret_cast<const lm::WordIndex*>(begin),
reinterpret_cast<const lm::WordIndex*>(begin + start),
begin[start], state).prob;
lm::ngram::State state2;
for (const jint *i = begin + start;;) {
if (i >= end)
float got = m_.Score(state, *i, state2);
prob += got;
if (i >= end)
got = m_.Score(state2, *i, state);
prob += got;
return prob;
uint8_t Order() const {
return m_.Order();
bool RegisterWord(const StringPiece& word, const int joshua_id) {
if (map_.size() <= joshua_id) {
map_.resize(joshua_id + 1, 0);
bool already_present = false;
if (map_[joshua_id] != 0)
already_present = true;
map_[joshua_id] = m_.GetVocabulary().Index(word);
return already_present;
Model m_;
std::vector<lm::WordIndex> map_;
// Opens the language model in `file_name` and returns a heap-allocated wrapper
// matching its binary format; the caller owns the result. Files that
// RecognizeBinary does not identify as KenLM binaries are loaded with the
// HASH_PROBING model type. Throws if the recognised type is not handled.
//
// NOTE(review): the case labels and the default/UTIL_THROW line were missing
// from the text under review; they are restored from the KenLM ModelType
// enumeration and the visible model classes -- confirm the exception type
// against upstream.
VirtualBase *ConstructModel(const char *file_name) {
  using namespace lm::ngram;
  ModelType model_type;
  if (!RecognizeBinary(file_name, model_type))
    model_type = HASH_PROBING;
  switch (model_type) {
  case PROBING:
    return new VirtualImpl<ProbingModel>(file_name);
  case REST_PROBING:
    return new VirtualImpl<RestProbingModel>(file_name);
  case TRIE:
    return new VirtualImpl<TrieModel>(file_name);
  case ARRAY_TRIE:
    return new VirtualImpl<ArrayTrieModel>(file_name);
  case QUANT_TRIE:
    return new VirtualImpl<QuantTrieModel>(file_name);
  case QUANT_ARRAY_TRIE:
    return new VirtualImpl<QuantArrayTrieModel>(file_name);
  default:
    UTIL_THROW(util::Exception,
        "Unrecognized file format " << (unsigned) model_type
            << " in file " << file_name);
  }
}
} // namespace
extern "C" {
// JNI: KenLM.construct. Loads the model named by `file_name` and returns an
// opaque handle (a VirtualBase* stored in a jlong), or 0 if the file name
// could not be converted to UTF-8.
JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_construct(
    JNIEnv *env, jclass, jstring file_name) {
  const char *str = env->GetStringUTFChars(file_name, 0);
  if (!str)
    return 0;

  VirtualBase *ret = nullptr;
  try {
    ret = ConstructModel(str);
  } catch (const std::exception &e) {
    std::cerr << e.what() << std::endl;
    // NOTE(review): the statement ending this handler was missing from the
    // text under review. Falling through would return an unset pointer, so
    // the process is terminated here -- confirm against upstream.
    abort();
  }
  env->ReleaseStringUTFChars(file_name, str);
  return reinterpret_cast<jlong>(ret);
}
// JNI: KenLM.destroy. Frees a model handle returned by construct; a zero
// handle is a no-op.
JNIEXPORT void JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_destroy(
    JNIEnv *env, jclass, jlong pointer) {
  delete reinterpret_cast<VirtualBase*>(pointer);
}
// JNI: KenLM.createPool. Allocates a Chart and returns it as an opaque handle
// to be released with destroyPool.
JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_createPool(
    JNIEnv *env, jclass) {
  // Cast to jlong rather than long: long is only 32 bits wide on some 64-bit
  // platforms and would truncate the pointer there.
  return reinterpret_cast<jlong>(new Chart());
}
// JNI: KenLM.destroyPool. Frees a Chart handle returned by createPool; a zero
// handle is a no-op.
JNIEXPORT void JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_destroyPool(
    JNIEnv *env, jclass, jlong pointer) {
  delete reinterpret_cast<Chart*>(pointer);
}
// JNI: KenLM.order. Returns the n-gram order of the model behind `pointer`,
// which must be a live handle from construct.
JNIEXPORT jint JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_order(
    JNIEnv *env, jclass, jlong pointer) {
  return reinterpret_cast<VirtualBase*>(pointer)->Order();
}
// JNI: KenLM.registerWord. Associates the Joshua id `id` with the model's
// index for `word`. Returns whether the id already had a mapping, or false if
// the string could not be converted to UTF-8.
JNIEXPORT jboolean JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_registerWord(
    JNIEnv *env, jclass, jlong pointer, jstring word, jint id) {
  const char *str = env->GetStringUTFChars(word, 0);
  if (!str)
    return false;

  jboolean ret = false;
  try {
    ret = reinterpret_cast<VirtualBase*>(pointer)->RegisterWord(str, id);
  } catch (const std::exception &e) {
    std::cerr << e.what() << std::endl;
    // NOTE(review): the statement ending this handler was missing from the
    // text under review; terminating matches the model-construction path --
    // confirm against upstream.
    abort();
  }
  env->ReleaseStringUTFChars(word, str);
  return ret;
}
// JNI: KenLM.prob. Scores an n-gram given as an array of Joshua word ids.
// Returns 0 for an empty array.
JNIEXPORT jfloat JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_prob(
    JNIEnv *env, jclass, jlong pointer, jintArray arr) {
  jint length = env->GetArrayLength(arr);
  if (length <= 0)
    return 0.0;
  // Work on a private copy: scoring rewrites the ids in place. A vector
  // replaces the former variable-length array, which is a compiler extension.
  std::vector<jint> values(length);
  env->GetIntArrayRegion(arr, 0, length, values.data());

  return reinterpret_cast<const VirtualBase*>(pointer)->Prob(values.data(),
      values.data() + length);
}
// JNI: KenLM.probForString. Scores an n-gram given as an array of word
// strings, looking each word up directly in the model's vocabulary.
// Returns 0 for an empty array or if a string cannot be read.
JNIEXPORT jfloat JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probForString(
    JNIEnv *env, jclass, jlong pointer, jobjectArray arr) {
  jint length = env->GetArrayLength(arr);
  if (length <= 0)
    return 0.0;
  std::vector<jint> values(length);
  const VirtualBase* lm_base = reinterpret_cast<const VirtualBase*>(pointer);
  for (jint i = 0; i < length; i++) {
    jstring word = static_cast<jstring>(env->GetObjectArrayElement(arr, i));
    if (!word)
      return 0.0;
    const char *str = env->GetStringUTFChars(word, 0);
    if (!str) {
      env->DeleteLocalRef(word);
      return 0.0;
    }
    values[i] = lm_base->GetLmId(str);
    env->ReleaseStringUTFChars(word, str);
    // Each GetObjectArrayElement creates a local reference; release it so
    // long arrays do not exhaust the local reference table.
    env->DeleteLocalRef(word);
  }
  return lm_base->ProbForWordIndexArray(values.data(),
      values.data() + length);
}
// JNI: KenLM.isLmOov. Reports whether the Joshua word id `word` is
// out-of-vocabulary for the model behind `pointer`.
JNIEXPORT jboolean JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_isLmOov(
    JNIEnv *env, jclass, jlong pointer, jint word) {
  const VirtualBase* lm_base = reinterpret_cast<const VirtualBase*>(pointer);
  return lm_base->IsLmOov(word);
}
// JNI: KenLM.isKnownWord. Reports whether the string `word` is in the
// vocabulary of the model behind `pointer`. Returns false if the string
// cannot be converted to UTF-8.
JNIEXPORT jboolean JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_isKnownWord(
    JNIEnv *env, jclass, jlong pointer, jstring word) {
  const char *str = env->GetStringUTFChars(word, 0);
  if (!str)
    return false;
  const VirtualBase* lm_base = reinterpret_cast<const VirtualBase*>(pointer);
  const lm::WordIndex id = lm_base->GetLmId(str);
  const jboolean ret = lm_base->IsKnownWordIndex(id);
  env->ReleaseStringUTFChars(word, str);
  return ret;
}
// JNI: KenLM.probString. Scores the words of `arr` (Joshua word ids) from
// position `start` onwards, with the words before `start` as context.
// Returns 0 when `start` is negative or not inside the array.
JNIEXPORT jfloat JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probString(
    JNIEnv *env, jclass, jlong pointer, jintArray arr, jint start) {
  jint length = env->GetArrayLength(arr);
  // A negative start would make ProbString address memory before the buffer.
  if (start < 0 || length <= start)
    return 0.0;
  // Work on a private copy: scoring rewrites the ids in place. A vector
  // replaces the former variable-length array, which is a compiler extension.
  std::vector<jint> values(length);
  env->GetIntArrayRegion(arr, 0, length, values.data());

  return reinterpret_cast<const VirtualBase*>(pointer)->ProbString(values.data(),
      values.data() + length, start);
}
// Gives access to the raw 32-bit pattern of a float: write `f`, read `i`.
// Reading the member that was not written last relies on the compiler
// supporting type punning through unions (GCC and Clang document this).
union FloatConverter {
  float f;
  uint32_t i;
};
// JNI: KenLM.probRule. Scores a rule and interns the resulting chart state.
// `arr` holds Joshua word ids (non-negative) and negated chart state ids
// (negative). The return value packs two 32-bit fields into one jlong:
// the interned state id in the upper half and the raw bit pattern of the
// float score in the lower half.
JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probRule(
    JNIEnv *env, jclass, jlong pointer, jlong chartPtr, jlongArray arr) {
  jint length = env->GetArrayLength(arr);
  // A vector replaces the former variable-length array, which is a compiler
  // extension and places arbitrarily long rules on the stack.
  std::vector<jlong> values(length);
  env->GetLongArrayRegion(arr, 0, length, values.data());

  // Compute the probability
  lm::ngram::ChartState outState;
  const VirtualBase *base = reinterpret_cast<const VirtualBase*>(pointer);
  Chart* chart = reinterpret_cast<Chart*>(chartPtr);
  FloatConverter prob;
  prob.f = base->ProbRule(values.data(), values.data() + length, outState, *chart);

  StateIndex index = chart->Intern(outState);
  return (static_cast<uint64_t>(index) << 32) | static_cast<uint64_t>(prob.i);
}
// JNI: KenLM.estimateRule. Scores a rule without consulting any chart states.
// `arr` uses the same encoding as for probRule (negative entries mark
// non-terminals).
JNIEXPORT jfloat JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_estimateRule(
    JNIEnv *env, jclass, jlong pointer, jlongArray arr) {
  jint length = env->GetArrayLength(arr);
  // A vector replaces the former variable-length array, which is a compiler
  // extension and places arbitrarily long rules on the stack.
  std::vector<jlong> values(length);
  env->GetLongArrayRegion(arr, 0, length, values.data());

  // Compute the probability
  return reinterpret_cast<const VirtualBase*>(pointer)->EstimateRule(values.data(),
      values.data() + length);
}
} // extern