Merge branch 'main' into java_21
diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 5fae1e4..418b302 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -200,8 +200,8 @@
* GITHUB#12677: Better detect vector module in non-default setups (e.g., custom module layers).
(Uwe Schindler)
-* GITHUB#12634, GITHUB#12632, GITHUB#12680, GITHUB#12681, GITHUB#12731: Speed up Panama vector support
- and test improvements. (Uwe Schindler, Robert Muir)
+* GITHUB#12634, GITHUB#12632, GITHUB#12680, GITHUB#12681, GITHUB#12731, GITHUB#12737: Speed up
+ Panama vector support and test improvements. (Uwe Schindler, Robert Muir)
* GITHUB#12586: Remove over-counting of deleted terms. (Guo Feng)
@@ -267,6 +267,8 @@
* GITHUB#12569: Prevent concurrent tasks from parallelizing execution further which could cause deadlock
(Luca Cavanna)
+* GITHUB#12765: Disable vectorization on VMs that are not Hotspot-based. (Uwe Schindler, Robert Muir)
+
Bug Fixes
---------------------
diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingStoredFieldsConsumer.java b/lucene/core/src/java/org/apache/lucene/index/SortingStoredFieldsConsumer.java
index 61bd680..1c7c582 100644
--- a/lucene/core/src/java/org/apache/lucene/index/SortingStoredFieldsConsumer.java
+++ b/lucene/core/src/java/org/apache/lucene/index/SortingStoredFieldsConsumer.java
@@ -61,7 +61,7 @@
public void decompress(
DataInput in, int originalLength, int offset, int length, BytesRef bytes)
throws IOException {
- bytes.bytes = ArrayUtil.grow(bytes.bytes, length);
+ bytes.bytes = ArrayUtil.growNoCopy(bytes.bytes, length);
in.skipBytes(offset);
in.readBytes(bytes.bytes, 0, length);
bytes.offset = 0;
diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java
index de546c9..750e0ee 100644
--- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java
+++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java
@@ -17,72 +17,46 @@
package org.apache.lucene.internal.vectorization;
+import org.apache.lucene.util.Constants;
+import org.apache.lucene.util.SuppressForbidden;
+
final class DefaultVectorUtilSupport implements VectorUtilSupport {
DefaultVectorUtilSupport() {}
+ // the way FMA should work! if available use it, otherwise fall back to mul/add
+ @SuppressForbidden(reason = "Uses FMA only where fast and carefully contained")
+ private static float fma(float a, float b, float c) {
+ if (Constants.HAS_FAST_SCALAR_FMA) {
+ return Math.fma(a, b, c);
+ } else {
+ return a * b + c;
+ }
+ }
+
@Override
public float dotProduct(float[] a, float[] b) {
float res = 0f;
- /*
- * If length of vector is larger than 8, we use unrolled dot product to accelerate the
- * calculation.
- */
- int i;
- for (i = 0; i < a.length % 8; i++) {
- res += b[i] * a[i];
+ int i = 0;
+
+ // if the array is big, unroll it
+ if (a.length > 32) {
+ float acc1 = 0;
+ float acc2 = 0;
+ float acc3 = 0;
+ float acc4 = 0;
+ int upperBound = a.length & ~(4 - 1);
+ for (; i < upperBound; i += 4) {
+ acc1 = fma(a[i], b[i], acc1);
+ acc2 = fma(a[i + 1], b[i + 1], acc2);
+ acc3 = fma(a[i + 2], b[i + 2], acc3);
+ acc4 = fma(a[i + 3], b[i + 3], acc4);
+ }
+ res += acc1 + acc2 + acc3 + acc4;
}
- if (a.length < 8) {
- return res;
- }
- for (; i + 31 < a.length; i += 32) {
- res +=
- b[i + 0] * a[i + 0]
- + b[i + 1] * a[i + 1]
- + b[i + 2] * a[i + 2]
- + b[i + 3] * a[i + 3]
- + b[i + 4] * a[i + 4]
- + b[i + 5] * a[i + 5]
- + b[i + 6] * a[i + 6]
- + b[i + 7] * a[i + 7];
- res +=
- b[i + 8] * a[i + 8]
- + b[i + 9] * a[i + 9]
- + b[i + 10] * a[i + 10]
- + b[i + 11] * a[i + 11]
- + b[i + 12] * a[i + 12]
- + b[i + 13] * a[i + 13]
- + b[i + 14] * a[i + 14]
- + b[i + 15] * a[i + 15];
- res +=
- b[i + 16] * a[i + 16]
- + b[i + 17] * a[i + 17]
- + b[i + 18] * a[i + 18]
- + b[i + 19] * a[i + 19]
- + b[i + 20] * a[i + 20]
- + b[i + 21] * a[i + 21]
- + b[i + 22] * a[i + 22]
- + b[i + 23] * a[i + 23];
- res +=
- b[i + 24] * a[i + 24]
- + b[i + 25] * a[i + 25]
- + b[i + 26] * a[i + 26]
- + b[i + 27] * a[i + 27]
- + b[i + 28] * a[i + 28]
- + b[i + 29] * a[i + 29]
- + b[i + 30] * a[i + 30]
- + b[i + 31] * a[i + 31];
- }
- for (; i + 7 < a.length; i += 8) {
- res +=
- b[i + 0] * a[i + 0]
- + b[i + 1] * a[i + 1]
- + b[i + 2] * a[i + 2]
- + b[i + 3] * a[i + 3]
- + b[i + 4] * a[i + 4]
- + b[i + 5] * a[i + 5]
- + b[i + 6] * a[i + 6]
- + b[i + 7] * a[i + 7];
+
+ for (; i < a.length; i++) {
+ res = fma(a[i], b[i], res);
}
return res;
}
@@ -92,50 +66,80 @@
float sum = 0.0f;
float norm1 = 0.0f;
float norm2 = 0.0f;
- int dim = a.length;
+ int i = 0;
- for (int i = 0; i < dim; i++) {
- float elem1 = a[i];
- float elem2 = b[i];
- sum += elem1 * elem2;
- norm1 += elem1 * elem1;
- norm2 += elem2 * elem2;
+ // if the array is big, unroll it
+ if (a.length > 32) {
+ float sum1 = 0;
+ float sum2 = 0;
+ float norm1_1 = 0;
+ float norm1_2 = 0;
+ float norm2_1 = 0;
+ float norm2_2 = 0;
+
+ int upperBound = a.length & ~(2 - 1);
+ for (; i < upperBound; i += 2) {
+ // one
+ sum1 = fma(a[i], b[i], sum1);
+ norm1_1 = fma(a[i], a[i], norm1_1);
+ norm2_1 = fma(b[i], b[i], norm2_1);
+
+ // two
+ sum2 = fma(a[i + 1], b[i + 1], sum2);
+ norm1_2 = fma(a[i + 1], a[i + 1], norm1_2);
+ norm2_2 = fma(b[i + 1], b[i + 1], norm2_2);
+ }
+ sum += sum1 + sum2;
+ norm1 += norm1_1 + norm1_2;
+ norm2 += norm2_1 + norm2_2;
+ }
+
+ for (; i < a.length; i++) {
+ sum = fma(a[i], b[i], sum);
+ norm1 = fma(a[i], a[i], norm1);
+ norm2 = fma(b[i], b[i], norm2);
}
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
}
@Override
public float squareDistance(float[] a, float[] b) {
- float squareSum = 0.0f;
- int dim = a.length;
- int i;
- for (i = 0; i + 8 <= dim; i += 8) {
- squareSum += squareDistanceUnrolled(a, b, i);
- }
- for (; i < dim; i++) {
- float diff = a[i] - b[i];
- squareSum += diff * diff;
- }
- return squareSum;
- }
+ float res = 0;
+ int i = 0;
- private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
- float diff0 = v1[index + 0] - v2[index + 0];
- float diff1 = v1[index + 1] - v2[index + 1];
- float diff2 = v1[index + 2] - v2[index + 2];
- float diff3 = v1[index + 3] - v2[index + 3];
- float diff4 = v1[index + 4] - v2[index + 4];
- float diff5 = v1[index + 5] - v2[index + 5];
- float diff6 = v1[index + 6] - v2[index + 6];
- float diff7 = v1[index + 7] - v2[index + 7];
- return diff0 * diff0
- + diff1 * diff1
- + diff2 * diff2
- + diff3 * diff3
- + diff4 * diff4
- + diff5 * diff5
- + diff6 * diff6
- + diff7 * diff7;
+ // if the array is big, unroll it
+ if (a.length > 32) {
+ float acc1 = 0;
+ float acc2 = 0;
+ float acc3 = 0;
+ float acc4 = 0;
+
+ int upperBound = a.length & ~(4 - 1);
+ for (; i < upperBound; i += 4) {
+ // one
+ float diff1 = a[i] - b[i];
+ acc1 = fma(diff1, diff1, acc1);
+
+ // two
+ float diff2 = a[i + 1] - b[i + 1];
+ acc2 = fma(diff2, diff2, acc2);
+
+ // three
+ float diff3 = a[i + 2] - b[i + 2];
+ acc3 = fma(diff3, diff3, acc3);
+
+ // four
+ float diff4 = a[i + 3] - b[i + 3];
+ acc4 = fma(diff4, diff4, acc4);
+ }
+ res += acc1 + acc2 + acc3 + acc4;
+ }
+
+ for (; i < a.length; i++) {
+ float diff = a[i] - b[i];
+ res = fma(diff, diff, res);
+ }
+ return res;
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java
index 3d565b6..35a4852 100644
--- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java
+++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java
@@ -111,6 +111,12 @@
+ Locale.getDefault());
return new DefaultVectorizationProvider();
}
+ // only use vector module with Hotspot VM
+ if (!Constants.IS_HOTSPOT_VM) {
+ LOG.warning(
+ "Java runtime is not using Hotspot VM; Java vector incubator API can't be enabled.");
+ return new DefaultVectorizationProvider();
+ }
// is the incubator module present and readable (JVM providers may to exclude them or it is
// build with jlink)
final var vectorMod = lookupVectorModule();
diff --git a/lucene/core/src/java/org/apache/lucene/util/ByteBlockPool.java b/lucene/core/src/java/org/apache/lucene/util/ByteBlockPool.java
index d2d5fc0..cbdd3bb 100644
--- a/lucene/core/src/java/org/apache/lucene/util/ByteBlockPool.java
+++ b/lucene/core/src/java/org/apache/lucene/util/ByteBlockPool.java
@@ -38,6 +38,8 @@
/** Abstract class for allocating and freeing byte blocks. */
public abstract static class Allocator {
+ // TODO: ByteBlockPool assumes the blockSize is always {@link BYTE_BLOCK_SIZE}, but this class
+ // allows arbitrary values of blockSize. We should make them consistent.
protected final int blockSize;
protected Allocator(int blockSize) {
@@ -215,19 +217,38 @@
/** Appends the bytes in the provided {@link BytesRef} at the current position. */
public void append(final BytesRef bytes) {
- int bytesLeft = bytes.length;
- int offset = bytes.offset;
+ append(bytes.bytes, bytes.offset, bytes.length);
+ }
+
+ /**
+ * Append the provided byte array at the current position.
+ *
+ * @param bytes the byte array to write
+ */
+ public void append(final byte[] bytes) {
+ append(bytes, 0, bytes.length);
+ }
+
+ /**
+ * Append some portion of the provided byte array at the current position.
+ *
+ * @param bytes the byte array to write
+ * @param offset the offset of the byte array
+ * @param length the number of bytes to write
+ */
+ public void append(final byte[] bytes, int offset, int length) {
+ int bytesLeft = length;
while (bytesLeft > 0) {
int bufferLeft = BYTE_BLOCK_SIZE - byteUpto;
if (bytesLeft < bufferLeft) {
// fits within current buffer
- System.arraycopy(bytes.bytes, offset, buffer, byteUpto, bytesLeft);
+ System.arraycopy(bytes, offset, buffer, byteUpto, bytesLeft);
byteUpto += bytesLeft;
break;
} else {
// fill up this buffer and move to next one
if (bufferLeft > 0) {
- System.arraycopy(bytes.bytes, offset, buffer, byteUpto, bufferLeft);
+ System.arraycopy(bytes, offset, buffer, byteUpto, bufferLeft);
}
nextBuffer();
bytesLeft -= bufferLeft;
@@ -256,6 +277,18 @@
}
}
+ /**
+ * Read a single byte at the given offset
+ *
+ * @param offset the offset to read
+ * @return the byte
+ */
+ public byte readByte(final long offset) {
+ int bufferIndex = (int) (offset >> BYTE_BLOCK_SHIFT);
+ int pos = (int) (offset & BYTE_BLOCK_MASK);
+ return buffers[bufferIndex][pos];
+ }
+
@Override
public long ramBytesUsed() {
long size = BASE_RAM_BYTES;
@@ -269,4 +302,9 @@
}
return size;
}
+
+ /** the current position (as an absolute offset) of this byte pool */
+ public long getPosition() {
+ return bufferUpto * allocator.blockSize + byteUpto;
+ }
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/Constants.java b/lucene/core/src/java/org/apache/lucene/util/Constants.java
index 3ef1298..01ead03 100644
--- a/lucene/core/src/java/org/apache/lucene/util/Constants.java
+++ b/lucene/core/src/java/org/apache/lucene/util/Constants.java
@@ -18,7 +18,6 @@
import java.security.AccessController;
import java.security.PrivilegedAction;
-import java.util.Objects;
import java.util.logging.Logger;
/** Some useful constants. */
@@ -60,19 +59,16 @@
/** The value of <code>System.getProperty("java.vendor")</code>. */
public static final String JAVA_VENDOR = getSysProp("java.vendor", UNKNOWN);
- /** True iff the Java runtime is a client runtime and C2 compiler is not enabled */
+ /** True iff the Java runtime is a client runtime and C2 compiler is not enabled. */
public static final boolean IS_CLIENT_VM =
getSysProp("java.vm.info", "").contains("emulated-client");
+ /** True iff the Java VM is based on Hotspot and has the Hotspot MX bean readable by Lucene. */
+ public static final boolean IS_HOTSPOT_VM = HotspotVMOptions.IS_HOTSPOT_VM;
+
/** True iff running on a 64bit JVM */
public static final boolean JRE_IS_64BIT = is64Bit();
- /** true iff we know fast FMA is supported, to deliver less error */
- public static final boolean HAS_FAST_FMA =
- (IS_CLIENT_VM == false)
- && Objects.equals(OS_ARCH, "amd64")
- && HotspotVMOptions.get("UseFMA").map(Boolean::valueOf).orElse(false);
-
private static boolean is64Bit() {
final String datamodel = getSysProp("sun.arch.data.model");
if (datamodel != null) {
@@ -82,6 +78,76 @@
}
}
+ /** true if FMA likely means a cpu instruction and not BigDecimal logic. */
+ private static final boolean HAS_FMA =
+ (IS_CLIENT_VM == false) && HotspotVMOptions.get("UseFMA").map(Boolean::valueOf).orElse(false);
+
+ /** maximum supported vectorsize. */
+ private static final int MAX_VECTOR_SIZE =
+ HotspotVMOptions.get("MaxVectorSize").map(Integer::valueOf).orElse(0);
+
+ /** true for an AMD cpu with SSE4a instructions. */
+ private static final boolean HAS_SSE4A =
+ HotspotVMOptions.get("UseXmmI2F").map(Boolean::valueOf).orElse(false);
+
+ /** true iff we know VFMA has faster throughput than separate vmul/vadd. */
+ public static final boolean HAS_FAST_VECTOR_FMA = hasFastVectorFMA();
+
+ /** true iff we know FMA has faster throughput than separate mul/add. */
+ public static final boolean HAS_FAST_SCALAR_FMA = hasFastScalarFMA();
+
+ private static boolean hasFastVectorFMA() {
+ if (HAS_FMA) {
+ String value = getSysProp("lucene.useVectorFMA", "auto");
+ if ("auto".equals(value)) {
+ // newer Neoverse cores have their act together
+ // the problem is just apple silicon (this is a practical heuristic)
+ if (OS_ARCH.equals("aarch64") && MAC_OS_X == false) {
+ return true;
+ }
+ // zen cores or newer, its a wash, turn it on as it doesn't hurt
+ // starts to yield gains for vectors only at zen4+
+ if (HAS_SSE4A && MAX_VECTOR_SIZE >= 32) {
+ return true;
+ }
+ // intel has their act together
+ if (OS_ARCH.equals("amd64") && HAS_SSE4A == false) {
+ return true;
+ }
+ } else {
+ return Boolean.parseBoolean(value);
+ }
+ }
+ // everyone else is slow, until proven otherwise by benchmarks
+ return false;
+ }
+
+ private static boolean hasFastScalarFMA() {
+ if (HAS_FMA) {
+ String value = getSysProp("lucene.useScalarFMA", "auto");
+ if ("auto".equals(value)) {
+ // newer Neoverse cores have their act together
+ // the problem is just apple silicon (this is a practical heuristic)
+ if (OS_ARCH.equals("aarch64") && MAC_OS_X == false) {
+ return true;
+ }
+ // latency becomes 4 for the Zen3 (0x19h), but still a wash
+ // until the Zen4 anyway, and big drop on previous zens:
+ if (HAS_SSE4A && MAX_VECTOR_SIZE >= 64) {
+ return true;
+ }
+ // intel has their act together
+ if (OS_ARCH.equals("amd64") && HAS_SSE4A == false) {
+ return true;
+ }
+ } else {
+ return Boolean.parseBoolean(value);
+ }
+ }
+ // everyone else is slow, until proven otherwise by benchmarks
+ return false;
+ }
+
private static String getSysProp(String property) {
try {
return doPrivileged(() -> System.getProperty(property));
diff --git a/lucene/core/src/java/org/apache/lucene/util/HotspotVMOptions.java b/lucene/core/src/java/org/apache/lucene/util/HotspotVMOptions.java
index 70f963e..e9fc584 100644
--- a/lucene/core/src/java/org/apache/lucene/util/HotspotVMOptions.java
+++ b/lucene/core/src/java/org/apache/lucene/util/HotspotVMOptions.java
@@ -26,8 +26,8 @@
final class HotspotVMOptions {
private HotspotVMOptions() {} // can't construct
- /** True if the Java VM is based on Hotspot and has the Hotspot MX bean readable by Lucene */
- public static final boolean IS_HOTSPOT;
+ /** True iff the Java VM is based on Hotspot and has the Hotspot MX bean readable by Lucene */
+ public static final boolean IS_HOTSPOT_VM;
/**
* Returns an optional with the value of a Hotspot VM option. If the VM option does not exist or
@@ -84,7 +84,7 @@
"Lucene cannot optimize algorithms or calculate object sizes for JVMs that are not based on Hotspot or a compatible implementation.");
}
}
- IS_HOTSPOT = isHotspot;
+ IS_HOTSPOT_VM = isHotspot;
ACCESSOR = accessor;
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/RamUsageEstimator.java b/lucene/core/src/java/org/apache/lucene/util/RamUsageEstimator.java
index 7e0bdfd..fd32ece 100644
--- a/lucene/core/src/java/org/apache/lucene/util/RamUsageEstimator.java
+++ b/lucene/core/src/java/org/apache/lucene/util/RamUsageEstimator.java
@@ -112,10 +112,10 @@
/** Initialize constants and try to collect information about the JVM internals. */
static {
- if (Constants.JRE_IS_64BIT && HotspotVMOptions.IS_HOTSPOT) {
+ if (Constants.JRE_IS_64BIT) {
+ JVM_IS_HOTSPOT_64BIT = HotspotVMOptions.IS_HOTSPOT_VM;
// Try to get compressed oops and object alignment (the default seems to be 8 on Hotspot);
// (this only works on 64 bit, on 32 bits the alignment and reference size is fixed):
- JVM_IS_HOTSPOT_64BIT = true;
COMPRESSED_REFS_ENABLED =
HotspotVMOptions.get("UseCompressedOops").map(Boolean::valueOf).orElse(false);
NUM_BYTES_OBJECT_ALIGNMENT =
diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
index ef52e605..4a792c1 100644
--- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
@@ -20,7 +20,31 @@
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider;
-/** Utilities for computations with numeric arrays */
+/**
+ * Utilities for computations with numeric arrays, especially algebraic operations like vector dot
+ * products. This class uses SIMD vectorization if the corresponding Java module is available and
+ * enabled. To enable vectorized code, pass {@code --add-modules jdk.incubator.vector} to Java's
+ * command line.
+ *
+ * <p>It will use CPU's <a href="https://en.wikipedia.org/wiki/Fused_multiply%E2%80%93add">FMA
+ * instructions</a> if it is known to perform faster than separate multiply+add. This requires at
+ * least Hotspot C2 enabled, which is the default for OpenJDK based JVMs.
+ *
+ * <p>To explicitly disable or enable FMA usage, pass the following system properties:
+ *
+ * <ul>
+ * <li>{@code -Dlucene.useScalarFMA=(auto|true|false)} for scalar operations
+ * <li>{@code -Dlucene.useVectorFMA=(auto|true|false)} for vectorized operations (with vector
+ * incubator module)
+ * </ul>
+ *
+ * <p>The default is {@code auto}, which enables this for known CPU types and JVM settings. If
+ * Hotspot C2 is disabled, FMA and vectorization are <strong>not</strong> used.
+ *
+ * <p>Vectorization and FMA are only supported for Hotspot-based JVMs; they won't work on OpenJ9-based
+ * JVMs unless they provide {@link com.sun.management.HotSpotDiagnosticMXBean}. Please also make
+ * sure that you have the {@code jdk.management} module enabled in modularized applications.
+ */
public final class VectorUtil {
private static final VectorUtilSupport IMPL =
diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/ByteBlockPoolReverseBytesReader.java b/lucene/core/src/java/org/apache/lucene/util/fst/ByteBlockPoolReverseBytesReader.java
new file mode 100644
index 0000000..41ca21d
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/fst/ByteBlockPoolReverseBytesReader.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util.fst;
+
+import java.io.IOException;
+import org.apache.lucene.util.ByteBlockPool;
+
+/** Reads in reverse from a ByteBlockPool. */
+final class ByteBlockPoolReverseBytesReader extends FST.BytesReader {
+
+ private final ByteBlockPool buf;
+ // the difference between the FST node address and the hash table copied node address
+ private long posDelta;
+ private long pos;
+
+ public ByteBlockPoolReverseBytesReader(ByteBlockPool buf) {
+ this.buf = buf;
+ }
+
+ @Override
+ public byte readByte() {
+ return buf.readByte(pos--);
+ }
+
+ @Override
+ public void readBytes(byte[] b, int offset, int len) {
+ for (int i = 0; i < len; i++) {
+ b[offset + i] = buf.readByte(pos--);
+ }
+ }
+
+ @Override
+ public void skipBytes(long numBytes) throws IOException {
+ pos -= numBytes;
+ }
+
+ @Override
+ public long getPosition() {
+ return pos + posDelta;
+ }
+
+ @Override
+ public void setPosition(long pos) {
+ this.pos = pos - posDelta;
+ }
+
+ @Override
+ public boolean reversed() {
+ return true;
+ }
+
+ public void setPosDelta(long posDelta) {
+ this.posDelta = posDelta;
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/BytesStore.java b/lucene/core/src/java/org/apache/lucene/util/fst/BytesStore.java
index 469454d..a03b9b0 100644
--- a/lucene/core/src/java/org/apache/lucene/util/fst/BytesStore.java
+++ b/lucene/core/src/java/org/apache/lucene/util/fst/BytesStore.java
@@ -444,11 +444,7 @@
@Override
public FST.BytesReader getReverseBytesReader() {
- return getReverseReader(true);
- }
-
- FST.BytesReader getReverseReader(boolean allowSingle) {
- if (allowSingle && blocks.size() == 1) {
+ if (blocks.size() == 1) {
return new ReverseBytesReader(blocks.get(0));
}
return new FST.BytesReader() {
diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/FSTCompiler.java b/lucene/core/src/java/org/apache/lucene/util/fst/FSTCompiler.java
index 3af6241..53cb18a 100644
--- a/lucene/core/src/java/org/apache/lucene/util/fst/FSTCompiler.java
+++ b/lucene/core/src/java/org/apache/lucene/util/fst/FSTCompiler.java
@@ -145,7 +145,7 @@
if (suffixRAMLimitMB < 0) {
throw new IllegalArgumentException("ramLimitMB must be >= 0; got: " + suffixRAMLimitMB);
} else if (suffixRAMLimitMB > 0) {
- dedupHash = new NodeHash<>(this, suffixRAMLimitMB, bytes.getReverseReader(false));
+ dedupHash = new NodeHash<>(this, suffixRAMLimitMB);
} else {
dedupHash = null;
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/NodeHash.java b/lucene/core/src/java/org/apache/lucene/util/fst/NodeHash.java
index 04c1be4..6907416 100644
--- a/lucene/core/src/java/org/apache/lucene/util/fst/NodeHash.java
+++ b/lucene/core/src/java/org/apache/lucene/util/fst/NodeHash.java
@@ -17,6 +17,7 @@
package org.apache.lucene.util.fst;
import java.io.IOException;
+import org.apache.lucene.util.ByteBlockPool;
import org.apache.lucene.util.packed.PackedInts;
import org.apache.lucene.util.packed.PagedGrowableWriter;
@@ -49,14 +50,17 @@
private final FSTCompiler<T> fstCompiler;
private final FST.Arc<T> scratchArc = new FST.Arc<>();
- private final FST.BytesReader in;
+ // store the last fallback table node length in getFallback()
+ private int lastFallbackNodeLength;
+ // store the last fallback table hash slot in getFallback()
+ private long lastFallbackHashSlot;
/**
* ramLimitMB is the max RAM we can use for recording suffixes. If we hit this limit, the least
* recently used suffixes are discarded, and the FST is no longer minimalI. Still, larger
* ramLimitMB will make the FST smaller (closer to minimal).
*/
- public NodeHash(FSTCompiler<T> fstCompiler, double ramLimitMB, FST.BytesReader in) {
+ public NodeHash(FSTCompiler<T> fstCompiler, double ramLimitMB) {
if (ramLimitMB <= 0) {
throw new IllegalArgumentException("ramLimitMB must be > 0; got: " + ramLimitMB);
}
@@ -70,28 +74,35 @@
primaryTable = new PagedGrowableHash();
this.fstCompiler = fstCompiler;
- this.in = in;
}
private long getFallback(FSTCompiler.UnCompiledNode<T> nodeIn, long hash) throws IOException {
+ this.lastFallbackNodeLength = -1;
+ this.lastFallbackHashSlot = -1;
if (fallbackTable == null) {
// no fallback yet (primary table is not yet large enough to swap)
return 0;
}
- long pos = hash & fallbackTable.mask;
+ long hashSlot = hash & fallbackTable.mask;
int c = 0;
while (true) {
- long node = fallbackTable.get(pos);
- if (node == 0) {
+ long nodeAddress = fallbackTable.getNodeAddress(hashSlot);
+ if (nodeAddress == 0) {
// not found
return 0;
- } else if (nodesEqual(nodeIn, node)) {
- // frozen version of this node is already here
- return node;
+ } else {
+ int length = fallbackTable.nodesEqual(nodeIn, nodeAddress, hashSlot);
+ if (length != -1) {
+ // store the node length for further use
+ this.lastFallbackNodeLength = length;
+ this.lastFallbackHashSlot = hashSlot;
+ // frozen version of this node is already here
+ return nodeAddress;
+ }
}
// quadratic probe (but is it, really?)
- pos = (pos + (++c)) & fallbackTable.mask;
+ hashSlot = (hashSlot + (++c)) & fallbackTable.mask;
}
}
@@ -99,36 +110,60 @@
long hash = hash(nodeIn);
- long pos = hash & primaryTable.mask;
+ long hashSlot = hash & primaryTable.mask;
int c = 0;
while (true) {
- long node = primaryTable.get(pos);
- if (node == 0) {
+ long nodeAddress = primaryTable.getNodeAddress(hashSlot);
+ if (nodeAddress == 0) {
// node is not in primary table; is it in fallback table?
- node = getFallback(nodeIn, hash);
- if (node != 0) {
+ nodeAddress = getFallback(nodeIn, hash);
+ if (nodeAddress != 0) {
+ assert lastFallbackHashSlot != -1 && lastFallbackNodeLength != -1;
+
// it was already in fallback -- promote to primary
- primaryTable.set(pos, node);
+ // TODO: Copy directly between 2 ByteBlockPool to avoid double-copy
+ primaryTable.setNode(
+ hashSlot,
+ nodeAddress,
+ fallbackTable.getBytes(lastFallbackHashSlot, lastFallbackNodeLength));
} else {
// not in fallback either -- freeze & add the incoming node
+ long startAddress = fstCompiler.bytes.getPosition();
// freeze & add
- node = fstCompiler.addNode(nodeIn);
+ nodeAddress = fstCompiler.addNode(nodeIn);
+ // TODO: Write the bytes directly from BytesStore
// we use 0 as empty marker in hash table, so it better be impossible to get a frozen node
// at 0:
- assert node != 0;
+ assert nodeAddress != FST.FINAL_END_NODE && nodeAddress != FST.NON_FINAL_END_NODE;
+ byte[] buf = new byte[Math.toIntExact(nodeAddress - startAddress + 1)];
+ fstCompiler.bytes.copyBytes(startAddress, buf, 0, buf.length);
+
+ primaryTable.setNode(hashSlot, nodeAddress, buf);
// confirm frozen hash and unfrozen hash are the same
- assert hash(node) == hash : "mismatch frozenHash=" + hash(node) + " vs hash=" + hash;
-
- primaryTable.set(pos, node);
+ assert primaryTable.hash(nodeAddress, hashSlot) == hash
+ : "mismatch frozenHash="
+ + primaryTable.hash(nodeAddress, hashSlot)
+ + " vs hash="
+ + hash;
}
// how many bytes would be used if we had "perfect" hashing:
- long ramBytesUsed = primaryTable.count * PackedInts.bitsRequired(node) / 8;
+ // - x2 for fstNodeAddress for FST node address
+ // - x2 for copiedNodeAddress for copied node address
+ // - the bytes copied out FST to the hashtable copiedNodes
+ // each accounting for approximate hash table overhead halfway between 33.3% and 66.7%
+ // note that some of the copiedNodes are shared between fallback and primary tables so this
+ // computation is pessimistic
+ long copiedBytes = primaryTable.copiedNodes.getPosition();
+ long ramBytesUsed =
+ primaryTable.count * 2 * PackedInts.bitsRequired(nodeAddress) / 8
+ + primaryTable.count * 2 * PackedInts.bitsRequired(copiedBytes) / 8
+ + copiedBytes;
// NOTE: we could instead use the more precise RAM used, but this leads to unpredictable
// quantized behavior due to 2X rehashing where for large ranges of the RAM limit, the
@@ -138,30 +173,29 @@
// in smaller FSTs, even if the precise RAM used is not always under the limit.
// divide limit by 2 because fallback gets half the RAM and primary gets the other half
- // divide by 2 again to account for approximate hash table overhead halfway between 33.3%
- // and 66.7% occupancy = 50%
- if (ramBytesUsed >= ramLimitBytes / (2 * 2)) {
+ if (ramBytesUsed >= ramLimitBytes / 2) {
// time to fallback -- fallback is now used read-only to promote a node (suffix) to
// primary if we encounter it again
fallbackTable = primaryTable;
// size primary table the same size to reduce rehash cost
// TODO: we could clear & reuse the previous fallbackTable, instead of allocating a new
// to reduce GC load
- primaryTable = new PagedGrowableHash(node, Math.max(16, primaryTable.entries.size()));
- } else if (primaryTable.count > primaryTable.entries.size() * (2f / 3)) {
+ primaryTable =
+ new PagedGrowableHash(nodeAddress, Math.max(16, primaryTable.fstNodeAddress.size()));
+ } else if (primaryTable.count > primaryTable.fstNodeAddress.size() * (2f / 3)) {
// rehash at 2/3 occupancy
- primaryTable.rehash(node);
+ primaryTable.rehash(nodeAddress);
}
- return node;
+ return nodeAddress;
- } else if (nodesEqual(nodeIn, node)) {
+ } else if (primaryTable.nodesEqual(nodeIn, nodeAddress, hashSlot) != -1) {
// same node (in frozen form) is already in primary table
- return node;
+ return nodeAddress;
}
// quadratic probe (but is it, really?)
- pos = (pos + (++c)) & primaryTable.mask;
+ hashSlot = (hashSlot + (++c)) & primaryTable.mask;
}
}
@@ -186,149 +220,233 @@
return h;
}
- // hash code for a frozen node. this must precisely match the hash computation of an unfrozen
- // node!
- private long hash(long node) throws IOException {
- final int PRIME = 31;
-
- long h = 0;
- fstCompiler.fst.readFirstRealTargetArc(node, scratchArc, in);
- while (true) {
- h = PRIME * h + scratchArc.label();
- h = PRIME * h + (int) (scratchArc.target() ^ (scratchArc.target() >> 32));
- h = PRIME * h + scratchArc.output().hashCode();
- h = PRIME * h + scratchArc.nextFinalOutput().hashCode();
- if (scratchArc.isFinal()) {
- h += 17;
- }
- if (scratchArc.isLast()) {
- break;
- }
- fstCompiler.fst.readNextRealArc(scratchArc, in);
- }
-
- return h;
- }
-
- /**
- * Compares an unfrozen node (UnCompiledNode) with a frozen node at byte location address (long),
- * returning true if they are equal.
- */
- private boolean nodesEqual(FSTCompiler.UnCompiledNode<T> node, long address) throws IOException {
- fstCompiler.fst.readFirstRealTargetArc(address, scratchArc, in);
-
- // fail fast for a node with fixed length arcs
- if (scratchArc.bytesPerArc() != 0) {
- assert node.numArcs > 0;
- // the frozen node uses fixed-with arc encoding (same number of bytes per arc), but may be
- // sparse or dense
- switch (scratchArc.nodeFlags()) {
- case FST.ARCS_FOR_BINARY_SEARCH:
- // sparse
- if (node.numArcs != scratchArc.numArcs()) {
- return false;
- }
- break;
- case FST.ARCS_FOR_DIRECT_ADDRESSING:
- // dense -- compare both the number of labels allocated in the array (some of which may
- // not actually be arcs), and the number of arcs
- if ((node.arcs[node.numArcs - 1].label - node.arcs[0].label + 1) != scratchArc.numArcs()
- || node.numArcs != FST.Arc.BitTable.countBits(scratchArc, in)) {
- return false;
- }
- break;
- default:
- throw new AssertionError("unhandled scratchArc.nodeFlag() " + scratchArc.nodeFlags());
- }
- }
-
- // compare arc by arc to see if there is a difference
- for (int arcUpto = 0; arcUpto < node.numArcs; arcUpto++) {
- final FSTCompiler.Arc<T> arc = node.arcs[arcUpto];
- if (arc.label != scratchArc.label()
- || arc.output.equals(scratchArc.output()) == false
- || ((FSTCompiler.CompiledNode) arc.target).node != scratchArc.target()
- || arc.nextFinalOutput.equals(scratchArc.nextFinalOutput()) == false
- || arc.isFinal != scratchArc.isFinal()) {
- return false;
- }
-
- if (scratchArc.isLast()) {
- if (arcUpto == node.numArcs - 1) {
- return true;
- } else {
- return false;
- }
- }
-
- fstCompiler.fst.readNextRealArc(scratchArc, in);
- }
-
- // unfrozen node has fewer arcs than frozen node
-
- return false;
- }
-
/** Inner class because it needs access to hash function and FST bytes. */
private class PagedGrowableHash {
- private PagedGrowableWriter entries;
+    // stores the FST node address at each slot; the slot index is the masked hash of the node's arcs
+ private PagedGrowableWriter fstNodeAddress;
+ // storing the local copiedNodes address in the same position as fstNodeAddress
+ // here we are effectively storing a Map<Long, Long> from the FST node address to copiedNodes
+ // address
+ private PagedGrowableWriter copiedNodeAddress;
private long count;
private long mask;
+ // storing the byte slice from the FST for nodes we added to the hash so that we don't need to
+ // look up from the FST itself, so the FST bytes can stream directly to disk as append-only
+ // writes.
+    // each node is appended immediately after the previous one
+ private final ByteBlockPool copiedNodes;
+ // the {@link FST.BytesReader} to read from copiedNodes. we use this when computing a frozen
+ // node hash
+    // or comparing whether a frozen and an unfrozen node are equal
+ private final ByteBlockPoolReverseBytesReader bytesReader;
// 256K blocks, but note that the final block is sized only as needed so it won't use the full
// block size when just a few elements were written to it
private static final int BLOCK_SIZE_BYTES = 1 << 18;
public PagedGrowableHash() {
- entries = new PagedGrowableWriter(16, BLOCK_SIZE_BYTES, 8, PackedInts.COMPACT);
+ fstNodeAddress = new PagedGrowableWriter(16, BLOCK_SIZE_BYTES, 8, PackedInts.COMPACT);
+ copiedNodeAddress = new PagedGrowableWriter(16, BLOCK_SIZE_BYTES, 8, PackedInts.COMPACT);
mask = 15;
+ copiedNodes = new ByteBlockPool(new ByteBlockPool.DirectAllocator());
+ bytesReader = new ByteBlockPoolReverseBytesReader(copiedNodes);
}
public PagedGrowableHash(long lastNodeAddress, long size) {
- entries =
+ fstNodeAddress =
new PagedGrowableWriter(
size, BLOCK_SIZE_BYTES, PackedInts.bitsRequired(lastNodeAddress), PackedInts.COMPACT);
+ copiedNodeAddress = new PagedGrowableWriter(size, BLOCK_SIZE_BYTES, 8, PackedInts.COMPACT);
mask = size - 1;
assert (mask & size) == 0 : "size must be a power-of-2; got size=" + size + " mask=" + mask;
+ copiedNodes = new ByteBlockPool(new ByteBlockPool.DirectAllocator());
+ bytesReader = new ByteBlockPoolReverseBytesReader(copiedNodes);
}
- public long get(long index) {
- return entries.get(index);
+ /**
+ * Get the copied bytes at the provided hash slot
+ *
+ * @param hashSlot the hash slot to read from
+ * @param length the number of bytes to read
+ * @return the copied byte array
+ */
+ public byte[] getBytes(long hashSlot, int length) {
+ long address = copiedNodeAddress.get(hashSlot);
+ assert address - length + 1 >= 0;
+ byte[] buf = new byte[length];
+ copiedNodes.readBytes(address - length + 1, buf, 0, length);
+ return buf;
}
- public void set(long index, long pointer) throws IOException {
- entries.set(index, pointer);
+ /**
+ * Get the node address from the provided hash slot
+ *
+ * @param hashSlot the hash slot to read
+ * @return the node address
+ */
+ public long getNodeAddress(long hashSlot) {
+ return fstNodeAddress.get(hashSlot);
+ }
+
+ /**
+ * Set the node address and bytes from the provided hash slot
+ *
+ * @param hashSlot the hash slot to write to
+ * @param nodeAddress the node address
+ * @param bytes the node bytes to be copied
+ */
+ public void setNode(long hashSlot, long nodeAddress, byte[] bytes) {
+ assert fstNodeAddress.get(hashSlot) == 0;
+ fstNodeAddress.set(hashSlot, nodeAddress);
count++;
+
+ copiedNodes.append(bytes);
+ // write the offset, which points to the last byte of the node we copied since we later read
+ // this node in reverse
+ assert copiedNodeAddress.get(hashSlot) == 0;
+ copiedNodeAddress.set(hashSlot, copiedNodes.getPosition() - 1);
}
private void rehash(long lastNodeAddress) throws IOException {
+ // TODO: https://github.com/apache/lucene/issues/12744
+      // should we always use a small startBitsPerValue here (e.g. 8) instead of basing it
+      // off of lastNodeAddress?
+
// double hash table size on each rehash
- PagedGrowableWriter newEntries =
+ long newSize = 2 * fstNodeAddress.size();
+ PagedGrowableWriter newCopiedNodeAddress =
new PagedGrowableWriter(
- 2 * entries.size(),
+ newSize,
+ BLOCK_SIZE_BYTES,
+ PackedInts.bitsRequired(copiedNodes.getPosition()),
+ PackedInts.COMPACT);
+ PagedGrowableWriter newFSTNodeAddress =
+ new PagedGrowableWriter(
+ newSize,
BLOCK_SIZE_BYTES,
PackedInts.bitsRequired(lastNodeAddress),
PackedInts.COMPACT);
- long newMask = newEntries.size() - 1;
- for (long idx = 0; idx < entries.size(); idx++) {
- long address = entries.get(idx);
+ long newMask = newFSTNodeAddress.size() - 1;
+ for (long idx = 0; idx < fstNodeAddress.size(); idx++) {
+ long address = fstNodeAddress.get(idx);
if (address != 0) {
- long pos = hash(address) & newMask;
+ long hashSlot = hash(address, idx) & newMask;
int c = 0;
while (true) {
- if (newEntries.get(pos) == 0) {
- newEntries.set(pos, address);
+ if (newFSTNodeAddress.get(hashSlot) == 0) {
+ newFSTNodeAddress.set(hashSlot, address);
+ newCopiedNodeAddress.set(hashSlot, copiedNodeAddress.get(idx));
break;
}
// quadratic probe
- pos = (pos + (++c)) & newMask;
+ hashSlot = (hashSlot + (++c)) & newMask;
}
}
}
mask = newMask;
- entries = newEntries;
+ fstNodeAddress = newFSTNodeAddress;
+ copiedNodeAddress = newCopiedNodeAddress;
+ }
+
+ // hash code for a frozen node. this must precisely match the hash computation of an unfrozen
+ // node!
+ private long hash(long nodeAddress, long hashSlot) throws IOException {
+ FST.BytesReader in = getBytesReader(nodeAddress, hashSlot);
+
+ final int PRIME = 31;
+
+ long h = 0;
+ fstCompiler.fst.readFirstRealTargetArc(nodeAddress, scratchArc, in);
+ while (true) {
+ h = PRIME * h + scratchArc.label();
+ h = PRIME * h + (int) (scratchArc.target() ^ (scratchArc.target() >> 32));
+ h = PRIME * h + scratchArc.output().hashCode();
+ h = PRIME * h + scratchArc.nextFinalOutput().hashCode();
+ if (scratchArc.isFinal()) {
+ h += 17;
+ }
+ if (scratchArc.isLast()) {
+ break;
+ }
+ fstCompiler.fst.readNextRealArc(scratchArc, in);
+ }
+
+ return h;
+ }
+
+ /**
+ * Compares an unfrozen node (UnCompiledNode) with a frozen node at byte location address
+     * (long), returning the node length if the two nodes are equal, or -1 otherwise
+ *
+ * <p>The node length will be used to promote the node from the fallback table to the primary
+ * table
+ */
+ private int nodesEqual(FSTCompiler.UnCompiledNode<T> node, long address, long hashSlot)
+ throws IOException {
+ FST.BytesReader in = getBytesReader(address, hashSlot);
+ fstCompiler.fst.readFirstRealTargetArc(address, scratchArc, in);
+
+ // fail fast for a node with fixed length arcs
+ if (scratchArc.bytesPerArc() != 0) {
+ assert node.numArcs > 0;
+      // the frozen node uses fixed-width arc encoding (same number of bytes per arc), but may be
+ // sparse or dense
+ switch (scratchArc.nodeFlags()) {
+ case FST.ARCS_FOR_BINARY_SEARCH:
+ // sparse
+ if (node.numArcs != scratchArc.numArcs()) {
+ return -1;
+ }
+ break;
+ case FST.ARCS_FOR_DIRECT_ADDRESSING:
+ // dense -- compare both the number of labels allocated in the array (some of which may
+ // not actually be arcs), and the number of arcs
+ if ((node.arcs[node.numArcs - 1].label - node.arcs[0].label + 1) != scratchArc.numArcs()
+ || node.numArcs != FST.Arc.BitTable.countBits(scratchArc, in)) {
+ return -1;
+ }
+ break;
+ default:
+ throw new AssertionError("unhandled scratchArc.nodeFlag() " + scratchArc.nodeFlags());
+ }
+ }
+
+ // compare arc by arc to see if there is a difference
+ for (int arcUpto = 0; arcUpto < node.numArcs; arcUpto++) {
+ final FSTCompiler.Arc<T> arc = node.arcs[arcUpto];
+ if (arc.label != scratchArc.label()
+ || arc.output.equals(scratchArc.output()) == false
+ || ((FSTCompiler.CompiledNode) arc.target).node != scratchArc.target()
+ || arc.nextFinalOutput.equals(scratchArc.nextFinalOutput()) == false
+ || arc.isFinal != scratchArc.isFinal()) {
+ return -1;
+ }
+
+ if (scratchArc.isLast()) {
+ if (arcUpto == node.numArcs - 1) {
+          // position is 1 index past the starting address, as we are reading in reverse
+ return Math.toIntExact(address - in.getPosition());
+ } else {
+ return -1;
+ }
+ }
+
+ fstCompiler.fst.readNextRealArc(scratchArc, in);
+ }
+
+ // unfrozen node has fewer arcs than frozen node
+
+ return -1;
+ }
+
+ private FST.BytesReader getBytesReader(long nodeAddress, long hashSlot) {
+ // make sure the nodeAddress and hashSlot is consistent
+ assert fstNodeAddress.get(hashSlot) == nodeAddress;
+ long localAddress = copiedNodeAddress.get(hashSlot);
+ bytesReader.setPosDelta(nodeAddress - localAddress);
+ return bytesReader;
}
}
}
diff --git a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
index d4e8a50..ccd838c 100644
--- a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
+++ b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
@@ -29,6 +29,7 @@
import jdk.incubator.vector.VectorShape;
import jdk.incubator.vector.VectorSpecies;
import org.apache.lucene.util.Constants;
+import org.apache.lucene.util.SuppressForbidden;
/**
* VectorUtil methods implemented with Panama incubating vector API.
@@ -79,13 +80,22 @@
// the way FMA should work! if available use it, otherwise fall back to mul/add
private static FloatVector fma(FloatVector a, FloatVector b, FloatVector c) {
- if (Constants.HAS_FAST_FMA) {
+ if (Constants.HAS_FAST_VECTOR_FMA) {
return a.fma(b, c);
} else {
return a.mul(b).add(c);
}
}
+ @SuppressForbidden(reason = "Uses FMA only where fast and carefully contained")
+ private static float fma(float a, float b, float c) {
+ if (Constants.HAS_FAST_SCALAR_FMA) {
+ return Math.fma(a, b, c);
+ } else {
+ return a * b + c;
+ }
+ }
+
@Override
public float dotProduct(float[] a, float[] b) {
int i = 0;
@@ -99,7 +109,7 @@
// scalar tail
for (; i < a.length; i++) {
- res += b[i] * a[i];
+ res = fma(a[i], b[i], res);
}
return res;
}
@@ -165,11 +175,9 @@
// scalar tail
for (; i < a.length; i++) {
- float elem1 = a[i];
- float elem2 = b[i];
- sum += elem1 * elem2;
- norm1 += elem1 * elem1;
- norm2 += elem2 * elem2;
+ sum = fma(a[i], b[i], sum);
+ norm1 = fma(a[i], a[i], norm1);
+ norm2 = fma(b[i], b[i], norm2);
}
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
}
@@ -230,7 +238,7 @@
// scalar tail
for (; i < a.length; i++) {
float diff = a[i] - b[i];
- res += diff * diff;
+ res = fma(diff, diff, res);
}
return res;
}
diff --git a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java
index ffd18df..11901d7 100644
--- a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java
+++ b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java
@@ -63,7 +63,7 @@
Locale.ENGLISH,
"Java vector incubator API enabled; uses preferredBitSize=%d%s%s",
PanamaVectorUtilSupport.VECTOR_BITSIZE,
- Constants.HAS_FAST_FMA ? "; FMA enabled" : "",
+ Constants.HAS_FAST_VECTOR_FMA ? "; FMA enabled" : "",
PanamaVectorUtilSupport.HAS_FAST_INTEGER_VECTORS
? ""
: "; floating-point vectors only"));
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestByteBlockPool.java b/lucene/core/src/test/org/apache/lucene/util/TestByteBlockPool.java
index b242f00..c7c4e80 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestByteBlockPool.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestByteBlockPool.java
@@ -79,6 +79,7 @@
ByteBlockPool pool = new ByteBlockPool(new ByteBlockPool.DirectTrackingAllocator(bytesUsed));
pool.nextBuffer();
+ long totalBytes = 0;
List<byte[]> items = new ArrayList<>();
for (int i = 0; i < 100; i++) {
int size;
@@ -91,6 +92,10 @@
random().nextBytes(bytes);
items.add(bytes);
pool.append(new BytesRef(bytes));
+ totalBytes += size;
+
+ // make sure we report the correct position
+ assertEquals(totalBytes, pool.getPosition());
}
long position = 0;