TEZ-3942. RPC getTask writable optimization invalid in hadoop 2.8+

Signed-off-by: Jason Lowe <jlowe@apache.org>
diff --git a/tez-api/src/main/java/org/apache/tez/common/TezUtils.java b/tez-api/src/main/java/org/apache/tez/common/TezUtils.java
index aed9e0f..072c02f 100644
--- a/tez-api/src/main/java/org/apache/tez/common/TezUtils.java
+++ b/tez-api/src/main/java/org/apache/tez/common/TezUtils.java
@@ -20,6 +20,7 @@
 
 import java.io.IOException;
 import java.io.OutputStream;
+import java.nio.ByteBuffer;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -97,7 +98,7 @@
    * @throws java.io.IOException
    */
   public static UserPayload createUserPayloadFromConf(Configuration conf) throws IOException {
-    return UserPayload.create(createByteStringFromConf(conf).asReadOnlyByteBuffer());
+    return UserPayload.create(ByteBuffer.wrap(createByteStringFromConf(conf).toByteArray()));
   }
 
   /**
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/DagTypeConverters.java b/tez-api/src/main/java/org/apache/tez/dag/api/DagTypeConverters.java
index c5d9c0b..acc5f12 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/DagTypeConverters.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/DagTypeConverters.java
@@ -735,7 +735,7 @@
     if (payload == null) {
       return null;
     }
-    return payload.getPayload();
+    return payload.getRawPayload();
   }
 
   public static VertexExecutionContextProto convertToProto(
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/EntityDescriptor.java b/tez-api/src/main/java/org/apache/tez/dag/api/EntityDescriptor.java
index dcc4ebf..13d4a93 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/EntityDescriptor.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/EntityDescriptor.java
@@ -23,6 +23,7 @@
 import java.io.IOException;
 import java.nio.ByteBuffer;
 
+import com.google.common.annotations.VisibleForTesting;
 import org.apache.hadoop.classification.InterfaceAudience.Private;
 import org.apache.hadoop.classification.InterfaceAudience.Public;
 import org.apache.hadoop.io.DataOutputBuffer;
@@ -94,36 +95,40 @@
     return this.className;
   }
 
+  void writeSingular(DataOutput out, ByteBuffer bb) throws IOException {
+    out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining());
+  }
+
+  void writeSegmented(DataOutput out, ByteBuffer bb) throws IOException {
+    // This code is just for fallback in case serialization is changed to
+    // use something other than DataOutputBuffer.
+    int len;
+    byte[] buf = new byte[SERIALIZE_BUFFER_SIZE];
+    do {
+      len = Math.min(bb.remaining(), SERIALIZE_BUFFER_SIZE);
+      bb.get(buf, 0, len);
+      out.write(buf, 0, len);
+    } while (bb.remaining() > 0);
+  }
+
   @Override
   public void write(DataOutput out) throws IOException {
     Text.writeString(out, className);
     // TODO: TEZ-305 - using protobuf serde instead of Writable serde.
     ByteBuffer bb = DagTypeConverters.convertFromTezUserPayload(userPayload);
-    if (bb == null) {
+    if (bb == null || bb.remaining() == 0) {
       out.writeInt(-1);
-    } else {
-      int size = bb.remaining();
-      if (size == 0) {
-        out.writeInt(-1);
-      } else {
-        out.writeInt(size);
-        if (out instanceof DataOutputBuffer) {
-          DataOutputBuffer buf = (DataOutputBuffer) out;
-          buf.write(new ByteBufferDataInput(bb), size);
-        } else {
-          // This code is just for fallback in case serialization is changed to
-          // use something other than DataOutputBuffer.
-          int len;
-          byte[] buf = new byte[SERIALIZE_BUFFER_SIZE];
-          do {
-            len = Math.min(bb.remaining(), SERIALIZE_BUFFER_SIZE);
-            bb.get(buf, 0, len);
-            out.write(buf, 0, len);
-          } while (bb.remaining() > 0);
-        }
-      }
-      out.writeInt(userPayload.getVersion());
+      return;
     }
+
+    // write size
+    out.writeInt(bb.remaining());
+    if (bb.hasArray()) {
+      writeSingular(out, bb);
+    } else {
+      writeSegmented(out, bb);
+    }
+    out.writeInt(userPayload.getVersion());
   }
 
   @Override
@@ -144,76 +149,4 @@
         userPayload == null ? false : userPayload.getPayload() == null ? false : true;
     return "ClassName=" + className + ", hasPayload=" + hasPayload;
   }
-
-  private static class ByteBufferDataInput implements DataInput {
-
-    private final ByteBuffer bb;
-
-    public ByteBufferDataInput(ByteBuffer bb) {
-      this.bb = bb;
-    }
-
-    @Override
-    public void readFully(byte[] b) throws IOException {
-      bb.get(b, 0, bb.remaining());
-    }
-
-    @Override
-    public void readFully(byte[] b, int off, int len) throws IOException {
-      bb.get(b, off, len);
-    }
-
-    @Override
-    public int skipBytes(int n) throws IOException {
-      throw new UnsupportedOperationException();
-    }
-    @Override
-    public boolean readBoolean() throws IOException {
-      throw new UnsupportedOperationException();
-    }
-    @Override
-    public byte readByte() throws IOException {
-      return bb.get();
-    }
-    @Override
-    public int readUnsignedByte() throws IOException {
-      throw new UnsupportedOperationException();
-    }
-    @Override
-    public short readShort() throws IOException {
-      throw new UnsupportedOperationException();
-    }
-    @Override
-    public int readUnsignedShort() throws IOException {
-      throw new UnsupportedOperationException();
-    }
-    @Override
-    public char readChar() throws IOException {
-      throw new UnsupportedOperationException();
-    }
-    @Override
-    public int readInt() throws IOException {
-      throw new UnsupportedOperationException();
-    }
-    @Override
-    public long readLong() throws IOException {
-      throw new UnsupportedOperationException();
-    }
-    @Override
-    public float readFloat() throws IOException {
-      throw new UnsupportedOperationException();
-    }
-    @Override
-    public double readDouble() throws IOException {
-      throw new UnsupportedOperationException();
-    }
-    @Override
-    public String readLine() throws IOException {
-      throw new UnsupportedOperationException();
-    }
-    @Override
-    public String readUTF() throws IOException {
-      throw new UnsupportedOperationException();
-    }
-  }
 }
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/UserPayload.java b/tez-api/src/main/java/org/apache/tez/dag/api/UserPayload.java
index fa617b5..087b17a 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/UserPayload.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/UserPayload.java
@@ -63,6 +63,17 @@
   }
 
   /**
+   * Return the payload as a ByteBuffer.
+   * @return ByteBuffer.
+   */
+  @Nullable
+  public ByteBuffer getRawPayload() {
+    // Note: Several bits of serialization, including deepCopyAsArray depend on a new instance of the
+    // ByteBuffer being returned, since they modify it. If changing this code to return the same
+    // ByteBuffer - deepCopyAsArray and TezEntityDescriptor need to be looked at.
+    return payload == EMPTY_BYTE ? null : payload.duplicate();
+  }
+  /**
    * Return the payload as a read-only ByteBuffer.
    * @return read-only ByteBuffer.
    */
diff --git a/tez-api/src/test/java/org/apache/tez/dag/api/TestEntityDescriptor.java b/tez-api/src/test/java/org/apache/tez/dag/api/TestEntityDescriptor.java
index 1e8a99d..606bf42 100644
--- a/tez-api/src/test/java/org/apache/tez/dag/api/TestEntityDescriptor.java
+++ b/tez-api/src/test/java/org/apache/tez/dag/api/TestEntityDescriptor.java
@@ -23,35 +23,24 @@
 import java.io.DataInputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
+import java.nio.ByteBuffer;
 
 import org.apache.commons.lang.RandomStringUtils;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DataOutputBuffer;
 import org.apache.tez.common.TezUtils;
 import org.junit.Assert;
 import org.junit.Test;
+import org.mockito.Mockito;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.spy;
 
 public class TestEntityDescriptor {
 
-  @Test
-  public void testEntityDescriptorHadoopSerialization() throws IOException {
-    // This tests the alternate serialization code path
-    // if the DataOutput is not DataOutputBuffer
-    Configuration conf = new Configuration(true);
-    String confVal = RandomStringUtils.random(10000, true, true);
-    conf.set("testKey", confVal);
-    UserPayload payload = TezUtils.createUserPayloadFromConf(conf);
-    InputDescriptor entityDescriptor =
-        InputDescriptor.create("inputClazz").setUserPayload(payload)
-        .setHistoryText("Bar123");
-
-    ByteArrayOutputStream bos = new ByteArrayOutputStream();
-    DataOutputStream out = new DataOutputStream(bos);
-    entityDescriptor.write(out);
-    out.close();
-
-    InputDescriptor deserialized = InputDescriptor.create("dummy");
-    deserialized.readFields(new DataInputStream(new ByteArrayInputStream(bos.toByteArray())));
-
+  public void verifyResults(InputDescriptor entityDescriptor, InputDescriptor deserialized, UserPayload payload,
+                             String confVal) throws IOException {
     Assert.assertEquals(entityDescriptor.getClassName(), deserialized.getClassName());
     // History text is not serialized when sending to tasks
     Assert.assertNull(deserialized.getHistoryText());
@@ -60,4 +49,54 @@
     Assert.assertEquals(confVal, deserializedConf.get("testKey"));
   }
 
+  public void testSingularWrite(InputDescriptor entityDescriptor, InputDescriptor deserialized, UserPayload payload,
+                                String confVal) throws IOException {
+    DataOutputBuffer out = new DataOutputBuffer();
+    entityDescriptor.write(out);
+    out.close();
+    ByteArrayOutputStream bos = new ByteArrayOutputStream(out.getData().length);
+    bos.write(out.getData(), 0, out.getLength());
+
+    Mockito.verify(entityDescriptor).writeSingular(eq(out), any(ByteBuffer.class));
+    deserialized.readFields(new DataInputStream(new ByteArrayInputStream(bos.toByteArray())));
+    verifyResults(entityDescriptor, deserialized, payload, confVal);
+  }
+
+  public void testSegmentedWrite(InputDescriptor entityDescriptor, InputDescriptor deserialized, UserPayload payload,
+                                 String confVal) throws IOException {
+    ByteArrayOutputStream bos = new ByteArrayOutputStream();
+    DataOutputStream out = new DataOutputStream(bos);
+    entityDescriptor.write(out);
+    out.close();
+
+    Mockito.verify(entityDescriptor).writeSegmented(eq(out), any(ByteBuffer.class));
+    deserialized.readFields(new DataInputStream(new ByteArrayInputStream(bos.toByteArray())));
+    verifyResults(entityDescriptor, deserialized, payload, confVal);
+  }
+
+  @Test (timeout=1000)
+  public void testEntityDescriptorHadoopSerialization() throws IOException {
+    /* This tests the alternate serialization code path
+     * if the DataOutput is not DataOutputBuffer
+     * AND, if it indeed is, with a read/write payload */
+    Configuration conf = new Configuration(true);
+    String confVal = RandomStringUtils.random(10000, true, true);
+    conf.set("testKey", confVal);
+    UserPayload payload = TezUtils.createUserPayloadFromConf(conf);
+
+    InputDescriptor deserialized = InputDescriptor.create("dummy");
+    InputDescriptor entityDescriptor =
+        InputDescriptor.create("inputClazz").setUserPayload(payload)
+                .setHistoryText("Bar123");
+    InputDescriptor entityDescriptorLivingInFear = spy(entityDescriptor);
+
+    testSingularWrite(entityDescriptorLivingInFear, deserialized, payload, confVal);
+
+    /* make read-only payload */
+    payload = UserPayload.create(payload.getPayload());
+    entityDescriptor = InputDescriptor.create("inputClazz").setUserPayload(payload)
+                      .setHistoryText("Bar123");
+    entityDescriptorLivingInFear = spy(entityDescriptor);
+    testSegmentedWrite(entityDescriptorLivingInFear, deserialized, payload, confVal);
+  }
 }