Merge pull request #29 from apache/allow_deny_list
AMQNET-829 Add allow, deny types support
diff --git a/src/Commands/ActiveMQObjectMessage.cs b/src/Commands/ActiveMQObjectMessage.cs
index 3919111..aec88b5 100644
--- a/src/Commands/ActiveMQObjectMessage.cs
+++ b/src/Commands/ActiveMQObjectMessage.cs
@@ -140,6 +140,10 @@
if (formatter == null)
{
formatter = new BinaryFormatter();
+ if (Connection.DeserializationPolicy != null)
+ {
+ formatter.Binder = new TrustedClassFilter(Connection.DeserializationPolicy, Destination);
+ }
}
return formatter;
}
diff --git a/src/Commands/TrustedClassFilter.cs b/src/Commands/TrustedClassFilter.cs
new file mode 100644
index 0000000..eb90662
--- /dev/null
+++ b/src/Commands/TrustedClassFilter.cs
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.Reflection;
+using System.Runtime.Serialization;
+
+namespace Apache.NMS.ActiveMQ.Commands
+{
+    /// <summary>
+    /// Serialization binder that asks an <see cref="INmsDeserializationPolicy"/> whether a type
+    /// may be deserialized before returning it to the formatter.
+    /// </summary>
+    internal class TrustedClassFilter : SerializationBinder
+    {
+        private readonly INmsDeserializationPolicy deserializationPolicy;
+        private readonly IDestination destination;
+
+        /// <summary>
+        /// Creates a filter that checks types against the given policy on behalf of the given destination.
+        /// </summary>
+        /// <param name="deserializationPolicy">The policy that decides whether a type is trusted.</param>
+        /// <param name="destination">The destination handed to the policy with every check.</param>
+        /// <exception cref="ArgumentNullException">Thrown when <paramref name="deserializationPolicy"/> is null.</exception>
+        public TrustedClassFilter(INmsDeserializationPolicy deserializationPolicy, IDestination destination)
+        {
+            this.deserializationPolicy = deserializationPolicy
+                ?? throw new ArgumentNullException(nameof(deserializationPolicy));
+            this.destination = destination;
+        }
+
+        /// <summary>
+        /// Loads the named assembly, looks the named type up in it and returns that type only
+        /// when the deserialization policy trusts it.
+        /// </summary>
+        /// <param name="assemblyName">The assembly name recorded for the serialized object.</param>
+        /// <param name="typeName">The type name recorded for the serialized object.</param>
+        /// <returns>The type the policy accepted.</returns>
+        /// <exception cref="SerializationException">Thrown when the policy does not trust the type.</exception>
+        public override Type BindToType(string assemblyName, string typeName)
+        {
+            var name = new AssemblyName(assemblyName);
+            var assembly = Assembly.Load(name);
+            var type = FormatterServices.GetTypeFromAssembly(assembly, typeName);
+            if (deserializationPolicy.IsTrustedType(destination, type))
+            {
+                // NOTE(review): type may be null here when the lookup found nothing and the policy
+                // accepts null; presumably the formatter then resolves the type itself - confirm.
+                return type;
+            }
+
+            // A policy may reject an unresolved (null) type, so fall back to the requested name
+            // instead of dereferencing null while building the message.
+            var rejectedName = type?.FullName ?? typeName;
+            var message = $"Forbidden {rejectedName}! " +
+                          "This type is not trusted to be deserialized under the current configuration. " +
+                          "Please refer to the documentation for more information on how to configure trusted types.";
+            throw new SerializationException(message);
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/Connection.cs b/src/Connection.cs
index df0cc57..a2d7df2 100644
--- a/src/Connection.cs
+++ b/src/Connection.cs
@@ -482,6 +482,8 @@
get { return this.compressionPolicy; }
set { this.compressionPolicy = value; }
}
+
+ public INmsDeserializationPolicy DeserializationPolicy { get; set; } = new NmsDefaultDeserializationPolicy();
internal MessageTransformation MessageTransformation
{
diff --git a/src/ConnectionFactory.cs b/src/ConnectionFactory.cs
index 84ecf22..6c778cb 100644
--- a/src/ConnectionFactory.cs
+++ b/src/ConnectionFactory.cs
@@ -397,6 +397,11 @@
}
}
}
+
+ /// <summary>
+ /// The deserialization policy that is applied when a connection is created.
+ /// </summary>
+ public INmsDeserializationPolicy DeserializationPolicy { get; set; } = new NmsDefaultDeserializationPolicy();
public IdGenerator ClientIdGenerator
{
@@ -546,6 +551,7 @@
connection.RedeliveryPolicy = this.redeliveryPolicy.Clone() as IRedeliveryPolicy;
connection.PrefetchPolicy = this.prefetchPolicy.Clone() as PrefetchPolicy;
connection.CompressionPolicy = this.compressionPolicy.Clone() as ICompressionPolicy;
+            connection.DeserializationPolicy = this.DeserializationPolicy?.Clone(); // a null policy is passed through instead of throwing here
connection.ConsumerTransformer = this.consumerTransformer;
connection.ProducerTransformer = this.producerTransformer;
connection.WatchTopicAdvisories = this.watchTopicAdvisories;
diff --git a/src/INmsDeserializationPolicy.cs b/src/INmsDeserializationPolicy.cs
new file mode 100644
index 0000000..743855d
--- /dev/null
+++ b/src/INmsDeserializationPolicy.cs
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+
+namespace Apache.NMS.ActiveMQ
+{
+ /// <summary>
+ /// Defines the interface for a policy that controls the permissible message content
+ /// during the deserialization of the body of an incoming <see cref="IObjectMessage"/>.
+ /// </summary>
+ public interface INmsDeserializationPolicy
+ {
+ /// <summary>
+ /// Determines if the given class is a trusted type that can be deserialized by the client.
+ /// </summary>
+ /// <param name="destination">The Destination for the message containing the type to be deserialized.</param>
+ /// <param name="type">The type of the object that is about to be read.</param>
+ /// <returns>True if the type is trusted, otherwise false.</returns>
+ bool IsTrustedType(IDestination destination, Type type);
+
+ /// <summary>
+        /// Makes a copy of this policy that another connection can use; the connection factory calls this for each connection it creates.
+ /// </summary>
+ /// <returns>A copy of INmsDeserializationPolicy object.</returns>
+ INmsDeserializationPolicy Clone();
+ }
+}
\ No newline at end of file
diff --git a/src/NmsDefaultDeserializationPolicy.cs b/src/NmsDefaultDeserializationPolicy.cs
new file mode 100644
index 0000000..b56531a
--- /dev/null
+++ b/src/NmsDefaultDeserializationPolicy.cs
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace Apache.NMS.ActiveMQ
+{
+    /// <summary>
+    /// Default implementation of the deserialization policy that can read allow and deny lists of
+    /// types/namespaces from the connection URI options.
+    ///
+    /// The policy reads a default deny list string value (comma separated) from the connection URI options
+    /// (nms.deserializationPolicy.denyList) which defaults to an empty deny list.
+    ///
+    /// The policy reads a default allow list string value (comma separated) from the connection URI options
+    /// (nms.deserializationPolicy.allowList) which defaults to <see cref="CATCH_ALL_WILDCARD"/> which
+    /// indicates that all types are allowed.
+    ///
+    /// The deny list overrides the allow list, entries that could match both are counted as denied.
+    ///
+    /// If the policy should treat all classes as untrusted, the deny list should be set to <see cref="CATCH_ALL_WILDCARD"/>.
+    /// </summary>
+    public class NmsDefaultDeserializationPolicy : INmsDeserializationPolicy
+    {
+        /// <summary>
+        /// Value used to indicate that all types should be allowed or denied
+        /// </summary>
+        public const string CATCH_ALL_WILDCARD = "*";
+
+        private IReadOnlyList<string> denyList = Array.Empty<string>();
+        private IReadOnlyList<string> allowList = new[] { CATCH_ALL_WILDCARD };
+
+        /// <summary>
+        /// Checks the full name of <paramref name="type"/> against the deny list and then the allow list.
+        /// </summary>
+        /// <param name="destination">The destination of the message; not consulted by this policy.</param>
+        /// <param name="type">The type about to be deserialized.</param>
+        /// <returns>
+        /// True when the type is null or has no full name, or when no deny entry matches and an allow
+        /// entry does; otherwise false.
+        /// </returns>
+        public bool IsTrustedType(IDestination destination, Type type)
+        {
+            var typeName = type?.FullName;
+            if (typeName == null)
+            {
+                return true;
+            }
+
+            // The deny list is consulted first so that it overrides the allow list.
+            if (MatchesAnyEntry(denyList, typeName))
+            {
+                return false;
+            }
+
+            // A type that is neither denied nor allowed is rejected.
+            return MatchesAnyEntry(allowList, typeName);
+        }
+
+        /// <summary>
+        /// Returns true when any entry is the catch-all wildcard or matches the type name.
+        /// </summary>
+        private static bool MatchesAnyEntry(IReadOnlyList<string> entries, string typeName)
+        {
+            foreach (var entry in entries)
+            {
+                if (CATCH_ALL_WILDCARD == entry || IsTypeOrNamespaceMatch(typeName, entry))
+                {
+                    return true;
+                }
+            }
+
+            return false;
+        }
+
+        /// <summary>
+        /// Returns true when the entry equals the type name, or is a prefix of it that ends right
+        /// before a '.' separator (a namespace prefix).
+        /// </summary>
+        private static bool IsTypeOrNamespaceMatch(string typeName, string listEntry)
+        {
+            // Check if type is an exact match of the entry
+            if (string.Equals(typeName, listEntry, StringComparison.Ordinal))
+            {
+                return true;
+            }
+
+            // Check if the type is from a namespace matching the entry. Type names are identifiers,
+            // so the prefix test is ordinal rather than culture-sensitive.
+            var entryLength = listEntry.Length;
+            return typeName.Length > entryLength
+                   && typeName.StartsWith(listEntry, StringComparison.Ordinal)
+                   && '.' == typeName[entryLength];
+        }
+
+        /// <summary>
+        /// Creates a new policy holding copies of this policy's allow and deny lists.
+        /// </summary>
+        /// <returns>The new policy instance.</returns>
+        public INmsDeserializationPolicy Clone()
+        {
+            return new NmsDefaultDeserializationPolicy
+            {
+                allowList = allowList.ToArray(),
+                denyList = denyList.ToArray()
+            };
+        }
+
+        /// <summary>
+        /// Gets or sets the deny list on this policy instance as a comma separated string.
+        /// </summary>
+        public string DenyList
+        {
+            get => string.Join(",", denyList);
+            set => denyList = ParseList(value);
+        }
+
+        /// <summary>
+        /// Gets or sets the allow list on this policy instance as a comma separated string.
+        /// </summary>
+        public string AllowList
+        {
+            get => string.Join(",", allowList);
+            set => allowList = ParseList(value);
+        }
+
+        /// <summary>
+        /// Splits a comma separated list into its entries, trimming surrounding whitespace and
+        /// dropping empty entries. A null or blank value yields an empty list.
+        /// </summary>
+        private static IReadOnlyList<string> ParseList(string value)
+        {
+            if (string.IsNullOrWhiteSpace(value))
+            {
+                return Array.Empty<string>();
+            }
+
+            return value.Split(',')
+                .Select(entry => entry.Trim())
+                .Where(entry => entry.Length > 0)
+                .ToArray();
+        }
+    }
+}
\ No newline at end of file
diff --git a/test/MessageConsumerTest.cs b/test/MessageConsumerTest.cs
index e635434..4d59dcd 100644
--- a/test/MessageConsumerTest.cs
+++ b/test/MessageConsumerTest.cs
@@ -20,6 +20,7 @@
using NUnit.Framework;
using Apache.NMS.ActiveMQ.Commands;
using System;
+using System.Runtime.Serialization;
using Apache.NMS.Util;
namespace Apache.NMS.ActiveMQ.Test
@@ -306,5 +307,91 @@
}
}
}
+
+ [Test, Timeout(20_000)]
+ public void TestShouldNotDeserializeUntrustedType()
+ {
+ string uri = "activemq:tcp://${{activemqhost}}:61616";
+ var factory = new ConnectionFactory(ReplaceEnvVar(uri))
+ {
+ DeserializationPolicy = new NmsDefaultDeserializationPolicy
+ {
+ DenyList = typeof(UntrustedType).FullName
+ }
+ };
+ using var connection = factory.CreateConnection("", "");
+
+ connection.Start();
+ var session = connection.CreateSession(AcknowledgementMode.AutoAcknowledge);
+ var queue = session.GetQueue(Guid.NewGuid().ToString());
+ var consumer = session.CreateConsumer(queue);
+ var producer = session.CreateProducer(queue);
+
+ var message = producer.CreateObjectMessage(new UntrustedType { Prop1 = "foo" });
+ producer.Send(message);
+
+ var receivedMessage = consumer.Receive();
+ var objectMessage = receivedMessage as IObjectMessage;
+ Assert.NotNull(objectMessage);
+ var exception = Assert.Throws<SerializationException>(() =>
+ {
+ _ = objectMessage.Body;
+ });
+ Assert.AreEqual($"Forbidden {typeof(UntrustedType).FullName}! " +
+ "This type is not trusted to be deserialized under the current configuration. " +
+ "Please refer to the documentation for more information on how to configure trusted types.",
+ exception.Message);
+ }
+
+        /// <summary>
+        /// Reading the body of a message whose type a custom policy rejects must throw; bounded by the same timeout as TestShouldNotDeserializeUntrustedType.
+        /// </summary>
+        [Test, Timeout(20_000)]
+        public void TestShouldUseCustomDeserializationPolicy()
+        {
+            string uri = "activemq:tcp://${{activemqhost}}:61616";
+            var factory = new ConnectionFactory(ReplaceEnvVar(uri))
+            {
+                DeserializationPolicy = new CustomDeserializationPolicy()
+            };
+            using var connection = factory.CreateConnection("", "");
+            connection.Start();
+            var session = connection.CreateSession(AcknowledgementMode.AutoAcknowledge);
+            var queue = session.GetQueue(Guid.NewGuid().ToString());
+            var consumer = session.CreateConsumer(queue);
+            var producer = session.CreateProducer(queue);
+
+            var message = producer.CreateObjectMessage(new UntrustedType { Prop1 = "foo" });
+            producer.Send(message);
+
+            var receivedMessage = consumer.Receive();
+            var objectMessage = receivedMessage as IObjectMessage;
+            Assert.NotNull(objectMessage);
+            Assert.Throws<SerializationException>(() => { _ = objectMessage.Body; });
+        }
+
+ [Serializable]
+ public class UntrustedType
+ {
+ public string Prop1 { get; set; }
+ }
+
+ private class CustomDeserializationPolicy : INmsDeserializationPolicy
+ {
+ public bool IsTrustedType(IDestination destination, Type type)
+ {
+ if (type == typeof(UntrustedType))
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ public INmsDeserializationPolicy Clone()
+ {
+ return this;
+ }
+ }
}
}
diff --git a/test/NMSConnectionFactoryTest.cs b/test/NMSConnectionFactoryTest.cs
index 5a6ce48..5be6ff6 100644
--- a/test/NMSConnectionFactoryTest.cs
+++ b/test/NMSConnectionFactoryTest.cs
@@ -211,6 +211,28 @@
connection.Close();
}
- }
+ }
+
+        /// <summary>Checks that the allowList/denyList URI options end up on the connection's deserialization policy.</summary>
+        [Test]
+        public void TestSetDeserializationPolicy()
+        {
+            string baseUri = "activemq:tcp://${{activemqhost}}:61616";
+            string configuredUri = baseUri +
+                                   "?nms.deserializationPolicy.allowList=a,b,c" +
+                                   "&nms.deserializationPolicy.denyList=c,d,e";
+
+            var factory = new NMSConnectionFactory(NMSTestSupport.ReplaceEnvVar(configuredUri));
+            Assert.IsNotNull(factory);
+            Assert.IsNotNull(factory.ConnectionFactory);
+            using IConnection connection = factory.CreateConnection("", "");
+            var amqConnection = connection as Connection;
+            Assert.IsNotNull(amqConnection); // fails cleanly when the connection is null or not an ActiveMQ Connection
+            var deserializationPolicy = amqConnection.DeserializationPolicy as NmsDefaultDeserializationPolicy;
+            Assert.IsNotNull(deserializationPolicy);
+            Assert.AreEqual("a,b,c", deserializationPolicy.AllowList);
+            Assert.AreEqual("c,d,e", deserializationPolicy.DenyList);
+            connection.Close();
+        }
}
}
diff --git a/test/NmsDefaultDeserializationPolicyTest.cs b/test/NmsDefaultDeserializationPolicyTest.cs
new file mode 100644
index 0000000..7ac772b
--- /dev/null
+++ b/test/NmsDefaultDeserializationPolicyTest.cs
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using Apache.NMS.Commands;
+using NUnit.Framework;
+
+namespace Apache.NMS.ActiveMQ.Test
+{
+ [TestFixture]
+ public class NmsDefaultDeserializationPolicyTest
+ {
+ [Test]
+ public void TestIsTrustedType()
+ {
+ var destination = new Queue("test-queue");
+ var policy = new NmsDefaultDeserializationPolicy();
+
+ Assert.True(policy.IsTrustedType(destination, null));
+ Assert.True(policy.IsTrustedType(destination, typeof(Guid)));
+ Assert.True(policy.IsTrustedType(destination, typeof(string)));
+ Assert.True(policy.IsTrustedType(destination, typeof(bool)));
+ Assert.True(policy.IsTrustedType(destination, typeof(double)));
+ Assert.True(policy.IsTrustedType(destination, typeof(object)));
+
+ // Only types in System
+ policy.AllowList = "System";
+ Assert.True(policy.IsTrustedType(destination, null));
+ Assert.True(policy.IsTrustedType(destination, typeof(Guid)));
+ Assert.True(policy.IsTrustedType(destination, typeof(string)));
+ Assert.True(policy.IsTrustedType(destination, typeof(bool)));
+ Assert.True(policy.IsTrustedType(destination, typeof(double)));
+ Assert.True(policy.IsTrustedType(destination, typeof(object)));
+ Assert.False(policy.IsTrustedType(destination, GetType()));
+
+            // An entry must be a complete namespace name to match as a prefix,
+            // i.e. although "System.C" is a prefix of "System.Collections", it
+            // won't match the System.Collections.Queue class below.
+ policy.AllowList = "System.C";
+ Assert.False(policy.IsTrustedType(destination, typeof(Guid)));
+ Assert.False(policy.IsTrustedType(destination, typeof(string)));
+ Assert.False(policy.IsTrustedType(destination, typeof(System.Collections.Queue)));
+
+ // Add a non-core namespace
+ policy.AllowList = $"System,{GetType().Namespace}";
+ Assert.True(policy.IsTrustedType(destination, typeof(string)));
+ Assert.True(policy.IsTrustedType(destination, GetType()));
+
+ // Try with a type-specific entry
+ policy.AllowList = typeof(string).FullName;
+ Assert.True(policy.IsTrustedType(destination, typeof(string)));
+ Assert.False(policy.IsTrustedType(destination, typeof(bool)));
+
+ // Verify deny list overrides allow list
+ policy.AllowList = "System";
+ policy.DenyList = "System";
+ Assert.False(policy.IsTrustedType(destination, typeof(string)));
+
+ // Verify deny list entry prefix overrides allow list
+ policy.AllowList = typeof(string).FullName;
+ policy.DenyList = typeof(string).Namespace;
+ Assert.False(policy.IsTrustedType(destination, typeof(string)));
+
+ // Verify deny list catch-all overrides allow list
+ policy.AllowList = typeof(string).FullName;
+ policy.DenyList = NmsDefaultDeserializationPolicy.CATCH_ALL_WILDCARD;
+ Assert.False(policy.IsTrustedType(destination, typeof(string)));
+ }
+
+ [Test]
+ public void TestNmsDefaultDeserializationPolicy()
+ {
+ var policy = new NmsDefaultDeserializationPolicy();
+
+ Assert.IsNotEmpty(policy.AllowList);
+ Assert.IsEmpty(policy.DenyList);
+ }
+
+ [Test]
+ public void TestNmsDefaultDeserializationPolicyClone()
+ {
+ var policy = new NmsDefaultDeserializationPolicy
+ {
+ AllowList = "a.b.c",
+ DenyList = "d.e.f"
+ };
+
+ var clone = (NmsDefaultDeserializationPolicy) policy.Clone();
+ Assert.AreEqual(policy.AllowList, clone.AllowList);
+ Assert.AreEqual(policy.DenyList, clone.DenyList);
+ Assert.AreNotSame(clone, policy);
+ }
+
+ [Test]
+ public void TestSetAllowList()
+ {
+ var policy = new NmsDefaultDeserializationPolicy();
+ Assert.NotNull(policy.AllowList);
+
+ policy.AllowList = null;
+ Assert.NotNull(policy.AllowList);
+ Assert.IsEmpty(policy.AllowList);
+
+ policy.AllowList = string.Empty;
+ Assert.NotNull(policy.AllowList);
+ Assert.IsEmpty(policy.AllowList);
+
+ policy.AllowList = "*";
+ Assert.NotNull(policy.AllowList);
+ Assert.IsNotEmpty(policy.AllowList);
+
+ policy.AllowList = "a,b,c";
+ Assert.NotNull(policy.AllowList);
+ Assert.IsNotEmpty(policy.AllowList);
+ Assert.AreEqual("a,b,c", policy.AllowList);
+ }
+
+ [Test]
+ public void TestSetDenyList()
+ {
+ var policy = new NmsDefaultDeserializationPolicy();
+ Assert.NotNull(policy.DenyList);
+
+ policy.DenyList = null;
+ Assert.NotNull(policy.DenyList);
+ Assert.IsEmpty(policy.DenyList);
+
+ policy.DenyList = string.Empty;
+ Assert.NotNull(policy.DenyList);
+ Assert.IsEmpty(policy.DenyList);
+
+ policy.DenyList = "*";
+ Assert.NotNull(policy.DenyList);
+ Assert.IsNotEmpty(policy.DenyList);
+
+ policy.DenyList = "a,b,c";
+ Assert.NotNull(policy.DenyList);
+ Assert.IsNotEmpty(policy.DenyList);
+ Assert.AreEqual("a,b,c", policy.DenyList);
+ }
+ }
+}
\ No newline at end of file