YARN-10883. [Router] Router Audit Log Add Client IP Address. (#4426)

diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-router/src/main/java/org/apache/hadoop/yarn/server/router/RouterAuditLogger.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-router/src/main/java/org/apache/hadoop/yarn/server/router/RouterAuditLogger.java
index cd5b0c9..cc82087 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-router/src/main/java/org/apache/hadoop/yarn/server/router/RouterAuditLogger.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-router/src/main/java/org/apache/hadoop/yarn/server/router/RouterAuditLogger.java
@@ -18,11 +18,14 @@
 
 package org.apache.hadoop.yarn.server.router;
 
+import org.apache.hadoop.ipc.Server;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.server.federation.store.records.SubClusterId;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.net.InetAddress;
+
 /**
  * Manages Router audit logs.
  * Audit log format is written as key=value pairs. Tab separated.
@@ -111,6 +114,7 @@
       String operation, String target) {
     StringBuilder b = new StringBuilder();
     start(Keys.USER, user, b);
+    addRemoteIP(b);
     add(Keys.OPERATION, operation, b);
     add(Keys.TARGET, target, b);
     add(Keys.RESULT, AuditConstants.SUCCESS, b);
@@ -216,6 +220,7 @@
       String operation, String target, String description, String perm) {
     StringBuilder b = new StringBuilder();
     start(Keys.USER, user, b);
+    addRemoteIP(b);
     add(Keys.OPERATION, operation, b);
     add(Keys.TARGET, target, b);
     add(Keys.RESULT, AuditConstants.FAILURE, b);
@@ -240,4 +245,15 @@
     b.append(AuditConstants.PAIR_SEPARATOR).append(key.name())
         .append(AuditConstants.KEY_VAL_SEPARATOR).append(value);
   }
+
+  /**
+   * A helper api to add remote IP address.
+   */
+  static void addRemoteIP(StringBuilder b) {
+    InetAddress ip = Server.getRemoteIp();
+    // ip address can be null for testcases
+    if (ip != null) {
+      add(Keys.IP, ip.getHostAddress(), b);
+    }
+  }
 }
\ No newline at end of file
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-router/src/test/java/org/apache/hadoop/yarn/server/router/TestRouterAuditLogger.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-router/src/test/java/org/apache/hadoop/yarn/server/router/TestRouterAuditLogger.java
index 40e2296..48d3ef6 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-router/src/test/java/org/apache/hadoop/yarn/server/router/TestRouterAuditLogger.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-router/src/test/java/org/apache/hadoop/yarn/server/router/TestRouterAuditLogger.java
@@ -22,11 +22,28 @@
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.ipc.ClientId;
+import org.apache.hadoop.ipc.ProtobufRpcEngine2;
+import org.apache.hadoop.ipc.RPC;
+import org.apache.hadoop.ipc.Server;
+import org.apache.hadoop.ipc.TestRpcBase;
+import org.apache.hadoop.ipc.protobuf.TestProtos;
+import org.apache.hadoop.ipc.protobuf.TestRpcServiceProtos;
+import org.apache.hadoop.ipc.TestRPC;
+import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.thirdparty.protobuf.BlockingService;
+import org.apache.hadoop.thirdparty.protobuf.RpcController;
+import org.apache.hadoop.thirdparty.protobuf.ServiceException;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.server.federation.store.records.SubClusterId;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+
 /**
  * Tests {@link RouterAuditLogger}.
  */
@@ -76,13 +93,17 @@
   /**
    * Test the AuditLog format for successful events.
    */
-  private void testSuccessLogFormatHelper(ApplicationId appId,
+  private void testSuccessLogFormatHelper(boolean checkIP, ApplicationId appId,
       SubClusterId subClusterId) {
     // check without the IP
     String sLog = RouterAuditLogger
         .createSuccessLog(USER, OPERATION, TARGET, appId, subClusterId);
     StringBuilder expLog = new StringBuilder();
     expLog.append("USER=test\t");
+    if (checkIP) {
+      InetAddress ip = Server.getRemoteIp();
+      expLog.append(RouterAuditLogger.Keys.IP.name() + "=" + ip.getHostAddress() + "\t");
+    }
     expLog.append("OPERATION=oper\tTARGET=tgt\tRESULT=SUCCESS");
     if (appId != null) {
       expLog.append("\tAPPID=app_1");
@@ -109,23 +130,27 @@
    * Test the AuditLog format for successful events with the various
    * parameters.
    */
-  private void testSuccessLogFormat() {
-    testSuccessLogFormatHelper(null, null);
-    testSuccessLogFormatHelper(APPID, null);
-    testSuccessLogFormatHelper(null, SUBCLUSTERID);
-    testSuccessLogFormatHelper(APPID, SUBCLUSTERID);
+  private void testSuccessLogFormat(boolean checkIP) {
+    testSuccessLogFormatHelper(checkIP, null, null);
+    testSuccessLogFormatHelper(checkIP, APPID, null);
+    testSuccessLogFormatHelper(checkIP, null, SUBCLUSTERID);
+    testSuccessLogFormatHelper(checkIP, APPID, SUBCLUSTERID);
   }
 
   /**
    *  Test the AuditLog format for failure events.
    */
-  private void testFailureLogFormatHelper(ApplicationId appId,
+  private void testFailureLogFormatHelper(boolean checkIP, ApplicationId appId,
       SubClusterId subClusterId) {
     String fLog = RouterAuditLogger
         .createFailureLog(USER, OPERATION, "UNKNOWN", TARGET, DESC, appId,
             subClusterId);
     StringBuilder expLog = new StringBuilder();
     expLog.append("USER=test\t");
+    if (checkIP) {
+      InetAddress ip = Server.getRemoteIp();
+      expLog.append(RouterAuditLogger.Keys.IP.name() + "=" + ip.getHostAddress() + "\t");
+    }
     expLog.append("OPERATION=oper\tTARGET=tgt\tRESULT=FAILURE\t");
     expLog.append("DESCRIPTION=description of an audit log");
     expLog.append("\tPERMISSIONS=UNKNOWN");
@@ -143,18 +168,79 @@
    * Test the AuditLog format for failure events with the various
    * parameters.
    */
-  private void testFailureLogFormat() {
-    testFailureLogFormatHelper(null, null);
-    testFailureLogFormatHelper(APPID, null);
-    testFailureLogFormatHelper(null, SUBCLUSTERID);
-    testFailureLogFormatHelper(APPID, SUBCLUSTERID);
+  private void testFailureLogFormat(boolean checkIP) {
+    testFailureLogFormatHelper(checkIP, null, null);
+    testFailureLogFormatHelper(checkIP, APPID, null);
+    testFailureLogFormatHelper(checkIP, null, SUBCLUSTERID);
+    testFailureLogFormatHelper(checkIP, APPID, SUBCLUSTERID);
   }
 
   /**
    *  Test {@link RouterAuditLogger}.
    */
-  @Test public void testRouterAuditLogger() throws Exception {
-    testSuccessLogFormat();
-    testFailureLogFormat();
+  @Test
+  public void testRouterAuditLoggerWithOutIP() throws Exception {
+    testSuccessLogFormat(false);
+    testFailureLogFormat(false);
+  }
+
+  /**
+   * A special extension of {@link TestRPC.TestImpl} RPC server with
+   * {@link TestRPC.TestImpl#ping()} testing the audit logs.
+   */
+  private class MyTestRouterRPCServer extends TestRpcBase.PBServerImpl {
+    @Override
+    public TestProtos.EmptyResponseProto ping(
+            RpcController unused, TestProtos.EmptyRequestProto request)
+            throws ServiceException {
+      // Ensure clientId is received
+      byte[] clientId = Server.getClientId();
+      Assert.assertNotNull(clientId);
+      Assert.assertEquals(ClientId.BYTE_LENGTH, clientId.length);
+      // test with ip set
+      testSuccessLogFormat(true);
+      testFailureLogFormat(true);
+      return TestProtos.EmptyResponseProto.newBuilder().build();
+    }
+  }
+
+  /**
+   * Test {@link RouterAuditLogger} with IP set.
+   */
+  @Test
+  public void testRouterAuditLoggerWithIP() throws Exception {
+    Configuration conf = new Configuration();
+    RPC.setProtocolEngine(conf, TestRpcBase.TestRpcService.class, ProtobufRpcEngine2.class);
+
+    // Create server side implementation
+    MyTestRouterRPCServer serverImpl = new MyTestRouterRPCServer();
+    BlockingService service = TestRpcServiceProtos.TestProtobufRpcProto
+        .newReflectiveBlockingService(serverImpl);
+
+    // start the IPC server
+    Server server = new RPC.Builder(conf)
+        .setProtocol(TestRpcBase.TestRpcService.class)
+        .setInstance(service).setBindAddress("0.0.0.0")
+        .setPort(0).setNumHandlers(5).setVerbose(true).build();
+
+    server.start();
+
+    InetSocketAddress address = NetUtils.getConnectAddress(server);
+
+    // Make a client connection and test the audit log
+    TestRpcBase.TestRpcService proxy = null;
+    try {
+      proxy = RPC.getProxy(TestRpcBase.TestRpcService.class,
+          TestRPC.TestProtocol.versionID, address, conf);
+      // Start the testcase
+      TestProtos.EmptyRequestProto pingRequest =
+          TestProtos.EmptyRequestProto.newBuilder().build();
+      proxy.ping(null, pingRequest);
+    } finally {
+      server.stop();
+      if (proxy != null) {
+        RPC.stopProxy(proxy);
+      }
+    }
   }
 }