Allow concatenated encoded strings provided that the padding is in the correct place.
Detect and report truncated input and badly placed padding.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/fileupload/trunk@1459121 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/src/main/java/org/apache/commons/fileupload/util/mime/Base64Decoder.java b/src/main/java/org/apache/commons/fileupload/util/mime/Base64Decoder.java
index ea51361..8e486df 100644
--- a/src/main/java/org/apache/commons/fileupload/util/mime/Base64Decoder.java
+++ b/src/main/java/org/apache/commons/fileupload/util/mime/Base64Decoder.java
@@ -30,6 +30,11 @@
     private static final int INVALID_BYTE = -1; // must be outside range 0-63
 
     /**
+     * Decoding table value for padding bytes, so that PAD can be detected after conversion.
+     */
+    private static final int PAD_BYTE = -2; // must be outside range 0-63
+
+    /**
      * Mask to treat byte as unsigned integer.
      */
     private static final int MASK_BYTE_UNSIGNED = 0xFF;
@@ -62,7 +67,7 @@
     private static final byte PADDING = (byte) '=';
 
     /**
-     * Set up the decoding table; this is indexed by a byte converted to an int,
+     * Set up the decoding table; this is indexed by a byte converted to an unsigned int,
      * so must be at least as large as the number of different byte values,
      * positive and negative and zero.
      */
@@ -77,6 +82,8 @@
         for (int i = 0; i < ENCODING_TABLE.length; i++) {
             DECODING_TABLE[ENCODING_TABLE[i]] = (byte) i;
         }
+        // Allow pad byte to be easily detected after conversion
+        DECODING_TABLE[PADDING] = PAD_BYTE;
     }
 
     /**
@@ -101,9 +108,6 @@
         int cachedBytes = 0;
 
         for (byte b : data) {
-            if (b == PADDING) { // Padding means end of input
-                break;
-            }
             final byte d = DECODING_TABLE[MASK_BYTE_UNSIGNED & b];
             if (d == INVALID_BYTE) {
                 continue; // Ignore invalid bytes
@@ -111,32 +115,27 @@
             cache[cachedBytes++] = d;
             if (cachedBytes == INPUT_BYTES_PER_CHUNK) {
                 // Convert 4 6-bit bytes to 3 8-bit bytes
-                // CHECKSTYLE IGNORE MagicNumber FOR NEXT 3 LINES
-                out.write((cache[0] << 2) | (cache[1] >> 4)); // 6 bits of b1 plus 2 bits of b2
-                out.write((cache[1] << 4) | (cache[2] >> 2)); // 4 bits of b2 plus 4 bits of b3
-                out.write((cache[2] << 6) | cache[3]);        // 2 bits of b3 plus 6 bits of b4
-
                 // CHECKSTYLE IGNORE MagicNumber FOR NEXT 1 LINE
-                outLen += 3;
+                out.write((cache[0] << 2) | (cache[1] >> 4)); // 6 bits of b1 plus 2 bits of b2
+                outLen++;
+                if (cache[2] != PAD_BYTE) {
+                    // CHECKSTYLE IGNORE MagicNumber FOR NEXT 1 LINE
+                    out.write((cache[1] << 4) | (cache[2] >> 2)); // 4 bits of b2 plus 4 bits of b3
+                    outLen++;
+                    if (cache[3] != PAD_BYTE) {
+                        // CHECKSTYLE IGNORE MagicNumber FOR NEXT 1 LINE
+                        out.write((cache[2] << 6) | cache[3]);        // 2 bits of b3 plus 6 bits of b4
+                        outLen++;
+                    }
+                } else if (cache[3] != PAD_BYTE) { // if byte 3 is pad, byte 4 must be pad too
+                    throw new IOException("Invalid Base64 input: incorrect padding");                    
+                }
                 cachedBytes = 0;
             }
         }
-        // CHECKSTYLE IGNORE MagicNumber FOR NEXT 2 LINES
-        if (cachedBytes >= 2) {
-            out.write((cache[0] << 2) | (cache[1] >> 4)); // 6 bits of b1 plus 2 bits of b2
-            outLen++;
-            // CHECKSTYLE IGNORE MagicNumber FOR NEXT 2 LINES
-            if (cachedBytes >= 3) {
-                out.write((cache[1] << 4) | (cache[2] >> 2)); // 4 bits of b2 plus 4 bits of b3
-                outLen++;
-                // CHECKSTYLE IGNORE MagicNumber FOR NEXT 2 LINES
-                if (cachedBytes >= 4) {
-                    out.write((cache[2] << 6) | cache[3]);        // 2 bits of b3 plus 6 bits of b4
-                    outLen++;
-                }
-            }
-        } else if (cachedBytes != 0){
-            throw new IOException("Invalid Base64 input: truncated");            
+        // Check for anything left over
+        if (cachedBytes != 0){
+            throw new IOException("Invalid Base64 input: truncated");
         }
         return outLen;
     }
diff --git a/src/test/java/org/apache/commons/fileupload/util/mime/Base64DecoderTestCase.java b/src/test/java/org/apache/commons/fileupload/util/mime/Base64DecoderTestCase.java
index 5e4bf74..efc9a17 100644
--- a/src/test/java/org/apache/commons/fileupload/util/mime/Base64DecoderTestCase.java
+++ b/src/test/java/org/apache/commons/fileupload/util/mime/Base64DecoderTestCase.java
@@ -17,9 +17,12 @@
 package org.apache.commons.fileupload.util.mime;
 
 import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
+import java.io.UnsupportedEncodingException;
 
 import org.junit.Test;
 
@@ -57,25 +60,12 @@
 
     /**
      * Test our decode with pad character in the middle.
-     * Returns data up to pad character.
-     *
-     *
-     * @throws Exception if any error occurs while decoding the input string.
+     * Continues provided that the padding is in the correct place,
+     * i.e. concatenated valid strings decode OK.
      */
     @Test
     public void decodeWithInnerPad() throws Exception {
-        assertEncoded("Hello World", "SGVsbG8gV29ybGQ=SGVsbG8gV29ybGQ=");
-    }
-
-    private static void assertEncoded(String clearText, String encoded) throws Exception {
-        byte[] expected = clearText.getBytes(US_ASCII_CHARSET);
-
-        ByteArrayOutputStream out = new ByteArrayOutputStream(encoded.length());
-        byte[] encodedData = encoded.getBytes(US_ASCII_CHARSET);
-        Base64Decoder.decode(encodedData, out);
-        byte[] actual = out.toByteArray();
-
-        assertArrayEquals(expected, actual);
+        assertEncoded("Hello WorldHello World", "SGVsbG8gV29ybGQ=SGVsbG8gV29ybGQ=");
     }
 
     /**
@@ -92,4 +82,55 @@
         Base64Decoder.decode(x, new ByteArrayOutputStream());
     }
 
+    @Test
+    public void decodeTrailingJunk() throws Exception {
+        assertEncoded("foobar", "Zm9vYmFy!!!");
+    }
+
+    // If there are valid trailing Base64 chars, complain
+    @Test
+    public void decodeTrailing1() throws Exception {
+        assertIOException("truncated", "Zm9vYmFy1");
+    }
+
+    // If there are valid trailing Base64 chars, complain
+    @Test
+    public void decodeTrailing2() throws Exception {
+        assertIOException("truncated", "Zm9vYmFy12");
+    }
+
+    // If there are valid trailing Base64 chars, complain
+    @Test
+    public void decodeTrailing3() throws Exception {
+        assertIOException("truncated", "Zm9vYmFy123");
+    }
+
+    @Test
+    public void badPadding() throws Exception {
+        assertIOException("incorrect padding", "Zg=a");
+    }
+
+    private static void assertEncoded(String clearText, String encoded) throws Exception {
+        byte[] expected = clearText.getBytes(US_ASCII_CHARSET);
+
+        ByteArrayOutputStream out = new ByteArrayOutputStream(encoded.length());
+        byte[] encodedData = encoded.getBytes(US_ASCII_CHARSET);
+        Base64Decoder.decode(encodedData, out);
+        byte[] actual = out.toByteArray();
+
+        assertArrayEquals(expected, actual);
+    }
+
+    private static void assertIOException(String messageText, String encoded) throws UnsupportedEncodingException {
+        ByteArrayOutputStream out = new ByteArrayOutputStream(encoded.length());
+        byte[] encodedData = encoded.getBytes(US_ASCII_CHARSET);
+        try {
+            Base64Decoder.decode(encodedData, out);
+            fail("Expected IOException");
+        } catch (IOException e) {
+            String em = e.getMessage();
+            assertTrue("Expected to find " + messageText + " in '" + em + "'",em.contains(messageText));
+        }
+    }
+
 }