Set the read length in ThriftBytesWriteSupport to avoid a potential OOM caused by corrupted data
diff --git a/parquet-thrift/src/main/java/parquet/hadoop/thrift/ThriftBytesWriteSupport.java b/parquet-thrift/src/main/java/parquet/hadoop/thrift/ThriftBytesWriteSupport.java
index f03858f..5eb4a30 100644
--- a/parquet-thrift/src/main/java/parquet/hadoop/thrift/ThriftBytesWriteSupport.java
+++ b/parquet-thrift/src/main/java/parquet/hadoop/thrift/ThriftBytesWriteSupport.java
@@ -21,6 +21,7 @@
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.thrift.TBase;
 import org.apache.thrift.TException;
+import org.apache.thrift.protocol.TBinaryProtocol;
 import org.apache.thrift.protocol.TProtocol;
 import org.apache.thrift.protocol.TProtocolFactory;
 import org.apache.thrift.transport.TIOStreamTransport;
@@ -54,7 +55,7 @@
     }
     try {
       @SuppressWarnings("unchecked")
-      Class<TProtocolFactory> tProtocolFactoryClass = (Class<TProtocolFactory>)Class.forName(tProtocolClassName+"$Factory");
+      Class<TProtocolFactory> tProtocolFactoryClass = (Class<TProtocolFactory>)Class.forName(tProtocolClassName + "$Factory");
       return tProtocolFactoryClass;
     } catch (ClassNotFoundException e) {
       throw new BadConfigurationException("the Factory for class " + tProtocolClassName + " in job conf at " + PARQUET_PROTOCOL_CLASS + " could not be found", e);
@@ -116,7 +117,16 @@
   }
 
   private TProtocol protocol(BytesWritable record) {
-    return protocolFactory.getProtocol(new TIOStreamTransport(new ByteArrayInputStream(record.getBytes())));
+    TProtocol protocol = protocolFactory.getProtocol(new TIOStreamTransport(new ByteArrayInputStream(record.getBytes(), 0, record.getLength())));
+
+    /* Reduce the chance of an OOM when the data is corrupted. getBytes() returns the whole backing buffer, so the stream is bounded to
+       getLength() to avoid reading stale trailing bytes. TBinaryProtocol.readBinary reads a length prefix first, so a corrupted prefix can
+       request a huge buffer; capping the read length at the record size rejects it. Only TBinaryProtocol defines setReadLength.
+    */
+    if (protocol instanceof TBinaryProtocol) {
+      ((TBinaryProtocol) protocol).setReadLength(record.getLength());
+    }
+    return protocol;
   }
 
   @Override