blob: c8fae3bc1987a379a815c259de55c3360784e6fb [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.opennlp.namefinder;
import java.util.Arrays;
import org.tensorflow.Tensor;
public class FeedDictionary implements AutoCloseable {
static int PAD_VALUE = 0;
private final Tensor<Float> dropoutTensor;
private final Tensor<Integer> charIdsTensor;
private final Tensor<Integer> wordLengthsTensor;
private final Tensor<Integer> wordIdsTensor;
private final int[] sentenceLengths;
private final Tensor<Integer> sentenceLengthsTensor;
private final int maxSentenceLength;
private final int maxCharLength;
private final int numberOfSentences;
public int[] getSentenceLengths() {
return sentenceLengths;
}
public int getMaxSentenceLength() {
return maxSentenceLength;
}
public int getNumberOfSentences() {
return numberOfSentences;
}
public Tensor<Float> getDropoutTensor() {
return dropoutTensor;
}
public Tensor<Integer> getCharIdsTensor() {
return charIdsTensor;
}
public Tensor<Integer> getSentenceLengthsTensor() {
return sentenceLengthsTensor;
}
public Tensor<Integer> getWordLengthsTensor() {
return wordLengthsTensor;
}
public Tensor<Integer> getWordIdsTensor() {
return wordIdsTensor;
}
private FeedDictionary(final float dropout,
final int[][][] charIds,
final int[][] wordLengths,
final int[][] wordIds,
final int[] sentenceLengths,
final int maxSentenceLength,
final int maxCharLength,
final int numberOfSentences) {
dropoutTensor = Tensor.create(dropout, Float.class);
charIdsTensor = Tensor.create(charIds, Integer.class);
wordLengthsTensor = Tensor.create(wordLengths, Integer.class);
wordIdsTensor = Tensor.create(wordIds, Integer.class);
this.sentenceLengths = sentenceLengths;
sentenceLengthsTensor = Tensor.create(sentenceLengths, Integer.class);
this.maxSentenceLength = maxSentenceLength;
this.maxCharLength = maxCharLength;
this.numberOfSentences = numberOfSentences;
}
public void close() {
dropoutTensor.close();
charIdsTensor.close();
wordLengthsTensor.close();
wordIdsTensor.close();
sentenceLengthsTensor.close();
}
// multi sentences
public static FeedDictionary create(TokenIds sentences) {
int numberOfSentences = sentences.getWordIds().length;
int[][][] charIds = new int[numberOfSentences][][];
int[][] wordLengths = new int[numberOfSentences][];
int maxSentenceLength = Arrays.stream(sentences.getWordIds()).map(s -> s.length).reduce(Integer::max).get();
Padded paddedSentences = padArrays(sentences.getWordIds(), maxSentenceLength);
int[][] wordIds = paddedSentences.ids;
int[] sentenceLengths = paddedSentences.lengths;
int maxCharLength = Arrays.stream(sentences.getCharIds()).flatMap(s -> Arrays.stream(s).map(c -> c.length)).reduce(Integer::max).get();
for (int i=0; i < numberOfSentences; i++) {
Padded paddedWords = padArrays(sentences.getCharIds()[i], maxCharLength);
charIds[i] = paddedWords.ids;
wordLengths[i] = paddedWords.lengths;
}
return new FeedDictionary(1.0f, charIds, wordLengths, wordIds, sentenceLengths, maxSentenceLength, maxCharLength, numberOfSentences);
}
private static Padded padArrays(int[][] ids, int length) {
int[][] paddedIds = new int[ids.length][length];
int[] lengths = new int[ids.length];
for (int i = 0; i < ids.length; i++) {
int[] src = ids[i];
int[] dest = new int[length];
System.arraycopy(src, 0, dest, 0, src.length);
if (src.length < length)
Arrays.fill(dest, src.length, length, PAD_VALUE);
paddedIds[i] = dest;
lengths[i] = src.length;
}
return new Padded(paddedIds, lengths);
}
private static class Padded {
Padded(int[][] ids, int[] lengths) {
this.ids = ids;
this.lengths = lengths;
}
private int[][] ids;
private int[] lengths;
}
}