blob: 01adddc8f3062c0543f1aabb39b959f056d0890a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
using System;
using System.Threading.Tasks;
using Grpc.Core;
using Grpc.Core.Interceptors;
using NLog;
namespace Org.Apache.Rocketmq
{
/// <summary>
/// Client-side gRPC interceptor that logs the start of every outgoing call and
/// attaches caller identification headers (user name, machine name and OS version)
/// to the call metadata. For asynchronous unary calls the response, or the failure,
/// is logged as well.
/// </summary>
public class ClientLoggerInterceptor : Interceptor
{
    private static readonly Logger Logger = MqLogManager.Instance.GetCurrentClassLogger();

    /// <summary>
    /// Intercepts a blocking unary call: logs it, adds the caller headers and then
    /// invokes the continuation.
    /// </summary>
    /// <param name="request">The request message.</param>
    /// <param name="context">The call context; replaced when a new header collection has to be created.</param>
    /// <param name="continuation">The next step of the invocation pipeline.</param>
    /// <returns>The response returned by the continuation.</returns>
    public override TResponse BlockingUnaryCall<TRequest, TResponse>(
        TRequest request,
        ClientInterceptorContext<TRequest, TResponse> context,
        BlockingUnaryCallContinuation<TRequest, TResponse> continuation)
    {
        LogCall(context.Method);
        AddCallerMetadata(ref context);
        return continuation(request, context);
    }

    /// <summary>
    /// Intercepts an asynchronous unary call: logs it, adds the caller headers and
    /// wraps the response task so that its outcome is logged when it completes.
    /// </summary>
    /// <param name="request">The request message.</param>
    /// <param name="context">The call context; replaced when a new header collection has to be created.</param>
    /// <param name="continuation">The next step of the invocation pipeline.</param>
    /// <returns>
    /// A call object whose response task is the logging wrapper; headers, status,
    /// trailers and disposal are forwarded to the underlying call.
    /// </returns>
    public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(
        TRequest request,
        ClientInterceptorContext<TRequest, TResponse> context,
        AsyncUnaryCallContinuation<TRequest, TResponse> continuation)
    {
        LogCall(context.Method);
        AddCallerMetadata(ref context);
        var call = continuation(request, context);
        return new AsyncUnaryCall<TResponse>(
            HandleResponse(call.ResponseAsync),
            call.ResponseHeadersAsync,
            call.GetStatus,
            call.GetTrailers,
            call.Dispose);
    }

    /// <summary>
    /// Awaits the response task, logging the response on success and the exception
    /// on failure. The exception is rethrown unchanged.
    /// </summary>
    /// <param name="responseTask">The response task of the underlying call.</param>
    /// <returns>The response produced by <paramref name="responseTask"/>.</returns>
    private static async Task<TResponse> HandleResponse<TResponse>(Task<TResponse> responseTask)
    {
        try
        {
            // Library code: no need to resume on the caller's synchronization context.
            var response = await responseTask.ConfigureAwait(false);
            // Parameterized message: the response is only formatted when Debug is enabled.
            Logger.Debug("Response received: {0}", response);
            return response;
        }
        catch (Exception ex)
        {
            // Pass the exception object itself so its type, stack trace and inner
            // exceptions are available to the logging targets, not just the message.
            Logger.Error(ex, "Call error: {0}", ex.Message);
            throw;
        }
    }

    /// <summary>
    /// Intercepts a client-streaming call: logs it, adds the caller headers and then
    /// invokes the continuation. The streamed messages are not logged.
    /// </summary>
    /// <param name="context">The call context; replaced when a new header collection has to be created.</param>
    /// <param name="continuation">The next step of the invocation pipeline.</param>
    /// <returns>The call object returned by the continuation.</returns>
    public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(
        ClientInterceptorContext<TRequest, TResponse> context,
        AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation)
    {
        LogCall(context.Method);
        AddCallerMetadata(ref context);
        return continuation(context);
    }

    /// <summary>
    /// Intercepts a server-streaming call: logs it, adds the caller headers and then
    /// invokes the continuation. The streamed messages are not logged.
    /// </summary>
    /// <param name="request">The request message.</param>
    /// <param name="context">The call context; replaced when a new header collection has to be created.</param>
    /// <param name="continuation">The next step of the invocation pipeline.</param>
    /// <returns>The call object returned by the continuation.</returns>
    public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(
        TRequest request,
        ClientInterceptorContext<TRequest, TResponse> context,
        AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation)
    {
        LogCall(context.Method);
        AddCallerMetadata(ref context);
        return continuation(request, context);
    }

    /// <summary>
    /// Intercepts a duplex-streaming call: logs it, adds the caller headers and then
    /// invokes the continuation. The streamed messages are not logged.
    /// </summary>
    /// <param name="context">The call context; replaced when a new header collection has to be created.</param>
    /// <param name="continuation">The next step of the invocation pipeline.</param>
    /// <returns>The call object returned by the continuation.</returns>
    public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(
        ClientInterceptorContext<TRequest, TResponse> context,
        AsyncDuplexStreamingCallContinuation<TRequest, TResponse> continuation)
    {
        LogCall(context.Method);
        AddCallerMetadata(ref context);
        return continuation(context);
    }

    /// <summary>
    /// Writes a debug entry with the method type and the request/response CLR types.
    /// </summary>
    /// <param name="method">The gRPC method being invoked.</param>
    private static void LogCall<TRequest, TResponse>(Method<TRequest, TResponse> method)
        where TRequest : class
        where TResponse : class
    {
        // Parameterized message: arguments are only formatted when Debug is enabled.
        Logger.Debug(
            "Starting call. Type: {0}. Request: {1}. Response: {2}",
            method.Type,
            typeof(TRequest),
            typeof(TResponse));
    }

    /// <summary>
    /// Adds the <c>caller-user</c>, <c>caller-machine</c> and <c>caller-os</c> headers
    /// to the call. When the call has no header collection, or the collection is
    /// read-only, a new writable collection (carrying over any existing entries) is
    /// created and <paramref name="context"/> is replaced with a context that uses it.
    /// </summary>
    /// <param name="context">The call context, possibly replaced on return.</param>
    private static void AddCallerMetadata<TRequest, TResponse>(ref ClientInterceptorContext<TRequest, TResponse> context)
        where TRequest : class
        where TResponse : class
    {
        var headers = context.Options.Headers;

        // The call either has no headers collection, or has one that cannot be
        // modified (Add would throw on it). Build a writable collection and a new
        // context that carries it.
        if (headers == null || headers.IsReadOnly)
        {
            var writableHeaders = new Metadata();
            if (headers != null)
            {
                foreach (var entry in headers)
                {
                    writableHeaders.Add(entry);
                }
            }

            headers = writableHeaders;
            var options = context.Options.WithHeaders(headers);
            context = new ClientInterceptorContext<TRequest, TResponse>(context.Method, context.Host, options);
        }

        // Add caller metadata to call headers
        headers.Add("caller-user", Environment.UserName);
        headers.Add("caller-machine", Environment.MachineName);
        headers.Add("caller-os", Environment.OSVersion.ToString());
    }
}
}