Set the read length on TBinaryProtocol in ThriftBytesWriteSupport to avoid a potential OOM caused by corrupted data
diff --git a/parquet-thrift/src/main/java/parquet/hadoop/thrift/ThriftBytesWriteSupport.java b/parquet-thrift/src/main/java/parquet/hadoop/thrift/ThriftBytesWriteSupport.java
index f03858f..5eb4a30 100644
--- a/parquet-thrift/src/main/java/parquet/hadoop/thrift/ThriftBytesWriteSupport.java
+++ b/parquet-thrift/src/main/java/parquet/hadoop/thrift/ThriftBytesWriteSupport.java
@@ -21,6 +21,7 @@
import org.apache.hadoop.io.BytesWritable;
import org.apache.thrift.TBase;
import org.apache.thrift.TException;
+import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.transport.TIOStreamTransport;
@@ -54,7 +55,7 @@
}
try {
@SuppressWarnings("unchecked")
- Class<TProtocolFactory> tProtocolFactoryClass = (Class<TProtocolFactory>)Class.forName(tProtocolClassName+"$Factory");
+ Class<TProtocolFactory> tProtocolFactoryClass = (Class<TProtocolFactory>)Class.forName(tProtocolClassName + "$Factory");
return tProtocolFactoryClass;
} catch (ClassNotFoundException e) {
throw new BadConfigurationException("the Factory for class " + tProtocolClassName + " in job conf at " + PARQUET_PROTOCOL_CLASS + " could not be found", e);
@@ -116,7 +117,16 @@
}
private TProtocol protocol(BytesWritable record) {
- return protocolFactory.getProtocol(new TIOStreamTransport(new ByteArrayInputStream(record.getBytes())));
+ // getBytes() exposes the whole backing array, which may be longer than the record; only the first getLength() bytes are valid.
+ TProtocol protocol = protocolFactory.getProtocol(new TIOStreamTransport(new ByteArrayInputStream(record.getBytes(), 0, record.getLength())));
+ /* Reduce the chance of OOM when data is corrupted. When readBinary is called on TBinaryProtocol, it reads the length of the binary first,
+ so if the data is corrupted, it could read a huge integer as that length and trigger an OOM while allocating the buffer.
+ Currently this safeguard only applies to TBinaryProtocol, which is the protocol that defines setReadLength.
+ */
+ if (protocol instanceof TBinaryProtocol) {
+ ((TBinaryProtocol)protocol).setReadLength(record.getLength());
+ }
+ return protocol;
}
@Override