blob: 7e0e72a5b67811e518fe68cff627149b826258a8 [file] [log] [blame]
package org.apache.airavata.datalake.metadata.interceptors;
import io.grpc.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Iterator;
import java.util.Stack;
/**
* This class execute interceptor stack sequentially
*/
public class InterceptorPipelineExecutor implements ServerInterceptor {
private final Logger LOGGER = LoggerFactory.getLogger(InterceptorPipelineExecutor.class);
private Stack<ServiceInterceptor> interceptorSet;
public InterceptorPipelineExecutor(Stack<ServiceInterceptor> integrationServiceInterceptors) {
this.interceptorSet = integrationServiceInterceptors;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> serverCall,
Metadata metadata, ServerCallHandler<ReqT, RespT> serverCallHandler) {
String fullMethod = serverCall.getMethodDescriptor().getFullMethodName();
String methodName = fullMethod.split("/")[1];
String serviceName = fullMethod.split("/")[0];
LOGGER.debug("Calling method : " + serverCall.getMethodDescriptor().getFullMethodName());
metadata.put(Metadata.Key.of("service-name", Metadata.ASCII_STRING_MARSHALLER), serviceName);
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(serverCallHandler.startCall(serverCall, metadata)) {
ReqT resp = null;
@Override
public void onMessage(ReqT message) {
try {
Iterator it = interceptorSet.iterator();
while (it.hasNext()) {
ServiceInterceptor interceptor = (ServiceInterceptor) it.next();
resp = interceptor.intercept(methodName, metadata, (resp == null) ? message : resp);
}
super.onMessage(resp);
} catch (Exception ex) {
String msg = "Error while validating method " + methodName + ", " + ex.getMessage();
LOGGER.error(msg, ex);
serverCall.close(Status.INVALID_ARGUMENT.withDescription(msg), new Metadata());
}
}
@Override
public void onHalfClose() {
try {
super.onHalfClose();
} catch (IllegalStateException e) {
LOGGER.debug(e.getMessage());
} catch (Exception e) {
String msg = "Error while validating method " + methodName + ", " + e.getMessage();
LOGGER.error(msg);
serverCall.close(Status.INVALID_ARGUMENT.withDescription(msg), metadata);
}
}
};
}
}