PROTON-2084 Only use the key type restriction on top level Map keys

Don't apply the fixed key type to other Maps that are nested inside a
map with a key type retriction.
diff --git a/proton-j/src/main/java/org/apache/qpid/proton/codec/MapType.java b/proton-j/src/main/java/org/apache/qpid/proton/codec/MapType.java
index b791c91..40c27b2 100644
--- a/proton-j/src/main/java/org/apache/qpid/proton/codec/MapType.java
+++ b/proton-j/src/main/java/org/apache/qpid/proton/codec/MapType.java
@@ -26,6 +26,7 @@
 import java.util.LinkedHashMap;
 import java.util.Map;
 
+@SuppressWarnings({ "rawtypes", "unchecked" })
 public class MapType extends AbstractPrimitiveType<Map>
 {
     private final MapEncoding _mapEncoding;
@@ -62,8 +63,8 @@
     @Override
     public MapEncoding getEncoding(final Map val)
     {
-        int calculatedSize = calculateSize(val, _encoder, fixedKeyType);
-        MapEncoding encoding = (val.size() > 127 || calculatedSize >= 254)
+        final int calculatedSize = calculateSize(val);
+        final MapEncoding encoding = (val.size() > 127 || calculatedSize >= 254)
                                     ? _mapEncoding
                                     : _shortMapEncoding;
 
@@ -71,26 +72,38 @@
         return encoding;
     }
 
-    private static int calculateSize(final Map map, EncoderImpl encoder, AMQPType<?> fixedKeyType)
+    private int calculateSize(final Map map)
     {
         int len = 0;
-        Iterator<Map.Entry> iter = map.entrySet().iterator();
+
+        final Iterator<Map.Entry<?, ?>> iter = map.entrySet().iterator();
+        final AMQPType fixedKeyType = this.fixedKeyType;
+
+        // Clear existing fixed key type encoding to prevent application to nested Maps
+        setKeyEncoding(null);
 
         while (iter.hasNext())
         {
-            Map.Entry element = iter.next();
+            final Map.Entry<?, ?> element = iter.next();
+            TypeEncoding elementEncoding;
 
-            AMQPType keyType = fixedKeyType;
             if (fixedKeyType == null)
             {
-                keyType = encoder.getType(element.getKey());
+                elementEncoding = _encoder.getType(element.getKey()).getEncoding(element.getKey());
+            }
+            else
+            {
+                elementEncoding = fixedKeyType.getEncoding(element.getKey());
             }
 
-            TypeEncoding elementEncoding = keyType.getEncoding(element.getKey());
-            len += elementEncoding.getConstructorSize()+elementEncoding.getValueSize(element.getKey());
-            elementEncoding = encoder.getType(element.getValue()).getEncoding(element.getValue());
-            len += elementEncoding.getConstructorSize()+elementEncoding.getValueSize(element.getValue());
+            len += elementEncoding.getConstructorSize() + elementEncoding.getValueSize(element.getKey());
+            elementEncoding = _encoder.getType(element.getValue()).getEncoding(element.getValue());
+            len += elementEncoding.getConstructorSize() + elementEncoding.getValueSize(element.getValue());
         }
+
+        // Reset Existing key type encoding for later encode step or reuse until cleared by caller
+        setKeyEncoding(fixedKeyType);
+
         return len;
     }
 
@@ -109,7 +122,7 @@
             }
             else
             {
-                PrimitiveTypeEncoding<?> primitiveConstructor = (PrimitiveTypeEncoding<?>) previousConstructor;
+                final PrimitiveTypeEncoding<?> primitiveConstructor = (PrimitiveTypeEncoding<?>) previousConstructor;
                 if (encodingCode != primitiveConstructor.getEncodingCode())
                 {
                     return decoder.readConstructor();
@@ -156,31 +169,41 @@
             getEncoder().getBuffer().ensureRemaining(getSizeBytes() + getEncodedValueSize(map));
             getEncoder().writeRaw(2 * map.size());
 
-            Iterator<Map.Entry> iter = map.entrySet().iterator();
+            final Iterator<Map.Entry> iter = map.entrySet().iterator();
+            final AMQPType fixedKeyType = MapType.this.fixedKeyType;
+
+            // Clear existing fixed key type encoding to prevent application to nested Maps
+            setKeyEncoding(null);
 
             while (iter.hasNext())
             {
-                Map.Entry element = iter.next();
+                final Map.Entry<?, ?> element = iter.next();
+                TypeEncoding elementEncoding;
 
-                AMQPType keyType = fixedKeyType;
-                if (keyType == null)
+                if (fixedKeyType == null)
                 {
-                    keyType = getEncoder().getType(element.getKey());
+                    elementEncoding = _encoder.getType(element.getKey()).getEncoding(element.getKey());
+                }
+                else
+                {
+                    elementEncoding = fixedKeyType.getEncoding(element.getKey());
                 }
 
-                TypeEncoding elementEncoding = keyType.getEncoding(element.getKey());
                 elementEncoding.writeConstructor();
                 elementEncoding.writeValue(element.getKey());
                 elementEncoding = getEncoder().getType(element.getValue()).getEncoding(element.getValue());
                 elementEncoding.writeConstructor();
                 elementEncoding.writeValue(element.getValue());
             }
+
+            // Reset Existing key type encoding for later encode step or reuse until cleared by caller
+            setKeyEncoding(fixedKeyType);
         }
 
         @Override
         protected int getEncodedValueSize(final Map val)
         {
-            return 4 + ((val == _value) ? _length : calculateSize(val, getEncoder(), fixedKeyType));
+            return 4 + ((val == _value) ? _length : calculateSize(val));
         }
 
         @Override
@@ -204,12 +227,12 @@
         @Override
         public Map readValue()
         {
-            DecoderImpl decoder = getDecoder();
-            ReadableBuffer buffer = decoder.getBuffer();
+            final DecoderImpl decoder = getDecoder();
+            final ReadableBuffer buffer = decoder.getBuffer();
 
-            int size = decoder.readRawInt();
+            final int size = decoder.readRawInt();
             // todo - limit the decoder with size
-            int count = decoder.readRawInt();
+            final int count = decoder.readRawInt();
             if (count > decoder.getByteBufferRemaining()) {
                 throw new IllegalArgumentException("Map element count "+count+" is specified to be greater than the amount of data available ("+
                                                    decoder.getByteBufferRemaining()+")");
@@ -227,10 +250,10 @@
                     throw new DecodeException("Unknown constructor");
                 }
 
-                Object key = keyConstructor.readValue();
+                final Object key = keyConstructor.readValue();
 
                 boolean arrayType = false;
-                byte code = buffer.get(buffer.position());
+                final byte code = buffer.get(buffer.position());
                 switch (code)
                 {
                     case EncodingCodes.ARRAY8:
@@ -264,9 +287,9 @@
         @Override
         public void skipValue()
         {
-            DecoderImpl decoder = getDecoder();
-            ReadableBuffer buffer = decoder.getBuffer();
-            int size = decoder.readRawInt();
+            final DecoderImpl decoder = getDecoder();
+            final ReadableBuffer buffer = decoder.getBuffer();
+            final int size = decoder.readRawInt();
             buffer.position(buffer.position() + size);
         }
 
@@ -282,7 +305,6 @@
             extends SmallFloatingSizePrimitiveTypeEncoding<Map>
             implements MapEncoding
     {
-
         private Map _value;
         private int _length;
 
@@ -297,30 +319,41 @@
             getEncoder().getBuffer().ensureRemaining(getSizeBytes() + getEncodedValueSize(map));
             getEncoder().writeRaw((byte)(2 * map.size()));
 
-            Iterator<Map.Entry> iter = map.entrySet().iterator();
+            final Iterator<Map.Entry> iter = map.entrySet().iterator();
+            final AMQPType fixedKeyType = MapType.this.fixedKeyType;
+
+            // Clear existing fixed key type encoding to prevent application to nested Maps
+            setKeyEncoding(null);
+
             while (iter.hasNext())
             {
-                Map.Entry element = iter.next();
+                final Map.Entry<?, ?> element = iter.next();
+                TypeEncoding elementEncoding;
 
-                AMQPType keyType = fixedKeyType;
-                if (keyType == null)
+                if (fixedKeyType == null)
                 {
-                    keyType = getEncoder().getType(element.getKey());
+                    elementEncoding = _encoder.getType(element.getKey()).getEncoding(element.getKey());
+                }
+                else
+                {
+                    elementEncoding = fixedKeyType.getEncoding(element.getKey());
                 }
 
-                TypeEncoding elementEncoding = keyType.getEncoding(element.getKey());
                 elementEncoding.writeConstructor();
                 elementEncoding.writeValue(element.getKey());
                 elementEncoding = getEncoder().getType(element.getValue()).getEncoding(element.getValue());
                 elementEncoding.writeConstructor();
                 elementEncoding.writeValue(element.getValue());
             }
+
+            // Reset Existing key type encoding for later encode step or reuse until cleared by caller
+            setKeyEncoding(fixedKeyType);
         }
 
         @Override
         protected int getEncodedValueSize(final Map val)
         {
-            return 1 + ((val == _value) ? _length : calculateSize(val, getEncoder(), fixedKeyType));
+            return 1 + ((val == _value) ? _length : calculateSize(val));
         }
 
         @Override
@@ -344,17 +377,17 @@
         @Override
         public Map readValue()
         {
-            DecoderImpl decoder = getDecoder();
-            ReadableBuffer buffer = decoder.getBuffer();
+            final DecoderImpl decoder = getDecoder();
+            final ReadableBuffer buffer = decoder.getBuffer();
 
-            int size = (decoder.readRawByte()) & 0xff;
+            final int size = (decoder.readRawByte()) & 0xff;
             // todo - limit the decoder with size
-            int count = (decoder.readRawByte()) & 0xff;
+            final int count = (decoder.readRawByte()) & 0xff;
 
             TypeConstructor<?> keyConstructor = null;
             TypeConstructor<?> valueConstructor = null;
 
-            Map<Object, Object> map = new LinkedHashMap<>(count);
+            final Map<Object, Object> map = new LinkedHashMap<>(count);
             for(int i = 0; i < count / 2; i++)
             {
                 keyConstructor = findNextDecoder(decoder, buffer, keyConstructor);
@@ -363,10 +396,10 @@
                     throw new DecodeException("Unknown constructor");
                 }
 
-                Object key = keyConstructor.readValue();
+                final Object key = keyConstructor.readValue();
 
                 boolean arrayType = false;
-                byte code = buffer.get(buffer.position());
+                final byte code = buffer.get(buffer.position());
                 switch (code)
                 {
                     case EncodingCodes.ARRAY8:
@@ -400,9 +433,9 @@
         @Override
         public void skipValue()
         {
-            DecoderImpl decoder = getDecoder();
-            ReadableBuffer buffer = decoder.getBuffer();
-            int size = ((int)decoder.readRawByte()) & 0xff;
+            final DecoderImpl decoder = getDecoder();
+            final ReadableBuffer buffer = decoder.getBuffer();
+            final int size = ((int)decoder.readRawByte()) & 0xff;
             buffer.position(buffer.position() + size);
         }
 
diff --git a/proton-j/src/test/java/org/apache/qpid/proton/codec/MessageAnnotationsTypeCodecTest.java b/proton-j/src/test/java/org/apache/qpid/proton/codec/MessageAnnotationsTypeCodecTest.java
index d1f6ba7..a19b10e 100644
--- a/proton-j/src/test/java/org/apache/qpid/proton/codec/MessageAnnotationsTypeCodecTest.java
+++ b/proton-j/src/test/java/org/apache/qpid/proton/codec/MessageAnnotationsTypeCodecTest.java
@@ -23,6 +23,7 @@
 import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.UUID;
 
 import org.apache.qpid.proton.amqp.Symbol;
 import org.apache.qpid.proton.amqp.UnsignedByte;
@@ -51,14 +52,13 @@
         doTestDecodeMessageAnnotationsSeries(1);
     }
 
-    @SuppressWarnings({ "rawtypes", "unchecked" })
     private void doTestDecodeMessageAnnotationsSeries(int size) throws IOException {
 
         final Symbol SYMBOL_1 = Symbol.valueOf("test1");
         final Symbol SYMBOL_2 = Symbol.valueOf("test2");
         final Symbol SYMBOL_3 = Symbol.valueOf("test3");
 
-        MessageAnnotations annotations = new MessageAnnotations(new HashMap());
+        MessageAnnotations annotations = new MessageAnnotations(new HashMap<>());
         annotations.getValue().put(SYMBOL_1, UnsignedByte.valueOf((byte) 128));
         annotations.getValue().put(SYMBOL_2, UnsignedShort.valueOf((short) 128));
         annotations.getValue().put(SYMBOL_3, UnsignedInteger.valueOf(128));
@@ -85,4 +85,45 @@
             assertEquals(resultMap.get(SYMBOL_3), UnsignedInteger.valueOf(128));
         }
     }
+
+    @Test
+    public void testEncodeAndDecodeAnnoationsWithEmbeddedMaps() throws IOException {
+        final Symbol SYMBOL_1 = Symbol.valueOf("x-opt-test1");
+        final Symbol SYMBOL_2 = Symbol.valueOf("x-opt-test2");
+
+        final String VALUE_1 = "string";
+        final UnsignedInteger VALUE_2 = UnsignedInteger.valueOf(42);
+        final UUID VALUE_3 = UUID.randomUUID();
+
+        Map<String, Object> stringKeyedMap = new HashMap<>();
+        stringKeyedMap.put("key1", VALUE_1);
+        stringKeyedMap.put("key2", VALUE_2);
+        stringKeyedMap.put("key3", VALUE_3);
+
+        Map<Symbol, Object> symbolKeyedMap = new HashMap<>();
+        symbolKeyedMap.put(Symbol.valueOf("key1"), VALUE_1);
+        symbolKeyedMap.put(Symbol.valueOf("key2"), VALUE_2);
+        symbolKeyedMap.put(Symbol.valueOf("key3"), VALUE_3);
+
+        MessageAnnotations annotations = new MessageAnnotations(new HashMap<>());
+        annotations.getValue().put(SYMBOL_1, stringKeyedMap);
+        annotations.getValue().put(SYMBOL_2, symbolKeyedMap);
+
+        encoder.writeObject(annotations);
+
+        buffer.clear();
+
+        final Object result = decoder.readObject();
+
+        assertNotNull(result);
+        assertTrue(result instanceof MessageAnnotations);
+
+        MessageAnnotations readAnnotations = (MessageAnnotations) result;
+
+        Map<Symbol, Object> resultMap = readAnnotations.getValue();
+
+        assertEquals(annotations.getValue().size(), resultMap.size());
+        assertEquals(resultMap.get(SYMBOL_1), stringKeyedMap);
+        assertEquals(resultMap.get(SYMBOL_2), symbolKeyedMap);
+    }
 }