Add super class check
diff --git a/src/main/java/com/alibaba/com/caucho/hessian/io/ClassFactory.java b/src/main/java/com/alibaba/com/caucho/hessian/io/ClassFactory.java
index d5976ba..d64b6fe 100644
--- a/src/main/java/com/alibaba/com/caucho/hessian/io/ClassFactory.java
+++ b/src/main/java/com/alibaba/com/caucho/hessian/io/ClassFactory.java
@@ -54,6 +54,7 @@
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
+import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
@@ -71,6 +72,7 @@
protected static final Logger log
= Logger.getLogger(ClassFactory.class.getName());
private static final ArrayList<Allow> _staticAllowList;
+ private static final Map<String, Object> _allowSubClassSet = new ConcurrentHashMap<>();
private static final Map<String, Object> _allowClassSet = new ConcurrentHashMap<>();
private ClassLoader _loader;
@@ -88,10 +90,43 @@
throws ClassNotFoundException
{
if (isAllow(className)) {
- return Class.forName(className, false, _loader);
+ Class<?> aClass = Class.forName(className, false, _loader);
+
+ if (_allowClassSet.containsKey(className)) {
+ return aClass;
+ }
+
+ if (aClass.getInterfaces().length > 0) {
+ for (Class<?> anInterface : aClass.getInterfaces()) {
+ if(!isAllow(anInterface.getName())) {
+ log.log(Level.SEVERE, className + "'s interfaces: " + anInterface.getName() + " in blacklist or not in whitelist, deserialization with type 'HashMap' instead.");
+ return HashMap.class;
+ }
+ }
+ }
+
+ List<Class<?>> allSuperClasses = new LinkedList<>();
+
+ Class<?> superClass = aClass.getSuperclass();
+ while (superClass != null) {
+ // add current super class
+ allSuperClasses.add(superClass);
+ superClass = superClass.getSuperclass();
+ }
+
+ for (Class<?> aSuperClass : allSuperClasses) {
+ if(!isAllow(aSuperClass.getName())) {
+ log.log(Level.SEVERE, className + "'s superClass: " + aSuperClass.getName() + " in blacklist or not in whitelist, deserialization with type 'HashMap' instead.");
+ return HashMap.class;
+ }
+
+ }
+
+ _allowClassSet.put(className, className);
+ return aClass;
}
else {
- log.log(Level.SEVERE, className + " in blacklist or not in whitelist, deserialization with type 'HashMap' instead.");
+ log.log(Level.SEVERE, className + " in blacklist or not in whitelist, deserialization with type 'HashMap' instead.");
return HashMap.class;
}
}
@@ -104,19 +139,16 @@
return true;
}
- if (_allowClassSet.containsKey(className)) {
+ if (_allowSubClassSet.containsKey(className)) {
return true;
}
- int size = allowList.size();
- for (int i = 0; i < size; i++) {
- Allow allow = allowList.get(i);
-
+ for (Allow allow : allowList) {
Boolean isAllow = allow.allow(className);
if (isAllow != null) {
if (isAllow) {
- _allowClassSet.put(className, className);
+ _allowSubClassSet.put(className, className);
}
return isAllow;
}
@@ -126,13 +158,14 @@
return false;
}
- _allowClassSet.put(className, className);
+ _allowSubClassSet.put(className, className);
return true;
}
public void setWhitelist(boolean isWhitelist)
{
_allowClassSet.clear();
+ _allowSubClassSet.clear();
_isWhitelist = isWhitelist;
initAllow();
@@ -141,6 +174,7 @@
public void allow(String pattern)
{
_allowClassSet.clear();
+ _allowSubClassSet.clear();
initAllow();
synchronized (this) {
@@ -151,6 +185,7 @@
public void deny(String pattern)
{
_allowClassSet.clear();
+ _allowSubClassSet.clear();
initAllow();
synchronized (this) {
@@ -158,7 +193,7 @@
}
}
- private String toPattern(String pattern)
+ private static String toPattern(String pattern)
{
pattern = pattern.replace(".", "\\.");
pattern = pattern.replace("*", ".*");
@@ -233,7 +268,11 @@
if (denyClass.startsWith("#")) {
continue;
}
- _staticAllowList.add(new AllowPrefix(denyClass, false));
+ if (denyClass.endsWith(".")) {
+ _staticAllowList.add(new AllowPrefix(denyClass, false));
+ } else {
+ _staticAllowList.add(new Allow(toPattern(denyClass), false));
+ }
}
} catch (IOException ignore) {
diff --git a/src/test/java/com/alibaba/com/caucho/hessian/io/DenyListTest.java b/src/test/java/com/alibaba/com/caucho/hessian/io/DenyListTest.java
index d05a0c7..3c8375d 100644
--- a/src/test/java/com/alibaba/com/caucho/hessian/io/DenyListTest.java
+++ b/src/test/java/com/alibaba/com/caucho/hessian/io/DenyListTest.java
@@ -18,6 +18,7 @@
import org.junit.Assert;
import org.junit.Test;
+import sun.rmi.transport.StreamRemoteCall;
import java.lang.reflect.Array;
import java.util.HashMap;
@@ -35,6 +36,15 @@
Assert.assertEquals(HashMap.class, classFactory.load("java.beans.C"));
Assert.assertEquals(HashMap.class, classFactory.load("java.beans.D"));
Assert.assertEquals(HashMap.class, classFactory.load("java.beans.E"));
+ Assert.assertEquals(HashMap.class, classFactory.load("sun.rmi.transport.StreamRemoteCall"));
+
+ classFactory.deny(TestClass.class.getName());
+ Assert.assertEquals(HashMap.class, classFactory.load(TestClass.class.getName()));
+ Assert.assertEquals(HashMap.class, classFactory.load(TestClass1.class.getName()));
+
+ classFactory.deny(TestInterface.class.getName());
+ Assert.assertEquals(HashMap.class, classFactory.load(TestInterface.class.getName()));
+ Assert.assertEquals(HashMap.class, classFactory.load(TestImpl.class.getName()));
}
@Test
@@ -47,5 +57,6 @@
Assert.assertEquals(List.class, classFactory.load(List.class.getName()));
Assert.assertEquals(Array.class, classFactory.load(Array.class.getName()));
Assert.assertEquals(LinkedList.class, classFactory.load(LinkedList.class.getName()));
+ Assert.assertEquals(RuntimeException.class, classFactory.load(RuntimeException.class.getName()));
}
}
diff --git a/src/test/java/com/alibaba/com/caucho/hessian/io/TestClass.java b/src/test/java/com/alibaba/com/caucho/hessian/io/TestClass.java
index 43907b3..0420962 100644
--- a/src/test/java/com/alibaba/com/caucho/hessian/io/TestClass.java
+++ b/src/test/java/com/alibaba/com/caucho/hessian/io/TestClass.java
@@ -20,3 +20,7 @@
public class TestClass implements Serializable {
}
+
+class TestClass1 extends TestClass {
+
+}
\ No newline at end of file
diff --git a/src/test/java/com/alibaba/com/caucho/hessian/io/TestInterface.java b/src/test/java/com/alibaba/com/caucho/hessian/io/TestInterface.java
new file mode 100644
index 0000000..72d6d3c
--- /dev/null
+++ b/src/test/java/com/alibaba/com/caucho/hessian/io/TestInterface.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.alibaba.com.caucho.hessian.io;
+
+public interface TestInterface {
+}
+
+class TestImpl implements TestInterface {
+
+}
\ No newline at end of file