blob: be72946294bfe9b68cca7eb8b5a75aaf23572eb7 [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
pub mod base_one;
pub mod base_two;
pub mod midlayer;
pub mod recursive;
pub mod ultimate;
#[cfg(test)]
mod tests {
use std::default::Default;
use super::*;
// The generated constructor accepts an explicit `Some(..)` for an optional field.
#[test]
fn must_be_able_to_use_constructor() {
    let noodle = base_one::Noodle::default();
    let _meal = midlayer::Meal::new(Some(noodle), None);
}
// Every optional field of the generated constructor may be omitted with `None`.
#[test]
fn must_be_able_to_use_constructor_with_no_fields() {
    let _meal = midlayer::Meal::new(None, None);
}
// The generated constructor also accepts a bare value for an optional field,
// without the caller wrapping it in `Some(..)`.
#[test]
fn must_be_able_to_use_constructor_without_option_wrap() {
    let noodle = base_one::Noodle::default();
    let _meal = midlayer::Meal::new(noodle, None);
}
// Struct-update syntax works on generated structs: set one field explicitly and
// take every other field from `Default`.
#[test]
fn must_be_able_to_use_defaults() {
    let noodle = Some(base_one::Noodle::default());
    let _meal = midlayer::Meal {
        noodle,
        ..Default::default()
    };
}
// Forward compatibility: a struct field holding a union whose only encoded
// variant has an id this build does not recognise (99) must deserialize as
// `None` instead of failing the whole AidKit read.
#[test]
fn unknown_union_variant_in_struct_field_is_treated_as_none() {
use std::io::Cursor;
use thrift::protocol::{
TBinaryInputProtocol, TBinaryOutputProtocol, TFieldIdentifier, TOutputProtocol,
TSerializable, TStructIdentifier, TType,
};
// Serialize an AidKit whose union field carries only an unknown variant (id=99),
// then verify that the field deserializes as None.
let mut write_buf: Vec<u8> = Vec::new();
// Hand-write the bytes with raw protocol calls. The inner scope drops the
// writer, and with it the borrow of write_buf, before the buffer is read back.
{
let cursor = Cursor::new(&mut write_buf);
let mut prot = TBinaryOutputProtocol::new(cursor, false);
// Outer struct: AidKit.
prot.write_struct_begin(&TStructIdentifier {
name: "AidKit".to_owned(),
})
.unwrap();
// AidKit field 1: a struct-typed field holding the MeasuringAids union
// (presumably the `aid` field asserted on below).
prot.write_field_begin(&TFieldIdentifier {
name: None,
field_type: TType::Struct,
id: Some(1),
})
.unwrap();
// MeasuringAids union body: a single i32 field with unknown id 99.
prot.write_struct_begin(&TStructIdentifier {
name: "MeasuringAids".to_owned(),
})
.unwrap();
prot.write_field_begin(&TFieldIdentifier {
name: None,
field_type: TType::I32,
id: Some(99),
})
.unwrap();
prot.write_i32(42).unwrap();
prot.write_field_end().unwrap();
prot.write_field_stop().unwrap(); // end of MeasuringAids
prot.write_struct_end().unwrap();
prot.write_field_end().unwrap(); // closes AidKit field 1
prot.write_field_stop().unwrap(); // end of AidKit
prot.write_struct_end().unwrap();
}
// Read the bytes back through the generated AidKit reader.
let read_cursor = Cursor::new(write_buf);
let mut rprot = TBinaryInputProtocol::new(read_cursor, false);
let kit = base_one::AidKit::read_from_in_protocol(&mut rprot)
.expect("forward-compat deserialization should succeed");
assert!(
kit.aid.is_none(),
"unknown union variant should result in None field (forward compat)"
);
}
// A union that carries one recognised variant plus an extra unknown field must
// deserialize to the recognised variant.
#[test]
fn union_with_known_and_unknown_fields_deserializes_to_known_variant() {
use std::io::Cursor;
use thrift::protocol::{
TBinaryInputProtocol, TBinaryOutputProtocol, TFieldIdentifier, TOutputProtocol,
TSerializable, TStructIdentifier, TType,
};
// Regression test for total_field_count: a newer server may send a union
// with a recognised variant (id=2, MeasuringCup) plus an extra unknown
// field (id=99). The deserializer must return the known variant rather
// than failing with "received multiple fields for union".
let mut buf: Vec<u8> = Vec::new();
// Inner scope: the writer's borrow of `buf` ends before `buf` is read back.
{
let mut prot = TBinaryOutputProtocol::new(Cursor::new(&mut buf), false);
// Union body: MeasuringAids.
prot.write_struct_begin(&TStructIdentifier {
name: "MeasuringAids".to_owned(),
})
.unwrap();
// Known variant: MeasuringCup (field id=2)
prot.write_field_begin(&TFieldIdentifier {
name: None,
field_type: TType::Struct,
id: Some(2),
})
.unwrap();
// MeasuringCup body: one double field (id 1) set to 250.0, then Stop.
prot.write_struct_begin(&TStructIdentifier {
name: "MeasuringCup".to_owned(),
})
.unwrap();
prot.write_field_begin(&TFieldIdentifier {
name: None,
field_type: TType::Double,
id: Some(1),
})
.unwrap();
prot.write_double(250.0).unwrap();
prot.write_field_end().unwrap();
prot.write_field_stop().unwrap(); // end of MeasuringCup
prot.write_struct_end().unwrap();
prot.write_field_end().unwrap(); // closes union field 2
// Unknown field from a newer schema (field id=99)
prot.write_field_begin(&TFieldIdentifier {
name: None,
field_type: TType::I32,
id: Some(99),
})
.unwrap();
prot.write_i32(0).unwrap();
prot.write_field_end().unwrap();
prot.write_field_stop().unwrap(); // end of MeasuringAids
prot.write_struct_end().unwrap();
}
// Read back through the generated MeasuringAids reader.
let aids = base_one::MeasuringAids::read_from_in_protocol(&mut TBinaryInputProtocol::new(
Cursor::new(buf),
false,
))
.expect("union with a known variant plus an unknown field must deserialize successfully");
assert!(
matches!(aids, base_one::MeasuringAids::Cup(_)),
"known variant (Cup) must be returned when an unknown field is also present"
);
}
// An unknown variant nested inside a known union variant must turn the outer
// union field into `None` without desynchronising the stream, so the fields
// that follow it in the enclosing struct are still read.
#[test]
fn nested_union_unknown_variant_does_not_corrupt_stream() {
use std::io::Cursor;
use thrift::protocol::{
TBinaryInputProtocol, TBinaryOutputProtocol, TFieldIdentifier, TOutputProtocol,
TSerializable, TStructIdentifier, TType,
};
// Regression test for stream corruption: when suppress_unknown catches
// an UnknownUnionVariant error from a union field whose variant value is
// itself a union, the outer union's Stop byte must have been consumed
// before the error propagates. If it hasn't, the enclosing struct loop
// reads that orphaned 0x00 as its own Stop and silently drops every
// subsequent field.
//
// Wire layout written below:
//   InstrumentBox {
//     1: InstrumentUnion { 1: MeasuringAids { 99: i32 0 } }
//     2: i32 42
//   }
let mut buf: Vec<u8> = Vec::new();
// Inner scope: the writer's borrow of `buf` ends before `buf` is read back.
{
let mut prot = TBinaryOutputProtocol::new(Cursor::new(&mut buf), false);
// InstrumentBox
prot.write_struct_begin(&TStructIdentifier {
name: "InstrumentBox".to_owned(),
})
.unwrap();
// field 1: InstrumentUnion
prot.write_field_begin(&TFieldIdentifier {
name: None,
field_type: TType::Struct,
id: Some(1),
})
.unwrap();
// InstrumentUnion: known variant 1 whose value (MeasuringAids) carries
// only an unknown sub-variant, so MeasuringAids::read returns
// UnknownUnionVariant, which propagates via ? before InstrumentUnion
// has consumed its own Stop byte.
prot.write_struct_begin(&TStructIdentifier {
name: "InstrumentUnion".to_owned(),
})
.unwrap();
prot.write_field_begin(&TFieldIdentifier {
name: None,
field_type: TType::Struct,
id: Some(1),
})
.unwrap();
// MeasuringAids body: only the unknown field id 99 (i32), then Stop.
prot.write_struct_begin(&TStructIdentifier {
name: "MeasuringAids".to_owned(),
})
.unwrap();
prot.write_field_begin(&TFieldIdentifier {
name: None,
field_type: TType::I32,
id: Some(99),
})
.unwrap();
prot.write_i32(0).unwrap();
prot.write_field_end().unwrap();
prot.write_field_stop().unwrap(); // MeasuringAids' Stop
prot.write_struct_end().unwrap();
prot.write_field_end().unwrap(); // closes InstrumentUnion field 1
prot.write_field_stop().unwrap(); // InstrumentUnion's Stop
prot.write_struct_end().unwrap();
prot.write_field_end().unwrap(); // closes InstrumentBox field 1
// field 2: tag=42 — must survive even if field 1 triggers suppress_unknown
prot.write_field_begin(&TFieldIdentifier {
name: None,
field_type: TType::I32,
id: Some(2),
})
.unwrap();
prot.write_i32(42).unwrap();
prot.write_field_end().unwrap();
prot.write_field_stop().unwrap(); // InstrumentBox's Stop
prot.write_struct_end().unwrap();
}
// Read back through the generated InstrumentBox reader.
let ibox = base_one::InstrumentBox::read_from_in_protocol(&mut TBinaryInputProtocol::new(
Cursor::new(buf),
false,
))
.expect("struct with a nested unknown union variant must deserialize without error");
assert!(
ibox.instrument.is_none(),
"outer union with nested unknown variant must become None"
);
assert_eq!(
ibox.tag,
Some(42),
"field following a suppressed union must not be silently dropped due to stream corruption"
);
}
// ---- Recursion depth limiting (THRIFT-6057) ----
//
// These exercise the generated read/write recursion guard (a thread-local
// DepthGuard acquired at the top of every struct/union read_from_in_protocol
// and write_to_out_protocol) through full round-trips over the recursive
// types in test/Recursive.thrift, not by poking the guard directly.
//
// NB: both skip() and the struct-read guard raise ProtocolErrorKind::DepthLimit,
// so a crafted over-limit payload must nest through the *known* recursive field
// (id 1, a struct), which the generated reader matches and recurses into. An
// unknown field would instead be skip()-ed and trip the unrelated skip guard.
// Because both guards report the same error kind, the over-limit assertions
// cannot by themselves show which guard fired. Confirming that it is the new
// guard -- not skip -- relies on comparison with the baseline (master) build,
// which emits no DepthGuard.
// Nesting depth the tests below expect default-configured protocols to accept;
// LIMIT + 1 levels must be rejected with ProtocolErrorKind::DepthLimit.
// NOTE(review): presumably mirrors the library's default recursion limit --
// confirm it stays in sync if that default changes.
const LIMIT: usize = 64;
// Build a CoRec -> CoRec2 -> CoRec ... chain that is `depth` structs deep in
// total. A depth of 0 or 1 yields a lone CoRec with `other: None`.
fn build_co_rec(depth: usize) -> recursive::CoRec {
    let other = match depth {
        0 | 1 => None,
        _ => Some(Box::new(build_co_rec2(depth - 1))),
    };
    recursive::CoRec { other }
}
// Counterpart of `build_co_rec`: a chain starting at CoRec2 that is `depth`
// structs deep. A depth of 0 or 1 yields a lone CoRec2 with `other: None`.
fn build_co_rec2(depth: usize) -> recursive::CoRec2 {
    let other = match depth {
        0 | 1 => None,
        _ => Some(build_co_rec(depth - 1)),
    };
    recursive::CoRec2 { other }
}
// Build a CoError -> CoError2 -> CoError ... chain that is `depth` levels deep
// in total. A depth of 0 or 1 yields a lone CoError with `other: None`.
fn build_co_error(depth: usize) -> recursive::CoError {
    let other = match depth {
        0 | 1 => None,
        _ => Some(Box::new(build_co_error2(depth - 1))),
    };
    recursive::CoError { other }
}
// Counterpart of `build_co_error`: a chain starting at CoError2 that is `depth`
// levels deep. A depth of 0 or 1 yields a lone CoError2 with `other: None`.
fn build_co_error2(depth: usize) -> recursive::CoError2 {
    let other = match depth {
        0 | 1 => None,
        _ => Some(build_co_error(depth - 1)),
    };
    recursive::CoError2 { other }
}
// Write a `depth`-deep nesting of structs, each linked to the next through the
// recursive field (id 1, a struct), using raw protocol calls only.
//
// This bypasses every generated `write_to_out_protocol`, so the only write-side
// bound that could apply is the protocol's own. Callers therefore pass a writer
// built with `TConfiguration::no_limits()`. Reading the result back through a
// generated reader is what must trip the read-side guard.
//
// The innermost struct has no fields (just a Stop). A `depth` of 0 is treated
// like 1 and produces a single empty struct. The nesting is emitted with two
// loops instead of recursion, so the protocol sees the same call sequence
// without using native stack proportional to `depth`.
//
// Panics if any protocol write fails.
fn write_nested_chain(prot: &mut dyn thrift::protocol::TOutputProtocol, depth: usize) {
    use thrift::protocol::{TFieldIdentifier, TOutputProtocol, TStructIdentifier, TType};

    // Every level uses the same struct and link-field identifiers.
    let struct_ident = TStructIdentifier {
        name: "Rec".to_owned(),
    };
    let link_field = TFieldIdentifier {
        name: None,
        field_type: TType::Struct,
        id: Some(1),
    };
    // Levels that contain a nested struct: all but the innermost one.
    let links = depth.saturating_sub(1);

    // Descend: open each enclosing struct and its recursive field.
    for _ in 0..links {
        prot.write_struct_begin(&struct_ident).unwrap();
        prot.write_field_begin(&link_field).unwrap();
    }
    // Innermost struct: no fields, only the Stop marker.
    prot.write_struct_begin(&struct_ident).unwrap();
    prot.write_field_stop().unwrap();
    prot.write_struct_end().unwrap();
    // Ascend: close each recursive field, then its enclosing struct.
    for _ in 0..links {
        prot.write_field_end().unwrap();
        prot.write_field_stop().unwrap();
        prot.write_struct_end().unwrap();
    }
}
// Assert that `result` failed with a protocol error of kind
// `ProtocolErrorKind::DepthLimit`. Panics, showing the actual outcome, on `Ok`,
// on any non-protocol error, or on a protocol error of another kind.
//
// `#[track_caller]` makes a failure point at the calling test's line rather
// than at this shared helper, which several tests call more than once.
#[track_caller]
fn assert_depth_limit<T: std::fmt::Debug>(result: thrift::Result<T>) {
    match result {
        Err(thrift::Error::Protocol(pe)) => assert_eq!(
            pe.kind,
            thrift::ProtocolErrorKind::DepthLimit,
            "expected DepthLimit, got {:?}",
            pe
        ),
        other => panic!("expected DepthLimit protocol error, got {:?}", other),
    }
}
// Struct case over the binary protocol: an at-limit CoRec chain round-trips,
// and one level past the limit is rejected on both write and read.
#[test]
fn recursion_depth_struct_round_trip_and_limit() {
    use std::io::Cursor;
    use thrift::protocol::{TBinaryInputProtocol, TBinaryOutputProtocol, TSerializable};

    // At the limit: write then read must reproduce the original value. The
    // writer is a temporary, so its borrow of `encoded` ends with the statement.
    let at_limit = build_co_rec(LIMIT);
    let mut encoded: Vec<u8> = Vec::new();
    at_limit
        .write_to_out_protocol(&mut TBinaryOutputProtocol::new(
            Cursor::new(&mut encoded),
            false,
        ))
        .expect("at-limit write must succeed");
    let mut reader = TBinaryInputProtocol::new(Cursor::new(encoded), false);
    let decoded =
        recursive::CoRec::read_from_in_protocol(&mut reader).expect("at-limit read must succeed");
    assert_eq!(at_limit, decoded);

    // One level past the limit: the write is rejected (counter restored on drop).
    let too_deep = build_co_rec(LIMIT + 1);
    let mut sink = TBinaryOutputProtocol::new(Cursor::new(Vec::new()), false);
    assert_depth_limit(too_deep.write_to_out_protocol(&mut sink));

    // One level past the limit on read: craft the payload with an unbounded
    // writer so that only the read side enforces the bound.
    let mut crafted: Vec<u8> = Vec::new();
    write_nested_chain(
        &mut TBinaryOutputProtocol::with_config(
            Cursor::new(&mut crafted),
            false,
            thrift::TConfiguration::no_limits(),
        ),
        LIMIT + 1,
    );
    let mut reader = TBinaryInputProtocol::new(Cursor::new(crafted), false);
    assert_depth_limit(recursive::CoRec::read_from_in_protocol(&mut reader));
}
// Exception case over the binary protocol (the CoError / CoError2 pair): an
// at-limit chain round-trips, and one level past the limit is rejected on both
// write and read.
#[test]
fn recursion_depth_exception_round_trip_and_limit() {
    use std::io::Cursor;
    use thrift::protocol::{TBinaryInputProtocol, TBinaryOutputProtocol, TSerializable};

    // At the limit: write then read must reproduce the original value.
    let at_limit = build_co_error(LIMIT);
    let mut encoded: Vec<u8> = Vec::new();
    at_limit
        .write_to_out_protocol(&mut TBinaryOutputProtocol::new(
            Cursor::new(&mut encoded),
            false,
        ))
        .expect("at-limit write must succeed");
    let mut reader = TBinaryInputProtocol::new(Cursor::new(encoded), false);
    let decoded = recursive::CoError::read_from_in_protocol(&mut reader)
        .expect("at-limit read must succeed");
    assert_eq!(at_limit, decoded);

    // One level past the limit: the write is rejected.
    let too_deep = build_co_error(LIMIT + 1);
    let mut sink = TBinaryOutputProtocol::new(Cursor::new(Vec::new()), false);
    assert_depth_limit(too_deep.write_to_out_protocol(&mut sink));

    // One level past the limit on read: craft the payload with an unbounded
    // writer so that only the read side enforces the bound.
    let mut crafted: Vec<u8> = Vec::new();
    write_nested_chain(
        &mut TBinaryOutputProtocol::with_config(
            Cursor::new(&mut crafted),
            false,
            thrift::TConfiguration::no_limits(),
        ),
        LIMIT + 1,
    );
    let mut reader = TBinaryInputProtocol::new(Cursor::new(crafted), false);
    assert_depth_limit(recursive::CoError::read_from_in_protocol(&mut reader));
}
// Union case: the read side rejects a CoUnion payload nested one level past
// the limit.
#[test]
fn recursion_depth_union_read_limit() {
    use std::io::Cursor;
    use thrift::protocol::{TBinaryInputProtocol, TBinaryOutputProtocol, TSerializable};

    // CoUnion has only the recursive variant (no leaf), so no finite value can
    // be constructed and written. Instead, the union read guard is driven by a
    // hand-crafted payload. That payload is written with an unbounded writer so
    // that only the read side enforces the bound.
    let mut crafted: Vec<u8> = Vec::new();
    write_nested_chain(
        &mut TBinaryOutputProtocol::with_config(
            Cursor::new(&mut crafted),
            false,
            thrift::TConfiguration::no_limits(),
        ),
        LIMIT + 1,
    );

    let mut reader = TBinaryInputProtocol::new(Cursor::new(crafted), false);
    assert_depth_limit(recursive::CoUnion::read_from_in_protocol(&mut reader));
}
// The depth bound is enforced on every struct read/write, so it applies to all
// generated types uniformly; the binary-protocol tests above already cover the
// struct, union, and exception paths. This test confirms that the compact
// protocol enforces the same bound on both write and read.
// Struct case over the compact protocol: an at-limit CoRec chain round-trips,
// and one level past the limit is rejected on both write and read.
#[test]
fn recursion_depth_compact_round_trip_and_limit() {
    use std::io::Cursor;
    use thrift::protocol::{TCompactInputProtocol, TCompactOutputProtocol, TSerializable};

    // At the limit: write then read must reproduce the original value.
    let at_limit = build_co_rec(LIMIT);
    let mut encoded: Vec<u8> = Vec::new();
    at_limit
        .write_to_out_protocol(&mut TCompactOutputProtocol::new(Cursor::new(&mut encoded)))
        .expect("at-limit write must succeed");
    let mut reader = TCompactInputProtocol::new(Cursor::new(encoded));
    let decoded =
        recursive::CoRec::read_from_in_protocol(&mut reader).expect("at-limit read must succeed");
    assert_eq!(at_limit, decoded);

    // One level past the limit: the write is rejected.
    let too_deep = build_co_rec(LIMIT + 1);
    let mut sink = TCompactOutputProtocol::new(Cursor::new(Vec::new()));
    assert_depth_limit(too_deep.write_to_out_protocol(&mut sink));

    // One level past the limit on read: craft the payload with an unbounded
    // writer so that only the read side enforces the bound.
    let mut crafted: Vec<u8> = Vec::new();
    write_nested_chain(
        &mut TCompactOutputProtocol::with_config(
            Cursor::new(&mut crafted),
            thrift::TConfiguration::no_limits(),
        ),
        LIMIT + 1,
    );
    let mut reader = TCompactInputProtocol::new(Cursor::new(crafted));
    assert_depth_limit(recursive::CoRec::read_from_in_protocol(&mut reader));
}
}