MAPREDUCE-5027. Shuffle does not limit number of outstanding connections (Robert Parker via jeagles)

git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@1453098 13f79535-47bb-0310-9956-ffa450edef68
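
Note on usage: the patch introduces an optional server-side cap on concurrent
shuffle connections, configured through mapreduce.shuffle.max.connections
(normally set in mapred-site.xml). A minimal sketch of setting it
programmatically (the key is the one this patch adds; the value of 100 and the
class name are purely illustrative):

    import org.apache.hadoop.conf.Configuration;

    public class ShuffleLimitConfigExample {
      public static void main(String[] args) {
        Configuration conf = new Configuration();
        // Key added by this patch; 0 (the shipped default) means no limit.
        conf.setInt("mapreduce.shuffle.max.connections", 100);
        System.out.println(conf.getInt("mapreduce.shuffle.max.connections", 0));
      }
    }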
diff --git a/hadoop-mapreduce-project/CHANGES.txt b/hadoop-mapreduce-project/CHANGES.txt
index 11528ba..9014022 100644
--- a/hadoop-mapreduce-project/CHANGES.txt
+++ b/hadoop-mapreduce-project/CHANGES.txt
@@ -731,6 +731,9 @@
     MAPREDUCE-4989. JSONify DataTables input data for Attempts page (Ravi
     Prakash via jlowe)
 
+    MAPREDUCE-5027. Shuffle does not limit number of outstanding connections
+    (Robert Parker via jeagles)
+
   OPTIMIZATIONS
 
     MAPREDUCE-4946. Fix a performance problem for large jobs by reducing the
diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/resources/mapred-default.xml b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/resources/mapred-default.xml
index e756860..83131e7 100644
--- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/resources/mapred-default.xml
+++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/resources/mapred-default.xml
@@ -296,6 +296,14 @@
 </property>
 
 <property>
+  <name>mapreduce.shuffle.max.connections</name>
+  <value>0</value>
+  <description>Max allowed connections for the shuffle.  Set to 0 (zero)
+               to indicate no limit on the number of connections.
+  </description>
+</property>
+
+<property>
   <name>mapreduce.reduce.markreset.buffer.percent</name>
   <value>0.0</value>
   <description>The percentage of memory -relative to the maximum heap size- to
diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/ShuffleHandler.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/ShuffleHandler.java
index cb3dfad..56ede18 100644
--- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/ShuffleHandler.java
+++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/main/java/org/apache/hadoop/mapred/ShuffleHandler.java
@@ -88,6 +88,7 @@
 import org.jboss.netty.channel.ChannelHandlerContext;
 import org.jboss.netty.channel.ChannelPipeline;
 import org.jboss.netty.channel.ChannelPipelineFactory;
+import org.jboss.netty.channel.ChannelStateEvent;
 import org.jboss.netty.channel.Channels;
 import org.jboss.netty.channel.ExceptionEvent;
 import org.jboss.netty.channel.MessageEvent;
@@ -133,15 +134,15 @@
   private final ChannelGroup accepted = new DefaultChannelGroup();
   protected HttpPipelineFactory pipelineFact;
   private int sslFileBufferSize;
 
   /**
    * Should the shuffle use posix_fadvise calls to manage the OS cache during
    * sendfile
    */
   private boolean manageOsCache;
   private int readaheadLength;
+  private int maxShuffleConnections;
   private ReadaheadPool readaheadPool = ReadaheadPool.getInstance();
-   
 
   public static final String MAPREDUCE_SHUFFLE_SERVICEID =
       "mapreduce.shuffle";
@@ -159,6 +160,9 @@
 
   public static final int DEFAULT_SUFFLE_SSL_FILE_BUFFER_SIZE = 60 * 1024;
 
+  public static final String MAX_SHUFFLE_CONNECTIONS = "mapreduce.shuffle.max.connections";
+  public static final int DEFAULT_MAX_SHUFFLE_CONNECTIONS = 0; // 0 implies no limit
+
   @Metrics(about="Shuffle output metrics", context="mapred")
   static class ShuffleMetrics implements ChannelFutureListener {
     @Metric("Shuffle output in bytes")
@@ -270,6 +274,9 @@
     readaheadLength = conf.getInt(SHUFFLE_READAHEAD_BYTES,
         DEFAULT_SHUFFLE_READAHEAD_BYTES);
     
+    maxShuffleConnections = conf.getInt(MAX_SHUFFLE_CONNECTIONS,
+                                        DEFAULT_MAX_SHUFFLE_CONNECTIONS);
+
     ThreadFactory bossFactory = new ThreadFactoryBuilder()
       .setNameFormat("ShuffleHandler Netty Boss #%d")
       .build();
@@ -400,6 +407,21 @@
     }
 
     @Override
+    public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent evt)
+        throws Exception {
+      if ((maxShuffleConnections > 0) && (accepted.size() >= maxShuffleConnections)) {
+        LOG.info(String.format("Current number of shuffle connections (%d) is " +
+            "greater than or equal to the max allowed shuffle connections (%d)",
+            accepted.size(), maxShuffleConnections));
+        evt.getChannel().close();
+        return;
+      }
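+      // Closed channels are removed from the group automatically.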
+      accepted.add(evt.getChannel());
+      super.channelOpen(ctx, evt);
+    }
+
+    @Override
     public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt)
         throws Exception {
       HttpRequest request = (HttpRequest) evt.getMessage();
@@ -620,6 +642,5 @@
         sendError(ctx, INTERNAL_SERVER_ERROR);
       }
     }
-
   }
 }
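
For readers less familiar with Netty 3, the channelOpen() override above is the
whole mechanism: the handler closes a newly opened channel once the accepted
ChannelGroup has reached the cap. A self-contained sketch of the same pattern
outside of Hadoop (class and field names here are illustrative, not part of
this patch):

    import org.jboss.netty.channel.ChannelHandlerContext;
    import org.jboss.netty.channel.ChannelStateEvent;
    import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
    import org.jboss.netty.channel.group.ChannelGroup;
    import org.jboss.netty.channel.group.DefaultChannelGroup;

    public class ConnectionLimitHandler extends SimpleChannelUpstreamHandler {
      private final ChannelGroup accepted = new DefaultChannelGroup();
      private final int maxConnections; // 0 disables the limit

      public ConnectionLimitHandler(int maxConnections) {
        this.maxConnections = maxConnections;
      }

      @Override
      public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent e)
          throws Exception {
        // Refuse the new channel outright once the cap is reached.
        if (maxConnections > 0 && accepted.size() >= maxConnections) {
          e.getChannel().close();
          return;
        }
        // DefaultChannelGroup removes channels automatically when they close.
        accepted.add(e.getChannel());
        super.channelOpen(ctx, e);
      }
    }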
diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/java/org/apache/hadoop/mapred/TestShuffleHandler.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/java/org/apache/hadoop/mapred/TestShuffleHandler.java
index 309a789..4d845c3 100644
--- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/java/org/apache/hadoop/mapred/TestShuffleHandler.java
+++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/java/org/apache/hadoop/mapred/TestShuffleHandler.java
@@ -24,13 +24,15 @@
 import static org.apache.hadoop.test.MockitoMaker.stub;
 import static org.jboss.netty.buffer.ChannelBuffers.wrappedBuffer;
 import static org.junit.Assert.assertEquals;
-
 import java.io.DataInputStream;
 import java.io.IOException;
 import java.net.HttpURLConnection;
+import java.net.SocketException;
 import java.net.URL;
 import java.util.ArrayList;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.DataOutputBuffer;
 import org.apache.hadoop.mapreduce.task.reduce.ShuffleHeader;
@@ -47,10 +49,13 @@
 import org.junit.Assert;
 import org.junit.Test;
 
-public class TestShuffleHandler {
-  static final long MiB = 1024 * 1024;
 
-  @Test public void testSerializeMeta()  throws Exception {
+public class TestShuffleHandler {
+  static final long MiB = 1024 * 1024;
+  private static final Log LOG = LogFactory.getLog(TestShuffleHandler.class);
+
+  @Test (timeout = 10000)
+  public void testSerializeMeta() throws Exception {
     assertEquals(1, ShuffleHandler.deserializeMetaData(
         ShuffleHandler.serializeMetaData(1)));
     assertEquals(-1, ShuffleHandler.deserializeMetaData(
@@ -59,7 +64,8 @@
         ShuffleHandler.serializeMetaData(8080)));
   }
 
-  @Test public void testShuffleMetrics() throws Exception {
+  @Test (timeout = 10000)
+  public void testShuffleMetrics() throws Exception {
     MetricsSystem ms = new MetricsSystemImpl();
     ShuffleHandler sh = new ShuffleHandler(ms);
     ChannelFuture cf = make(stub(ChannelFuture.class).
@@ -88,7 +94,7 @@
     assertGauge("ShuffleConnections", connections, rb);
   }
 
-  @Test
+  @Test (timeout = 10000)
   public void testClientClosesConnection() throws Exception {
     final ArrayList<Throwable> failures = new ArrayList<Throwable>(1);
     Configuration conf = new Configuration();
@@ -159,4 +165,84 @@
     Assert.assertTrue("sendError called when client closed connection",
         failures.size() == 0);
   }
+
+  @Test (timeout = 10000)
+  public void testMaxConnections() throws Exception {
+
+    Configuration conf = new Configuration();
+    conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0);
+    conf.setInt(ShuffleHandler.MAX_SHUFFLE_CONNECTIONS, 3);
+    ShuffleHandler shuffleHandler = new ShuffleHandler() {
+      @Override
+      protected Shuffle getShuffle(Configuration conf) {
+        // replace the shuffle handler with one stubbed for testing
+        return new Shuffle(conf) {
+          @Override
+          protected void verifyRequest(String appid, ChannelHandlerContext ctx,
+              HttpRequest request, HttpResponse response, URL requestUri)
+                  throws IOException {
+          }
+          @Override
+          protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx,
+              Channel ch, String user, String jobId, String mapId, int reduce)
+                  throws IOException {
+            // send a shuffle header and a lot of data down the channel
+            // to keep the channel open while the other connections arrive
+            ShuffleHeader header =
+                new ShuffleHeader("dummy_header", 5678, 5678, 1);
+            DataOutputBuffer dob = new DataOutputBuffer();
+            header.write(dob);
+            ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength()));
+            dob = new DataOutputBuffer();
+            for (int i = 0; i < 100000; ++i) {
+              header.write(dob);
+            }
+            return ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength()));
+          }
+        };
+      }
+    };
+    shuffleHandler.init(conf);
+    shuffleHandler.start();
+
+    // setup connections
+    int connAttempts = 3;
+    HttpURLConnection[] conns = new HttpURLConnection[connAttempts];
+
+    for (int i = 0; i < connAttempts; i++) {
+      String urlString = "http://127.0.0.1:"
+           + shuffleHandler.getConfig().get(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY)
+           + "/mapOutput?job=job_12345_1&reduce=1&map=attempt_12345_1_m_"
+           + i + "_0";
+      URL url = new URL(urlString);
+      conns[i] = (HttpURLConnection)url.openConnection();
+    }
+
+    // Try to connect all of them
+    for (int i = 0; i < connAttempts; i++) {
+      conns[i].connect();
+    }
+
+    // Ensure the first connections are okay
+    conns[0].getInputStream();
+    int rc = conns[0].getResponseCode();
+    Assert.assertEquals(HttpURLConnection.HTTP_OK, rc);
+
+    conns[1].getInputStream();
+    rc = conns[1].getResponseCode();
+    Assert.assertEquals(HttpURLConnection.HTTP_OK, rc);
+
+    // This connection should be closed because it is above the limit
+    try {
+      conns[2].getInputStream();
+      rc = conns[2].getResponseCode();
+      Assert.fail("Expected a SocketException");
+    } catch (SocketException se) {
+      LOG.info("Expected - connection should not be open");
+    } catch (Exception e) {
+      Assert.fail("Expected a SocketException");
+    }
+
+    shuffleHandler.stop();
+  }
 }
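
A behavioral detail the new test pins down: a fetch that arrives over the limit
is not answered with an HTTP error status; the server simply closes the socket,
so the client surfaces a SocketException. A sketch of what a caller observes
(host, port, and job/attempt IDs are placeholders):

    import java.io.IOException;
    import java.net.HttpURLConnection;
    import java.net.SocketException;
    import java.net.URL;

    public class OverLimitClientExample {
      public static void main(String[] args) throws IOException {
        URL url = new URL("http://127.0.0.1:8080/mapOutput?job=job_12345_1"
            + "&reduce=1&map=attempt_12345_1_m_0_0");
        HttpURLConnection conn = (HttpURLConnection) url.openConnection();
        try {
          conn.getInputStream();
          System.out.println("HTTP " + conn.getResponseCode());
        } catch (SocketException se) {
          // Over the limit the server closes the connection without a
          // response, so no status code is available.
          System.out.println("closed by shuffle connection limit: " + se);
        }
      }
    }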