| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.spark.serializer |
| |
| import java.io._ |
| import java.lang.invoke.MethodHandles |
| import java.lang.reflect.{Field, Method} |
| import java.security.{AccessController, PrivilegedAction} |
| |
| import scala.annotation.tailrec |
| import scala.collection.mutable |
| import scala.util.control.NonFatal |
| |
| import org.apache.spark.internal.Logging |
| import org.apache.spark.util.SparkClassUtils |
| |
| private[spark] object SerializationDebugger extends Logging { |
| |
/**
 * Enrich a [[NotSerializableException]] with the chain of references that leads from `obj`
 * to the object that could not be serialized.
 *
 * The debugger is skipped when `enableDebugging` is false (it is initialised from the JVM flag
 * `sun.io.serialization.extendedDebugInfo`) or when the reflection helpers in `reflect` could
 * not be set up. If the debugger itself fails with a non-fatal error, a warning is logged and
 * the original exception is handed back.
 *
 * @param obj root object whose serialization failed
 * @param e   exception raised by the failed serialization
 * @return a new exception whose message is the original message followed by a
 *         "Serialization stack" section, or `e` itself when no path could be computed
 */
def improveException(obj: Any, e: NotSerializableException): NotSerializableException = {
  val debuggerUsable = enableDebugging && reflect != null
  if (!debuggerUsable) {
    e
  } else {
    try {
      val header = e.getMessage + "\nSerialization stack:\n"
      // One tab-indented bullet per entry returned by find().
      val trace = find(obj).map(step => "\t- " + step).mkString("\n")
      new NotSerializableException(header + trace)
    } catch {
      case NonFatal(failure) =>
        // The debugger is best effort only: keep the original exception.
        logWarning("Exception in serialization debugger", failure)
        e
    }
  }
}
| |
/**
 * Find the path leading to a not serializable object. The traversal is modeled after OpenJDK's
 * serialization mechanism and covers the following cases:
 *
 *  - primitives and strings
 *  - arrays of primitives
 *  - arrays of non-primitive objects
 *  - Serializable objects, including classes that define writeObject
 *  - Externalizable objects
 *  - writeReplace
 *
 * @param obj root of the object graph to inspect
 * @return the path entries, starting with the non-serializable object and followed by the
 *         references that enclose it; empty if nothing non-serializable was found
 */
private[serializer] def find(obj: Any): List[String] = {
  // A fresh debugger instance per call, so the visited set is not shared between calls.
  val debugger = new SerializationDebugger()
  debugger.visit(obj, Nil)
}
| |
/**
 * Whether the serialization debugger should run. It is initialised to the negation of the JVM
 * system property `sun.io.serialization.extendedDebugInfo`.
 *
 * The property is first read through the JDK-internal privileged action
 * `sun.security.action.GetBooleanAction`. That class is not part of the public JDK API, so
 * loading or instantiating it can fail (for example when it is not accessible to this module).
 * Rather than letting such a failure abort the initialisation of this object, the property is
 * then read directly with `java.lang.Boolean.getBoolean`.
 */
private[serializer] var enableDebugging: Boolean = {
  val flag = "sun.io.serialization.extendedDebugInfo"
  val extendedDebugInfo = try {
    val lookup = MethodHandles.lookup()
    val clazz = SparkClassUtils.classForName("sun.security.action.GetBooleanAction")
    val constructor = clazz.getConstructor(classOf[String])
    val mh = lookup.unreflectConstructor(constructor)
    val action = mh.invoke(flag)
      .asInstanceOf[PrivilegedAction[Boolean]]
    AccessController.doPrivileged(action).booleanValue()
  } catch {
    case NonFatal(cause) =>
      // Same degrade-gracefully policy as the `reflect` helper of this object.
      logWarning(s"Cannot read $flag through a privileged action, reading it directly", cause)
      java.lang.Boolean.getBoolean(flag)
  }
  !extendedDebugInfo
}
| |
| private class SerializationDebugger { |
| |
| /** A set to track the list of objects we have visited, to avoid cycles in the graph. */ |
| private val visited = new mutable.HashSet[Any] |
| |
/**
 * Visit the object and, recursively, everything reachable from it, stopping at the first
 * object that is not serializable.
 *
 * @param o     object to inspect; null and objects already recorded in `visited` are skipped
 * @param stack description of the path followed so far, innermost entry first
 * @return the path ending at the offending object (that object first), or an empty list if
 *         everything reachable from `o` can be serialized
 */
def visit(o: Any, stack: List[String]): List[String] = {
  if (o == null || visited.contains(o)) {
    List.empty
  } else {
    visited += o
    val cls = o.getClass
    o match {
      // Primitive values, strings and primitive arrays are always serializable.
      case _ if cls.isPrimitive => List.empty
      case _: String => List.empty
      case _ if cls.isArray && cls.getComponentType.isPrimitive => List.empty

      // Any array that reaches this point holds references (primitive arrays matched above).
      case arr: Array[_] =>
        val elem = s"array (class ${cls.getName}, size ${arr.length})"
        visitArray(arr, elem :: stack)

      case ext: java.io.Externalizable =>
        val elem = s"externalizable object (class ${ext.getClass.getName}, $ext)"
        visitExternalizable(ext, elem :: stack)

      case ser: Object with java.io.Serializable =>
        val elem = s"object (class ${ser.getClass.getName}, $ser)"
        visitSerializable(ser, elem :: stack)

      case _ =>
        // Neither Externalizable nor Serializable: this is the culprit.
        s"object not serializable (class: ${cls.getName}, value: $o)" :: stack
    }
  }
}
| |
/**
 * Visit the elements of an array in index order and return the path of the first element
 * that leads to a non-serializable object, or an empty list if there is none.
 */
private def visitArray(o: Array[_], stack: List[String]): List[String] = {
  // The iterator is lazy, so elements after the first failing one are never visited.
  Iterator.range(0, o.length)
    .map(idx => visit(o(idx), s"element of array (index: $idx)" :: stack))
    .find(_.nonEmpty)
    .getOrElse(List.empty)
}
| |
/**
 * Visit an externalizable object.
 * writeExternal() decides at serialization time which objects it writes, so the object is
 * asked to write itself into a dummy ObjectOutput that merely records the objects it receives.
 * Those recorded objects are then visited in the order they were written.
 */
private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] =
{
  val recorder = new ListObjectOutput
  o.writeExternal(recorder)
  // Lazy traversal: stop at the first written object that yields a non-empty path.
  recorder.outputArray.iterator
    .map(child => visit(child, "writeExternal data" :: stack))
    .find(_.nonEmpty)
    .getOrElse(List.empty)
}
| |
/**
 * Visit a Serializable object: resolve writeReplace(), then inspect the object slot by slot.
 *
 * Every class is associated with one or more "slots", one per class in its hierarchy.
 * ObjectOutputStream uses these slots to serialize the fields declared by each class in turn.
 * For example, given
 *
 *   class ParentClass(parentField: Int)
 *   class ChildClass(childField: Int) extends ParentClass(1)
 *
 * serializing an object of type ChildClass first writes the fields of ParentClass
 * (parentField) and then those of ChildClass (childField), so the object has two slots:
 *
 *   1. the ParentClass slot, used to serialize parentField
 *   2. the ChildClass slot, used to serialize childField
 */
private def visitSerializable(o: Object, stack: List[String]): List[String] = {
  val (finalObj, desc) = findObjectAndDescriptor(o)

  // Inspect a single slot of finalObj; an empty result means the slot is fine.
  def visitSlot(slotDesc: ObjectStreamClass): List[String] = {
    if (slotDesc.hasWriteObjectMethod) {
      // With a custom writeObject() it is not obvious which fields get written, because the
      // method can choose arbitrary data. This case is handled separately.
      val elem = s"writeObject data (class: ${slotDesc.getName})"
      visitSerializableWithWriteObjectMethod(finalObj, elem :: stack)
    } else {
      // Visit every object-typed field declared by the class of this slot.
      val fields: Array[ObjectStreamField] = slotDesc.getFields
      val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields)
      // Object fields are looked up after the primitive ones in `fields`.
      val numPrims = fields.length - objFieldValues.length
      slotDesc.getObjFieldValues(finalObj, objFieldValues)

      Iterator.range(0, objFieldValues.length)
        .map { idx =>
          val fieldDesc = fields(numPrims + idx)
          val elem = s"field (class: ${slotDesc.getName}" +
            s", name: ${fieldDesc.getName}" +
            s", type: ${fieldDesc.getType})"
          visit(objFieldValues(idx), elem :: stack)
        }
        .find(_.nonEmpty)
        .getOrElse(List.empty)
    }
  }

  if (finalObj.getClass != o.getClass) {
    // writeReplace() substituted an object of another class: classify it from scratch.
    visit(finalObj, s"writeReplace data (class: ${finalObj.getClass.getName})" :: stack)
  } else {
    // Slots are examined lazily and in order; the first non-empty path wins.
    desc.getSlotDescs.iterator
      .map(visitSlot)
      .find(_.nonEmpty)
      .getOrElse(List.empty)
  }
}
| |
/**
 * Visit a serializable object whose class defines writeObject().
 * writeObject() can write arbitrary objects, so the object is serialized into a dummy
 * ObjectOutputStream that records every object passing through it. This mirrors the way
 * externalizable objects are visited.
 *
 * @return the path to a non-serializable object found among the recorded objects, or an
 *         empty list
 */
private def visitSerializableWithWriteObjectMethod(
    o: Object, stack: List[String]): List[String] = {
  val innerObjectsCatcher = new ListObjectOutputStream
  // An IOException while writing is taken as the sign that something was not serializable.
  val writeFailed = try {
    innerObjectsCatcher.writeObject(o)
    false
  } catch {
    case _: IOException => true
  }

  val captured = innerObjectsCatcher.outputArray
  if (writeFailed) {
    // Look for the culprit among the objects recorded before the failure.
    captured.iterator
      .map(inner => visit(inner, stack))
      .find(_.nonEmpty)
      .getOrElse(List.empty)
  } else {
    // Everything was written successfully, so the recorded objects need no visit of their
    // own. As an optimization they are simply marked as visited.
    visited ++= captured
    List.empty
  }
}
| } |
| |
/**
 * Resolve the object that will actually be serialized, together with its
 * [[ObjectStreamClass]]. This follows writeReplace in Serializable: starting from `o`, the
 * writeReplace method is invoked repeatedly until the current class has no such method or the
 * replacement has the same class as the object it replaces.
 */
@tailrec
private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = {
  val desc = ObjectStreamClass.lookupAny(o.getClass)
  if (desc.hasWriteReplaceMethod) {
    desc.invokeWriteReplace(o) match {
      // A replacement of the same class ends the chain; `desc` still describes it.
      case same if same.getClass == o.getClass => (same, desc)
      case other => findObjectAndDescriptor(other)
    }
  } else {
    (o, desc)
  }
}
| |
/**
 * A dummy [[ObjectOutput]] that records, in call order, every object handed to `writeObject`
 * and exposes them through `outputArray`. All other write operations, as well as `flush` and
 * `close`, do nothing.
 */
private class ListObjectOutput extends ObjectOutput {
  private val captured = new mutable.ArrayBuffer[Any]

  /** Snapshot of the objects recorded so far. */
  def outputArray: Array[Any] = captured.toArray

  // The only operation that records anything.
  override def writeObject(o: Any): Unit = captured += o

  // Raw byte writes are discarded.
  override def write(i: Int): Unit = {}
  override def write(bytes: Array[Byte]): Unit = {}
  override def write(bytes: Array[Byte], i: Int, i1: Int): Unit = {}

  // Primitive and string writes are discarded.
  override def writeBoolean(b: Boolean): Unit = {}
  override def writeByte(i: Int): Unit = {}
  override def writeShort(i: Int): Unit = {}
  override def writeChar(i: Int): Unit = {}
  override def writeInt(i: Int): Unit = {}
  override def writeLong(l: Long): Unit = {}
  override def writeFloat(v: Float): Unit = {}
  override def writeDouble(v: Double): Unit = {}
  override def writeBytes(s: String): Unit = {}
  override def writeChars(s: String): Unit = {}
  override def writeUTF(s: String): Unit = {}

  // Nothing is buffered and no resource is held.
  override def flush(): Unit = {}
  override def close(): Unit = {}
}
| |
/**
 * An output stream that emulates /dev/null: everything written to it is discarded.
 *
 * The bulk `write` is overridden as well, because the inherited implementation forwards each
 * byte to `write(Int)` one at a time, which is wasted work for data that is thrown away.
 */
private class NullOutputStream extends OutputStream {
  override def write(b: Int): Unit = { }

  override def write(b: Array[Byte], off: Int, len: Int): Unit = {
    // Keep the argument validation of OutputStream.write(byte[], int, int): a null array
    // fails first, then an out-of-range region is rejected.
    val size = b.length
    if (off < 0 || len < 0 || len > size - off) {
      throw new IndexOutOfBoundsException()
    }
  }
}
| |
/**
 * A dummy [[ObjectOutputStream]] that records the objects written to it and exposes them
 * through `outputArray`. It relies on [[ObjectOutputStream]]'s `replaceObject()` hook, which
 * is only called once replacing has been enabled. This subclass therefore enables replacing
 * and uses `replaceObject()` purely to observe the objects being serialized; each object is
 * returned unchanged. The serialized bytes themselves are sent to a [[NullOutputStream]],
 * which acts like /dev/null.
 */
private class ListObjectOutputStream extends ObjectOutputStream(new NullOutputStream) {
  private val captured = new mutable.ArrayBuffer[Any]

  // Without this call replaceObject() below would never be invoked.
  this.enableReplaceObject(true)

  /** Snapshot of the objects observed so far, in the order they were seen. */
  def outputArray: Array[Any] = captured.toArray

  override def replaceObject(obj: Object): Object = {
    captured += obj
    // No substitution takes place: the stream continues with the same object.
    obj
  }
}
| |
/**
 * An implicit class that lets callers invoke non-public members of [[ObjectStreamClass]] as
 * if they were regular methods. Every call is forwarded to the reflective handles held by
 * `reflect`.
 */
implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal {

  /** The descriptors of all slots of `desc`, read from the result of getClassDataLayout. */
  def getSlotDescs: Array[ObjectStreamClass] = {
    val layout = reflect.GetClassDataLayout.invoke(desc).asInstanceOf[Array[Object]]
    // Each layout entry is a ClassDataSlot; its `desc` field holds the slot's descriptor.
    layout.map(slot => reflect.DescField.get(slot).asInstanceOf[ObjectStreamClass])
  }

  /** Result of ObjectStreamClass.hasWriteObjectMethod for `desc`. */
  def hasWriteObjectMethod: Boolean =
    reflect.HasWriteObjectMethod.invoke(desc).asInstanceOf[Boolean]

  /** Result of ObjectStreamClass.hasWriteReplaceMethod for `desc`. */
  def hasWriteReplaceMethod: Boolean =
    reflect.HasWriteReplaceMethod.invoke(desc).asInstanceOf[Boolean]

  /** Invokes ObjectStreamClass.invokeWriteReplace on `obj` and returns its result. */
  def invokeWriteReplace(obj: Object): Object =
    reflect.InvokeWriteReplace.invoke(desc, obj)

  /** Result of ObjectStreamClass.getNumObjFields for `desc`. */
  def getNumObjFields: Int =
    reflect.GetNumObjFields.invoke(desc).asInstanceOf[Int]

  /** Invokes ObjectStreamClass.getObjFieldValues, which fills `out` from `obj`. */
  def getObjFieldValues(obj: Object, out: Array[Object]): Unit = {
    reflect.GetObjFieldValues.invoke(desc, obj, out)
  }
}
| |
/**
 * Holder of all the reflection handles used by the debugger. If they cannot be obtained on
 * the running JVM, a warning is logged and this field is null, in which case the debug helper
 * must be disabled.
 */
private val reflect: ObjectStreamClassReflection = {
  try {
    new ObjectStreamClassReflection
  } catch {
    case failure: Exception =>
      logWarning("Cannot find private methods using reflection", failure)
      null
  }
}
| |
/**
 * Reflective handles to the non-public members of [[ObjectStreamClass]] that the debugger
 * needs. All handles are resolved and made accessible when an instance is constructed, so the
 * constructor throws if any member is missing or cannot be made accessible.
 */
private class ObjectStreamClassReflection {

  /** Looks up a declared method of ObjectStreamClass and makes it accessible. */
  private def accessibleMethod(name: String, parameterTypes: Class[_]*): Method = {
    val method = classOf[ObjectStreamClass].getDeclaredMethod(name, parameterTypes: _*)
    method.setAccessible(true)
    method
  }

  /** ObjectStreamClass.getClassDataLayout */
  val GetClassDataLayout: Method = accessibleMethod("getClassDataLayout")

  /** ObjectStreamClass.hasWriteObjectMethod */
  val HasWriteObjectMethod: Method = accessibleMethod("hasWriteObjectMethod")

  /** ObjectStreamClass.hasWriteReplaceMethod */
  val HasWriteReplaceMethod: Method = accessibleMethod("hasWriteReplaceMethod")

  /** ObjectStreamClass.invokeWriteReplace */
  val InvokeWriteReplace: Method = accessibleMethod("invokeWriteReplace", classOf[Object])

  /** ObjectStreamClass.getNumObjFields */
  val GetNumObjFields: Method = accessibleMethod("getNumObjFields")

  /** ObjectStreamClass.getObjFieldValues */
  val GetObjFieldValues: Method =
    accessibleMethod("getObjFieldValues", classOf[Object], classOf[Array[Object]])

  /** ObjectStreamClass$ClassDataSlot.desc field */
  val DescField: Field = {
    // scalastyle:off classforname
    val slotClass = Class.forName("java.io.ObjectStreamClass$ClassDataSlot")
    // scalastyle:on classforname
    val field = slotClass.getDeclaredField("desc")
    field.setAccessible(true)
    field
  }
}
| } |