Merge pull request #2336 from allengeorge/thrift-5299

THRIFT-5299: Encode sequence numbers as non-zigzag varint
diff --git a/lib/rs/src/protocol/compact.rs b/lib/rs/src/protocol/compact.rs
index 6fa364f..f885e40 100644
--- a/lib/rs/src/protocol/compact.rs
+++ b/lib/rs/src/protocol/compact.rs
@@ -128,7 +128,8 @@
 
         // NOTE: unsigned right shift will pad with 0s
         let message_type: TMessageType = TMessageType::try_from(type_and_byte >> 5)?;
-        let sequence_number = self.read_i32()?;
+        // the writing side casts the i32 sequence number to u32 to avoid zigzag encoding
+        let sequence_number = self.transport.read_varint::<u32>()? as i32;
         let service_call_name = self.read_string()?;
 
         self.last_read_field_id = 0;
@@ -247,7 +248,9 @@
     }
 
     fn read_double(&mut self) -> crate::Result<f64> {
-        self.transport.read_f64::<LittleEndian>().map_err(From::from)
+        self.transport
+            .read_f64::<LittleEndian>()
+            .map_err(From::from)
     }
 
     fn read_string(&mut self) -> crate::Result<String> {
@@ -388,7 +391,11 @@
         Ok(())
     }
 
-    fn write_list_set_begin(&mut self, element_type: TType, element_count: i32) -> crate::Result<()> {
+    fn write_list_set_begin(
+        &mut self,
+        element_type: TType,
+        element_count: i32,
+    ) -> crate::Result<()> {
         let elem_identifier = collection_type_to_u8(element_type);
         if element_count <= 14 {
             let header = (element_count as u8) << 4 | elem_identifier;
@@ -396,6 +403,8 @@
         } else {
             let header = 0xF0 | elem_identifier;
             self.write_byte(header)?;
+            // element count is strictly positive as per the spec, so
+            // cast i32 as u32 so that varint writing won't use zigzag encoding
             self.transport
                 .write_varint(element_count as u32)
                 .map_err(From::from)
@@ -417,7 +426,8 @@
     fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> crate::Result<()> {
         self.write_byte(COMPACT_PROTOCOL_ID)?;
         self.write_byte((u8::from(identifier.message_type) << 5) | COMPACT_VERSION)?;
-        self.write_i32(identifier.sequence_number)?;
+        // cast i32 as u32 so that varint writing won't use zigzag encoding
+        self.transport.write_varint(identifier.sequence_number as u32)?;
         self.write_string(&identifier.name)?;
         Ok(())
     }
@@ -491,6 +501,8 @@
     }
 
     fn write_bytes(&mut self, b: &[u8]) -> crate::Result<()> {
+        // length is never negative (it may be zero), so
+        // cast usize as u32 so that varint writing won't use zigzag encoding
         self.transport.write_varint(b.len() as u32)?;
         self.transport.write_all(b).map_err(From::from)
     }
@@ -521,7 +533,9 @@
     }
 
     fn write_double(&mut self, d: f64) -> crate::Result<()> {
-        self.transport.write_f64::<LittleEndian>(d).map_err(From::from)
+        self.transport
+            .write_f64::<LittleEndian>(d)
+            .map_err(From::from)
     }
 
     fn write_string(&mut self, s: &str) -> crate::Result<()> {
@@ -548,6 +562,8 @@
         if identifier.size == 0 {
             self.write_byte(0)
         } else {
+            // element count is strictly positive as per the spec, so
+            // cast i32 as u32 so that varint writing won't use zigzag encoding
             self.transport.write_varint(identifier.size as u32)?;
 
             let key_type = identifier
@@ -593,7 +609,10 @@
 }
 
 impl TOutputProtocolFactory for TCompactOutputProtocolFactory {
-    fn create(&self, transport: Box<dyn TWriteTransport + Send>) -> Box<dyn TOutputProtocol + Send> {
+    fn create(
+        &self,
+        transport: Box<dyn TWriteTransport + Send>,
+    ) -> Box<dyn TOutputProtocol + Send> {
         Box::new(TCompactOutputProtocol::new(transport))
     }
 }
@@ -655,6 +674,8 @@
 #[cfg(test)]
 mod tests {
 
+    use std::i32;
+
     use crate::protocol::{
         TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier,
         TMessageType, TOutputProtocol, TSetIdentifier, TStructIdentifier, TType,
@@ -664,7 +685,62 @@
     use super::*;
 
     #[test]
-    fn must_write_message_begin_0() {
+    fn must_write_message_begin_largest_maximum_positive_sequence_number() {
+        let (_, mut o_prot) = test_objects();
+
+        assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new(
+            "bar",
+            TMessageType::Reply,
+            i32::MAX
+        )));
+
+        #[rustfmt::skip]
+        let expected: [u8; 11] = [
+            0x82, /* protocol ID */
+            0x41, /* message type | protocol version */
+            0xFF,
+            0xFF,
+            0xFF,
+            0xFF,
+            0x07, /* non-zig-zag varint sequence number */
+            0x03, /* message-name length */
+            0x62,
+            0x61,
+            0x72 /* "bar" */,
+        ];
+
+        assert_eq_written_bytes!(o_prot, expected);
+    }
+
+    #[test]
+    fn must_read_message_begin_largest_maximum_positive_sequence_number() {
+        let (mut i_prot, _) = test_objects();
+
+        #[rustfmt::skip]
+        let source_bytes: [u8; 11] = [
+            0x82, /* protocol ID */
+            0x41, /* message type | protocol version */
+            0xFF,
+            0xFF,
+            0xFF,
+            0xFF,
+            0x07, /* non-zig-zag varint sequence number */
+            0x03, /* message-name length */
+            0x62,
+            0x61,
+            0x72 /* "bar" */,
+        ];
+
+        i_prot.transport.set_readable_bytes(&source_bytes);
+
+        let expected = TMessageIdentifier::new("bar", TMessageType::Reply, i32::MAX);
+        let res = assert_success!(i_prot.read_message_begin());
+
+        assert_eq!(&expected, &res);
+    }
+
+    #[test]
+    fn must_write_message_begin_positive_sequence_number_0() {
         let (_, mut o_prot) = test_objects();
 
         assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new(
@@ -677,8 +753,8 @@
         let expected: [u8; 8] = [
             0x82, /* protocol ID */
             0x21, /* message type | protocol version */
-            0xDE,
-            0x06, /* zig-zag varint sequence number */
+            0xAF,
+            0x03, /* non-zig-zag varint sequence number */
             0x03, /* message-name length */
             0x66,
             0x6F,
@@ -689,7 +765,31 @@
     }
 
     #[test]
-    fn must_write_message_begin_1() {
+    fn must_read_message_begin_positive_sequence_number_0() {
+        let (mut i_prot, _) = test_objects();
+
+        #[rustfmt::skip]
+        let source_bytes: [u8; 8] = [
+            0x82, /* protocol ID */
+            0x21, /* message type | protocol version */
+            0xAF,
+            0x03, /* non-zig-zag varint sequence number */
+            0x03, /* message-name length */
+            0x66,
+            0x6F,
+            0x6F /* "foo" */,
+        ];
+
+        i_prot.transport.set_readable_bytes(&source_bytes);
+
+        let expected = TMessageIdentifier::new("foo", TMessageType::Call, 431);
+        let res = assert_success!(i_prot.read_message_begin());
+
+        assert_eq!(&expected, &res);
+    }
+
+    #[test]
+    fn must_write_message_begin_positive_sequence_number_1() {
         let (_, mut o_prot) = test_objects();
 
         assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new(
@@ -702,9 +802,9 @@
         let expected: [u8; 9] = [
             0x82, /* protocol ID */
             0x41, /* message type | protocol version */
-            0xA8,
-            0x89,
-            0x79, /* zig-zag varint sequence number */
+            0xD4,
+            0xC4,
+            0x3C, /* non-zig-zag varint sequence number */
             0x03, /* message-name length */
             0x62,
             0x61,
@@ -715,6 +815,305 @@
     }
 
     #[test]
+    fn must_read_message_begin_positive_sequence_number_1() {
+        let (mut i_prot, _) = test_objects();
+
+        #[rustfmt::skip]
+        let source_bytes: [u8; 9] = [
+            0x82, /* protocol ID */
+            0x41, /* message type | protocol version */
+            0xD4,
+            0xC4,
+            0x3C, /* non-zig-zag varint sequence number */
+            0x03, /* message-name length */
+            0x62,
+            0x61,
+            0x72 /* "bar" */,
+        ];
+
+        i_prot.transport.set_readable_bytes(&source_bytes);
+
+        let expected = TMessageIdentifier::new("bar", TMessageType::Reply, 991_828);
+        let res = assert_success!(i_prot.read_message_begin());
+
+        assert_eq!(&expected, &res);
+    }
+
+    #[test]
+    fn must_write_message_begin_zero_sequence_number() {
+        let (_, mut o_prot) = test_objects();
+
+        assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new(
+            "bar",
+            TMessageType::Reply,
+            0
+        )));
+
+        #[rustfmt::skip]
+        let expected: [u8; 7] = [
+            0x82, /* protocol ID */
+            0x41, /* message type | protocol version */
+            0x00, /* non-zig-zag varint sequence number */
+            0x03, /* message-name length */
+            0x62,
+            0x61,
+            0x72 /* "bar" */,
+        ];
+
+        assert_eq_written_bytes!(o_prot, expected);
+    }
+
+    #[test]
+    fn must_read_message_begin_zero_sequence_number() {
+        let (mut i_prot, _) = test_objects();
+
+        #[rustfmt::skip]
+        let source_bytes: [u8; 7] = [
+            0x82, /* protocol ID */
+            0x41, /* message type | protocol version */
+            0x00, /* non-zig-zag varint sequence number */
+            0x03, /* message-name length */
+            0x62,
+            0x61,
+            0x72 /* "bar" */,
+        ];
+
+        i_prot.transport.set_readable_bytes(&source_bytes);
+
+        let expected = TMessageIdentifier::new("bar", TMessageType::Reply, 0);
+        let res = assert_success!(i_prot.read_message_begin());
+
+        assert_eq!(&expected, &res);
+    }
+
+    #[test]
+    fn must_write_message_begin_largest_minimum_negative_sequence_number() {
+        let (_, mut o_prot) = test_objects();
+
+        assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new(
+            "bar",
+            TMessageType::Reply,
+            i32::MIN
+        )));
+
+        // two's complement notation of i32::MIN = 1000_0000_0000_0000_0000_0000_0000_0000
+        #[rustfmt::skip]
+        let expected: [u8; 11] = [
+            0x82, /* protocol ID */
+            0x41, /* message type | protocol version */
+            0x80,
+            0x80,
+            0x80,
+            0x80,
+            0x08, /* non-zig-zag varint sequence number */
+            0x03, /* message-name length */
+            0x62,
+            0x61,
+            0x72 /* "bar" */,
+        ];
+
+        assert_eq_written_bytes!(o_prot, expected);
+    }
+
+    #[test]
+    fn must_read_message_begin_largest_minimum_negative_sequence_number() {
+        let (mut i_prot, _) = test_objects();
+
+        // two's complement notation of i32::MIN = 1000_0000_0000_0000_0000_0000_0000_0000
+        #[rustfmt::skip]
+        let source_bytes: [u8; 11] = [
+            0x82, /* protocol ID */
+            0x41, /* message type | protocol version */
+            0x80,
+            0x80,
+            0x80,
+            0x80,
+            0x08, /* non-zig-zag varint sequence number */
+            0x03, /* message-name length */
+            0x62,
+            0x61,
+            0x72 /* "bar" */,
+        ];
+
+        i_prot.transport.set_readable_bytes(&source_bytes);
+
+        let expected = TMessageIdentifier::new("bar", TMessageType::Reply, i32::MIN);
+        let res = assert_success!(i_prot.read_message_begin());
+
+        assert_eq!(&expected, &res);
+    }
+
+    #[test]
+    fn must_write_message_begin_negative_sequence_number_0() {
+        let (_, mut o_prot) = test_objects();
+
+        assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new(
+            "foo",
+            TMessageType::Call,
+            -431
+        )));
+
+        // signed two's complement of -431 = 1111_1111_1111_1111_1111_1110_0101_0001
+        #[rustfmt::skip]
+        let expected: [u8; 11] = [
+            0x82, /* protocol ID */
+            0x21, /* message type | protocol version */
+            0xD1,
+            0xFC,
+            0xFF,
+            0xFF,
+            0x0F, /* non-zig-zag varint sequence number */
+            0x03, /* message-name length */
+            0x66,
+            0x6F,
+            0x6F /* "foo" */,
+        ];
+
+        assert_eq_written_bytes!(o_prot, expected);
+    }
+
+    #[test]
+    fn must_read_message_begin_negative_sequence_number_0() {
+        let (mut i_prot, _) = test_objects();
+
+        // signed two's complement of -431 = 1111_1111_1111_1111_1111_1110_0101_0001
+        #[rustfmt::skip]
+        let source_bytes: [u8; 11] = [
+            0x82, /* protocol ID */
+            0x21, /* message type | protocol version */
+            0xD1,
+            0xFC,
+            0xFF,
+            0xFF,
+            0x0F, /* non-zig-zag varint sequence number */
+            0x03, /* message-name length */
+            0x66,
+            0x6F,
+            0x6F /* "foo" */,
+        ];
+
+        i_prot.transport.set_readable_bytes(&source_bytes);
+
+        let expected = TMessageIdentifier::new("foo", TMessageType::Call, -431);
+        let res = assert_success!(i_prot.read_message_begin());
+
+        assert_eq!(&expected, &res);
+    }
+
+    #[test]
+    fn must_write_message_begin_negative_sequence_number_1() {
+        let (_, mut o_prot) = test_objects();
+
+        assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new(
+            "foo",
+            TMessageType::Call,
+            -73_184_125
+        )));
+
+        // signed two's complement of -73184125 = 1111_1011_1010_0011_0100_1100_1000_0011
+        #[rustfmt::skip]
+        let expected: [u8; 11] = [
+            0x82, /* protocol ID */
+            0x21, /* message type | protocol version */
+            0x83,
+            0x99,
+            0x8D,
+            0xDD,
+            0x0F, /* non-zig-zag varint sequence number */
+            0x03, /* message-name length */
+            0x66,
+            0x6F,
+            0x6F /* "foo" */,
+        ];
+
+        assert_eq_written_bytes!(o_prot, expected);
+    }
+
+    #[test]
+    fn must_read_message_begin_negative_sequence_number_1() {
+        let (mut i_prot, _) = test_objects();
+
+        // signed two's complement of -73184125 = 1111_1011_1010_0011_0100_1100_1000_0011
+        #[rustfmt::skip]
+        let source_bytes: [u8; 11] = [
+            0x82, /* protocol ID */
+            0x21, /* message type | protocol version */
+            0x83,
+            0x99,
+            0x8D,
+            0xDD,
+            0x0F, /* non-zig-zag varint sequence number */
+            0x03, /* message-name length */
+            0x66,
+            0x6F,
+            0x6F /* "foo" */,
+        ];
+
+        i_prot.transport.set_readable_bytes(&source_bytes);
+
+        let expected = TMessageIdentifier::new("foo", TMessageType::Call, -73_184_125);
+        let res = assert_success!(i_prot.read_message_begin());
+
+        assert_eq!(&expected, &res);
+    }
+
+    #[test]
+    fn must_write_message_begin_negative_sequence_number_2() {
+        let (_, mut o_prot) = test_objects();
+
+        assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new(
+            "foo",
+            TMessageType::Call,
+            -1_073_741_823
+        )));
+
+        // signed two's complement of -1073741823 = 1100_0000_0000_0000_0000_0000_0000_0001
+        #[rustfmt::skip]
+        let expected: [u8; 11] = [
+            0x82, /* protocol ID */
+            0x21, /* message type | protocol version */
+            0x81,
+            0x80,
+            0x80,
+            0x80,
+            0x0C, /* non-zig-zag varint sequence number */
+            0x03, /* message-name length */
+            0x66,
+            0x6F,
+            0x6F /* "foo" */,
+        ];
+
+        assert_eq_written_bytes!(o_prot, expected);
+    }
+
+    #[test]
+    fn must_read_message_begin_negative_sequence_number_2() {
+        let (mut i_prot, _) = test_objects();
+
+        // signed two's complement of -1073741823 = 1100_0000_0000_0000_0000_0000_0000_0001
+        let source_bytes: [u8; 11] = [
+            0x82, /* protocol ID */
+            0x21, /* message type | protocol version */
+            0x81,
+            0x80,
+            0x80,
+            0x80,
+            0x0C, /* non-zig-zag varint sequence number */
+            0x03, /* message-name length */
+            0x66,
+            0x6F,
+            0x6F /* "foo" */,
+        ];
+
+        i_prot.transport.set_readable_bytes(&source_bytes);
+
+        let expected = TMessageIdentifier::new("foo", TMessageType::Call, -1_073_741_823);
+        let res = assert_success!(i_prot.read_message_begin());
+
+        assert_eq!(&expected, &res);
+    }
+
+    #[test]
     fn must_round_trip_upto_i64_maxvalue() {
         // See https://issues.apache.org/jira/browse/THRIFT-5131
         for i in 0..64 {
@@ -722,11 +1121,7 @@
             let val: i64 = ((1u64 << i) - 1) as i64;
 
             o_prot
-                .write_field_begin(&TFieldIdentifier::new(
-                    "val",
-                    TType::I64,
-                    1
-                ))
+                .write_field_begin(&TFieldIdentifier::new("val", TType::I64, 1))
                 .unwrap();
             o_prot.write_i64(val).unwrap();
             o_prot.write_field_end().unwrap();