// blob: 2970bbef847703ef9eeea94be6056a8631d0b029 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ratis.datastream;
import org.apache.ratis.BaseTest;
import org.apache.ratis.client.impl.ClientProtoUtils;
import org.apache.ratis.client.impl.DataStreamClientImpl.DataStreamOutputImpl;
import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer;
import org.apache.ratis.datastream.impl.DataStreamRequestByteBuffer;
import org.apache.ratis.io.StandardWriteOption;
import org.apache.ratis.proto.RaftProtos.DataStreamPacketHeaderProto.Type;
import org.apache.ratis.proto.RaftProtos.LogEntryProto;
import org.apache.ratis.proto.RaftProtos.StateMachineLogEntryProto;
import org.apache.ratis.protocol.ClientId;
import org.apache.ratis.protocol.ClientInvocationId;
import org.apache.ratis.protocol.DataStreamReply;
import org.apache.ratis.protocol.Message;
import org.apache.ratis.protocol.RaftClientMessage;
import org.apache.ratis.protocol.RaftClientReply;
import org.apache.ratis.protocol.RaftClientRequest;
import org.apache.ratis.protocol.RaftPeer;
import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.protocol.RoutingTable;
import org.apache.ratis.protocol.exceptions.AlreadyClosedException;
import org.apache.ratis.server.RaftServer;
import org.apache.ratis.server.raftlog.LogEntryHeader;
import org.apache.ratis.server.raftlog.LogProtoUtils;
import org.apache.ratis.server.raftlog.RaftLog;
import org.apache.ratis.statemachine.StateMachine.DataChannel;
import org.apache.ratis.statemachine.StateMachine.DataStream;
import org.apache.ratis.statemachine.TransactionContext;
import org.apache.ratis.statemachine.impl.BaseStateMachine;
import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
import org.apache.ratis.util.CollectionUtils;
import org.apache.ratis.util.FileUtils;
import org.apache.ratis.util.JavaUtils;
import org.junit.jupiter.api.Assertions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.channels.ReadableByteChannel;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ThreadLocalRandom;
public interface DataStreamTestUtils {
/** Logger shared by the static helpers and nested classes of this interface. */
Logger LOG = LoggerFactory.getLogger(DataStreamTestUtils.class);
/** Reply content that {@link #assertCloseReply} accepts without comparing it to the expected byte count. */
ByteString MOCK = ByteString.copyFromUtf8("mock");
/** Divisor used by {@link #pos2byte(int)}; the generated byte pattern repeats with this period. */
int MODULUS = 23;
/**
 * Maps a stream position to the byte expected at that position.
 * The value is {@code 'A'} plus the remainder of {@code pos} divided by {@link #MODULUS}.
 *
 * @param pos the position in the stream
 * @return the pattern byte for {@code pos}
 */
static byte pos2byte(int pos) {
  final int shift = pos % MODULUS;
  return (byte) ('A' + shift);
}
/**
 * Allocates a direct buffer of {@code size} bytes and fills it with the
 * {@link #pos2byte(int)} pattern for positions {@code offset .. offset + size - 1}.
 *
 * @param offset the stream position corresponding to the first byte of the buffer
 * @param size the number of bytes to allocate and fill
 * @return the filled buffer, flipped so that all {@code size} bytes are readable
 */
static ByteBuffer initBuffer(int offset, int size) {
  final ByteBuffer direct = ByteBuffer.allocateDirect(size);
  final int capacity = direct.capacity();
  // Start from a clean state: position 0, limit == capacity.
  direct.clear();
  while (direct.hasRemaining()) {
    direct.put(pos2byte(offset + direct.position()));
  }
  direct.flip();
  Assertions.assertEquals(capacity, direct.remaining());
  return direct;
}
/**
 * Writes {@code size} bytes of the {@link #pos2byte(int)} pattern to {@code f},
 * creating the parent directories first, and asserts that exactly {@code size}
 * bytes were transferred.
 * The file is opened with CREATE and WRITE only, so an existing longer file is not truncated.
 *
 * @param f the file to write
 * @param size the number of pattern bytes to write, starting at position 0
 * @throws Exception if creating the directories or writing the file fails
 */
static void createFile(File f, int size) throws Exception {
  final ReadableByteChannel source = new ReadableByteChannel() {
    private int offset = 0;
    // The open state is tracked separately from the read position:
    // an exhausted (or empty) source is still an open channel at end-of-stream.
    private boolean open = true;

    @Override
    public boolean isOpen() {
      return open;
    }

    @Override
    public void close() {
      open = false;
    }

    @Override
    public int read(ByteBuffer dst) throws java.nio.channels.ClosedChannelException {
      if (!open) {
        throw new java.nio.channels.ClosedChannelException();
      }
      if (offset >= size) {
        // End-of-stream is signalled with -1, as the ReadableByteChannel contract requires.
        return -1;
      }
      final int start = offset;
      for (; dst.hasRemaining() && offset < size; offset++) {
        dst.put(pos2byte(offset));
      }
      return offset - start;
    }
  };

  // A bare relative file name has no parent; there is nothing to create in that case.
  final File parent = f.getParentFile();
  if (parent != null) {
    FileUtils.createDirectories(parent);
  }
  try (FileChannel out = FileUtils.newFileChannel(f, StandardOpenOption.CREATE, StandardOpenOption.WRITE)) {
    final long transferred = out.transferFrom(source, 0, size);
    Assertions.assertEquals(size, transferred);
  }
}
/**
 * Encodes a byte count as the UTF-8 text {@code "bytesWritten=<count>"}.
 *
 * @param bytesWritten the byte count to encode
 * @return the encoded text as a {@link ByteString}
 */
static ByteString bytesWritten2ByteString(long bytesWritten) {
  final String text = "bytesWritten=" + bytesWritten;
  return ByteString.copyFromUtf8(text);
}
/**
 * Peer-based convenience overload: converts the peers to their ids and
 * delegates to the id-based overload.
 *
 * @param peers the peers to chain
 * @param primary the peer at the head of the chain
 * @return the routing table built by the id-based overload
 */
static RoutingTable getRoutingTableChainTopology(Iterable<RaftPeer> peers, RaftPeer primary) {
  final Iterable<RaftPeerId> peerIds = CollectionUtils.as(peers, RaftPeer::getId);
  return getRoutingTableChainTopology(peerIds, primary.getId());
}
/**
 * Builds a routing table in which the peers form a single chain:
 * {@code primary} is followed by every other peer in iteration order,
 * each one registered as the successor of the one before it.
 * Entries equal to {@code primary} are skipped.
 *
 * @param peers the ids of the peers to chain, in the desired order
 * @param primary the id at the head of the chain
 * @return the built routing table
 */
static RoutingTable getRoutingTableChainTopology(Iterable<RaftPeerId> peers, RaftPeerId primary) {
  final RoutingTable.Builder chain = RoutingTable.newBuilder();
  RaftPeerId tail = primary;
  for (RaftPeerId candidate : peers) {
    if (!candidate.equals(primary)) {
      chain.addSuccessor(tail, candidate);
      tail = candidate;
    }
  }
  return chain.build();
}
/**
 * A state machine for tests that keeps one {@link SingleDataStream} per client invocation
 * and replies to applied transactions with the number of bytes written to that stream.
 */
class MultiDataStreamStateMachine extends BaseStateMachine {
  /** The streams created so far, keyed by the invocation id of the request that created them. */
  private final ConcurrentMap<ClientInvocationId, SingleDataStream> streams = new ConcurrentHashMap<>();

  /** Creates a new stream for the request and registers it under the request's invocation id. */
  @Override
  public CompletableFuture<DataStream> stream(RaftClientRequest request) {
    final SingleDataStream created = new SingleDataStream(request);
    streams.put(ClientInvocationId.valueOf(request), created);
    return CompletableFuture.completedFuture(created);
  }

  /** Records the log entry in the given stream; fails the returned future when the stream is null. */
  @Override
  public CompletableFuture<?> link(DataStream stream, LogEntryProto entry) {
    LOG.info("link {}", stream);
    if (stream != null) {
      final SingleDataStream single = (SingleDataStream) stream;
      single.setLogEntry(entry);
      return CompletableFuture.completedFuture(null);
    }
    return JavaUtils.completeExceptionally(new IllegalStateException("Null stream: entry=" + entry));
  }

  /**
   * Updates the last applied term-index and replies with the encoded number of bytes
   * written to the stream matching the entry's invocation id.
   */
  @Override
  public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
    final LogEntryProto entry = Objects.requireNonNull(trx.getLogEntryUnsafe());
    updateLastAppliedTermIndex(entry.getTerm(), entry.getIndex());

    final ClientInvocationId invocationId = ClientInvocationId.valueOf(entry.getStateMachineLogEntry());
    final SingleDataStream found = getSingleDataStream(invocationId);
    final ByteString content = bytesWritten2ByteString(found.getDataChannel().getBytesWritten());
    final Message reply = () -> content;
    return CompletableFuture.completedFuture(reply);
  }

  /** @return the stream registered for the invocation id of the given request, or null if there is none. */
  SingleDataStream getSingleDataStream(RaftClientRequest request) {
    final ClientInvocationId invocationId = ClientInvocationId.valueOf(request);
    return getSingleDataStream(invocationId);
  }

  /** @return the stream registered for the given invocation id, or null if there is none. */
  SingleDataStream getSingleDataStream(ClientInvocationId invocationId) {
    return streams.get(invocationId);
  }

  /** @return a view of all the registered streams. */
  Collection<SingleDataStream> getStreams() {
    return streams.values();
  }
}
/**
 * A {@link DataStream} for tests: it holds the request that created it,
 * a {@link MyDataChannel} receiving the data, and the log entry set via {@link #setLogEntry}.
 */
class SingleDataStream implements DataStream {
  private final RaftClientRequest writeRequest;
  private final MyDataChannel channel = new MyDataChannel();
  /** Null until {@link #setLogEntry} is called. */
  private volatile LogEntryProto logEntry;

  SingleDataStream(RaftClientRequest request) {
    this.writeRequest = request;
  }

  @Override
  public MyDataChannel getDataChannel() {
    return channel;
  }

  /** Closes the channel; any failure is reported through the returned future rather than thrown. */
  @Override
  public CompletableFuture<?> cleanUp() {
    try {
      channel.close();
      return CompletableFuture.completedFuture(null);
    } catch (Throwable failure) {
      return JavaUtils.completeExceptionally(failure);
    }
  }

  void setLogEntry(LogEntryProto logEntry) {
    this.logEntry = logEntry;
  }

  LogEntryProto getLogEntry() {
    return logEntry;
  }

  RaftClientRequest getWriteRequest() {
    return writeRequest;
  }

  @Override
  public String toString() {
    return new StringBuilder()
        .append(JavaUtils.getClassSimpleName(getClass()))
        .append(": writeRequest=").append(writeRequest)
        .append(", logEntry=").append(LogProtoUtils.toLogEntryString(logEntry))
        .toString();
  }
}
/**
 * A {@link DataChannel} for tests that stores nothing: it checks every incoming byte against
 * the {@link #pos2byte(int)} pattern and counts the bytes written and the position last forced.
 */
class MyDataChannel implements DataChannel {
  private volatile boolean open = true;
  /** Number of bytes accepted so far; also the pattern position expected for the next byte. */
  private int bytesWritten = 0;
  /** The value of {@link #bytesWritten} at the most recent {@link #force(boolean)} call. */
  private int forcedPosition = 0;

  int getBytesWritten() {
    return bytesWritten;
  }

  int getForcedPosition() {
    return forcedPosition;
  }

  /** Records the current byte count as the forced position; the argument is ignored. */
  @Override
  public void force(boolean metadata) {
    forcedPosition = bytesWritten;
  }

  /**
   * Consumes all remaining bytes of {@code src}, asserting that each one matches the expected pattern.
   *
   * @return the number of bytes that were remaining in {@code src} on entry
   * @throws IllegalStateException if this channel has been closed
   */
  @Override
  public int write(ByteBuffer src) {
    if (!open) {
      throw new IllegalStateException("Already closed");
    }
    final int requested = src.remaining();
    while (src.hasRemaining()) {
      Assertions.assertEquals(pos2byte(bytesWritten), src.get());
      bytesWritten++;
    }
    return requested;
  }

  @Override
  public boolean isOpen() {
    return open;
  }

  @Override
  public void close() {
    open = false;
  }
}
/**
 * Writes {@code bufferNum} pattern buffers of random size to {@code out} and then asserts that
 * the header reply and every data reply succeeded.
 * Each buffer has a size in {@code [bufferSize/2, 2*(bufferSize/2))};
 * the last buffer is written with the FLUSH and SYNC options.
 *
 * @param out the stream to write to
 * @param bufferSize the upper bound used to derive the random buffer sizes
 * @param bufferNum the number of buffers to write
 * @return the total number of bytes written
 */
static int writeAndAssertReplies(DataStreamOutputImpl out, int bufferSize, int bufferNum) {
  final List<CompletableFuture<DataStreamReply>> pending = new ArrayList<>();
  final List<Integer> lengths = new ArrayList<>();

  // send data
  final int half = bufferSize / 2;
  final int lastIndex = bufferNum - 1;
  int total = 0;
  for (int i = 0; i < bufferNum; i++) {
    final int length = half + ThreadLocalRandom.current().nextInt(half);
    lengths.add(length);

    final ByteBuffer data = initBuffer(total, length);
    final CompletableFuture<DataStreamReply> written;
    if (i == lastIndex) {
      written = out.writeAsync(data, StandardWriteOption.FLUSH, StandardWriteOption.SYNC);
    } else {
      written = out.writeAsync(data);
    }
    pending.add(written);
    total += length;
  }

  // check header: it carries no data, so zero bytes are expected
  final DataStreamReply headerReply = out.getHeaderFuture().join();
  assertSuccessReply(Type.STREAM_HEADER, 0, headerReply);

  // check writeAsync requests: each reply reports the size of its own buffer
  int index = 0;
  for (CompletableFuture<DataStreamReply> future : pending) {
    final DataStreamReply dataReply = future.join();
    assertSuccessReply(Type.STREAM_DATA, lengths.get(index).longValue(), dataReply);
    index++;
  }
  return total;
}
/**
 * Asserts that {@code reply} is successful and has the expected byte count and packet type,
 * checked in that order.
 *
 * @param expectedType the expected packet type
 * @param expectedBytesWritten the expected value of {@code reply.getBytesWritten()}
 * @param reply the reply to check
 */
static void assertSuccessReply(Type expectedType, long expectedBytesWritten, DataStreamReply reply) {
  final boolean success = reply.isSuccess();
  Assertions.assertTrue(success);

  final long actualBytesWritten = reply.getBytesWritten();
  Assertions.assertEquals(expectedBytesWritten, actualBytesWritten);

  Assertions.assertEquals(expectedType, reply.getType());
}
/**
 * Writes data to {@code out}, checks the stream state on every server,
 * then closes the stream and checks the close reply.
 *
 * @param servers the servers whose state machines are checked after writing
 * @param leader the server id expected in the close reply (not checked when {@code stepDownLeader})
 * @param out the stream to write to and close
 * @param bufferSize passed to {@link #writeAndAssertReplies}
 * @param bufferNum passed to {@link #writeAndAssertReplies}
 * @param clientId the client id expected in the replies
 * @param stepDownLeader true to skip the server id checks
 * @return a future of the client reply parsed from the close reply
 * @throws CompletionException wrapping any failure of the per-server checks
 */
static CompletableFuture<RaftClientReply> writeAndCloseAndAssertReplies(
    Iterable<RaftServer> servers, RaftPeerId leader, DataStreamOutputImpl out, int bufferSize, int bufferNum,
    ClientId clientId, boolean stepDownLeader) {
  LOG.info("start Stream{}", out.getHeader().getCallId());

  final int total = writeAndAssertReplies(out, bufferSize, bufferNum);
  try {
    for (RaftServer server : servers) {
      assertHeader(server, out.getHeader(), total, stepDownLeader);
    }
  } catch (Throwable failure) {
    // Surface check failures the same way a failed future would.
    throw new CompletionException(failure);
  }
  LOG.info("Stream{}: bytesWritten={}", out.getHeader().getCallId(), total);

  return out.closeAsync().thenCompose(closeReply -> {
    return assertCloseReply(out, closeReply, total, leader, clientId, stepDownLeader);
  });
}
/**
 * Checks, on one server, the stream created for the given header request:
 * the request type, the bytes written and forced in the stream's channel,
 * and the write request stored in the stream.
 *
 * @param server the server whose state machine is inspected
 * @param header the stream header request
 * @param dataSize the expected number of bytes both written and forced
 * @param stepDownLeader true to skip the server id check
 * @throws Exception if the division cannot be obtained from the server
 */
static void assertHeader(RaftServer server, RaftClientRequest header, int dataSize, boolean stepDownLeader)
    throws Exception {
  // check header
  Assertions.assertEquals(RaftClientRequest.dataStreamRequestType(), header.getType());

  // check stream
  final RaftServer.Division division = server.getDivision(header.getRaftGroupId());
  final MultiDataStreamStateMachine machine = (MultiDataStreamStateMachine) division.getStateMachine();
  final SingleDataStream single = machine.getSingleDataStream(header);
  final MyDataChannel dataChannel = single.getDataChannel();
  Assertions.assertEquals(dataSize, dataChannel.getBytesWritten());
  Assertions.assertEquals(dataSize, dataChannel.getForcedPosition());

  // check writeRequest; a null expected server id means "use the header's server id"
  final RaftClientRequest stored = single.getWriteRequest();
  Assertions.assertEquals(RaftClientRequest.dataStreamRequestType(), stored.getType());
  assertRaftClientMessage(header, null, stored, header.getClientId(), stepDownLeader);
}
/**
 * Checks the reply of closing {@code out} and converts it to a {@link RaftClientReply}.
 * The checks are: closing again returns the same reply object, the reply carries the expected
 * client id, writing after close fails with {@link AlreadyClosedException}, and the parsed
 * client reply matches the header. For a successful reply whose content is not {@link #MOCK},
 * the content must also equal the encoded {@code bytesWritten}.
 *
 * @param out the already-closed stream
 * @param dataStreamReply the reply returned by the first close
 * @param bytesWritten the total number of bytes written to the stream
 * @param leader the server id expected in the client reply (not checked when {@code stepDownLeader})
 * @param clientId the expected client id
 * @param stepDownLeader true to skip the server id check
 * @return a future completed with the parsed client reply, or completed exceptionally
 *         when parsing or one of the checks on the parsed reply fails
 */
static CompletableFuture<RaftClientReply> assertCloseReply(DataStreamOutputImpl out, DataStreamReply dataStreamReply,
    long bytesWritten, RaftPeerId leader, ClientId clientId, boolean stepDownLeader) {
  // Test close idempotent
  Assertions.assertSame(dataStreamReply, out.closeAsync().join());
  // JUnit takes (expected, actual); the expected value is the caller-supplied client id.
  Assertions.assertEquals(clientId, dataStreamReply.getClientId());
  BaseTest.testFailureCase("writeAsync should fail",
      () -> out.writeAsync(DataStreamRequestByteBuffer.EMPTY_BYTE_BUFFER).join(),
      CompletionException.class, (Logger) null, AlreadyClosedException.class);

  final DataStreamReplyByteBuffer buffer = (DataStreamReplyByteBuffer) dataStreamReply;
  try {
    final RaftClientReply reply = ClientProtoUtils.toRaftClientReply(buffer.slice());
    assertRaftClientMessage(out.getHeader(), leader, reply, clientId, stepDownLeader);
    if (reply.isSuccess()) {
      final ByteString bytes = reply.getMessage().getContent();
      // A MOCK content is accepted as is; anything else must be the encoded byte count.
      if (!bytes.equals(MOCK)) {
        Assertions.assertEquals(bytesWritten2ByteString(bytesWritten), bytes);
      }
    }
    return CompletableFuture.completedFuture(reply);
  } catch (Throwable t) {
    return JavaUtils.completeExceptionally(t);
  }
}
/**
 * Asserts that {@code computed} is non-null and agrees with the expected client id,
 * server id and group id.
 *
 * @param expected the message supplying the expected group id and the fallback server id
 * @param expectedServerId the expected server id; when null, {@code expected.getServerId()} is used
 * @param computed the message to check
 * @param expectedClientId the expected client id
 * @param stepDownLeader true to skip the server id check
 */
static void assertRaftClientMessage(
    RaftClientMessage expected, RaftPeerId expectedServerId, RaftClientMessage computed, ClientId expectedClientId,
    boolean stepDownLeader) {
  Assertions.assertNotNull(computed);
  Assertions.assertEquals(expectedClientId, computed.getClientId());

  final boolean checkServerId = !stepDownLeader;
  if (checkServerId) {
    final RaftPeerId wantedServerId = Optional.ofNullable(expectedServerId).orElseGet(expected::getServerId);
    Assertions.assertEquals(wantedServerId, computed.getServerId());
  }

  Assertions.assertEquals(expected.getRaftGroupId(), computed.getRaftGroupId());
}
/**
 * Scans the log from index 0 for the first state machine entry matching the given invocation id.
 *
 * @param invocationId the invocation id to look for
 * @param log the log to scan
 * @return the first matching entry, or null when no entry matches
 * @throws Exception if reading from the log fails
 */
static LogEntryProto searchLogEntry(ClientInvocationId invocationId, RaftLog log) throws Exception {
  for (LogEntryHeader header : log.getEntries(0, Long.MAX_VALUE)) {
    final LogEntryProto candidate = log.get(header.getIndex());
    final boolean matched = candidate.hasStateMachineLogEntry()
        && invocationId.match(candidate.getStateMachineLogEntry());
    if (matched) {
      return candidate;
    }
  }
  return null;
}
/**
 * Asserts that {@code logEntry} is a non-null DATASTREAM state machine entry
 * carrying the call id and client id of {@code request}.
 *
 * @param logEntry the log entry to check
 * @param request the request the entry is expected to originate from
 */
static void assertLogEntry(LogEntryProto logEntry, RaftClientRequest request) {
  Assertions.assertNotNull(logEntry);
  Assertions.assertTrue(logEntry.hasStateMachineLogEntry());

  final StateMachineLogEntryProto smEntry = logEntry.getStateMachineLogEntry();
  Assertions.assertEquals(StateMachineLogEntryProto.Type.DATASTREAM, smEntry.getType());
  Assertions.assertEquals(request.getCallId(), smEntry.getCallId());
  Assertions.assertEquals(request.getClientId().toByteString(), smEntry.getClientId());
}
/**
 * Asserts that the log entry stored in {@code stream} is valid for the stream's write request
 * and equals the entry found in the division's log for the same invocation id.
 *
 * @param division the division whose log is searched
 * @param stream the stream holding the write request and its log entry
 * @throws Exception if reading from the log fails
 */
static void assertLogEntry(RaftServer.Division division, SingleDataStream stream) throws Exception {
  final RaftClientRequest writeRequest = stream.getWriteRequest();
  final LogEntryProto linked = stream.getLogEntry();
  assertLogEntry(linked, writeRequest);

  final ClientInvocationId invocationId = ClientInvocationId.valueOf(writeRequest);
  final LogEntryProto logged = searchLogEntry(invocationId, division.getRaftLog());
  Assertions.assertEquals(linked, logged);
}
}