blob: b468468d7795166220f464630ca0ed395ee1cd0c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.qpid.systests.jms_1_1.extensions.message;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.io.Serializable;
import java.lang.reflect.Field;
import java.util.HashMap;
import javax.jms.Connection;
import javax.jms.JMSException;
import javax.jms.Message;
import javax.jms.MessageConsumer;
import javax.jms.MessageFormatException;
import javax.jms.MessageProducer;
import javax.jms.ObjectMessage;
import javax.jms.Queue;
import javax.jms.Session;
import org.junit.Test;
import org.apache.qpid.systests.JmsTestBase;
/**
 * System tests for the client-side deserialization policy applied to {@link ObjectMessage} payloads.
 * <p>
 * Each test builds a connection with a particular deserialization white list and/or black list, sends an
 * {@link ObjectMessage} to a queue named after the test, receives it on the same connection and checks
 * whether {@link ObjectMessage#getObject()} returns the original payload or rejects it with an exception.
 */
public class ObjectMessageClassWhitelistingTest extends JmsTestBase
{
    /** Value placed in every test payload and checked after the round trip. */
    private static final int TEST_VALUE = 37;

    /**
     * A white list of {@code "*"} is expected to allow the {@link HashMap} payload to be deserialized.
     */
    @Test
    public void testObjectMessage() throws Exception
    {
        Queue destination = createQueue(getTestName());
        final Connection c = getConnectionBuilder().setDeserializationPolicyWhiteList("*").build();
        try
        {
            c.start();
            Session s = c.createSession(false, Session.AUTO_ACKNOWLEDGE);
            MessageConsumer consumer = s.createConsumer(destination);
            MessageProducer producer = s.createProducer(destination);
            sendTestObjectMessage(s, producer);

            ObjectMessage receivedObjectMessage = receiveObjectMessage(consumer);
            Object payloadObject = receivedObjectMessage.getObject();
            assertTrue("payload is of wrong type", payloadObject instanceof HashMap);
            @SuppressWarnings("unchecked")
            HashMap<String, Integer> payload = (HashMap<String, Integer>) payloadObject;
            assertEquals("payload has wrong value", (Integer) TEST_VALUE, payload.get("value"));
        }
        finally
        {
            c.close();
        }
    }

    /**
     * A white list of {@code "org.apache.qpid"} is expected to reject the {@link HashMap} payload.
     */
    @Test
    public void testNotWhiteListedByConnectionUrlObjectMessage() throws Exception
    {
        Queue destination = createQueue(getTestName());
        final Connection c = getConnectionBuilder().setDeserializationPolicyWhiteList("org.apache.qpid").build();
        try
        {
            c.start();
            Session s = c.createSession(false, Session.AUTO_ACKNOWLEDGE);
            MessageConsumer consumer = s.createConsumer(destination);
            MessageProducer producer = s.createProducer(destination);
            sendTestObjectMessage(s, producer);

            ObjectMessage receivedObjectMessage = receiveObjectMessage(consumer);
            try
            {
                receivedObjectMessage.getObject();
                fail("should not deserialize class");
            }
            catch (MessageFormatException e)
            {
                // pass
            }
        }
        finally
        {
            c.close();
        }
    }

    /**
     * White-listing {@code java.util.HashMap} and {@code java.lang} is expected to allow the
     * {@code HashMap<String, Integer>} payload to be deserialized.
     */
    @Test
    public void testWhiteListedClassByConnectionUrlObjectMessage() throws Exception
    {
        Queue destination = createQueue(getTestName());
        final Connection c =
                getConnectionBuilder().setDeserializationPolicyWhiteList("java.util.HashMap,java.lang").build();
        try
        {
            c.start();
            Session s = c.createSession(false, Session.AUTO_ACKNOWLEDGE);
            MessageConsumer consumer = s.createConsumer(destination);
            MessageProducer producer = s.createProducer(destination);
            sendTestObjectMessage(s, producer);

            ObjectMessage receivedObjectMessage = receiveObjectMessage(consumer);
            @SuppressWarnings("unchecked")
            HashMap<String, Integer> object = (HashMap<String, Integer>) receivedObjectMessage.getObject();
            assertEquals("Unexpected value", (Integer) TEST_VALUE, object.get("value"));
        }
        finally
        {
            c.close();
        }
    }

    /**
     * Black-listing {@code java.lang.Integer} is expected to reject the payload even though {@code "java"}
     * is white-listed.
     */
    @Test
    public void testBlackListedClassByConnectionUrlObjectMessage() throws Exception
    {
        Queue destination = createQueue(getTestName());
        final Connection c = getConnectionBuilder().setDeserializationPolicyWhiteList("java")
                                                   .setDeserializationPolicyBlackList("java.lang.Integer")
                                                   .build();
        try
        {
            c.start();
            Session s = c.createSession(false, Session.AUTO_ACKNOWLEDGE);
            MessageConsumer consumer = s.createConsumer(destination);
            MessageProducer producer = s.createProducer(destination);
            sendTestObjectMessage(s, producer);

            ObjectMessage receivedObjectMessage = receiveObjectMessage(consumer);
            try
            {
                receivedObjectMessage.getObject();
                fail("Should not be allowed to deserialize black listed class");
            }
            catch (JMSException e)
            {
                // pass - note this accepts any JMSException, whereas the other rejection
                // tests in this class require the narrower MessageFormatException
            }
        }
        finally
        {
            c.close();
        }
    }

    /**
     * White-listing this test class's canonical name is expected to allow an anonymous class declared
     * inside it to be deserialized.
     */
    @Test
    public void testWhiteListedAnonymousClassByConnectionUrlObjectMessage() throws Exception
    {
        final Connection c =
                getConnectionBuilder().setDeserializationPolicyWhiteList(ObjectMessageClassWhitelistingTest.class.getCanonicalName())
                                      .build();
        try
        {
            doTestWhiteListedEnclosedClassTest(c, createAnonymousObject(TEST_VALUE));
        }
        finally
        {
            c.close();
        }
    }

    /**
     * Black-listing this test class's canonical name is expected to reject an anonymous class declared
     * inside it, even though the enclosing package is white-listed.
     */
    @Test
    public void testBlackListedAnonymousClassByConnectionUrlObjectMessage() throws Exception
    {
        final Connection c = getConnectionBuilder()
                .setDeserializationPolicyWhiteList(ObjectMessageClassWhitelistingTest.class.getPackage().getName())
                .setDeserializationPolicyBlackList(ObjectMessageClassWhitelistingTest.class.getCanonicalName())
                .build();
        try
        {
            doTestBlackListedEnclosedClassTest(c, createAnonymousObject(TEST_VALUE));
        }
        finally
        {
            c.close();
        }
    }

    /**
     * White-listing {@link NestedClass} by canonical name is expected to allow it to be deserialized.
     */
    @Test
    public void testWhiteListedNestedClassByConnectionUrlObjectMessage() throws Exception
    {
        final Connection c = getConnectionBuilder()
                .setDeserializationPolicyWhiteList(ObjectMessageClassWhitelistingTest.NestedClass.class.getCanonicalName())
                .build();
        try
        {
            doTestWhiteListedEnclosedClassTest(c, new NestedClass(TEST_VALUE));
        }
        finally
        {
            c.close();
        }
    }

    /**
     * Black-listing {@link NestedClass} is expected to reject it even though the enclosing test class is
     * white-listed.
     */
    @Test
    public void testBlackListedNestedClassByConnectionUrlObjectMessage() throws Exception
    {
        final Connection c = getConnectionBuilder()
                .setDeserializationPolicyWhiteList(ObjectMessageClassWhitelistingTest.class.getCanonicalName())
                .setDeserializationPolicyBlackList(NestedClass.class.getCanonicalName())
                .build();
        try
        {
            doTestBlackListedEnclosedClassTest(c, new NestedClass(TEST_VALUE));
        }
        finally
        {
            c.close();
        }
    }

    /**
     * Round-trips {@code content} over {@code c} and asserts that the received payload has the same class
     * and is equal to the original.
     *
     * @param c       connection configured with the deserialization policy under test; not closed here
     * @param content payload to send
     */
    private void doTestWhiteListedEnclosedClassTest(Connection c, Serializable content) throws Exception
    {
        final ObjectMessage receivedMessage = sendAndReceiveEnclosedObject(c, content);
        Object receivedObject = receivedMessage.getObject();
        // Fail with an assertion rather than an NPE if the payload comes back null.
        assertNotNull("Received object is null", receivedObject);
        assertEquals("Received object has unexpected class", content.getClass(), receivedObject.getClass());
        assertEquals("Received object has unexpected content", content, receivedObject);
    }

    /**
     * Round-trips {@code content} over {@code c} and asserts that reading the received payload fails with
     * a {@link MessageFormatException}.
     *
     * @param c       connection configured with the deserialization policy under test; not closed here
     * @param content payload to send
     */
    private void doTestBlackListedEnclosedClassTest(final Connection c, final Serializable content) throws Exception
    {
        final ObjectMessage receivedMessage = sendAndReceiveEnclosedObject(c, content);
        try
        {
            receivedMessage.getObject();
            fail("Exception not thrown");
        }
        catch (MessageFormatException e)
        {
            // pass
        }
    }

    /**
     * Creates a queue named after the current test, starts {@code c}, and sends {@code content} in an
     * {@link ObjectMessage} populated via {@link ObjectMessage#setObject(Serializable)}. It then receives
     * that message from the same queue.
     *
     * @param c       connection to use; started here but not closed
     * @param content payload to send
     * @return the received message, already asserted to be a non-null {@link ObjectMessage}
     */
    private ObjectMessage sendAndReceiveEnclosedObject(final Connection c, final Serializable content)
            throws Exception
    {
        Queue destination = createQueue(getTestName());
        c.start();
        Session s = c.createSession(false, Session.AUTO_ACKNOWLEDGE);
        MessageConsumer consumer = s.createConsumer(destination);
        MessageProducer producer = s.createProducer(destination);
        final ObjectMessage sendMessage = s.createObjectMessage();
        sendMessage.setObject(content);
        producer.send(sendMessage);
        return receiveObjectMessage(consumer);
    }

    /**
     * Receives one message within the test receive timeout and asserts that it exists and is an
     * {@link ObjectMessage}.
     *
     * @param consumer consumer to receive from
     * @return the received message cast to {@link ObjectMessage}
     */
    private ObjectMessage receiveObjectMessage(final MessageConsumer consumer) throws JMSException
    {
        Message receivedMessage = consumer.receive(getReceiveTimeout());
        assertNotNull("did not receive message within receive timeout", receivedMessage);
        assertTrue("message is of wrong type", receivedMessage instanceof ObjectMessage);
        return (ObjectMessage) receivedMessage;
    }

    /**
     * Sends an {@link ObjectMessage} whose payload is a {@link HashMap} mapping {@code "value"} to
     * {@link #TEST_VALUE}.
     */
    private void sendTestObjectMessage(final Session s, final MessageProducer producer) throws JMSException
    {
        HashMap<String, Integer> messageContent = new HashMap<>();
        messageContent.put("value", TEST_VALUE);
        Message objectMessage = s.createObjectMessage(messageContent);
        producer.send(objectMessage);
    }

    /**
     * Creates an instance of an anonymous {@link Serializable} class declared inside this test class.
     * Because it is created in a static method, the instance holds no reference to an enclosing test instance.
     * <p>
     * Two instances are equal when they are of the same class and hold the same {@code field} value. The
     * value is read by reflection because the anonymous type cannot be named in a cast.
     *
     * @param field value stored in the instance and also used as its hash code
     * @return the anonymous serializable instance
     */
    public static Serializable createAnonymousObject(final int field)
    {
        return new Serializable()
        {
            private final int _field = field;

            @Override
            public int hashCode()
            {
                return _field;
            }

            @Override
            public boolean equals(final Object o)
            {
                if (this == o)
                {
                    return true;
                }
                if (o == null || getClass() != o.getClass())
                {
                    return false;
                }
                final Serializable that = (Serializable) o;
                return getFieldValueByReflection(that).equals(_field);
            }

            /** Reads the private {@code _field} of another instance of this anonymous class. */
            private Object getFieldValueByReflection(final Serializable that)
            {
                try
                {
                    final Field f = that.getClass().getDeclaredField("_field");
                    f.setAccessible(true);
                    return f.get(that);
                }
                catch (NoSuchFieldException | IllegalAccessException e)
                {
                    throw new RuntimeException(e);
                }
            }
        };
    }

    /**
     * Static nested {@link Serializable} payload type used to exercise white/black listing of nested classes.
     * Instances are equal when their {@code field} values are equal.
     */
    public static class NestedClass implements Serializable
    {
        private final int _field;

        public NestedClass(final int field)
        {
            _field = field;
        }

        @Override
        public boolean equals(final Object o)
        {
            if (this == o)
            {
                return true;
            }
            if (o == null || getClass() != o.getClass())
            {
                return false;
            }
            final NestedClass that = (NestedClass) o;
            return _field == that._field;
        }

        @Override
        public int hashCode()
        {
            return _field;
        }
    }
}