blob: 00cb699cfccf4fb8b1c9b3fbbac62f522f9f5caf [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ratis.grpc.server;
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.ratis.grpc.GrpcUtil;
import org.apache.ratis.grpc.metrics.ZeroCopyMetrics;
import org.apache.ratis.grpc.util.ZeroCopyMessageMarshaller;
import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.server.RaftServer;
import org.apache.ratis.server.protocol.RaftServerProtocol;
import org.apache.ratis.server.util.ServerStringUtils;
import org.apache.ratis.thirdparty.io.grpc.ServerServiceDefinition;
import org.apache.ratis.thirdparty.io.grpc.Status;
import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
import org.apache.ratis.proto.RaftProtos.*;
import org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.RaftServerProtocolServiceImplBase;
import org.apache.ratis.util.ProtoUtils;
import org.apache.ratis.util.ReferenceCountedObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import static org.apache.ratis.grpc.GrpcUtil.addMethodWithCustomMarshaller;
import static org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.getAppendEntriesMethod;
/**
 * gRPC service implementing the Raft server-to-server protocol.
 * Each rpc is delegated to the {@link RaftServer} passed to the constructor.
 */
class GrpcServerProtocolService extends RaftServerProtocolServiceImplBase {
  public static final Logger LOG = LoggerFactory.getLogger(GrpcServerProtocolService.class);

  /**
   * A request whose reply must be sent in order.
   * It retains the request reference until {@link #release()} is called,
   * and carries a future that is completed once the reply has been handled.
   */
  static class PendingServerRequest<REQUEST> {
    private final AtomicReference<ReferenceCountedObject<REQUEST>> requestRef;
    private final CompletableFuture<Void> future = new CompletableFuture<>();

    PendingServerRequest(ReferenceCountedObject<REQUEST> requestRef) {
      requestRef.retain();
      this.requestRef = new AtomicReference<>(requestRef);
    }

    /** @return the request, or null if {@link #release()} has already been called. */
    REQUEST getRequest() {
      return Optional.ofNullable(requestRef.get())
          .map(ReferenceCountedObject::get)
          .orElse(null);
    }

    /** @return a future completed (normally or exceptionally) once this request has been handled. */
    CompletableFuture<Void> getFuture() {
      return future;
    }

    /** Release the retained reference; only the first call has any effect. */
    void release() {
      Optional.ofNullable(requestRef.getAndSet(null))
          .ifPresent(ReferenceCountedObject::release);
    }
  }

  /**
   * Observer for a stream of requests of a single {@link RaftServer.Op}; replies are sent to the response observer.
   * Requests for which {@link #replyInOrder(Object)} returns true get their replies sent only after the
   * previous ordered request has been handled; other requests are replied to as soon as they are processed.
   */
  abstract class ServerRequestStreamObserver<REQUEST, REPLY> implements StreamObserver<REQUEST> {
    private final RaftServer.Op op;
    private final StreamObserver<REPLY> responseObserver;
    /** For ordered {@link #onNext(Object)} requests. */
    private final AtomicReference<PendingServerRequest<REQUEST>> previousOnNext = new AtomicReference<>();
    /** For both ordered and unordered {@link #onNext(Object)} requests. */
    private final AtomicReference<CompletableFuture<REPLY>> requestFuture
        = new AtomicReference<>(CompletableFuture.completedFuture(null));
    private final AtomicBoolean isClosed = new AtomicBoolean(false);

    ServerRequestStreamObserver(RaftServer.Op op, StreamObserver<REPLY> responseObserver) {
      this.op = op;
      this.responseObserver = responseObserver;
    }

    /** @return the string form of the last ordered request still held, or null if none. */
    private String getPreviousRequestString() {
      return Optional.ofNullable(previousOnNext.get())
          .map(PendingServerRequest::getRequest)
          .map(this::requestToString)
          .orElse(null);
    }

    /**
     * Process the given request.
     * Subclasses override either this method or {@link #process(ReferenceCountedObject)}.
     *
     * @throws UnsupportedOperationException if not overridden.
     */
    CompletableFuture<REPLY> process(REQUEST request) throws IOException {
      throw new UnsupportedOperationException("This method is not supported.");
    }

    /**
     * Process the given reference-counted request.
     * By default, the reference is retained only for the duration of {@link #process(Object)}.
     */
    CompletableFuture<REPLY> process(ReferenceCountedObject<REQUEST> requestRef)
        throws IOException {
      try {
        return process(requestRef.retain());
      } finally {
        requestRef.release();
      }
    }

    /**
     * Invoked from the release callback of the wrapper created in {@link #onNext(Object)}
     * when that callback reports the request as released. No-op by default.
     */
    void release(REQUEST req) {
    }

    abstract long getCallId(REQUEST request);

    boolean isHeartbeat(REQUEST request) {
      return false;
    }

    abstract String requestToString(REQUEST request);

    abstract String replyToString(REPLY reply);

    /** @return true if the reply for the given request must be sent in request order. */
    abstract boolean replyInOrder(REQUEST request);

    /** Log the failure and, if the stream is not yet closed, close it with an error. */
    private synchronized void handleError(Throwable e, long callId, boolean isHeartbeat) {
      GrpcUtil.warn(LOG, () -> getId() + ": Failed " + op + " request cid=" + callId + ", isHeartbeat? "
          + isHeartbeat, e);
      if (isClosed.compareAndSet(false, true)) {
        responseObserver.onError(GrpcUtil.wrapException(e, callId, isHeartbeat));
      }
    }

    /** Send the reply unless the stream is already closed; returns the given reply. */
    private synchronized REPLY handleReply(REPLY reply) {
      if (!isClosed.get()) {
        if (LOG.isDebugEnabled()) {
          LOG.debug("{}: reply {}", getId(), replyToString(reply));
        }
        responseObserver.onNext(reply);
      }
      return reply;
    }

    /**
     * Complete the response stream.
     * Synchronized like {@link #handleReply(Object)} and {@link #handleError(Throwable, long, boolean)}
     * so that calls to the response observer are never made concurrently.
     */
    private synchronized void completeResponse() {
      responseObserver.onCompleted();
    }

    /** Chain the given future after all previously composed request futures. */
    void composeRequest(CompletableFuture<REPLY> current) {
      requestFuture.updateAndGet(previous -> previous.thenCompose(reply -> current));
    }

    @Override
    public void onNext(REQUEST request) {
      final ReferenceCountedObject<REQUEST> requestRef = ReferenceCountedObject.wrap(request, () -> {}, released -> {
        if (released) {
          release(request);
        }
      });

      if (!replyInOrder(request)) {
        // Hold a reference across process(..): release(request) is then triggered exactly once,
        // by whichever holder releases last, whether or not process(..) retained it or threw.
        requestRef.retain();
        try {
          composeRequest(process(requestRef).thenApply(this::handleReply));
        } catch (Exception e) {
          handleError(e, getCallId(request), isHeartbeat(request));
        } finally {
          requestRef.release();
        }
        return;
      }

      // Ordered request: its reply is sent only after the previous ordered request has been handled.
      final PendingServerRequest<REQUEST> current = new PendingServerRequest<>(requestRef);
      final long callId = getCallId(request);
      final boolean isHeartbeat = isHeartbeat(request);
      final Optional<PendingServerRequest<REQUEST>> previous = Optional.ofNullable(previousOnNext.getAndSet(current));
      final CompletableFuture<Void> previousFuture = previous.map(PendingServerRequest::getFuture)
          .orElse(CompletableFuture.completedFuture(null));
      try {
        final CompletableFuture<REPLY> f = process(requestRef).exceptionally(e -> {
          // Handle cases, such as RaftServer is paused
          handleError(e, callId, isHeartbeat);
          current.getFuture().completeExceptionally(e);
          return null;
        }).thenCombine(previousFuture, (reply, v) -> {
          handleReply(reply);
          current.getFuture().complete(null);
          return null;
        });
        composeRequest(f);
      } catch (Exception e) {
        handleError(e, callId, isHeartbeat);
        current.getFuture().completeExceptionally(e);
      } finally {
        previous.ifPresent(PendingServerRequest::release);
        if (isClosed.get()) {
          // Some requests may come after onCompleted or onError, ensure they're released.
          releaseLast();
        }
      }
    }

    @Override
    public void onCompleted() {
      if (isClosed.compareAndSet(false, true)) {
        LOG.info("{}: Completed {}, lastRequest: {}", getId(), op, getPreviousRequestString());
        // whenComplete (not thenAccept): once any composed request future has failed, requestFuture stays
        // failed, and a thenAccept callback would never run, leaving the response stream open.
        requestFuture.get().whenComplete((reply, e) -> {
          if (e != null) {
            GrpcUtil.warn(LOG, () -> getId() + ": Completed " + op + " with a failed request", e);
          } else {
            LOG.info("{}: Completed {}, lastReply: {}", getId(), op, reply);
          }
          completeResponse();
        });
        releaseLast();
      }
    }

    @Override
    public void onError(Throwable t) {
      GrpcUtil.warn(LOG, () -> getId() + ": "+ op + " onError, lastRequest: " + getPreviousRequestString(), t);
      if (isClosed.compareAndSet(false, true)) {
        final Status status = Status.fromThrowable(t);
        if (status != null && status.getCode() != Status.Code.CANCELLED) {
          completeResponse();
        }
        releaseLast();
      }
    }

    /** Release the last ordered request, if it is still held. */
    private void releaseLast() {
      Optional.ofNullable(previousOnNext.get()).ifPresent(PendingServerRequest::release);
    }
  }

  private final Supplier<RaftPeerId> idSupplier;
  private final RaftServer server;
  private final ZeroCopyMessageMarshaller<AppendEntriesRequestProto> zeroCopyRequestMarshaller;

  GrpcServerProtocolService(Supplier<RaftPeerId> idSupplier, RaftServer server, ZeroCopyMetrics zeroCopyMetrics) {
    this.idSupplier = idSupplier;
    this.server = server;
    this.zeroCopyRequestMarshaller = new ZeroCopyMessageMarshaller<>(AppendEntriesRequestProto.getDefaultInstance(),
        zeroCopyMetrics::onZeroCopyMessage, zeroCopyMetrics::onNonZeroCopyMessage, zeroCopyMetrics::onReleasedMessage);
  }

  RaftPeerId getId() {
    return idSupplier.get();
  }

  /**
   * Bind this service, using the zero-copy marshaller for appendEntries requests
   * and the default marshallers for all other methods.
   */
  ServerServiceDefinition bindServiceWithZeroCopy() {
    final ServerServiceDefinition orig = super.bindService();
    final ServerServiceDefinition.Builder builder
        = ServerServiceDefinition.builder(orig.getServiceDescriptor().getName());

    // Add appendEntries with zero copy marshaller.
    addMethodWithCustomMarshaller(orig, builder, getAppendEntriesMethod(), zeroCopyRequestMarshaller);
    // Add remaining methods as is.
    final String appendEntriesName = getAppendEntriesMethod().getFullMethodName();
    orig.getMethods().stream()
        .filter(m -> !m.getMethodDescriptor().getFullMethodName().equals(appendEntriesName))
        .forEach(builder::addMethod);

    return builder.build();
  }

  @Override
  public void requestVote(RequestVoteRequestProto request,
      StreamObserver<RequestVoteReplyProto> responseObserver) {
    try {
      final RequestVoteReplyProto reply = server.requestVote(request);
      responseObserver.onNext(reply);
      responseObserver.onCompleted();
    } catch (Exception e) {
      GrpcUtil.warn(LOG, () -> getId() + ": Failed requestVote " + ProtoUtils.toString(request.getServerRequest()), e);
      responseObserver.onError(GrpcUtil.wrapException(e));
    }
  }

  @Override
  public void startLeaderElection(StartLeaderElectionRequestProto request,
      StreamObserver<StartLeaderElectionReplyProto> responseObserver) {
    try {
      final StartLeaderElectionReplyProto reply = server.startLeaderElection(request);
      responseObserver.onNext(reply);
      responseObserver.onCompleted();
    } catch (Throwable e) {
      GrpcUtil.warn(LOG,
          () -> getId() + ": Failed startLeaderElection " + ProtoUtils.toString(request.getServerRequest()), e);
      responseObserver.onError(GrpcUtil.wrapException(e));
    }
  }

  @Override
  public void readIndex(ReadIndexRequestProto request, StreamObserver<ReadIndexReplyProto> responseObserver) {
    final Consumer<Throwable> warning = e -> GrpcUtil.warn(LOG,
        () -> getId() + ": Failed readIndex " + ProtoUtils.toString(request.getServerRequest()), e);
    GrpcUtil.asyncCall(responseObserver, () -> server.readIndexAsync(request), Function.identity(), warning);
  }

  /** Heartbeats (requests without entries) are replied to immediately; other requests are replied in order. */
  @Override
  public StreamObserver<AppendEntriesRequestProto> appendEntries(
      StreamObserver<AppendEntriesReplyProto> responseObserver) {
    return new ServerRequestStreamObserver<AppendEntriesRequestProto, AppendEntriesReplyProto>(
        RaftServerProtocol.Op.APPEND_ENTRIES, responseObserver) {
      @Override
      CompletableFuture<AppendEntriesReplyProto> process(ReferenceCountedObject<AppendEntriesRequestProto> requestRef)
          throws IOException {
        return server.appendEntriesAsync(requestRef);
      }

      @Override
      void release(AppendEntriesRequestProto req) {
        zeroCopyRequestMarshaller.release(req);
      }

      @Override
      long getCallId(AppendEntriesRequestProto request) {
        return request.getServerRequest().getCallId();
      }

      @Override
      boolean isHeartbeat(AppendEntriesRequestProto request) {
        return request.getEntriesCount() == 0;
      }

      @Override
      String requestToString(AppendEntriesRequestProto request) {
        return ServerStringUtils.toAppendEntriesRequestString(request, null);
      }

      @Override
      String replyToString(AppendEntriesReplyProto reply) {
        return ServerStringUtils.toAppendEntriesReplyString(reply);
      }

      @Override
      boolean replyInOrder(AppendEntriesRequestProto request) {
        return request.getEntriesCount() != 0;
      }
    };
  }

  /** Every installSnapshot request is processed synchronously and replied to in order. */
  @Override
  public StreamObserver<InstallSnapshotRequestProto> installSnapshot(
      StreamObserver<InstallSnapshotReplyProto> responseObserver) {
    return new ServerRequestStreamObserver<InstallSnapshotRequestProto, InstallSnapshotReplyProto>(
        RaftServerProtocol.Op.INSTALL_SNAPSHOT, responseObserver) {
      @Override
      CompletableFuture<InstallSnapshotReplyProto> process(InstallSnapshotRequestProto request) throws IOException {
        return CompletableFuture.completedFuture(server.installSnapshot(request));
      }

      @Override
      long getCallId(InstallSnapshotRequestProto request) {
        return request.getServerRequest().getCallId();
      }

      @Override
      String requestToString(InstallSnapshotRequestProto request) {
        return ServerStringUtils.toInstallSnapshotRequestString(request);
      }

      @Override
      String replyToString(InstallSnapshotReplyProto reply) {
        return ServerStringUtils.toInstallSnapshotReplyString(reply);
      }

      @Override
      boolean replyInOrder(InstallSnapshotRequestProto installSnapshotRequestProto) {
        return true;
      }
    };
  }
}