OPENNLP-1556 Improve speed of checksum computation in TwoPassDataIndexer (#600)
- adjusts TwoPassDataIndexer to make use of JDK's built-in CheckedOutputStream / CheckedInputStream for checksum (CRC32c) computations
- removes untested class HashSumEventStream which is just a wrapper for calling a slow toString() in Event to get some bytes to use for the computation of a checksum
- provides a HashSumEventStream replacement: ChecksumEventStream which makes use of the faster CRC32c checksum computation, avoiding cryptographic hash functions such as MD5
- adds JUnit tests for ChecksumEventStream
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java
index d546739..9ea5ddc 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java
@@ -20,10 +20,10 @@
import java.io.IOException;
import opennlp.tools.ml.model.AbstractDataIndexer;
+import opennlp.tools.ml.model.ChecksumEventStream;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.DataIndexerFactory;
import opennlp.tools.ml.model.Event;
-import opennlp.tools.ml.model.HashSumEventStream;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.util.InsufficientTrainingDataException;
import opennlp.tools.util.ObjectStream;
@@ -85,10 +85,10 @@
public final MaxentModel train(ObjectStream<Event> events) throws IOException {
validate();
- HashSumEventStream hses = new HashSumEventStream(events);
+ ChecksumEventStream hses = new ChecksumEventStream(events);
DataIndexer indexer = getDataIndexer(hses);
- addToReport("Training-Eventhash", hses.calculateHashSum().toString(16));
+ addToReport("Training-Eventhash", String.valueOf(hses.calculateChecksum()));
return train(indexer);
}
}
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/model/ChecksumEventStream.java b/opennlp-tools/src/main/java/opennlp/tools/ml/model/ChecksumEventStream.java
new file mode 100644
index 0000000..52af8ed
--- /dev/null
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/model/ChecksumEventStream.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml.model;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.zip.CRC32C;
+import java.util.zip.Checksum;
+
+import opennlp.tools.util.AbstractObjectStream;
+import opennlp.tools.util.ObjectStream;
+
+/**
+ * A {@link Checksum}-based {@link AbstractObjectStream event stream} implementation.
+ * Computes the checksum while consuming the event stream.
+ * By default, this implementation will use {@link CRC32C} for checksum calculations
+ * as it can make use of CPU-specific acceleration instructions at runtime.
+ *
+ * @see Event
+ * @see Checksum
+ * @see AbstractObjectStream
+ */
+public class ChecksumEventStream extends AbstractObjectStream<Event> {
+
+ private final Checksum checksum;
+
+
+ /**
+ * Initializes a {@link ChecksumEventStream}.
+ *
+ * @param eventStream The {@link ObjectStream} that provides the {@link Event} samples.
+ */
+ public ChecksumEventStream(ObjectStream<Event> eventStream) {
+ super(eventStream);
+ // CRC32C supports CPU-specific acceleration instructions
+ checksum = new CRC32C();
+ }
+
+ @Override
+ public Event read() throws IOException {
+ Event event = super.read();
+ if (event != null) {
+ checksum.update(event.toString().getBytes(StandardCharsets.UTF_8));
+ }
+ return event;
+ }
+
+ /**
+ * Calculates and returns the (current) checksum.
+ * <p>
+ * Note: This should be called once the underlying stream has been (fully) consumed.
+ *
+ * @return The calculated checksum as {@code long}.
+ */
+ public long calculateChecksum() {
+ return checksum.getValue();
+ }
+}
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/model/HashSumEventStream.java b/opennlp-tools/src/main/java/opennlp/tools/ml/model/HashSumEventStream.java
deleted file mode 100644
index 6fafb24..0000000
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/model/HashSumEventStream.java
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package opennlp.tools.ml.model;
-
-import java.io.IOException;
-import java.math.BigInteger;
-import java.nio.charset.StandardCharsets;
-import java.security.MessageDigest;
-import java.security.NoSuchAlgorithmException;
-
-import opennlp.tools.util.AbstractObjectStream;
-import opennlp.tools.util.ObjectStream;
-
-/**
- * A hash sum based {@link AbstractObjectStream} implementation.
- *
- * @see Event
- * @see MessageDigest
- * @see AbstractObjectStream
- */
-public class HashSumEventStream extends AbstractObjectStream<Event> {
-
- private final MessageDigest digest;
-
- public HashSumEventStream(ObjectStream<Event> eventStream) {
- super(eventStream);
-
- try {
- digest = MessageDigest.getInstance("MD5");
- } catch (NoSuchAlgorithmException e) {
- // should never happen: do all java runtimes have md5 ?!
- throw new IllegalStateException(e);
- }
- }
-
- @Override
- public Event read() throws IOException {
- Event event = super.read();
-
- if (event != null) {
- digest.update(event.toString().getBytes(StandardCharsets.UTF_8));
- }
-
- return event;
- }
-
- /**
- * Calculates the hash sum of the stream and wraps it into a {@link BigInteger}.
- * Note: The method must be called after the stream is completely consumed.
- *
- * @return The calculated hash sum as {@link BigInteger}.
- * @throws IllegalStateException Thrown if the stream is not consumed completely,
- * completely means that hasNext() returns {@code false}.
- */
- public BigInteger calculateHashSum() {
- return new BigInteger(1, digest.digest());
- }
-
-}
diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/model/TwoPassDataIndexer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/model/TwoPassDataIndexer.java
index dd67dc2..0e49a4b 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/ml/model/TwoPassDataIndexer.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/ml/model/TwoPassDataIndexer.java
@@ -17,7 +17,6 @@
package opennlp.tools.ml.model;
-
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
@@ -26,11 +25,13 @@
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
-import java.math.BigInteger;
import java.nio.file.Files;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.zip.CRC32C;
+import java.util.zip.CheckedInputStream;
+import java.util.zip.CheckedOutputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -63,48 +64,50 @@
int cutoff = trainingParameters.getIntParameter(CUTOFF_PARAM, CUTOFF_DEFAULT);
boolean sort = trainingParameters.getBooleanParameter(SORT_PARAM, SORT_DEFAULT);
- long start = System.currentTimeMillis();
-
logger.info("Indexing events with TwoPass using cutoff of {}", cutoff);
-
logger.info("Computing event counts...");
+ long start = System.currentTimeMillis();
Map<String,Integer> predicateIndex = new HashMap<>();
-
File tmp = Files.createTempFile("events", null).toFile();
tmp.deleteOnExit();
int numEvents;
- BigInteger writeHash;
- HashSumEventStream writeEventStream = new HashSumEventStream(eventStream); // do not close.
- try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(tmp)))) {
- numEvents = computeEventCounts(writeEventStream, dos, predicateIndex, cutoff);
- }
- writeHash = writeEventStream.calculateHashSum();
+ long writeChecksum;
- logger.info("done. {} events", numEvents);
- logger.info("Indexing...");
+ try (BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(tmp));
+ CheckedOutputStream writeStream = new CheckedOutputStream(out, new CRC32C());
+ DataOutputStream dos = new DataOutputStream(writeStream)) {
+
+ numEvents = computeEventCounts(eventStream, dos, predicateIndex, cutoff);
+ writeChecksum = writeStream.getChecksum().getValue();
+ logger.info("done. {} events", numEvents);
+ }
List<ComparableEvent> eventsToCompare;
- BigInteger readHash = null;
- try (HashSumEventStream readStream = new HashSumEventStream(new EventStream(tmp))) {
- eventsToCompare = index(readStream, predicateIndex);
- readHash = readStream.calculateHashSum();
+ long readChecksum;
+ try (BufferedInputStream in = new BufferedInputStream(new FileInputStream(tmp));
+ CheckedInputStream readStream = new CheckedInputStream(in, new CRC32C());
+ EventStream readEventsStream = new EventStream(new DataInputStream(readStream))) {
+ logger.info("Indexing...");
+ eventsToCompare = index(readEventsStream, predicateIndex);
+ readChecksum = readStream.getChecksum().getValue();
}
tmp.delete();
- if (readHash.compareTo(writeHash) != 0)
- throw new IOException("Event hash for writing and reading events did not match.");
+ if (readChecksum != writeChecksum) {
+ throw new IOException("Checksum for writing and reading events did not match.");
+ } else {
+ logger.info("done.");
- logger.info("done.");
-
- if (sort) {
- logger.info("Sorting and merging events... ");
+ if (sort) {
+ logger.info("Sorting and merging events... ");
+ }
+ else {
+ logger.info("Collecting events... ");
+ }
+ sortAndMerge(eventsToCompare,sort);
+ logger.info(String.format("Done indexing in %.2f s.", (System.currentTimeMillis() - start) / 1000d));
}
- else {
- logger.info("Collecting events... ");
- }
- sortAndMerge(eventsToCompare,sort);
- logger.info(String.format("Done indexing in %.2f s.", (System.currentTimeMillis() - start) / 1000d));
}
/**
@@ -170,8 +173,8 @@
private final DataInputStream inputStream;
- public EventStream(File file) throws IOException {
- inputStream = new DataInputStream(new BufferedInputStream(new FileInputStream(file)));
+ public EventStream(DataInputStream dataInputStream) {
+ this.inputStream = dataInputStream;
}
@Override
diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/model/ChecksumEventStreamTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/model/ChecksumEventStreamTest.java
new file mode 100644
index 0000000..95d38b2
--- /dev/null
+++ b/opennlp-tools/src/test/java/opennlp/tools/ml/model/ChecksumEventStreamTest.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml.model;
+
+import java.io.IOException;
+
+import org.junit.jupiter.api.Test;
+
+import opennlp.tools.util.ObjectStream;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class ChecksumEventStreamTest {
+
+ @Test
+ void testCalculateChecksumEquality() throws IOException {
+ ChecksumEventStream ces1 = new ChecksumEventStream(createEventStreamFull());
+ ChecksumEventStream ces2 = new ChecksumEventStream(createEventStreamFull());
+ consumeEventStream(ces1, 7);
+ consumeEventStream(ces2, 7);
+
+ long checksum1 = ces1.calculateChecksum();
+ long checksum2 = ces2.calculateChecksum();
+ assertTrue(checksum1 != 0);
+ assertTrue(checksum2 != 0);
+ assertEquals(checksum1, checksum2);
+ }
+
+ @Test
+ void testCalculateChecksumMismatch() throws IOException {
+ ChecksumEventStream ces1 = new ChecksumEventStream(createEventStreamFull());
+ ChecksumEventStream ces2 = new ChecksumEventStream(createEventStreamPartial());
+ consumeEventStream(ces1, 7);
+ consumeEventStream(ces2, 2);
+
+ long checksum1 = ces1.calculateChecksum();
+ long checksum2 = ces2.calculateChecksum();
+ assertTrue(checksum1 != 0);
+ assertTrue(checksum2 != 0);
+ assertNotEquals(checksum1, checksum2);
+ }
+
+ private ObjectStream<Event> createEventStreamFull() {
+ // He belongs to <START:org> Apache Software Foundation <END> .
+ return new SimpleEventStreamBuilder()
+ .add("other/w=he n1w=belongs n2w=to po=other pow=other,He powf=other,ic ppo=other")
+ .add("other/w=belongs p1w=he n1w=to n2w=apache po=other pow=other,belongs powf=other,lc ppo=other")
+ .add("other/w=to p1w=belongs p2w=he n1w=apache n2w=software po=other pow=other,to" +
+ " powf=other,lc ppo=other")
+ .add("org-start/w=apache p1w=to p2w=belongs n1w=software n2w=foundation po=other pow=other,Apache" +
+ " powf=other,ic ppo=other")
+ .add("org-cont/w=software p1w=apache p2w=to n1w=foundation n2w=. po=org-start" +
+ " pow=org-start,Software powf=org-start,ic ppo=other")
+ .add("org-cont/w=foundation p1w=software p2w=apache n1w=. po=org-cont pow=org-cont,Foundation" +
+ " powf=org-cont,ic ppo=org-start")
+ .add("other/w=. p1w=foundation p2w=software po=org-cont pow=org-cont,. powf=org-cont,other" +
+ " ppo=org-cont")
+ .build();
+ }
+
+ private ObjectStream<Event> createEventStreamPartial() {
+ // He .
+ return new SimpleEventStreamBuilder()
+ .add("other/w=he n1w=belongs n2w=to po=other pow=other,He powf=other,ic ppo=other")
+ .add("other/w=. p1w=foundation p2w=software po=org-cont pow=org-cont,. powf=org-cont,other" +
+ " ppo=org-cont")
+ .build();
+ }
+
+ private void consumeEventStream(ObjectStream<Event> eventStream, int eventCount) throws IOException {
+ for (int i = 0; i < eventCount; i++) {
+ assertNotNull(eventStream.read());
+ }
+ }
+}