/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.cassandra.spark.bulkwriter;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Range;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import o.a.c.sidecar.client.shaded.common.request.Request;
import o.a.c.sidecar.client.shaded.common.request.data.CreateSliceRequestPayload;
import o.a.c.sidecar.client.shaded.common.response.data.RingEntry;
import org.apache.cassandra.sidecar.client.SidecarClient;
import org.apache.cassandra.sidecar.client.SidecarInstanceImpl;
import org.apache.cassandra.sidecar.client.exception.RetriesExhaustedException;
import org.apache.cassandra.spark.bulkwriter.ImportCompletionCoordinator.RequestAndInstance;
import org.apache.cassandra.spark.bulkwriter.blobupload.BlobDataTransferApi;
import org.apache.cassandra.spark.bulkwriter.blobupload.BlobStreamResult;
import org.apache.cassandra.spark.bulkwriter.blobupload.CreatedRestoreSlice;
import org.apache.cassandra.spark.bulkwriter.blobupload.StorageClient;
import org.apache.cassandra.spark.bulkwriter.token.ConsistencyLevel;
import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping;
import org.apache.cassandra.spark.data.QualifiedTableName;
import org.apache.cassandra.spark.data.partitioner.Partitioner;
import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportExtension;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atMostOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
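
/**
 * Unit tests for {@link ImportCompletionCoordinator}, covering the happy path, slow and stuck
 * imports, consistency-level failures, unsatisfiable token ranges, and job cancellation on
 * topology change. The sidecar and storage clients are mocked; import behavior is stubbed per
 * instance on a spied {@link BlobDataTransferApi}.
 */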
class ImportCompletionCoordinatorTest
{
private static final int TOTAL_INSTANCES = 10;
BulkWriterContext mockWriterContext;
BulkWriteValidator writerValidator;
TokenRangeMapping<RingInstance> topology;
JobInfo mockJobInfo;
BlobDataTransferApi dataTransferApi;
UUID jobId;
StorageTransportExtension mockExtension;
ArgumentCaptor<String> appliedObjectKeys;
Consumer<CancelJobEvent> onCancelJob;
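
// Common fixture: a 10-instance DC1 ring with 3 replicas per range, a spied transfer API whose
// createRestoreSliceFromDriver calls are stubbed per instance, and an extension mock that
// captures the object keys reported as applied.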
@BeforeEach
public void setup() throws Exception
{
mockJobInfo = mock(JobInfo.class);
jobId = UUID.randomUUID();
when(mockJobInfo.getId()).thenReturn(jobId.toString());
when(mockJobInfo.getRestoreJobId()).thenReturn(jobId);
when(mockJobInfo.qualifiedTableName()).thenReturn(new QualifiedTableName("testkeyspace", "testtable"));
when(mockJobInfo.getConsistencyLevel()).thenReturn(ConsistencyLevel.CL.QUORUM);
when(mockJobInfo.effectiveSidecarPort()).thenReturn(9043);
when(mockJobInfo.jobKeepAliveMinutes()).thenReturn(-1);
mockWriterContext = mock(BulkWriterContext.class);
ClusterInfo mockClusterInfo = mock(ClusterInfo.class);
when(mockWriterContext.cluster()).thenReturn(mockClusterInfo);
CassandraContext mockCassandraContext = mock(CassandraContext.class);
when(mockClusterInfo.getCassandraContext()).thenReturn(mockCassandraContext);
topology = TokenRangeMappingUtils.buildTokenRangeMapping(0, ImmutableMap.of("DC1", 3), TOTAL_INSTANCES);
when(mockClusterInfo.getTokenRangeMapping(anyBoolean())).thenReturn(topology);
when(mockWriterContext.job()).thenReturn(mockJobInfo);
writerValidator = new BulkWriteValidator(mockWriterContext, new ReplicaAwareFailureHandler<>(Partitioner.Murmur3Partitioner));
// the clients are not exercised directly in this test class; mocking happens at the API method level
BlobDataTransferApi api = new BlobDataTransferApi(mockJobInfo, mock(SidecarClient.class), mock(StorageClient.class));
dataTransferApi = spy(api);
mockExtension = mock(StorageTransportExtension.class);
appliedObjectKeys = ArgumentCaptor.forClass(String.class);
doNothing().when(mockExtension).onObjectApplied(any(), appliedObjectKeys.capture(), anyLong(), anyLong());
onCancelJob = event -> {
throw new RuntimeException("It should not be called");
};
}
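
// Happy path: every import succeeds immediately, so every object key must be applied exactly once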
@Test
void testAwaitForCompletionWithNoErrors()
{
List<BlobStreamResult> resultList = buildBlobStreamResult(0, false, 0);
ImportCompletionCoordinator.of(0, mockWriterContext, dataTransferApi,
writerValidator, resultList, mockExtension, onCancelJob)
.waitForCompletion();
validateAllSlicesWereCalledAtMostOnce(resultList);
assertEquals(resultList.size(), appliedObjectKeys.getAllValues().size(),
"All objects should be applied and reported for exactly once");
assertEquals(allTestObjectKeys(), new HashSet<>(appliedObjectKeys.getAllValues()));
}
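
// Same as the happy path, but the imports of the last replica set complete with an artificial delay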
@Test
void testAwaitForCompletionWithNoErrorsAndSlowImport()
{
List<BlobStreamResult> resultList = buildBlobStreamResult(0, true, 0);
ImportCompletionCoordinator.of(0, mockWriterContext, dataTransferApi,
writerValidator, resultList, mockExtension, onCancelJob)
.waitForCompletion();
validateAllSlicesWereCalledAtMostOnce(resultList);
assertEquals(resultList.size(), appliedObjectKeys.getAllValues().size(),
"All objects should be applied and reported for exactly once");
assertEquals(allTestObjectKeys(), new HashSet<>(appliedObjectKeys.getAllValues()));
}

@Test // the test scenario has errors during the check, but CL passes overall and the import succeeds
void testAwaitForCompletionWithErrorsAndCLPasses()
{
// There is 1 failure in each replica set; 2 out of 3 replicas succeed.
List<BlobStreamResult> resultList = buildBlobStreamResult(1, false, 0);
ImportCompletionCoordinator.of(0, mockWriterContext, dataTransferApi,
writerValidator, resultList, mockExtension, onCancelJob)
.waitForCompletion();
validateAllSlicesWereCalledAtMostOnce(resultList);
assertEquals(resultList.size(), appliedObjectKeys.getAllValues().size(),
"All objects should be applied and reported for exactly once");
assertEquals(allTestObjectKeys(), new HashSet<>(appliedObjectKeys.getAllValues()));
}

@Test // the test scenario has errors that fail CL, so the import fails
void testAwaitForCompletionWithErrorsAndCLFails()
{
// There are 2 failures in each replica set, so only 1 out of 3 replicas succeeds.
// All replica sets fail, but the number of failed ranges is not deterministic.
// Therefore, the assertion omits the number of ranges from the expected message.
String errorMessage = "ranges with QUORUM for job " + jobId + " in phase WaitForCommitCompletion";
List<BlobStreamResult> resultList = buildBlobStreamResult(2, false, 0);
RuntimeException exception = assertThrows(RuntimeException.class, () -> {
ImportCompletionCoordinator.of(0, mockWriterContext, dataTransferApi,
writerValidator, resultList, mockExtension, onCancelJob)
.waitForCompletion();
});
assertNotNull(exception.getMessage());
assertTrue(exception.getMessage().contains("Failed to load"));
assertTrue(exception.getMessage().contains(errorMessage));
assertNotNull(exception.getCause());
validateAllSlicesWereCalledAtMostOnce(resultList);
assertEquals(0, appliedObjectKeys.getAllValues().size(),
"No object should be applied and reported");
}
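
// No instance fails outright, but 2 of the 3 replicas of each range are excluded from the
// results, so QUORUM can never be satisfied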
@Test
void testCLUnsatisfiedRanges()
{
String errorMessage = "Some of the token ranges cannot satisfy with consistency level. job=" + jobId + " phase=WaitForCommitCompletion";
// The CL check won't fail, as there are no failed instances.
// It won't be satisfied either, since there are not enough available instances.
List<BlobStreamResult> resultList = buildBlobStreamResult(0, false, 2);
RuntimeException exception = assertThrows(RuntimeException.class, () -> {
ImportCompletionCoordinator.of(0, mockWriterContext, dataTransferApi,
writerValidator, resultList, mockExtension, onCancelJob)
.waitForCompletion();
});
assertNotNull(exception.getMessage());
assertTrue(exception.getMessage().contains(errorMessage));
assertNull(exception.getCause());
validateAllSlicesWereCalledAtMostOnce(resultList);
assertEquals(0, appliedObjectKeys.getAllValues().size(),
"No object should be applied and reported");
}

@Test
void testAwaitShouldPassWithStuckSliceWhenClSatisfied()
{
/*
* When a slice import is stuck on the server side, i.e. the import request never indicates
* that the slice is complete, but the consistency level has been satisfied for all ranges,
* it is safe to ignore the abnormal status of the stuck slices.
* The test verifies that in such a scenario, ImportCompletionCoordinator does not block
* forever and can conclude with a successful result.
*/
List<BlobStreamResult> resultList = buildBlobStreamResultWithNoProgressImports(1);
ImportCompletionCoordinator coordinator = ImportCompletionCoordinator.of(0, mockWriterContext, dataTransferApi,
writerValidator, resultList, mockExtension, onCancelJob);
coordinator.waitForCompletion();
assertEquals(resultList.size(), appliedObjectKeys.getAllValues().size(),
"All objects should be applied and reported for exactly once");
assertEquals(allTestObjectKeys(), new HashSet<>(appliedObjectKeys.getAllValues()));
Map<CompletableFuture<Void>, RequestAndInstance> importFutures = coordinator.importFutures();
int cancelledImports = importFutures.keySet().stream().mapToInt(f -> f.isCancelled() ? 1 : 0).sum();
assertEquals(TOTAL_INSTANCES, cancelledImports,
"Each replica set should have a slice gets cancelled due to making no progress");
}
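
// The second topology fetch returns an extra instance (cluster expansion); the monitor detects
// the change, fails the coordinator exceptionally, and invokes the cancellation callback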
@Test
void testJobCancelOnTopologyChanged()
{
AtomicBoolean isCancelled = new AtomicBoolean(false);
Consumer<CancelJobEvent> onCancel = event -> {
isCancelled.set(true);
};
BulkWriterContext mockWriterContext = mock(BulkWriterContext.class);
ClusterInfo mockClusterInfo = mock(ClusterInfo.class);
when(mockWriterContext.cluster()).thenReturn(mockClusterInfo);
when(mockClusterInfo.getTokenRangeMapping(false))
.thenReturn(TokenRangeMappingUtils.buildTokenRangeMapping(0,
ImmutableMap.of("DC1", 3),
TOTAL_INSTANCES))
.thenReturn(TokenRangeMappingUtils.buildTokenRangeMapping(0,
ImmutableMap.of("DC1", 3),
TOTAL_INSTANCES + 1)); // adding a new instance; expansion
List<BlobStreamResult> resultList = buildBlobStreamResult(0, false, 0);
AtomicReference<CassandraTopologyMonitor> monitorRef = new AtomicReference<>(null);
ImportCompletionCoordinator coordinator = new ImportCompletionCoordinator(0, mockWriterContext, dataTransferApi,
writerValidator, resultList, mockExtension, onCancel,
(clusterInfo, onCancelJob) -> {
monitorRef.set(new CassandraTopologyMonitor(clusterInfo, onCancelJob));
return monitorRef.get();
});
monitorRef.get().checkTopologyOnDemand();
CompletionException coordinatorEx = assertThrows(CompletionException.class, coordinator::waitForCompletion);
assertEquals("Topology changed during bulk write", coordinatorEx.getCause().getMessage());
assertTrue(isCancelled.get());
CompletableFuture<Void> firstFailure = coordinator.firstFailure();
assertTrue(firstFailure.isCompletedExceptionally());
ExecutionException firstFailureEx = assertThrows(ExecutionException.class, firstFailure::get);
assertEquals(coordinatorEx.getCause(), firstFailureEx.getCause());
}
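
// Expected object keys: one per replica set, matching the keys stubbed in buildBlobStreamResult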
private Set<String> allTestObjectKeys()
{
return IntStream.range(0, TOTAL_INSTANCES).boxed().map(i -> "key_for_instance_" + i).collect(Collectors.toSet());
}
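
// Builds results where the given number of instances per replica set never complete their imports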
private List<BlobStreamResult> buildBlobStreamResultWithNoProgressImports(int noProgressInstanceCount)
{
return buildBlobStreamResult(0, false, 0, noProgressInstanceCount);
}

private List<BlobStreamResult> buildBlobStreamResult(int failedInstanceCount, boolean simulateSlowImport, int unavailableInstanceCount)
{
return buildBlobStreamResult(failedInstanceCount, simulateSlowImport, unavailableInstanceCount, 0);
}

/**
* @param failedInstanceCount number of instances in each replica set that fail the HTTP request
* @param simulateSlowImport whether to simulate slow imports with an artificial delay
* @param unavailableInstanceCount number of instances in each replica set that are not included in the BlobStreamResult
* @param noProgressInstanceCount number of instances in each replica set that make no progress, i.e. their import futures never complete
* @return a list of blob stream results
*/
private List<BlobStreamResult> buildBlobStreamResult(int failedInstanceCount,
boolean simulateSlowImport,
int unavailableInstanceCount,
int noProgressInstanceCount)
{
List<BlobStreamResult> resultList = new ArrayList<>();
int totalInstances = TOTAL_INSTANCES;
for (int i = 0; i < totalInstances; i++)
{
List<RingInstance> replicaSet = Arrays.asList(ringInstance(i, totalInstances),
ringInstance(i + 1, totalInstances),
ringInstance(i + 2, totalInstances));
Set<CreatedRestoreSlice> createdRestoreSlices = new HashSet<>();
int failedPerReplica = failedInstanceCount;
int unavailablePerReplica = unavailableInstanceCount;
int noProgressPerReplicaSet = noProgressInstanceCount;
// create one distinct slice per replica set, keyed by its first instance
CreateSliceRequestPayload mockCreateSliceRequestPayload = mock(CreateSliceRequestPayload.class);
when(mockCreateSliceRequestPayload.startToken()).thenReturn(BigInteger.valueOf(100 * i));
when(mockCreateSliceRequestPayload.endToken()).thenReturn(BigInteger.valueOf(100 * (1 + i)));
when(mockCreateSliceRequestPayload.sliceId()).thenReturn(UUID.randomUUID().toString());
when(mockCreateSliceRequestPayload.key()).thenReturn("key_for_instance_" + i); // to be captured by extension mock
when(mockCreateSliceRequestPayload.bucket()).thenReturn("bucket"); // to be captured by extension mock
when(mockCreateSliceRequestPayload.compressedSize()).thenReturn(1L); // to be captured by extension mock
when(mockCreateSliceRequestPayload.compressedSizeOrZero()).thenReturn(1L);
List<RingInstance> passedReplicaSet = new ArrayList<>();
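// Stub the import call for each instance according to the scenario counters: unavailable
// instances are excluded entirely; the rest get a slow, never-completing, failing, or
// immediately successful future.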
for (RingInstance instance : replicaSet)
{
if (unavailablePerReplica-- > 0)
{
continue; // do not include this instance
}
passedReplicaSet.add(instance);
createdRestoreSlices.add(new CreatedRestoreSlice(mockCreateSliceRequestPayload));
if (simulateSlowImport && i == totalInstances - 1)
{
// only add slowness to the imports of the last replica set
doAnswer((Answer<CompletableFuture<Void>>) invocation -> {
Thread.sleep(ThreadLocalRandom.current().nextInt(2000));
return CompletableFuture.completedFuture(null);
})
.when(dataTransferApi)
.createRestoreSliceFromDriver(eq(new SidecarInstanceImpl(instance.nodeName(), 9043)),
eq(mockCreateSliceRequestPayload));
}
else if (noProgressPerReplicaSet-- > 0)
{
// return a future that never completes
doReturn(new CompletableFuture<>())
.when(dataTransferApi)
.createRestoreSliceFromDriver(eq(new SidecarInstanceImpl(instance.nodeName(), 9043)),
eq(mockCreateSliceRequestPayload));
}
else if (failedPerReplica-- > 0)
{
CompletableFuture<Void> future = new CompletableFuture<>();
future.completeExceptionally(RetriesExhaustedException.of(10, mock(Request.class), null));
doReturn(future)
.when(dataTransferApi)
.createRestoreSliceFromDriver(eq(new SidecarInstanceImpl(instance.nodeName(), 9043)),
eq(mockCreateSliceRequestPayload));
}
else
{
doReturn(CompletableFuture.completedFuture(null))
.when(dataTransferApi)
.createRestoreSliceFromDriver(eq(new SidecarInstanceImpl(instance.nodeName(), 9043)),
eq(mockCreateSliceRequestPayload));
}
}
BlobStreamResult result = new BlobStreamResult("", mock(Range.class), Collections.emptyList(),
passedReplicaSet, createdRestoreSlices, 0, 0);
resultList.add(result);
}
return resultList;
}

// Some slices might never be imported due to short-circuiting, hence the at-most-once verification
private void validateAllSlicesWereCalledAtMostOnce(List<BlobStreamResult> resultList)
{
for (BlobStreamResult blobStreamResult : resultList)
{
for (RingInstance instance : blobStreamResult.passed)
{
for (CreatedRestoreSlice createdRestoreSlice : blobStreamResult.createdRestoreSlices)
{
verify(dataTransferApi, atMostOnce())
.createRestoreSliceFromDriver(eq(new SidecarInstanceImpl(instance.nodeName(), 9043)),
eq(createdRestoreSlice.sliceRequestPayload()));
}
}
}
}
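
// Builds the i-th ring instance; the index wraps around the ring size, and addresses and FQDNs are 1-based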
private RingInstance ringInstance(int i, int totalInstances)
{
int instanceInRing = i % totalInstances + 1;
return new RingInstance(new RingEntry.Builder()
.datacenter("DC1")
.address("127.0.0." + instanceInRing)
.token(String.valueOf(i * 100))
.fqdn("instance-" + instanceInRing)
.build());
}
}