blob: 5a7f003183332239aae52786bb5b35f2fbc07115 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ratis.grpc;
import java.util.function.Consumer;
import org.apache.ratis.protocol.exceptions.ServerNotReadyException;
import org.apache.ratis.protocol.exceptions.TimeoutIOException;
import org.apache.ratis.security.TlsConf.TrustManagerConf;
import org.apache.ratis.security.TlsConf.CertificatesConf;
import org.apache.ratis.security.TlsConf.PrivateKeyConf;
import org.apache.ratis.security.TlsConf.KeyManagerConf;
import org.apache.ratis.thirdparty.com.google.protobuf.MessageLite;
import org.apache.ratis.thirdparty.io.grpc.ManagedChannel;
import org.apache.ratis.thirdparty.io.grpc.Metadata;
import org.apache.ratis.thirdparty.io.grpc.MethodDescriptor;
import org.apache.ratis.thirdparty.io.grpc.ServerCallHandler;
import org.apache.ratis.thirdparty.io.grpc.ServerServiceDefinition;
import org.apache.ratis.thirdparty.io.grpc.Status;
import org.apache.ratis.thirdparty.io.grpc.StatusRuntimeException;
import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder;
import org.apache.ratis.util.IOUtils;
import org.apache.ratis.util.JavaUtils;
import org.apache.ratis.util.LogUtils;
import org.apache.ratis.util.ReflectionUtils;
import org.apache.ratis.util.function.CheckedSupplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ssl.KeyManager;
import javax.net.ssl.TrustManager;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;
public interface GrpcUtil {
Logger LOG = LoggerFactory.getLogger(GrpcUtil.class);
Metadata.Key<String> EXCEPTION_TYPE_KEY =
Metadata.Key.of("exception-type", Metadata.ASCII_STRING_MARSHALLER);
Metadata.Key<byte[]> EXCEPTION_OBJECT_KEY =
Metadata.Key.of("exception-object-bin", Metadata.BINARY_BYTE_MARSHALLER);
Metadata.Key<String> CALL_ID =
Metadata.Key.of("call-id", Metadata.ASCII_STRING_MARSHALLER);
Metadata.Key<String> HEARTBEAT =
Metadata.Key.of("heartbeat", Metadata.ASCII_STRING_MARSHALLER);
static StatusRuntimeException wrapException(Throwable t) {
t = JavaUtils.unwrapCompletionException(t);
Metadata trailers = new StatusRuntimeExceptionMetadataBuilder(t).build();
return wrapException(t, trailers);
}
static StatusRuntimeException wrapException(Throwable t, long callId, boolean isHeartbeat) {
t = JavaUtils.unwrapCompletionException(t);
Metadata trailers = new StatusRuntimeExceptionMetadataBuilder(t)
.addCallId(callId)
.addIsHeartbeat(isHeartbeat)
.build();
return wrapException(t, trailers);
}
static StatusRuntimeException wrapException(Throwable t, Metadata trailers) {
return new StatusRuntimeException(
Status.INTERNAL.withCause(t).withDescription(t.getMessage()), trailers);
}
static Throwable unwrapThrowable(Throwable t) {
if (t instanceof StatusRuntimeException) {
final Throwable unwrapped = tryUnwrapThrowable((StatusRuntimeException)t);
if (unwrapped != null) {
return unwrapped;
}
}
return t;
}
static IOException unwrapException(StatusRuntimeException se) {
final Throwable t = tryUnwrapThrowable(se);
return t instanceof IOException? (IOException) t: new IOException(t != null? t: se);
}
static Throwable tryUnwrapThrowable(StatusRuntimeException se) {
final Status status = se.getStatus();
if (status != null && status.getCode() == Status.Code.DEADLINE_EXCEEDED) {
return new TimeoutIOException(status.getDescription(), se);
}
final Metadata trailers = se.getTrailers();
if (trailers == null) {
return null;
}
final byte[] bytes = trailers.get(EXCEPTION_OBJECT_KEY);
if (bytes != null) {
try {
return IOUtils.bytes2Object(bytes, Throwable.class);
} catch (Exception e) {
se.addSuppressed(e);
}
}
if (status != null) {
final String className = trailers.get(EXCEPTION_TYPE_KEY);
if (className != null) {
try {
final Class<? extends Throwable> clazz = Class.forName(className).asSubclass(Throwable.class);
final Throwable unwrapped = ReflectionUtils.instantiateException(clazz, status.getDescription());
return unwrapped.initCause(se);
} catch (Throwable e) {
se.addSuppressed(e);
return new IOException(se);
}
}
}
return null;
}
static long getCallId(Throwable t) {
if (t instanceof StatusRuntimeException) {
final Metadata trailers = ((StatusRuntimeException)t).getTrailers();
String callId = trailers.get(CALL_ID);
return callId != null ? Integer.parseInt(callId) : -1;
}
return -1;
}
static boolean isHeartbeat(Throwable t) {
if (t instanceof StatusRuntimeException) {
final Metadata trailers = ((StatusRuntimeException)t).getTrailers();
String isHeartbeat = trailers != null ? trailers.get(HEARTBEAT) : null;
return isHeartbeat != null && Boolean.valueOf(isHeartbeat);
}
return false;
}
static IOException unwrapIOException(Throwable t) {
final IOException e;
if (t instanceof StatusRuntimeException) {
e = GrpcUtil.unwrapException((StatusRuntimeException) t);
} else {
e = IOUtils.asIOException(t);
}
return e;
}
static <REPLY, REPLY_PROTO> void asyncCall(
StreamObserver<REPLY_PROTO> responseObserver,
CheckedSupplier<CompletableFuture<REPLY>, IOException> supplier,
Function<REPLY, REPLY_PROTO> toProto) {
asyncCall(responseObserver, supplier, toProto, throwable -> {});
}
static <REPLY, REPLY_PROTO> void asyncCall(
StreamObserver<REPLY_PROTO> responseObserver,
CheckedSupplier<CompletableFuture<REPLY>, IOException> supplier,
Function<REPLY, REPLY_PROTO> toProto,
Consumer<Throwable> warning) {
try {
supplier.get().whenComplete((reply, exception) -> {
if (exception != null) {
warning.accept(exception);
responseObserver.onError(GrpcUtil.wrapException(exception));
} else {
responseObserver.onNext(toProto.apply(reply));
responseObserver.onCompleted();
}
});
} catch (Exception e) {
warning.accept(e);
responseObserver.onError(GrpcUtil.wrapException(e));
}
}
static void warn(Logger log, Supplier<String> message, Throwable t) {
LogUtils.warn(log, message, unwrapThrowable(t), StatusRuntimeException.class, ServerNotReadyException.class);
}
class StatusRuntimeExceptionMetadataBuilder {
private Metadata trailers = new Metadata();
StatusRuntimeExceptionMetadataBuilder(Throwable t) {
trailers.put(EXCEPTION_TYPE_KEY, t.getClass().getCanonicalName());
trailers.put(EXCEPTION_OBJECT_KEY, IOUtils.object2Bytes(t));
}
StatusRuntimeExceptionMetadataBuilder addCallId(long callId) {
if (callId > 0) {
trailers.put(CALL_ID, String.valueOf(callId));
}
return this;
}
StatusRuntimeExceptionMetadataBuilder addIsHeartbeat(boolean isHeartbeat) {
trailers.put(HEARTBEAT, String.valueOf(isHeartbeat));
return this;
}
Metadata build() {
return trailers;
}
}
/**
* Tries to gracefully shut down the managed channel. Falls back to forceful shutdown if
* graceful shutdown times out.
*/
static void shutdownManagedChannel(ManagedChannel managedChannel) {
// Close the gRPC managed-channel if not shut down already.
if (!managedChannel.isShutdown()) {
try {
managedChannel.shutdown();
if (!managedChannel.awaitTermination(3, TimeUnit.SECONDS)) {
LOG.warn("Timed out gracefully shutting down connection: {}. ", managedChannel);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (Exception e) {
LOG.error("Unexpected exception while waiting for channel termination", e);
}
}
// Forceful shut down if still not terminated.
if (!managedChannel.isTerminated()) {
try {
managedChannel.shutdownNow();
if (!managedChannel.awaitTermination(2, TimeUnit.SECONDS)) {
LOG.warn("Timed out forcefully shutting down connection: {}. ", managedChannel);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (Exception e) {
LOG.error("Unexpected exception while waiting for channel termination", e);
}
}
}
static SslContextBuilder initSslContextBuilderForServer(KeyManagerConf keyManagerConfig) {
final KeyManager keyManager = keyManagerConfig.getKeyManager();
if (keyManager != null) {
return SslContextBuilder.forServer(keyManager);
}
final PrivateKeyConf privateKey = keyManagerConfig.getPrivateKey();
final CertificatesConf certificates = keyManagerConfig.getKeyCertificates();
if (keyManagerConfig.isFileBased()) {
return SslContextBuilder.forServer(certificates.getFile(), privateKey.getFile());
} else {
return SslContextBuilder.forServer(privateKey.get(), certificates.get());
}
}
static void setTrustManager(SslContextBuilder b, TrustManagerConf trustManagerConfig) {
if (trustManagerConfig == null) {
return;
}
final TrustManager trustManager = trustManagerConfig.getTrustManager();
if (trustManager != null) {
b.trustManager(trustManager);
return;
}
final CertificatesConf certificates = trustManagerConfig.getTrustCertificates();
if (certificates.isFileBased()) {
b.trustManager(certificates.getFile());
} else {
b.trustManager(certificates.get());
}
}
static void setKeyManager(SslContextBuilder b, KeyManagerConf keyManagerConfig) {
if (keyManagerConfig == null) {
return;
}
final KeyManager keyManager = keyManagerConfig.getKeyManager();
if (keyManager != null) {
b.keyManager(keyManager);
return;
}
final PrivateKeyConf privateKey = keyManagerConfig.getPrivateKey();
final CertificatesConf certificates = keyManagerConfig.getKeyCertificates();
if (keyManagerConfig.isFileBased()) {
b.keyManager(certificates.getFile(), privateKey.getFile());
} else {
b.keyManager(privateKey.get(), certificates.get());
}
}
/**
* Used to add a method to Service definition with a custom request marshaller.
*
* @param orig original service definition.
* @param newServiceBuilder builder of the new service definition.
* @param origMethod the original method definition.
* @param customMarshaller custom marshaller to be set for the method.
* @param <Req>
* @param <Resp>
*/
static <Req extends MessageLite, Resp> void addMethodWithCustomMarshaller(
ServerServiceDefinition orig, ServerServiceDefinition.Builder newServiceBuilder,
MethodDescriptor<Req, Resp> origMethod, MethodDescriptor.PrototypeMarshaller<Req> customMarshaller) {
MethodDescriptor<Req, Resp> newMethod = origMethod.toBuilder()
.setRequestMarshaller(customMarshaller)
.build();
@SuppressWarnings("unchecked")
ServerCallHandler<Req, Resp> serverCallHandler =
(ServerCallHandler<Req, Resp>) orig.getMethod(newMethod.getFullMethodName()).getServerCallHandler();
newServiceBuilder.addMethod(newMethod, serverCallHandler);
}
}