Merge pull request #29 from apache/allow_deny_list

AMQNET-829 Add allow/deny list support for trusted types in ObjectMessage deserialization
diff --git a/src/Commands/ActiveMQObjectMessage.cs b/src/Commands/ActiveMQObjectMessage.cs
index 3919111..aec88b5 100644
--- a/src/Commands/ActiveMQObjectMessage.cs
+++ b/src/Commands/ActiveMQObjectMessage.cs
@@ -140,6 +140,10 @@
                 if (formatter == null)
                 {
                     formatter = new BinaryFormatter();
+                    if (Connection.DeserializationPolicy is INmsDeserializationPolicy trustPolicy)
+                    {
+                        formatter.Binder = new TrustedClassFilter(trustPolicy, Destination); // reject types the connection's policy does not trust
+                    }
                 }
                 return formatter;
             }
diff --git a/src/Commands/TrustedClassFilter.cs b/src/Commands/TrustedClassFilter.cs
new file mode 100644
index 0000000..eb90662
--- /dev/null
+++ b/src/Commands/TrustedClassFilter.cs
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.Reflection;
+using System.Runtime.Serialization;
+
+namespace Apache.NMS.ActiveMQ.Commands
+{
+    /// <summary>
+    /// <see cref="SerializationBinder"/> that only lets the formatter bind types accepted by an
+    /// <see cref="INmsDeserializationPolicy"/> for the destination of the message being read.
+    /// </summary>
+    internal class TrustedClassFilter : SerializationBinder
+    {
+        private readonly INmsDeserializationPolicy deserializationPolicy;
+        private readonly IDestination destination;
+
+        /// <summary>
+        /// Creates a binder that checks every bound type against <paramref name="deserializationPolicy"/>.
+        /// </summary>
+        /// <param name="deserializationPolicy">Policy consulted for every type; must not be null.</param>
+        /// <param name="destination">Destination passed to the policy together with each type.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="deserializationPolicy"/> is null.</exception>
+        public TrustedClassFilter(INmsDeserializationPolicy deserializationPolicy, IDestination destination)
+        {
+            this.deserializationPolicy = deserializationPolicy
+                ?? throw new ArgumentNullException(nameof(deserializationPolicy));
+            this.destination = destination;
+        }
+
+        /// <summary>
+        /// Resolves <paramref name="typeName"/> from <paramref name="assemblyName"/> and returns it only
+        /// when the deserialization policy trusts it.
+        /// </summary>
+        /// <param name="assemblyName">Assembly name of the serialized object.</param>
+        /// <param name="typeName">Type name of the serialized object.</param>
+        /// <returns>The resolved, trusted type.</returns>
+        /// <exception cref="SerializationException">
+        /// The type cannot be resolved from the assembly, or the policy does not trust it.
+        /// </exception>
+        public override Type BindToType(string assemblyName, string typeName)
+        {
+            var name = new AssemblyName(assemblyName);
+            var assembly = Assembly.Load(name);
+            var type = FormatterServices.GetTypeFromAssembly(assembly, typeName);
+            if (type == null)
+            {
+                // Returning null would make the formatter fall back to its own, unfiltered type
+                // resolution, so a type that cannot be resolved here is rejected instead.
+                throw new SerializationException(
+                    $"Forbidden {typeName}! The type could not be resolved from assembly '{assemblyName}' " +
+                    "and therefore cannot be checked against the deserialization policy.");
+            }
+
+            if (deserializationPolicy.IsTrustedType(destination, type))
+            {
+                return type;
+            }
+
+            var message = $"Forbidden {type.FullName}! " +
+                          "This type is not trusted to be deserialized under the current configuration. " +
+                          "Please refer to the documentation for more information on how to configure trusted types.";
+            throw new SerializationException(message);
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/Connection.cs b/src/Connection.cs
index df0cc57..a2d7df2 100644
--- a/src/Connection.cs
+++ b/src/Connection.cs
@@ -482,6 +482,8 @@
             get { return this.compressionPolicy; }
             set { this.compressionPolicy = value; }
         }
+        /// <summary>Restricts the types that may be deserialized from <see cref="IObjectMessage"/> bodies; when null, no type filter is installed.</summary>
+        public INmsDeserializationPolicy DeserializationPolicy { get; set; } = new NmsDefaultDeserializationPolicy();
 
         internal MessageTransformation MessageTransformation
         {
diff --git a/src/ConnectionFactory.cs b/src/ConnectionFactory.cs
index 84ecf22..6c778cb 100644
--- a/src/ConnectionFactory.cs
+++ b/src/ConnectionFactory.cs
@@ -397,6 +397,11 @@
 				}
 			}
 		}
+		
+		/// <summary>
+		/// The deserialization policy applied when a connection is created; each new connection is assigned a copy obtained from <see cref="INmsDeserializationPolicy.Clone"/>.
+		/// </summary>
+		public INmsDeserializationPolicy DeserializationPolicy { get; set; } = new NmsDefaultDeserializationPolicy();
 
 		public IdGenerator ClientIdGenerator
 		{
@@ -546,6 +551,7 @@
 			connection.RedeliveryPolicy = this.redeliveryPolicy.Clone() as IRedeliveryPolicy;
 			connection.PrefetchPolicy = this.prefetchPolicy.Clone() as PrefetchPolicy;
 			connection.CompressionPolicy = this.compressionPolicy.Clone() as ICompressionPolicy;
+			connection.DeserializationPolicy = this.DeserializationPolicy?.Clone(); // a null factory policy leaves the connection without a type filter instead of throwing
 			connection.ConsumerTransformer = this.consumerTransformer;
 			connection.ProducerTransformer = this.producerTransformer;
             connection.WatchTopicAdvisories = this.watchTopicAdvisories;
diff --git a/src/INmsDeserializationPolicy.cs b/src/INmsDeserializationPolicy.cs
new file mode 100644
index 0000000..743855d
--- /dev/null
+++ b/src/INmsDeserializationPolicy.cs
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+
+namespace Apache.NMS.ActiveMQ
+{
+    /// <summary>
+    /// Policy deciding which types may be materialized while the body of an incoming
+    /// <see cref="IObjectMessage"/> is being deserialized.
+    /// </summary>
+    public interface INmsDeserializationPolicy
+    {
+        /// <summary>
+        /// Checks whether the client trusts a type enough to deserialize it.
+        /// </summary>
+        /// <param name="destination">Destination of the message whose body contains the type.</param>
+        /// <param name="type">Type of the object that is about to be read.</param>
+        /// <returns><c>true</c> when the type is trusted; <c>false</c> otherwise.</returns>
+        bool IsTrustedType(IDestination destination, Type type);
+
+        /// <summary>
+        /// Produces a thread-safe copy of this policy. Connection factories assign the result of
+        /// this method to each connection they create.
+        /// </summary>
+        /// <returns>The copied <see cref="INmsDeserializationPolicy"/>.</returns>
+        INmsDeserializationPolicy Clone();
+    }
+}
\ No newline at end of file
diff --git a/src/NmsDefaultDeserializationPolicy.cs b/src/NmsDefaultDeserializationPolicy.cs
new file mode 100644
index 0000000..b56531a
--- /dev/null
+++ b/src/NmsDefaultDeserializationPolicy.cs
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace Apache.NMS.ActiveMQ
+{
+    /// <summary>
+    /// Default implementation of the deserialization policy, driven by allow and deny lists of type
+    /// names and namespaces.
+    ///
+    /// Both lists are comma separated strings that can be configured from the connection URI options
+    /// <c>nms.deserializationPolicy.allowList</c> and <c>nms.deserializationPolicy.denyList</c>.
+    /// Whitespace around entries is ignored and empty entries are dropped.
+    ///
+    /// The deny list defaults to empty. The allow list defaults to <see cref="CATCH_ALL_WILDCARD"/>,
+    /// which indicates that all types are allowed.
+    ///
+    /// An entry matches a type when it equals the type's full name, or when it is a dot-separated
+    /// prefix (such as a namespace) of that full name. Matching is ordinal and case-sensitive.
+    ///
+    /// The deny list overrides the allow list, entries that could match both are counted as denied.
+    ///
+    /// If the policy should treat all types as untrusted, the deny list should be set to <see cref="CATCH_ALL_WILDCARD"/>.
+    /// </summary>
+    public class NmsDefaultDeserializationPolicy : INmsDeserializationPolicy
+    {
+        /// <summary>
+        /// Value used to indicate that all types should be allowed or denied
+        /// </summary>
+        public const string CATCH_ALL_WILDCARD = "*";
+
+        private IReadOnlyList<string> denyList = Array.Empty<string>();
+        private IReadOnlyList<string> allowList = new[] { CATCH_ALL_WILDCARD };
+
+        /// <summary>
+        /// Determines whether <paramref name="type"/> may be deserialized. Deny list entries are checked
+        /// first, then allow list entries; a type matched by neither list is rejected. A null type, or one
+        /// without a full name, is reported as trusted.
+        /// </summary>
+        /// <param name="destination">Destination of the message; not consulted by this policy.</param>
+        /// <param name="type">The type of the object that is about to be read.</param>
+        /// <returns>True if the type is trusted, otherwise false.</returns>
+        public bool IsTrustedType(IDestination destination, Type type)
+        {
+            var typeName = type?.FullName;
+            if (typeName == null)
+            {
+                return true;
+            }
+
+            if (MatchesAny(typeName, denyList))
+            {
+                return false;
+            }
+
+            // Failing outright rejection, only types matched by the allow list are trusted.
+            return MatchesAny(typeName, allowList);
+        }
+
+        /// <summary>
+        /// Returns true when any entry is the catch-all wildcard or matches <paramref name="typeName"/>.
+        /// </summary>
+        private static bool MatchesAny(string typeName, IEnumerable<string> entries)
+        {
+            return entries.Any(entry => CATCH_ALL_WILDCARD == entry || IsTypeOrNamespaceMatch(typeName, entry));
+        }
+
+        /// <summary>
+        /// Returns true when <paramref name="listEntry"/> equals <paramref name="typeName"/> or is a prefix
+        /// of it that ends right before a '.' separator.
+        /// </summary>
+        private static bool IsTypeOrNamespaceMatch(string typeName, string listEntry)
+        {
+            if (string.Equals(typeName, listEntry, StringComparison.Ordinal))
+            {
+                return true;
+            }
+
+            // Ordinal comparison keeps the matched prefix exactly listEntry.Length characters long, so the
+            // separator check indexes the right character, and the result never depends on the current culture.
+            var entryLength = listEntry.Length;
+            return typeName.Length > entryLength
+                && typeName.StartsWith(listEntry, StringComparison.Ordinal)
+                && typeName[entryLength] == '.';
+        }
+
+        /// <summary>
+        /// Creates an independent copy of this policy with the same allow and deny lists.
+        /// </summary>
+        public INmsDeserializationPolicy Clone()
+        {
+            return new NmsDefaultDeserializationPolicy
+            {
+                allowList = allowList.ToArray(),
+                denyList = denyList.ToArray()
+            };
+        }
+
+        /// <summary>
+        /// Gets or sets the deny list on this policy instance as a comma separated string.
+        /// Null, empty or whitespace-only values clear the list.
+        /// </summary>
+        public string DenyList
+        {
+            get => string.Join(",", denyList);
+            set => denyList = ParseList(value);
+        }
+
+        /// <summary>
+        /// Gets or sets the allow list on this policy instance as a comma separated string.
+        /// Null, empty or whitespace-only values clear the list, which rejects every type that has a full name.
+        /// </summary>
+        public string AllowList
+        {
+            get => string.Join(",", allowList);
+            set => allowList = ParseList(value);
+        }
+
+        /// <summary>
+        /// Splits a comma separated list, trimming whitespace around each entry and dropping empty entries.
+        /// </summary>
+        private static string[] ParseList(string value)
+        {
+            if (string.IsNullOrWhiteSpace(value))
+            {
+                return Array.Empty<string>();
+            }
+
+            return value.Split(',')
+                .Select(entry => entry.Trim())
+                .Where(entry => entry.Length > 0)
+                .ToArray();
+        }
+    }
+}
\ No newline at end of file
diff --git a/test/MessageConsumerTest.cs b/test/MessageConsumerTest.cs
index e635434..4d59dcd 100644
--- a/test/MessageConsumerTest.cs
+++ b/test/MessageConsumerTest.cs
@@ -20,6 +20,7 @@
 using NUnit.Framework;
 using Apache.NMS.ActiveMQ.Commands;
 using System;
+using System.Runtime.Serialization;
 using Apache.NMS.Util;
 
 namespace Apache.NMS.ActiveMQ.Test
@@ -306,5 +307,91 @@
                 }
             }
         }
+
+        [Test, Timeout(20_000)]
+        public void TestShouldNotDeserializeUntrustedType()
+        {
+            string uri = "activemq:tcp://${{activemqhost}}:61616";
+            var factory = new ConnectionFactory(ReplaceEnvVar(uri))
+            {
+                DeserializationPolicy = new NmsDefaultDeserializationPolicy
+                {
+                    DenyList = typeof(UntrustedType).FullName
+                }
+            };
+            using var connection = factory.CreateConnection("", "");
+
+            var objectMessage = SendAndReceiveObjectMessage(connection, new UntrustedType { Prop1 = "foo" });
+
+            var exception = Assert.Throws<SerializationException>(() =>
+            {
+                _ = objectMessage.Body;
+            });
+            Assert.AreEqual($"Forbidden {typeof(UntrustedType).FullName}! " +
+                            "This type is not trusted to be deserialized under the current configuration. " +
+                            "Please refer to the documentation for more information on how to configure trusted types.",
+                exception.Message);
+        }
+
+        [Test, Timeout(20_000)]
+        public void TestShouldUseCustomDeserializationPolicy()
+        {
+            string uri = "activemq:tcp://${{activemqhost}}:61616";
+            var factory = new ConnectionFactory(ReplaceEnvVar(uri))
+            {
+                DeserializationPolicy = new CustomDeserializationPolicy()
+            };
+            using var connection = factory.CreateConnection("", "");
+
+            var objectMessage = SendAndReceiveObjectMessage(connection, new UntrustedType { Prop1 = "foo" });
+
+            Assert.Throws<SerializationException>(() =>
+            {
+                _ = objectMessage.Body;
+            });
+        }
+
+        /// <summary>
+        /// Starts the connection, round-trips <paramref name="body"/> through a fresh queue and returns the
+        /// received object message. Its body is deserialized only when <c>Body</c> is read.
+        /// </summary>
+        private static IObjectMessage SendAndReceiveObjectMessage(IConnection connection, object body)
+        {
+            connection.Start();
+            var session = connection.CreateSession(AcknowledgementMode.AutoAcknowledge);
+            var queue = session.GetQueue(Guid.NewGuid().ToString());
+            var consumer = session.CreateConsumer(queue);
+            var producer = session.CreateProducer(queue);
+            producer.Send(producer.CreateObjectMessage(body));
+
+            // Bounded wait: a lost message should fail the test instead of blocking it forever.
+            var receivedMessage = consumer.Receive(TimeSpan.FromSeconds(10));
+            Assert.IsNotNull(receivedMessage, "Timed out waiting for the object message.");
+            var objectMessage = receivedMessage as IObjectMessage;
+            Assert.IsNotNull(objectMessage);
+            return objectMessage;
+        }
+
+        [Serializable]
+        public class UntrustedType
+        {
+            public string Prop1 { get; set; }
+        }
+
+        /// <summary>
+        /// Policy that rejects only <see cref="UntrustedType"/>. It holds no state, so Clone returns itself.
+        /// </summary>
+        private class CustomDeserializationPolicy : INmsDeserializationPolicy
+        {
+            public bool IsTrustedType(IDestination destination, Type type)
+            {
+                return type != typeof(UntrustedType);
+            }
+
+            public INmsDeserializationPolicy Clone()
+            {
+                return this;
+            }
+        }
     }
 }
diff --git a/test/NMSConnectionFactoryTest.cs b/test/NMSConnectionFactoryTest.cs
index 5a6ce48..5be6ff6 100644
--- a/test/NMSConnectionFactoryTest.cs
+++ b/test/NMSConnectionFactoryTest.cs
@@ -211,6 +211,28 @@
 
 				connection.Close();
 			}
-        }		
+        }
+
+        [Test]
+        public void TestSetDeserializationPolicy()
+        {
+            string baseUri = "activemq:tcp://${{activemqhost}}:61616";
+            string configuredUri = baseUri +
+                                   "?nms.deserializationPolicy.allowList=a,b,c" +
+                                   "&nms.deserializationPolicy.denyList=c,d,e";
+
+            var factory = new NMSConnectionFactory(NMSTestSupport.ReplaceEnvVar(configuredUri));
+            Assert.IsNotNull(factory.ConnectionFactory);
+
+            using IConnection connection = factory.CreateConnection("", "");
+            var amqConnection = connection as Connection;
+            Assert.IsNotNull(amqConnection, "Expected an Apache.NMS.ActiveMQ connection.");
+
+            // The URI options must reach the policy held by the created connection.
+            var deserializationPolicy = amqConnection.DeserializationPolicy as NmsDefaultDeserializationPolicy;
+            Assert.IsNotNull(deserializationPolicy);
+            Assert.AreEqual("a,b,c", deserializationPolicy.AllowList);
+            Assert.AreEqual("c,d,e", deserializationPolicy.DenyList);
+        }
     }
 }
diff --git a/test/NmsDefaultDeserializationPolicyTest.cs b/test/NmsDefaultDeserializationPolicyTest.cs
new file mode 100644
index 0000000..7ac772b
--- /dev/null
+++ b/test/NmsDefaultDeserializationPolicyTest.cs
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using Apache.NMS.Commands;
+using NUnit.Framework;
+
+namespace Apache.NMS.ActiveMQ.Test
+{
+    [TestFixture]
+    public class NmsDefaultDeserializationPolicyTest
+    {
+        [Test]
+        public void TestIsTrustedType()
+        {
+            var destination = new Queue("test-queue");
+            var policy = new NmsDefaultDeserializationPolicy();
+
+            Assert.True(policy.IsTrustedType(destination, null));
+            Assert.True(policy.IsTrustedType(destination, typeof(Guid)));
+            Assert.True(policy.IsTrustedType(destination, typeof(string)));
+            Assert.True(policy.IsTrustedType(destination, typeof(bool)));
+            Assert.True(policy.IsTrustedType(destination, typeof(double)));
+            Assert.True(policy.IsTrustedType(destination, typeof(object)));
+
+            // Only types in System
+            policy.AllowList = "System";
+            Assert.True(policy.IsTrustedType(destination, null));
+            Assert.True(policy.IsTrustedType(destination, typeof(Guid)));
+            Assert.True(policy.IsTrustedType(destination, typeof(string)));
+            Assert.True(policy.IsTrustedType(destination, typeof(bool)));
+            Assert.True(policy.IsTrustedType(destination, typeof(double)));
+            Assert.True(policy.IsTrustedType(destination, typeof(object)));
+            Assert.False(policy.IsTrustedType(destination, GetType()));
+
+            // An entry must be a complete namespace prefix to match,
+            // i.e. while "System.C" is a prefix of "System.Collections", it
+            // won't match the Queue class below.
+            policy.AllowList = "System.C";
+            Assert.False(policy.IsTrustedType(destination, typeof(Guid)));
+            Assert.False(policy.IsTrustedType(destination, typeof(string)));
+            Assert.False(policy.IsTrustedType(destination, typeof(System.Collections.Queue)));
+
+            // Add a non-core namespace
+            policy.AllowList = $"System,{GetType().Namespace}";
+            Assert.True(policy.IsTrustedType(destination, typeof(string)));
+            Assert.True(policy.IsTrustedType(destination, GetType()));
+
+            // Try with a type-specific entry
+            policy.AllowList = typeof(string).FullName;
+            Assert.True(policy.IsTrustedType(destination, typeof(string)));
+            Assert.False(policy.IsTrustedType(destination, typeof(bool)));
+
+            // Verify deny list overrides allow list
+            policy.AllowList = "System";
+            policy.DenyList = "System";
+            Assert.False(policy.IsTrustedType(destination, typeof(string)));
+
+            // Verify deny list entry prefix overrides allow list
+            policy.AllowList = typeof(string).FullName;
+            policy.DenyList = typeof(string).Namespace;
+            Assert.False(policy.IsTrustedType(destination, typeof(string)));
+
+            // Verify deny list catch-all overrides allow list
+            policy.AllowList = typeof(string).FullName;
+            policy.DenyList = NmsDefaultDeserializationPolicy.CATCH_ALL_WILDCARD;
+            Assert.False(policy.IsTrustedType(destination, typeof(string)));
+        }
+
+        [Test]
+        public void TestNmsDefaultDeserializationPolicy()
+        {
+            var policy = new NmsDefaultDeserializationPolicy();
+
+            // By default every type is allowed and nothing is denied.
+            Assert.AreEqual(NmsDefaultDeserializationPolicy.CATCH_ALL_WILDCARD, policy.AllowList);
+            Assert.IsEmpty(policy.DenyList);
+        }
+
+        [Test]
+        public void TestNmsDefaultDeserializationPolicyClone()
+        {
+            var policy = new NmsDefaultDeserializationPolicy
+            {
+                AllowList = "a.b.c",
+                DenyList = "d.e.f"
+            };
+
+            var clone = (NmsDefaultDeserializationPolicy) policy.Clone();
+            Assert.AreNotSame(clone, policy);
+            Assert.AreEqual(policy.AllowList, clone.AllowList);
+            Assert.AreEqual(policy.DenyList, clone.DenyList);
+
+            // Reconfiguring the clone must leave the original untouched.
+            clone.AllowList = "x.y.z";
+            clone.DenyList = NmsDefaultDeserializationPolicy.CATCH_ALL_WILDCARD;
+            Assert.AreEqual("a.b.c", policy.AllowList);
+            Assert.AreEqual("d.e.f", policy.DenyList);
+        }
+
+        [Test]
+        public void TestSetAllowList()
+        {
+            var policy = new NmsDefaultDeserializationPolicy();
+            Assert.NotNull(policy.AllowList);
+
+            policy.AllowList = null;
+            Assert.NotNull(policy.AllowList);
+            Assert.IsEmpty(policy.AllowList);
+
+            policy.AllowList = string.Empty;
+            Assert.NotNull(policy.AllowList);
+            Assert.IsEmpty(policy.AllowList);
+
+            policy.AllowList = "*";
+            Assert.AreEqual("*", policy.AllowList);
+
+            policy.AllowList = "a,b,c";
+            Assert.AreEqual("a,b,c", policy.AllowList);
+        }
+
+        [Test]
+        public void TestSetDenyList()
+        {
+            var policy = new NmsDefaultDeserializationPolicy();
+            Assert.NotNull(policy.DenyList);
+
+            policy.DenyList = null;
+            Assert.NotNull(policy.DenyList);
+            Assert.IsEmpty(policy.DenyList);
+
+            policy.DenyList = string.Empty;
+            Assert.NotNull(policy.DenyList);
+            Assert.IsEmpty(policy.DenyList);
+
+            policy.DenyList = "*";
+            Assert.AreEqual("*", policy.DenyList);
+
+            policy.DenyList = "a,b,c";
+            Assert.AreEqual("a,b,c", policy.DenyList);
+        }
+    }
+}
\ No newline at end of file