THRIFT-5784: Add THeaderTransforms to TConfiguration

Client: go

While I'm here, also automatically add any compression transform seen
when reading a message (currently only zlib is supported) to
writeTransforms, so that a server automatically uses the same compression
on its responses as the client chose to use in its requests.
diff --git a/lib/go/test/tests/header_zlib_test.go b/lib/go/test/tests/header_zlib_test.go
new file mode 100644
index 0000000..cf2f849
--- /dev/null
+++ b/lib/go/test/tests/header_zlib_test.go
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package tests
+
+import (
+	"context"
+	"errors"
+	"io"
+	"net"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/apache/thrift/lib/go/test/gopath/src/servicestest"
+	"github.com/apache/thrift/lib/go/thrift"
+)
+
+// zlibTestHandler is the server-side handler used by TestTHeaderZlibClient.
+type zlibTestHandler struct {
+	servicestest.AServ
+	tb   testing.TB // receives a failure when an unexpected argument arrives
+	text string     // the string expected as the argument and sent back as the result
+}
+
+func (z zlibTestHandler) Stringfunc_1int_1s(ctx context.Context, i int64, s string) (string, error) {
+	if want := z.text; s != want { // the request must carry exactly the configured text
+		z.tb.Errorf("string arg got %q want %q", s, want)
+	}
+	return z.text, nil // always answer with the configured text, whatever s was
+}
+
+type countingProxy struct { // TCP proxy that counts the bytes it forwards, per direction
+	// Must be filled in by the caller before serve is called.
+	tb         testing.TB // used to report proxy errors
+	remoteAddr net.Addr   // address each accepted connection is forwarded to
+
+	// Internal state, managed by serve and the goroutines it starts.
+	listener      net.Listener // the proxy's own listening socket
+	clientWritten atomic.Int64 // bytes read from clients since the last reset
+	serverWritten atomic.Int64 // bytes read from the server since the last reset
+}
+
+// getAndResetCounters returns the bytes read from clients (req) and from the
+// server (resp) since the last reset, zeroing each counter with an atomic swap.
+func (cp *countingProxy) getAndResetCounters() (req, resp int64) {
+	return cp.clientWritten.Swap(0), cp.serverWritten.Swap(0)
+}
+
+// serve starts the proxy on a random port; it returns once it is listening.
+func (cp *countingProxy) serve() {
+	cp.tb.Helper()
+	listener, err := net.Listen("tcp", ":0")
+	if err != nil {
+		cp.tb.Fatalf("Failed to listen proxy: %v", err)
+	}
+	go func() {
+		for {
+			client, err := listener.Accept()
+			if err != nil {
+				if !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
+					cp.tb.Errorf("proxy accept error: %v", err)
+				}
+				return
+			}
+			server, err := net.Dial(cp.remoteAddr.Network(), cp.remoteAddr.String())
+			if err != nil {
+				cp.tb.Logf("proxy failed to dial server %v: %v", cp.remoteAddr, err)
+				client.Close() // nothing to forward to; a nil server would panic in proxy
+				continue
+			}
+			proxy := func(read, write net.Conn, count *atomic.Int64) {
+				var buf [1024]byte
+				for {
+					n, err := read.Read(buf[:])
+					if n > 0 {
+						count.Add(int64(n))
+						if _, err := write.Write(buf[:n]); err != nil {
+							cp.tb.Errorf("proxy write error: %v", err)
+						}
+					}
+					if err != nil {
+						if !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
+							cp.tb.Errorf("proxy read error: %v", err)
+						}
+						read.Close()
+						write.Close()
+						return
+					}
+				}
+			}
+			go proxy(client, server, &cp.clientWritten) // client -> server
+			go proxy(server, client, &cp.serverWritten) // server -> client
+		}
+	}()
+	cp.listener = listener
+}
+
+// TestTHeaderZlibClient checks, by counting bytes through a proxy, that a client
+// configured with zlib gets both its requests and the server's responses
+// compressed, and that a later client without zlib is unaffected.
+func TestTHeaderZlibClient(t *testing.T) {
+	// Some text that zlib should be able to compress
+	const text = `Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.`
+
+	socket, err := thrift.NewTServerSocket(":0")
+	if err != nil {
+		t.Fatalf("Failed to create server socket: %v", err)
+	}
+	// Call listen to reserve a port and check for any issues early
+	if err := socket.Listen(); err != nil {
+		t.Fatalf("Failed to listen server socket: %v", err)
+	}
+	server := thrift.NewTSimpleServer4(
+		servicestest.NewAServProcessor(zlibTestHandler{
+			tb:   t,
+			text: text,
+		}),
+		socket,
+		thrift.NewTHeaderTransportFactoryConf(nil, nil),
+		thrift.NewTHeaderProtocolFactoryConf(nil),
+	)
+	go server.Serve()
+	// give the server a little time to start serving
+	time.Sleep(10 * time.Millisecond)
+	t.Cleanup(func() { server.Stop() })
+	t.Logf("server running on %v", socket.Addr())
+
+	proxy := countingProxy{
+		tb:         t,
+		remoteAddr: socket.Addr(),
+	}
+	proxy.serve()
+	t.Cleanup(func() { proxy.listener.Close() })
+	t.Logf("proxy running on %v", proxy.listener.Addr())
+
+	// clientRoundtrip makes one call through the proxy on a fresh connection.
+	clientRoundtrip := func(cfg *thrift.TConfiguration) {
+		t.Helper()
+
+		socket := thrift.NewTSocketConf(proxy.listener.Addr().String(), cfg)
+		if err := socket.Open(); err != nil {
+			t.Errorf("failed to open socket: %v", err)
+			return
+		}
+		defer socket.Close()
+		protoFactory := thrift.NewTHeaderProtocolFactoryConf(cfg)
+		client := thrift.NewTStandardClient(
+			protoFactory.GetProtocol(socket),
+			protoFactory.GetProtocol(socket),
+		)
+		c := servicestest.NewAServClient(client)
+		got, err := c.Stringfunc_1int_1s(context.Background(), 0, text)
+		if err != nil {
+			t.Errorf("Stringfunc_1int_1s call failed: %v", err)
+			return
+		}
+		if got != text {
+			t.Errorf("Stringfunc_1int_1s got %q want %q", got, text)
+		}
+	}
+
+	clientRoundtrip(nil)
+	nozlibReq, nozlibResp := proxy.getAndResetCounters()
+	t.Logf("nozlib request size: %d, response size: %d", nozlibReq, nozlibResp)
+
+	clientRoundtrip(&thrift.TConfiguration{
+		THeaderTransforms: []thrift.THeaderTransformID{thrift.TransformZlib},
+	})
+	zlibReq, zlibResp := proxy.getAndResetCounters()
+	t.Logf("zlib request size: %d, response size: %d", zlibReq, zlibResp)
+
+	if zlibReq >= nozlibReq {
+		t.Errorf("zlib request size %d >= nozlib request size %d", zlibReq, nozlibReq)
+	}
+	if zlibResp >= nozlibResp {
+		t.Errorf("zlib response size %d >= nozlib response size %d", zlibResp, nozlibResp)
+	}
+
+	clientRoundtrip(nil)
+	nozlibReq2, nozlibResp2 := proxy.getAndResetCounters()
+	t.Logf("nozlib2 request size: %d, response size: %d", nozlibReq2, nozlibResp2)
+
+	if nozlibReq2 != nozlibReq {
+		t.Errorf("nozlib request 2 size %d != nozlib request size %d", nozlibReq2, nozlibReq)
+	}
+	if nozlibResp2 != nozlibResp {
+		t.Errorf("nozlib response 2 size %d != nozlib response size %d", nozlibResp2, nozlibResp)
+	}
+}
diff --git a/lib/go/thrift/configuration.go b/lib/go/thrift/configuration.go
index de27edd..a9565d3 100644
--- a/lib/go/thrift/configuration.go
+++ b/lib/go/thrift/configuration.go
@@ -56,47 +56,47 @@
 //
 // For example, say you want to migrate this old code into using TConfiguration:
 //
-//     sccket, err := thrift.NewTSocketTimeout("host:port", time.Second, time.Second)
-//     transFactory := thrift.NewTFramedTransportFactoryMaxLength(
-//         thrift.NewTTransportFactory(),
-//         1024 * 1024 * 256,
-//     )
-//     protoFactory := thrift.NewTBinaryProtocolFactory(true, true)
+//	socket, err := thrift.NewTSocketTimeout("host:port", time.Second, time.Second)
+//	transFactory := thrift.NewTFramedTransportFactoryMaxLength(
+//	    thrift.NewTTransportFactory(),
+//	    1024 * 1024 * 256,
+//	)
+//	protoFactory := thrift.NewTBinaryProtocolFactory(true, true)
 //
 // This is the wrong way to do it because in the end the TConfiguration used by
 // socket and transFactory will be overwritten by the one used by protoFactory
 // because of TConfiguration propagation:
 //
-//     // bad example, DO NOT USE
-//     sccket := thrift.NewTSocketConf("host:port", &thrift.TConfiguration{
-//         ConnectTimeout: time.Second,
-//         SocketTimeout:  time.Second,
-//     })
-//     transFactory := thrift.NewTFramedTransportFactoryConf(
-//         thrift.NewTTransportFactory(),
-//         &thrift.TConfiguration{
-//             MaxFrameSize: 1024 * 1024 * 256,
-//         },
-//     )
-//     protoFactory := thrift.NewTBinaryProtocolFactoryConf(&thrift.TConfiguration{
-//         TBinaryStrictRead:  thrift.BoolPtr(true),
-//         TBinaryStrictWrite: thrift.BoolPtr(true),
-//     })
+//	// bad example, DO NOT USE
+//	socket := thrift.NewTSocketConf("host:port", &thrift.TConfiguration{
+//	    ConnectTimeout: time.Second,
+//	    SocketTimeout:  time.Second,
+//	})
+//	transFactory := thrift.NewTFramedTransportFactoryConf(
+//	    thrift.NewTTransportFactory(),
+//	    &thrift.TConfiguration{
+//	        MaxFrameSize: 1024 * 1024 * 256,
+//	    },
+//	)
+//	protoFactory := thrift.NewTBinaryProtocolFactoryConf(&thrift.TConfiguration{
+//	    TBinaryStrictRead:  thrift.BoolPtr(true),
+//	    TBinaryStrictWrite: thrift.BoolPtr(true),
+//	})
 //
 // This is the correct way to do it:
 //
-//     conf := &thrift.TConfiguration{
-//         ConnectTimeout: time.Second,
-//         SocketTimeout:  time.Second,
+//	conf := &thrift.TConfiguration{
+//	    ConnectTimeout: time.Second,
+//	    SocketTimeout:  time.Second,
 //
-//         MaxFrameSize: 1024 * 1024 * 256,
+//	    MaxFrameSize: 1024 * 1024 * 256,
 //
-//         TBinaryStrictRead:  thrift.BoolPtr(true),
-//         TBinaryStrictWrite: thrift.BoolPtr(true),
-//     }
-//     sccket := thrift.NewTSocketConf("host:port", conf)
-//     transFactory := thrift.NewTFramedTransportFactoryConf(thrift.NewTTransportFactory(), conf)
-//     protoFactory := thrift.NewTBinaryProtocolFactoryConf(conf)
+//	    TBinaryStrictRead:  thrift.BoolPtr(true),
+//	    TBinaryStrictWrite: thrift.BoolPtr(true),
+//	}
+//	socket := thrift.NewTSocketConf("host:port", conf)
+//	transFactory := thrift.NewTFramedTransportFactoryConf(thrift.NewTTransportFactory(), conf)
+//	protoFactory := thrift.NewTBinaryProtocolFactoryConf(conf)
 //
 // [1]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-tconfiguration.md
 type TConfiguration struct {
@@ -132,6 +132,8 @@
 	// THeaderProtocolIDPtr and THeaderProtocolIDPtrMust helper functions
 	// are provided to help filling this value.
 	THeaderProtocolID *THeaderProtocolID
+	// The write transforms to be applied to THeaderTransport.
+	THeaderTransforms []THeaderTransformID
 
 	// Used internally by deprecated constructors, to avoid overriding
 	// underlying TTransport/TProtocol's cfg by accidental propagations.
@@ -245,6 +247,18 @@
 	return protoID
 }
 
+// GetTHeaderTransforms reports the write transforms configured for
+// THeaderTransport.
+//
+// It is safe to call on a nil tc, in which case nil is returned (no
+// transforms). The result shares storage with tc; do not modify it in place.
+func (tc *TConfiguration) GetTHeaderTransforms() []THeaderTransformID {
+	if tc != nil {
+		return tc.THeaderTransforms
+	}
+	return nil
+}
+
 // THeaderProtocolIDPtr validates and returns the pointer to id.
 //
 // If id is not a valid THeaderProtocolID, a pointer to THeaderProtocolDefault
diff --git a/lib/go/thrift/header_protocol.go b/lib/go/thrift/header_protocol.go
index 36777b4..bec84b8 100644
--- a/lib/go/thrift/header_protocol.go
+++ b/lib/go/thrift/header_protocol.go
@@ -119,6 +119,11 @@
 }
 
 // AddTransform add a transform for writing.
+//
+// Deprecated: Transforms added this way only last until the next message is
+// read, at which point write transforms are reset to the ones configured in
+// TConfiguration (plus any compression transforms read). For sticky
+// transforms, use TConfiguration.THeaderTransforms instead.
 func (p *THeaderProtocol) AddTransform(transform THeaderTransformID) error {
 	return p.transport.AddTransform(transform)
 }
diff --git a/lib/go/thrift/header_protocol_test.go b/lib/go/thrift/header_protocol_test.go
index 48a69bf..dfd84f8 100644
--- a/lib/go/thrift/header_protocol_test.go
+++ b/lib/go/thrift/header_protocol_test.go
@@ -39,4 +39,24 @@
 			}))
 		},
 	)
+
+	t.Run(
+		"binary-zlib",
+		func(t *testing.T) {
+			ReadWriteProtocolTest(t, NewTHeaderProtocolFactoryConf(&TConfiguration{
+				THeaderProtocolID: THeaderProtocolIDPtrMust(THeaderProtocolBinary),
+				THeaderTransforms: []THeaderTransformID{TransformZlib},
+			}))
+		},
+	)
+
+	t.Run(
+		"compact-zlib",
+		func(t *testing.T) {
+			ReadWriteProtocolTest(t, NewTHeaderProtocolFactoryConf(&TConfiguration{
+				THeaderProtocolID: THeaderProtocolIDPtrMust(THeaderProtocolCompact),
+				THeaderTransforms: []THeaderTransformID{TransformZlib},
+			}))
+		},
+	)
 }
diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go
index 772d922..d81fb29 100644
--- a/lib/go/thrift/header_transport.go
+++ b/lib/go/thrift/header_transport.go
@@ -151,6 +151,11 @@
 }
 
 // AddTransform adds a transform.
+//
+// Deprecated: This only applies to the next message written, and the next read
+// message will cause write transforms to be reset from what's configured in
+// TConfiguration. For sticky transforms, use TConfiguration.THeaderTransforms
+// instead.
 func (tr *TransformReader) AddTransform(id THeaderTransformID) error {
 	switch id {
 	default:
@@ -300,11 +305,12 @@
 	}
 	PropagateTConfiguration(trans, conf)
 	return &THeaderTransport{
-		transport:    trans,
-		reader:       bufio.NewReader(trans),
-		writeHeaders: make(THeaderMap),
-		protocolID:   conf.GetTHeaderProtocolID(),
-		cfg:          conf,
+		transport:       trans,
+		reader:          bufio.NewReader(trans),
+		writeHeaders:    make(THeaderMap),
+		writeTransforms: conf.GetTHeaderTransforms(),
+		protocolID:      conf.GetTHeaderProtocolID(),
+		cfg:             conf,
 	}
 }
 
@@ -449,6 +455,11 @@
 	}
 	t.protocolID = THeaderProtocolID(protoID)
 
+	// Reset writeTransforms to the ones from cfg, as we are going to add
+	// compression transforms from what we read, we don't want to accumulate
+	// different transforms read from different requests
+	t.writeTransforms = t.cfg.GetTHeaderTransforms()
+
 	var transformCount int32
 	transformCount, err = hp.readVarint32()
 	if err != nil {
@@ -466,7 +477,16 @@
 			if err != nil {
 				return err
 			}
-			transformIDs[i] = THeaderTransformID(id)
+			tID := THeaderTransformID(id)
+			transformIDs[i] = tID
+
+			// For compression transforms, we should also add them
+			// to writeTransforms so that the response (assuming we
+			// are reading a request) would do the same compression.
+			switch tID {
+			case TransformZlib:
+				t.addWriteTransformsDedupe(tID)
+			}
 		}
 		// The transform IDs on the wire was added based on the order of
 		// writing, so on the reading side we need to reverse the order.
@@ -726,6 +746,9 @@
 }
 
 // AddTransform add a transform for writing.
+//
+// NOTE: This is provided as a low-level API, but in general you should use
+// TConfiguration.THeaderTransforms to set transforms for writing instead.
 func (t *THeaderTransport) AddTransform(transform THeaderTransformID) error {
 	if !supportedTransformIDs[transform] {
 		return NewTProtocolExceptionWithType(
@@ -758,6 +781,17 @@
 	}
 }
 
+// addWriteTransformsDedupe adds id to writeTransforms only if it's not already
+// there. The capped slice forces a copy, as it may alias cfg.THeaderTransforms.
+func (t *THeaderTransport) addWriteTransformsDedupe(id THeaderTransformID) {
+	for _, existingID := range t.writeTransforms {
+		if existingID == id {
+			return
+		}
+	}
+	t.writeTransforms = append(t.writeTransforms[:len(t.writeTransforms):len(t.writeTransforms)], id)
+}
+
 // SetTConfiguration implements TConfigurationSetter.
 func (t *THeaderTransport) SetTConfiguration(cfg *TConfiguration) {
 	PropagateTConfiguration(t.transport, cfg)