[FLINK-9676][network] clarify contracts of BufferListener#notifyBufferAvailable() and fix a deadlock
When recycling exclusive buffers of a RemoteInputChannel and recycling
(other/floating) buffers to the buffer pool concurrently while the
RemoteInputChannel is registered as a listener to the buffer pool and adding the
exclusive buffer triggers a floating buffer to be recycled back to the same
buffer pool, a deadlock would occur holding locks on
LocalBufferPool#availableMemorySegments and RemoteInputChannel#bufferQueue but
acquiring them in reverse order.
One such instance would be:
Task canceler thread -> RemoteInputChannel1#releaseAllResources -> recycle floating buffers
-> lock(LocalBufferPool#availableMemorySegments) -> RemoteInputChannel2#notifyBufferAvailable
-> try to lock(RemoteInputChannel2#bufferQueue)
Task thread -> RemoteInputChannel2#recycle
-> lock(RemoteInputChannel2#bufferQueue) -> bufferQueue#addExclusiveBuffer -> floatingBuffer#recycleBuffer
-> try to lock(LocalBufferPool#availableMemorySegments)
Therefore, we decouple the listener callback from lock around
LocalBufferPool#availableMemorySegments and implicitly enforce that
RemoteInputChannel2#bufferQueue takes precedence over this lock, i.e. must
be acquired first and should never be taken after having locked on
LocalBufferPool#availableMemorySegments.
This closes #6257.
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java
index 7d9aa210..77eb601 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java
@@ -57,6 +57,12 @@
/**
* The currently available memory segments. These are segments, which have been requested from
* the network buffer pool and are currently not handed out as Buffer instances.
+ *
+ * <p><strong>BEWARE:</strong> Take special care with the interactions between this lock and
+ * locks acquired before entering this class vs. locks being acquired during calls to external
+ * code inside this class, e.g. with
+ * {@link org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel#bufferQueue}
+ * via the {@link #registeredListeners} callback.
*/
private final ArrayDeque<MemorySegment> availableMemorySegments = new ArrayDeque<MemorySegment>();
@@ -251,27 +257,56 @@
@Override
public void recycle(MemorySegment segment) {
+ BufferListener listener;
synchronized (availableMemorySegments) {
if (isDestroyed || numberOfRequestedMemorySegments > currentPoolSize) {
returnMemorySegment(segment);
+ return;
}
else {
- BufferListener listener = registeredListeners.poll();
+ listener = registeredListeners.poll();
if (listener == null) {
availableMemorySegments.add(segment);
availableMemorySegments.notify();
+ return;
}
- else {
- try {
- boolean needMoreBuffers = listener.notifyBufferAvailable(new NetworkBuffer(segment, this));
- if (needMoreBuffers) {
- registeredListeners.add(listener);
- }
+ }
+ }
+
+ // We do not know which locks have been acquired before the recycle() or are needed in the
+ // notification and which other threads also access them.
+ // -> call notifyBufferAvailable() outside of the synchronized block to avoid a deadlock (FLINK-9676)
+ boolean success = false;
+ boolean needMoreBuffers = false;
+ try {
+ needMoreBuffers = listener.notifyBufferAvailable(new NetworkBuffer(segment, this));
+ success = true;
+ } catch (Throwable ignored) {
+ // handled below, under the lock
+ }
+
+ if (!success || needMoreBuffers) {
+ synchronized (availableMemorySegments) {
+ if (isDestroyed) {
+ // cleanup tasks how they would have been done if we only had one synchronized block
+ if (needMoreBuffers) {
+ listener.notifyBufferDestroyed();
}
- catch (Throwable ignored) {
- availableMemorySegments.add(segment);
- availableMemorySegments.notify();
+ if (!success) {
+ returnMemorySegment(segment);
+ }
+ } else {
+ if (needMoreBuffers) {
+ registeredListeners.add(listener);
+ }
+ if (!success) {
+ if (numberOfRequestedMemorySegments > currentPoolSize) {
+ returnMemorySegment(segment);
+ } else {
+ availableMemorySegments.add(segment);
+ availableMemorySegments.notify();
+ }
}
}
}
@@ -283,6 +318,7 @@
*/
@Override
public void lazyDestroy() {
+ // NOTE: if you change this logic, be sure to update recycle() as well!
synchronized (availableMemorySegments) {
if (!isDestroyed) {
MemorySegment segment;
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPoolTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPoolTest.java
index b04286e..8834291 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPoolTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPoolTest.java
@@ -36,6 +36,7 @@
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
@@ -405,13 +406,13 @@
private BufferListener createBufferListener(int notificationTimes) {
return spy(new BufferListener() {
- int times = 0;
+ AtomicInteger times = new AtomicInteger(0);
@Override
public boolean notifyBufferAvailable(Buffer buffer) {
- times++;
+ int newCount = times.incrementAndGet();
buffer.recycleBuffer();
- return times < notificationTimes;
+ return newCount < notificationTimes;
}
@Override
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
index 802cb93..4080106 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
@@ -33,6 +33,7 @@
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
import org.apache.flink.runtime.taskmanager.TaskActions;
+import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
@@ -52,6 +53,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
@@ -62,6 +64,9 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
+/**
+ * Tests for the {@link RemoteInputChannel}.
+ */
public class RemoteInputChannelTest {
@Test
@@ -804,7 +809,7 @@
recycleFloatingBufferTask(bufferPool, numFloatingBuffers),
requestBufferTask});
- assertEquals("There should be " + inputChannel.getNumberOfRequiredBuffers() +" buffers available in channel.",
+ assertEquals("There should be " + inputChannel.getNumberOfRequiredBuffers() + " buffers available in channel.",
inputChannel.getNumberOfRequiredBuffers(), inputChannel.getNumberOfAvailableBuffers());
assertEquals("There should be no buffers available in local pool.",
0, bufferPool.getNumberOfAvailableMemorySegments());
@@ -878,6 +883,95 @@
}
}
+ /**
+ * Tests to verify that there is no race condition with two things running in parallel:
+ * recycling exclusive buffers and recycling external buffers to the buffer pool while the
+ * recycling of the exclusive buffer triggers recycling a floating buffer (FLINK-9676).
+ */
+ @Test
+ public void testConcurrentRecycleAndRelease2() throws Exception {
+ // Setup
+ final int retries = 1_000;
+ final int numExclusiveBuffers = 2;
+ final int numFloatingBuffers = 2;
+ final int numTotalBuffers = numExclusiveBuffers + numFloatingBuffers;
+ final NetworkBufferPool networkBufferPool = new NetworkBufferPool(
+ numTotalBuffers, 32);
+
+ final ExecutorService executor = Executors.newFixedThreadPool(2);
+
+ final SingleInputGate inputGate = createSingleInputGate();
+ final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate);
+ inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel);
+ try {
+ final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers);
+ inputGate.setBufferPool(bufferPool);
+ inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers);
+ inputChannel.requestSubpartition(0);
+
+ final Callable<Void> bufferPoolInteractionsTask = () -> {
+ for (int i = 0; i < retries; ++i) {
+ Buffer buffer = bufferPool.requestBufferBlocking();
+ buffer.recycleBuffer();
+ }
+ return null;
+ };
+
+ final Callable<Void> channelInteractionsTask = () -> {
+ ArrayList<Buffer> exclusiveBuffers = new ArrayList<>(numExclusiveBuffers);
+ ArrayList<Buffer> floatingBuffers = new ArrayList<>(numExclusiveBuffers);
+ try {
+ for (int i = 0; i < retries; ++i) {
+ // note: we may still have a listener on the buffer pool and receive
+ // floating buffers as soon as we take exclusive ones
+ for (int j = 0; j < numTotalBuffers; ++j) {
+ Buffer buffer = inputChannel.requestBuffer();
+ if (buffer == null) {
+ break;
+ } else {
+ //noinspection ObjectEquality
+ if (buffer.getRecycler() == inputChannel) {
+ exclusiveBuffers.add(buffer);
+ } else {
+ floatingBuffers.add(buffer);
+ }
+ }
+ }
+ // recycle excess floating buffers (will go back into the channel)
+ floatingBuffers.forEach(Buffer::recycleBuffer);
+ floatingBuffers.clear();
+
+ assertEquals(numExclusiveBuffers, exclusiveBuffers.size());
+ inputChannel.onSenderBacklog(0); // trigger subscription to buffer pool
+ // note: if we got a floating buffer by increasing the backlog, it will be released again when recycling the exclusive buffer, if not, we should release it once we get it
+ exclusiveBuffers.forEach(Buffer::recycleBuffer);
+ exclusiveBuffers.clear();
+ }
+ } finally {
+ inputChannel.releaseAllResources();
+ }
+
+ return null;
+ };
+
+ // Submit tasks and wait to finish
+ submitTasksAndWaitForResults(executor,
+ new Callable[] {bufferPoolInteractionsTask, channelInteractionsTask});
+ } catch (Throwable t) {
+ inputChannel.releaseAllResources();
+
+ try {
+ networkBufferPool.destroyAllBufferPools();
+ } catch (Throwable tInner) {
+ t.addSuppressed(tInner);
+ }
+
+ networkBufferPool.destroy();
+ executor.shutdown();
+ ExceptionUtils.rethrowException(t);
+ }
+ }
+
// ---------------------------------------------------------------------------------------------
private SingleInputGate createSingleInputGate() {
@@ -986,7 +1080,8 @@
private void submitTasksAndWaitForResults(ExecutorService executor, Callable[] tasks) throws Exception {
final List<Future> results = Lists.newArrayListWithCapacity(tasks.length);
- for(Callable task : tasks) {
+ for (Callable task : tasks) {
+ //noinspection unchecked
results.add(executor.submit(task));
}