OPENNLP-1375: Adding option for GPU inference. (#421)
* OPENNLP-1375: Adding option for GPU inference.
diff --git a/opennlp-dl/pom.xml b/opennlp-dl/pom.xml
index c3c4a81..db9159a 100644
--- a/opennlp-dl/pom.xml
+++ b/opennlp-dl/pom.xml
@@ -38,7 +38,8 @@
</dependency>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
- <artifactId>onnxruntime</artifactId>
+ <!-- The onnxruntime_gpu artifact supports both CPU and GPU inference (GPU requires CUDA). -->
+ <artifactId>onnxruntime_gpu</artifactId>
<version>${onnxruntime.version}</version>
</dependency>
<dependency>
diff --git a/opennlp-dl/src/main/java/opennlp/dl/Inference.java b/opennlp-dl/src/main/java/opennlp/dl/Inference.java
index 66ac9b9..03122f0 100644
--- a/opennlp-dl/src/main/java/opennlp/dl/Inference.java
+++ b/opennlp-dl/src/main/java/opennlp/dl/Inference.java
@@ -62,7 +62,13 @@
throws OrtException, IOException {
this.env = OrtEnvironment.getEnvironment();
- this.session = env.createSession(model.getPath(), new OrtSession.SessionOptions());
+
+ final OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
+ if (inferenceOptions.isGpu()) {
+ sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
+ }
+
+ this.session = env.createSession(model.getPath(), sessionOptions);
this.vocabulary = loadVocab(vocab);
this.tokenizer = new WordpieceTokenizer(vocabulary.keySet());
this.inferenceOptions = inferenceOptions;
diff --git a/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java b/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java
index 99d3c83..9241a20 100644
--- a/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java
+++ b/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java
@@ -19,25 +19,41 @@
public class InferenceOptions {
- private boolean includeAttentionMask;
- private boolean includeTokenTypeIds;
-
- public InferenceOptions() {
- this.includeAttentionMask = true;
- this.includeTokenTypeIds = true;
- }
-
- public InferenceOptions(boolean includeAttentionMask, boolean includeTokenTypeIds) {
- this.includeAttentionMask = includeAttentionMask;
- this.includeTokenTypeIds = includeTokenTypeIds;
- }
+ private boolean includeAttentionMask = true;
+ private boolean includeTokenTypeIds = true;
+ private boolean gpu;
+ private int gpuDeviceId = 0;
public boolean isIncludeAttentionMask() {
return includeAttentionMask;
}
+ public void setIncludeAttentionMask(boolean includeAttentionMask) {
+ this.includeAttentionMask = includeAttentionMask;
+ }
+
public boolean isIncludeTokenTypeIds() {
return includeTokenTypeIds;
}
+ public void setIncludeTokenTypeIds(boolean includeTokenTypeIds) {
+ this.includeTokenTypeIds = includeTokenTypeIds;
+ }
+
+ public boolean isGpu() {
+ return gpu;
+ }
+
+ public void setGpu(boolean gpu) {
+ this.gpu = gpu;
+ }
+
+ public int getGpuDeviceId() {
+ return gpuDeviceId;
+ }
+
+ public void setGpuDeviceId(int gpuDeviceId) {
+ this.gpuDeviceId = gpuDeviceId;
+ }
+
}
diff --git a/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java b/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
index a2d5847..577ef2c 100644
--- a/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
+++ b/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
@@ -24,6 +24,7 @@
import java.util.Set;
import org.junit.Assert;
+import org.junit.Ignore;
import org.junit.Test;
import opennlp.dl.AbstactDLTest;
@@ -60,6 +61,40 @@
}
+ @Ignore("This test will only run if a GPU device is present.")
+ @Test
+ public void categorizeWithGpu() throws Exception {
+
+ final File model = new File(getOpennlpDataDir(),
+ "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.onnx");
+ final File vocab = new File(getOpennlpDataDir(),
+ "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.vocab");
+
+ final InferenceOptions inferenceOptions = new InferenceOptions();
+ inferenceOptions.setGpu(true);
+ inferenceOptions.setGpuDeviceId(0);
+
+ final DocumentCategorizerDL documentCategorizerDL =
+ new DocumentCategorizerDL(model, vocab, getCategories(), inferenceOptions);
+
+ final double[] result = documentCategorizerDL.categorize(new String[]{"I am happy"});
+ System.out.println(Arrays.toString(result));
+
+ final double[] expected = new double[]
+ {0.007819971069693565,
+ 0.006593209225684404,
+ 0.04995147883892059,
+ 0.3003573715686798,
+ 0.6352779865264893};
+
+ Assert.assertTrue(Arrays.equals(expected, result));
+ Assert.assertEquals(5, result.length);
+
+ final String category = documentCategorizerDL.getBestCategory(result);
+ Assert.assertEquals("very good", category);
+
+ }
+
@Test
public void categorizeWithInferenceOptions() throws Exception {
@@ -68,8 +103,8 @@
final File vocab = new File(getOpennlpDataDir(),
"onnx/doccat/lvwerra_distilbert-imdb.vocab");
- final InferenceOptions inferenceOptions =
- new InferenceOptions(true, false);
+ final InferenceOptions inferenceOptions = new InferenceOptions();
+ inferenceOptions.setIncludeTokenTypeIds(false);
final Map<Integer, String> categories = new HashMap<>();
categories.put(0, "negative");