blob: 445b57be1e2dcd5fb5b979ef58af2804c1b14d81 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "lm/enumerate_vocab.hh"
#include "lm/model.hh"
#include "lm/left.hh"
#include "lm/state.hh"
#include "util/murmur_hash.hh"
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <stdexcept>
// Grr. Everybody's compiler is slightly different and I'm trying to not depend on boost.
#include <unordered_set>
#include <vector>
#include <jni.h>
#include <pthread.h>
// Verify that jint and lm::WordIndex are the same size: arrays of jint are
// reinterpret_cast to lm::WordIndex arrays when scoring. If this breaks for
// you, ProbForWordIndexArray and ProbString need to be revised.
namespace {
// Compile-time guard: Joshua word-id buffers (jint*) are reinterpret_cast to
// const lm::WordIndex* when scoring, which is only valid if both types have
// the same size.
static_assert(sizeof(jint) == sizeof(lm::WordIndex),
              "jint and lm::WordIndex must have the same size");
// Identifier of an interned ChartState inside a Chart. Could be uint64_t if you
// wanted to have 33-bit support, but note that probRule packs it into the high
// 32 bits of a jlong.
using StateIndex = uint32_t;
// Backing storage for interned chart states; a StateIndex is a position in it.
using StateVector = std::vector<lm::ngram::ChartState>;
class HashIndex : public std::unary_function<StateIndex, uint64_t> {
public:
explicit HashIndex(const StateVector &vec) : vec_(vec) {}
uint64_t operator()(StateIndex index) const {
return hash_value(vec_[index]);
}
private:
const StateVector &vec_;
};
class EqualIndex : public std::binary_function<StateIndex, StateIndex, bool> {
public:
explicit EqualIndex(const StateVector &vec) : vec_(vec) {}
bool operator()(StateIndex first, StateIndex second) const {
return vec_[first] == vec_[second];
}
private:
const StateVector &vec_;
};
// Set of indices into a StateVector, hashed and compared by the states they denote.
using Lookup = std::unordered_set<StateIndex, HashIndex, EqualIndex>;
/**
* A Chart bundles together a vector holding CharStates and an unordered_set of StateIndexes
* which provides a mapping between StateIndexes and the positions of ChartStates in the vector.
* This allows for duplicate states to avoid allocating separate state objects at multiple places
* throughout a sentence.
*/
class Chart {
public:
Chart(long* ngramBuffer) :
ngramBuffer_(ngramBuffer),
lookup_(1000, HashIndex(vec_), EqualIndex(vec_)) {}
StateIndex Intern(const lm::ngram::ChartState &state) {
vec_.push_back(state);
std::pair<Lookup::iterator, bool> ins(lookup_.insert(vec_.size() - 1));
if (!ins.second) {
vec_.pop_back();
}
return *ins.first + 1; // +1 so that the first id is 1, not 0. We use sign bit to
// distinguish ChartState from vocab id.
}
const lm::ngram::ChartState &InterpretState(StateIndex index) const {
return vec_[index - 1];
}
long* ngramBuffer_;
private:
StateVector vec_;
Lookup lookup_;
};
// Rewrites Joshua vocabulary ids in [begin, end) in place with the model's
// lm::WordIndex values from `map`. Vocab ids above what the vocabulary knows
// about (>= map.size(), i.e. never registered), as well as negative ids, are
// unknown and are mapped to 0, the value used for unknown/unregistered entries.
void MapArray(const std::vector<lm::WordIndex>& map, jint *begin, jint *end) {
  for (jint *i = begin; i < end; ++i) {
    const jint id = *i;
    const bool known = id >= 0 && static_cast<std::size_t>(id) < map.size();
    *i = known ? map[static_cast<std::size_t>(id)] : 0;
  }
}
// Returns a NUL-terminated copy of `str` allocated with malloc(); the caller
// must release it with free(). Returns nullptr if the allocation fails.
char *PieceCopy(const StringPiece &str) {
  char *ret = static_cast<char*>(malloc(str.size() + 1));
  if (ret == nullptr) {
    return nullptr;
  }
  memcpy(ret, str.data(), str.size());
  ret[str.size()] = '\0';
  return ret;
}
// Rather than handle several different instantiations over JNI, we'll just
// do virtual calls C++-side. VirtualBase is the model-type-independent
// interface used by the JNI entry points; VirtualImpl<Model> implements it for
// each KenLM model type. Instances are created with new and handed to Java as
// opaque jlong handles, so copying is disabled.
class VirtualBase {
 public:
  virtual ~VirtualBase() {}

  VirtualBase(const VirtualBase &) = delete;
  VirtualBase &operator=(const VirtualBase &) = delete;

  // Computes/returns the n-gram probability of the last id in [begin, end)
  // given the preceding ids as context; ids are Joshua word ids. The array may
  // be rewritten in place.
  virtual float Prob(jint *begin, jint *end) const = 0;
  // Same as Prob, but [begin, end) already holds lm::WordIndex values.
  virtual float ProbForWordIndexArray(jint *begin, jint *end) const = 0;
  // Returns the internal lm::WordIndex for a string.
  virtual lm::WordIndex GetLmId(const StringPiece& word) const = 0;
  // True if joshua_id was never registered or maps to a word the model does not know.
  virtual bool IsLmOov(const int joshua_id) const = 0;
  // True if `id` is not the vocabulary's NotFound index.
  virtual bool IsKnownWordIndex(const lm::WordIndex& id) const = 0;
  // Scores the rule held in chart.ngramBuffer_ (length-prefixed; negative
  // entries are interned nonterminal state ids) and writes the resulting
  // state to `state`.
  virtual float ProbRule(lm::ngram::ChartState& state, const Chart &chart) const = 0;
  // Scores the words from index `start` onward, using the words before
  // `start` as context. The array may be rewritten in place.
  virtual float ProbString(jint * const begin, jint * const end,
                           jint start) const = 0;
  // Estimates a rule's score without nonterminal states; negative entries
  // (nonterminals) reset the scorer.
  virtual float EstimateRule(jlong *begin, jlong *end) const = 0;
  // N-gram order of the loaded model.
  virtual uint8_t Order() const = 0;
  // Maps joshua_id to the model's index for `word`. Returns true if joshua_id
  // was already mapped to a nonzero index.
  virtual bool RegisterWord(const StringPiece& word, const int joshua_id) = 0;

 protected:
  VirtualBase() {}
};
template<class Model> class VirtualImpl: public VirtualBase {
public:
VirtualImpl(const char *name) :
m_(name) {
// Insert unknown id mapping.
map_.push_back(0);
}
~VirtualImpl() {
}
float Prob(jint * const begin, jint * const end) const {
// map Joshua word ids to lm::WordIndexes
MapArray(map_, begin, end);
return ProbForWordIndexArray(begin, end);
}
float ProbForWordIndexArray(jint * const begin, jint * const end) const {
std::reverse(begin, end - 1);
lm::ngram::State ignored;
return m_.FullScoreForgotState(
reinterpret_cast<const lm::WordIndex*>(begin),
reinterpret_cast<const lm::WordIndex*>(end - 1), *(end - 1),
ignored).prob;
}
uint GetLmId(const StringPiece& word) const {
return m_.GetVocabulary().Index(word);
}
bool IsLmOov(const int joshua_id) const {
if (map_.size() <= joshua_id) {
return true;
}
return !IsKnownWordIndex(map_[joshua_id]);
}
bool IsKnownWordIndex(const lm::WordIndex& id) const {
return id != m_.GetVocabulary().NotFound();
}
float ProbRule(lm::ngram::ChartState& state, const Chart &chart) const {
// By convention the first long in the ngramBuffer denotes the size of the buffer
long* begin = chart.ngramBuffer_ + 1;
long* end = begin + *chart.ngramBuffer_;
if (begin == end) return 0.0;
lm::ngram::RuleScore<Model> ruleScore(m_, state);
if (*begin < 0) {
ruleScore.BeginNonTerminal(chart.InterpretState(-*begin));
} else {
const lm::WordIndex word = map_[*begin];
if (word == m_.GetVocabulary().BeginSentence()) {
ruleScore.BeginSentence();
} else {
ruleScore.Terminal(word);
}
}
for (jlong* i = begin + 1; i != end; i++) {
long word = *i;
if (word < 0)
ruleScore.NonTerminal(chart.InterpretState(-word));
else
ruleScore.Terminal(map_[word]);
}
return ruleScore.Finish();
}
float EstimateRule(jlong * const begin, jlong * const end) const {
if (begin == end) return 0.0;
lm::ngram::ChartState nullState;
lm::ngram::RuleScore<Model> ruleScore(m_, nullState);
if (*begin < 0) {
ruleScore.Reset();
} else {
const lm::WordIndex word = map_[*begin];
if (word == m_.GetVocabulary().BeginSentence()) {
ruleScore.BeginSentence();
} else {
ruleScore.Terminal(word);
}
}
for (jlong* i = begin + 1; i != end; i++) {
long word = *i;
if (word < 0)
ruleScore.Reset();
else
ruleScore.Terminal(map_[word]);
}
return ruleScore.Finish();
}
float ProbString(jint * const begin, jint * const end, jint start) const {
MapArray(map_, begin, end);
float prob;
lm::ngram::State state;
if (start == 0) {
prob = 0;
state = m_.NullContextState();
} else {
std::reverse(begin, begin + start);
prob = m_.FullScoreForgotState(
reinterpret_cast<const lm::WordIndex*>(begin),
reinterpret_cast<const lm::WordIndex*>(begin + start),
begin[start], state).prob;
++start;
}
lm::ngram::State state2;
for (const jint *i = begin + start;;) {
if (i >= end)
break;
float got = m_.Score(state, *i, state2);
i++;
prob += got;
if (i >= end)
break;
got = m_.Score(state2, *i, state);
i++;
prob += got;
}
return prob;
}
uint8_t Order() const {
return m_.Order();
}
bool RegisterWord(const StringPiece& word, const int joshua_id) {
if (map_.size() <= joshua_id) {
map_.resize(joshua_id + 1, 0);
}
bool already_present = false;
if (map_[joshua_id] != 0)
already_present = true;
map_[joshua_id] = m_.GetVocabulary().Index(word);
return already_present;
}
private:
Model m_;
std::vector<lm::WordIndex> map_;
};
// Loads `file_name` with the Model instantiation matching its KenLM binary
// header. Files RecognizeBinary does not identify as a KenLM binary
// (presumably ARPA text) are loaded as a probing hash model. Returns a
// heap-allocated model owned by the caller (release with delete). Throws
// lm::FormatLoadException for a binary model type not handled here.
VirtualBase *ConstructModel(const char *file_name) {
  using namespace lm::ngram;
  ModelType model_type;
  if (!RecognizeBinary(file_name, model_type)) {
    // PROBING is the current name of the historical HASH_PROBING alias.
    model_type = PROBING;
  }
  switch (model_type) {
    case PROBING:
      return new VirtualImpl<ProbingModel>(file_name);
    case REST_PROBING:
      return new VirtualImpl<RestProbingModel>(file_name);
    case TRIE:
      return new VirtualImpl<TrieModel>(file_name);
    case ARRAY_TRIE:
      return new VirtualImpl<ArrayTrieModel>(file_name);
    case QUANT_TRIE:
      return new VirtualImpl<QuantTrieModel>(file_name);
    case QUANT_ARRAY_TRIE:
      return new VirtualImpl<QuantArrayTrieModel>(file_name);
    default:
      UTIL_THROW(
          lm::FormatLoadException,
          "Unrecognized file format " << static_cast<unsigned>(model_type)
          << " in file " << file_name);
  }
}
} // namespace
extern "C" {
// Loads the language model named by `file_name` and returns it as an opaque
// handle (a VirtualBase* stored in a jlong) to be released with destroy().
// Returns 0 if the Java string cannot be accessed; aborts the process if
// loading throws.
JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_construct(
    JNIEnv *env, jclass, jstring file_name) {
  const char *path = env->GetStringUTFChars(file_name, nullptr);
  if (path == nullptr) {
    return 0;
  }
  VirtualBase *model = nullptr;
  try {
    model = ConstructModel(path);
  } catch (std::exception &e) {
    std::cerr << e.what() << std::endl;
    abort();
  }
  env->ReleaseStringUTFChars(file_name, path);
  return reinterpret_cast<jlong>(model);
}
// Deletes a model handle previously returned by construct().
JNIEXPORT void JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_destroy(
    JNIEnv *env, jclass, jlong pointer) {
  VirtualBase *const model = reinterpret_cast<VirtualBase*>(pointer);
  delete model;
}
// Creates a Chart (state pool) whose n-gram buffer is the memory of the Java
// direct buffer `arr`; rule n-grams for probRule() are passed through it.
// Returns an opaque handle to be released with destroyPool(). If `arr` is not
// a direct buffer (GetDirectBufferAddress returns NULL), throws
// IllegalArgumentException in Java and returns 0 instead of creating a Chart
// that would later dereference NULL.
JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_createPool(
    JNIEnv *env, jclass, jobject arr) {
  jlong *ngramBuffer = static_cast<jlong*>(env->GetDirectBufferAddress(arr));
  if (ngramBuffer == nullptr) {
    jclass iae = env->FindClass("java/lang/IllegalArgumentException");
    if (iae != nullptr) {
      env->ThrowNew(iae, "KenLM.createPool: argument is not a direct buffer");
    }
    return 0;
  }
  Chart *newChart = new Chart(ngramBuffer);
  // jlong, not long: long is only 32 bits on LLP64 platforms.
  return reinterpret_cast<jlong>(newChart);
}
// Deletes a Chart handle previously returned by createPool(). The buffer the
// Chart pointed at is not owned by it and is left untouched.
JNIEXPORT void JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_destroyPool(
    JNIEnv *env, jclass, jlong pointer) {
  Chart *const pool = reinterpret_cast<Chart*>(pointer);
  delete pool;
}
// Returns the n-gram order of the model behind `pointer`.
JNIEXPORT jint JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_order(
    JNIEnv *env, jclass, jlong pointer) {
  const VirtualBase *const model = reinterpret_cast<const VirtualBase*>(pointer);
  const jint order = model->Order();
  return order;
}
// Maps Joshua word id `id` to the model's index for `word`. Returns JNI_TRUE if
// `id` was already mapped, JNI_FALSE otherwise (or if the string cannot be
// accessed). Aborts the process if registration throws.
JNIEXPORT jboolean JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_registerWord(
    JNIEnv *env, jclass, jlong pointer, jstring word, jint id) {
  const char *str = env->GetStringUTFChars(word, 0);
  if (!str)
    return JNI_FALSE;
  // bool, not jint: RegisterWord returns bool and the result is a jboolean.
  bool already_present = false;
  try {
    already_present = reinterpret_cast<VirtualBase*>(pointer)->RegisterWord(str, id);
  } catch (std::exception &e) {
    std::cerr << e.what() << std::endl;
    abort();
  }
  env->ReleaseStringUTFChars(word, str);
  return already_present ? JNI_TRUE : JNI_FALSE;
}
// Scores one n-gram given as Joshua word ids (context followed by the
// predicted word). Returns 0 for an empty array.
JNIEXPORT jfloat JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_prob(
    JNIEnv *env, jclass, jlong pointer, jintArray arr) {
  const jint count = env->GetArrayLength(arr);
  if (count < 1) {
    return 0.0;
  }
  // Variable-length array: a GCC/Clang extension, not standard C++.
  jint ids[count];
  env->GetIntArrayRegion(arr, 0, count, ids);
  const VirtualBase *const model = reinterpret_cast<const VirtualBase*>(pointer);
  return model->Prob(ids, ids + count);
}
// Scores one n-gram given as an array of word strings (context followed by
// the predicted word), looking each word up in the model's vocabulary.
// Returns 0 for an empty array. If an element is null, throws
// NullPointerException in Java and returns 0; if a string cannot be accessed
// (an exception is already pending), returns 0.
JNIEXPORT jfloat JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probForString(
    JNIEnv *env, jclass, jlong pointer, jobjectArray arr) {
  const jint length = env->GetArrayLength(arr);
  if (length <= 0)
    return 0.0;
  std::vector<jint> values(static_cast<std::size_t>(length));
  const VirtualBase *lm_base = reinterpret_cast<const VirtualBase*>(pointer);
  for (jint i = 0; i < length; ++i) {
    jstring word = static_cast<jstring>(env->GetObjectArrayElement(arr, i));
    if (word == nullptr) {
      jclass npe = env->FindClass("java/lang/NullPointerException");
      if (npe != nullptr) {
        env->ThrowNew(npe, "KenLM.probForString: null word in array");
      }
      return 0.0;
    }
    const char *str = env->GetStringUTFChars(word, 0);
    if (str == nullptr) {
      env->DeleteLocalRef(word);
      return 0.0;
    }
    values[i] = static_cast<jint>(lm_base->GetLmId(str));
    env->ReleaseStringUTFChars(word, str);
    // JNI only guarantees a small number of local references per native
    // frame, so release each element's reference inside the loop.
    env->DeleteLocalRef(word);
  }
  return lm_base->ProbForWordIndexArray(values.data(),
                                        values.data() + length);
}
// True if Joshua word id `word` is unregistered or unknown to the model.
JNIEXPORT jboolean JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_isLmOov(
    JNIEnv *env, jclass, jlong pointer, jint word) {
  return reinterpret_cast<const VirtualBase*>(pointer)->IsLmOov(word);
}
// True if the string `word` is in the model's vocabulary. Returns false if the
// Java string cannot be accessed.
JNIEXPORT jboolean JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_isKnownWord(
    JNIEnv *env, jclass, jlong pointer, jstring word) {
  const char *text = env->GetStringUTFChars(word, 0);
  if (text == nullptr) {
    return false;
  }
  const VirtualBase *const model = reinterpret_cast<const VirtualBase*>(pointer);
  const bool known = model->IsKnownWordIndex(model->GetLmId(text));
  env->ReleaseStringUTFChars(word, text);
  return known;
}
// Scores the Joshua word ids in `arr` from index `start` onward, using the ids
// before `start` as context. Returns 0 when there is nothing to score
// (start >= length) or when `start` is negative (invalid: it would index
// before the array).
JNIEXPORT jfloat JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probString(
    JNIEnv *env, jclass, jlong pointer, jintArray arr, jint start) {
  jint length = env->GetArrayLength(arr);
  if (start < 0 || length <= start)
    return 0.0;
  // GCC/Clang only (variable-length array extension).
  jint values[length];
  env->GetIntArrayRegion(arr, 0, length, values);
  return reinterpret_cast<const VirtualBase*>(pointer)->ProbString(values,
      values + length, start);
}
// Views a 32-bit float's bit pattern as an integer: write `f`, then read `i`.
// Reading the member that was not last written (union type punning) is not
// sanctioned by the C++ standard; it relies on compiler support (GCC documents
// it). memcpy is the portable alternative.
union FloatConverter {
float f;
uint32_t i;
};
// Scores the rule currently held in the Chart's n-gram buffer, interns the
// resulting state in the Chart, and returns both packed into one jlong:
// the state id in the high 32 bits and the float probability's raw bit
// pattern in the low 32 bits.
JNIEXPORT jlong JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_probRule(
    JNIEnv *env, jclass, jlong pointer, jlong chartPtr) {
  // Value-initialised so the state is never hashed while indeterminate (e.g.
  // if ProbRule returns early without writing it). NOTE: this zero-fills only
  // if ChartState has no user-provided default constructor, as in KenLM.
  lm::ngram::ChartState outState = lm::ngram::ChartState();
  const VirtualBase *base = reinterpret_cast<const VirtualBase*>(pointer);
  Chart *chart = reinterpret_cast<Chart*>(chartPtr);
  // Compute the probability.
  const float prob = base->ProbRule(outState, *chart);
  const StateIndex index = chart->Intern(outState);
  // memcpy is the well-defined way to read a float's bit pattern.
  static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits");
  uint32_t prob_bits;
  memcpy(&prob_bits, &prob, sizeof(prob_bits));
  return static_cast<jlong>(static_cast<uint64_t>(index) << 32 |
                            static_cast<uint64_t>(prob_bits));
}
// Estimates the score of a rule given as an array of Joshua word ids, where
// negative entries stand for nonterminals. Returns 0 for an empty array.
JNIEXPORT jfloat JNICALL Java_org_apache_joshua_decoder_ff_lm_KenLM_estimateRule(
    JNIEnv *env, jclass, jlong pointer, jlongArray arr) {
  jint length = env->GetArrayLength(arr);
  // Avoid creating a zero-length variable-length array.
  if (length <= 0)
    return 0.0;
  // GCC/Clang only (variable-length array extension).
  jlong values[length];
  env->GetLongArrayRegion(arr, 0, length, values);
  // Compute the probability
  return reinterpret_cast<const VirtualBase*>(pointer)->EstimateRule(values,
      values + length);
}
} // extern