| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.ratis.grpc.server; |
| |
| import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; |
| import org.apache.ratis.grpc.GrpcUtil; |
| import org.apache.ratis.protocol.RaftPeerId; |
| import org.apache.ratis.server.RaftServer; |
| import org.apache.ratis.server.protocol.RaftServerProtocol; |
| import org.apache.ratis.server.util.ServerStringUtils; |
| import org.apache.ratis.thirdparty.io.grpc.Status; |
| import org.apache.ratis.thirdparty.io.grpc.StatusRuntimeException; |
| import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver; |
| import org.apache.ratis.proto.RaftProtos.*; |
| import org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.RaftServerProtocolServiceImplBase; |
| import org.apache.ratis.util.ProtoUtils; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| import java.io.IOException; |
| import java.util.Optional; |
| import java.util.concurrent.CompletableFuture; |
| import java.util.concurrent.atomic.AtomicBoolean; |
| import java.util.concurrent.atomic.AtomicReference; |
| import java.util.function.Supplier; |
| |
| class GrpcServerProtocolService extends RaftServerProtocolServiceImplBase { |
| public static final Logger LOG = LoggerFactory.getLogger(GrpcServerProtocolService.class); |
| |
| static class PendingServerRequest<REQUEST> { |
| private final REQUEST request; |
| private final CompletableFuture<Void> future = new CompletableFuture<>(); |
| |
| PendingServerRequest(REQUEST request) { |
| this.request = request; |
| } |
| |
| REQUEST getRequest() { |
| return request; |
| } |
| |
| CompletableFuture<Void> getFuture() { |
| return future; |
| } |
| } |
| |
| abstract class ServerRequestStreamObserver<REQUEST, REPLY> implements StreamObserver<REQUEST> { |
| private final RaftServer.Op op; |
| private final StreamObserver<REPLY> responseObserver; |
| private final AtomicReference<PendingServerRequest<REQUEST>> previousOnNext = new AtomicReference<>(); |
| private final AtomicBoolean isClosed = new AtomicBoolean(false); |
| |
| ServerRequestStreamObserver(RaftServer.Op op, StreamObserver<REPLY> responseObserver) { |
| this.op = op; |
| this.responseObserver = responseObserver; |
| } |
| |
| private String getPreviousRequestString() { |
| return Optional.ofNullable(previousOnNext.get()) |
| .map(PendingServerRequest::getRequest) |
| .map(this::requestToString) |
| .orElse(null); |
| } |
| |
| abstract CompletableFuture<REPLY> process(REQUEST request) throws IOException; |
| |
| abstract long getCallId(REQUEST request); |
| |
| abstract String requestToString(REQUEST request); |
| |
| abstract String replyToString(REPLY reply); |
| |
| abstract boolean replyInOrder(REQUEST request); |
| |
| StatusRuntimeException wrapException(Throwable e, REQUEST request) { |
| return GrpcUtil.wrapException(e, getCallId(request)); |
| } |
| |
| private void handleError(Throwable e, REQUEST request) { |
| GrpcUtil.warn(LOG, () -> getId() + ": Failed " + op + " request " + requestToString(request), e); |
| responseObserver.onError(wrapException(e, request)); |
| } |
| |
| private synchronized void handleReply(REPLY reply) { |
| if (!isClosed.get()) { |
| if (LOG.isDebugEnabled()) { |
| LOG.debug("{}: reply {}", getId(), replyToString(reply)); |
| } |
| responseObserver.onNext(reply); |
| } |
| } |
| |
| @Override |
| @SuppressFBWarnings("NP_NULL_PARAM_DEREF") |
| public void onNext(REQUEST request) { |
| if (!replyInOrder(request)) { |
| try { |
| process(request).thenAccept(this::handleReply); |
| } catch (Exception e) { |
| handleError(e, request); |
| } |
| return; |
| } |
| |
| final PendingServerRequest<REQUEST> current = new PendingServerRequest<>(request); |
| final PendingServerRequest<REQUEST> previous = previousOnNext.getAndSet(current); |
| final CompletableFuture<Void> previousFuture = Optional.ofNullable(previous) |
| .map(PendingServerRequest::getFuture) |
| .orElse(CompletableFuture.completedFuture(null)); |
| try { |
| process(request).thenCombine(previousFuture, (reply, v) -> { |
| handleReply(reply); |
| current.getFuture().complete(null); |
| return null; |
| }); |
| } catch (Exception e) { |
| handleError(e, request); |
| current.getFuture().completeExceptionally(e); |
| } |
| } |
| |
| @Override |
| public void onCompleted() { |
| if (isClosed.compareAndSet(false, true)) { |
| LOG.info("{}: Completed {}, lastRequest: {}", getId(), op, getPreviousRequestString()); |
| responseObserver.onCompleted(); |
| } |
| } |
| @Override |
| public void onError(Throwable t) { |
| GrpcUtil.warn(LOG, () -> getId() + ": installSnapshot onError, lastRequest: " + getPreviousRequestString(), t); |
| if (isClosed.compareAndSet(false, true)) { |
| Status status = Status.fromThrowable(t); |
| if (status != null && status.getCode() != Status.Code.CANCELLED) { |
| responseObserver.onCompleted(); |
| } |
| } |
| } |
| } |
| |
| private final Supplier<RaftPeerId> idSupplier; |
| private final RaftServer server; |
| |
| GrpcServerProtocolService(Supplier<RaftPeerId> idSupplier, RaftServer server) { |
| this.idSupplier = idSupplier; |
| this.server = server; |
| } |
| |
| RaftPeerId getId() { |
| return idSupplier.get(); |
| } |
| |
| @Override |
| public void requestVote(RequestVoteRequestProto request, |
| StreamObserver<RequestVoteReplyProto> responseObserver) { |
| try { |
| final RequestVoteReplyProto reply = server.requestVote(request); |
| responseObserver.onNext(reply); |
| responseObserver.onCompleted(); |
| } catch (Exception e) { |
| GrpcUtil.warn(LOG, () -> getId() + ": Failed requestVote " + ProtoUtils.toString(request.getServerRequest()), e); |
| responseObserver.onError(GrpcUtil.wrapException(e)); |
| } |
| } |
| |
| @Override |
| public void startLeaderElection(StartLeaderElectionRequestProto request, |
| StreamObserver<StartLeaderElectionReplyProto> responseObserver) { |
| try { |
| final StartLeaderElectionReplyProto reply = server.startLeaderElection(request); |
| responseObserver.onNext(reply); |
| responseObserver.onCompleted(); |
| } catch (Throwable e) { |
| GrpcUtil.warn(LOG, |
| () -> getId() + ": Failed startLeaderElection " + ProtoUtils.toString(request.getServerRequest()), e); |
| responseObserver.onError(GrpcUtil.wrapException(e)); |
| } |
| } |
| |
| @Override |
| public StreamObserver<AppendEntriesRequestProto> appendEntries( |
| StreamObserver<AppendEntriesReplyProto> responseObserver) { |
| return new ServerRequestStreamObserver<AppendEntriesRequestProto, AppendEntriesReplyProto>( |
| RaftServerProtocol.Op.APPEND_ENTRIES, responseObserver) { |
| @Override |
| CompletableFuture<AppendEntriesReplyProto> process(AppendEntriesRequestProto request) throws IOException { |
| return server.appendEntriesAsync(request); |
| } |
| |
| @Override |
| long getCallId(AppendEntriesRequestProto request) { |
| return request.getServerRequest().getCallId(); |
| } |
| |
| @Override |
| String requestToString(AppendEntriesRequestProto request) { |
| return ServerStringUtils.toAppendEntriesRequestString(request); |
| } |
| |
| @Override |
| String replyToString(AppendEntriesReplyProto reply) { |
| return ServerStringUtils.toAppendEntriesReplyString(reply); |
| } |
| |
| @Override |
| boolean replyInOrder(AppendEntriesRequestProto request) { |
| return request.getEntriesCount() != 0; |
| } |
| |
| @Override |
| StatusRuntimeException wrapException(Throwable e, AppendEntriesRequestProto request) { |
| return GrpcUtil.wrapException(e, getCallId(request), request.getEntriesCount() == 0); |
| } |
| }; |
| } |
| |
| @Override |
| public StreamObserver<InstallSnapshotRequestProto> installSnapshot( |
| StreamObserver<InstallSnapshotReplyProto> responseObserver) { |
| return new ServerRequestStreamObserver<InstallSnapshotRequestProto, InstallSnapshotReplyProto>( |
| RaftServerProtocol.Op.INSTALL_SNAPSHOT, responseObserver) { |
| @Override |
| CompletableFuture<InstallSnapshotReplyProto> process(InstallSnapshotRequestProto request) throws IOException { |
| return CompletableFuture.completedFuture(server.installSnapshot(request)); |
| } |
| |
| @Override |
| long getCallId(InstallSnapshotRequestProto request) { |
| return request.getServerRequest().getCallId(); |
| } |
| |
| @Override |
| String requestToString(InstallSnapshotRequestProto request) { |
| return ServerStringUtils.toInstallSnapshotRequestString(request); |
| } |
| |
| @Override |
| String replyToString(InstallSnapshotReplyProto reply) { |
| return ServerStringUtils.toInstallSnapshotReplyString(reply); |
| } |
| |
| @Override |
| boolean replyInOrder(InstallSnapshotRequestProto installSnapshotRequestProto) { |
| return true; |
| } |
| }; |
| } |
| } |