Proxy Protocol out fixes (#9341)
This makes the following adjustments for proxy.config.http.proxy_protocol_out:
* Logging updates.
* Fix Proxy Protocol src and dest order. This fixes #9335.
* Fix Proxy Protocol for blind tunnels
* Add a proxy_protocol_out test
diff --git a/iocore/net/ProxyProtocol.cc b/iocore/net/ProxyProtocol.cc
index 0b99355..c18f563 100644
--- a/iocore/net/ProxyProtocol.cc
+++ b/iocore/net/ProxyProtocol.cc
@@ -356,6 +356,7 @@
bw.fill(len);
}
+ Debug("proxyprotocol_v1", "Proxy Protocol v1: %.*s", static_cast<int>(bw.size()), bw.data());
bw.write("\r\n");
return bw.size();
@@ -441,6 +442,7 @@
// Set len field (number of following bytes part of the header) in the hdr
uint16_t len = htons(bw.size() - PPv2_CONNECTION_HEADER_LEN);
memcpy(buf + len_field_offset, &len, sizeof(uint16_t));
+ Debug("proxyprotocol_v2", "Proxy Protocol v2 of %zu bytes", bw.size());
return bw.size();
}
diff --git a/proxy/http/HttpSM.cc b/proxy/http/HttpSM.cc
index 3eded70..97519be 100644
--- a/proxy/http/HttpSM.cc
+++ b/proxy/http/HttpSM.cc
@@ -127,9 +127,10 @@
// nothing to forward
return 0;
} else {
+ Debug("proxyprotocol", "vc_in had no Proxy Protocol. Manufacturing from the vc_in socket.");
// set info from incoming NetVConnection
IpEndpoint local = vc_in->get_local_endpoint();
- info = ProxyProtocol{pp_version, local.family(), local, vc_in->get_remote_endpoint()};
+ info = ProxyProtocol{pp_version, local.family(), vc_in->get_remote_endpoint(), local};
}
}
@@ -6982,7 +6983,13 @@
client_response_hdr_bytes = 0;
}
- client_request_body_bytes = 0;
+ int64_t nbytes = 0;
+ if (t_state.txn_conf->proxy_protocol_out >= 0) {
+ nbytes = do_outbound_proxy_protocol(from_ua_buf, static_cast<NetVConnection *>(server_entry->vc), ua_txn->get_netvc(),
+ t_state.txn_conf->proxy_protocol_out);
+ }
+
+ client_request_body_bytes = nbytes;
if (ua_raw_buffer_reader != nullptr) {
client_request_body_bytes += from_ua_buf->write(ua_raw_buffer_reader, client_request_hdr_bytes);
ua_raw_buffer_reader->dealloc();
diff --git a/tests/gold_tests/autest-site/trafficserver.test.ext b/tests/gold_tests/autest-site/trafficserver.test.ext
index 43f768e..1fe9eed 100755
--- a/tests/gold_tests/autest-site/trafficserver.test.ext
+++ b/tests/gold_tests/autest-site/trafficserver.test.ext
@@ -41,7 +41,7 @@
def MakeATSProcess(obj, name, command='traffic_server', select_ports=True,
enable_tls=False, enable_cache=True, enable_quic=False,
block_for_debug=False, log_data=default_log_data,
- use_traffic_out=True):
+ use_traffic_out=True, enable_proxy_protocol=False):
#####################################
# common locations
@@ -328,6 +328,14 @@
if enable_tls:
get_port(p, "ssl_port")
get_port(p, "ssl_portv6")
+
+ if enable_proxy_protocol:
+ get_port(p, "proxy_protocol_port")
+ get_port(p, "proxy_protocol_portv6")
+
+ if enable_tls:
+ get_port(p, "proxy_protocol_ssl_port")
+ get_port(p, "proxy_protocol_ssl_portv6")
else:
p.Variables.port = 8080
p.Variables.portv6 = 8080
@@ -382,6 +390,10 @@
if enable_quic:
port_str += " {ssl_port}:quic {ssl_portv6}:quic:ipv6".format(
ssl_port=p.Variables.ssl_port, ssl_portv6=p.Variables.ssl_portv6)
+ if enable_proxy_protocol:
+ port_str += f" {p.Variables.proxy_protocol_port}:pp {p.Variables.proxy_protocol_portv6}:pp:ipv6"
+ if enable_tls:
+ port_str += f" {p.Variables.proxy_protocol_ssl_port}:pp:ssl {p.Variables.proxy_protocol_ssl_portv6}:pp:ssl:ipv6"
#p.Env['PROXY_CONFIG_HTTP_SERVER_PORTS'] = port_str
p.Disk.records_config.update({
'proxy.config.http.server_ports': port_str,
diff --git a/tests/gold_tests/proxy_protocol/proxy_protocol.test.py b/tests/gold_tests/proxy_protocol/proxy_protocol.test.py
index 4125ba0..ed58315 100644
--- a/tests/gold_tests/proxy_protocol/proxy_protocol.test.py
+++ b/tests/gold_tests/proxy_protocol/proxy_protocol.test.py
@@ -17,6 +17,7 @@
# limitations under the License.
import os
+from ports import get_port
import sys
Test.Summary = 'Test PROXY Protocol'
@@ -27,6 +28,8 @@
class ProxyProtocolTest:
+ """Test that ATS can receive Proxy Protocol."""
+
def __init__(self):
self.setupOriginServer()
self.setupTS()
@@ -39,7 +42,7 @@
'''
def setupTS(self):
- self.ts = Test.MakeATSProcess("ts", enable_tls=True, enable_cache=False)
+ self.ts = Test.MakeATSProcess("ts_in", enable_tls=True, enable_cache=False, enable_proxy_protocol=True)
self.ts.addDefaultSSLFiles()
self.ts.Disk.ssl_multicert_config.AddLine("dest_ip=* ssl_cert_name=server.pem ssl_key_name=server.key")
@@ -48,7 +51,6 @@
f"map / http://127.0.0.1:{self.httpbin.Variables.Port}/")
self.ts.Disk.records_config.update({
- "proxy.config.http.server_ports": f"{self.ts.Variables.port}:pp {self.ts.Variables.ssl_port}:ssl:pp",
"proxy.config.http.proxy_protocol_allowlist": "127.0.0.1",
"proxy.config.http.insert_forwarded": "for|by=ip|proto",
"proxy.config.ssl.server.cert.path": f"{self.ts.Variables.SSLDir}",
@@ -76,7 +78,7 @@
tr = Test.AddTestRun()
tr.Processes.Default.StartBefore(self.httpbin)
tr.Processes.Default.StartBefore(self.ts)
- tr.Processes.Default.Command = f"curl -vs --haproxy-protocol http://localhost:{self.ts.Variables.port}/get | {self.json_printer}"
+ tr.Processes.Default.Command = f"curl -vs --haproxy-protocol http://localhost:{self.ts.Variables.proxy_protocol_port}/get | {self.json_printer}"
tr.Processes.Default.ReturnCode = 0
tr.Processes.Default.Streams.stdout = "gold/test_case_0_stdout.gold"
tr.Processes.Default.Streams.stderr = "gold/test_case_0_stderr.gold"
@@ -88,7 +90,7 @@
Incoming PROXY Protocol v1 on SSL port
"""
tr = Test.AddTestRun()
- tr.Processes.Default.Command = f"curl -vsk --haproxy-protocol --http1.1 https://localhost:{self.ts.Variables.ssl_port}/get | {self.json_printer}"
+ tr.Processes.Default.Command = f"curl -vsk --haproxy-protocol --http1.1 https://localhost:{self.ts.Variables.proxy_protocol_ssl_port}/get | {self.json_printer}"
tr.Processes.Default.ReturnCode = 0
tr.Processes.Default.Streams.stdout = "gold/test_case_1_stdout.gold"
tr.Processes.Default.Streams.stderr = "gold/test_case_1_stderr.gold"
@@ -100,7 +102,7 @@
Test with netcat
"""
tr = Test.AddTestRun()
- tr.Processes.Default.Command = f"echo 'PROXY TCP4 198.51.100.1 198.51.100.2 51137 80\r\nGET /get HTTP/1.1\r\nHost: 127.0.0.1:80\r\n' | nc localhost {self.ts.Variables.port}"
+ tr.Processes.Default.Command = f"echo 'PROXY TCP4 198.51.100.1 198.51.100.2 51137 80\r\nGET /get HTTP/1.1\r\nHost: 127.0.0.1:80\r\n' | nc localhost {self.ts.Variables.proxy_protocol_port}"
tr.Processes.Default.ReturnCode = 0
tr.Processes.Default.Streams.stdout = "gold/test_case_2_stdout.gold"
tr.StillRunningAfter = self.httpbin
@@ -127,4 +129,160 @@
self.addTestCase99()
+class ProxyProtocolOutTest:
+ """Test that ATS can send Proxy Protocol."""
+
+ _pp_server = 'proxy_protocol_server.py'
+
+ _dns_counter = 0
+ _server_counter = 0
+ _ts_counter = 0
+
+ def __init__(self, pp_version: int, is_tunnel: bool) -> None:
+ """Initialize a ProxyProtocolOutTest.
+
+ :param pp_version: The Proxy Protocol version to use (1 or 2).
+ :param is_tunnel: Whether ATS should tunnel to the origin.
+ """
+
+ if pp_version not in (-1, 1, 2):
+ raise ValueError(
+ f'Invalid Proxy Protocol version (not 1 or 2): {pp_version}')
+ self._pp_version = pp_version
+ self._is_tunnel = is_tunnel
+
+ def setupOriginServer(self, tr: 'TestRun') -> None:
+ """Configure the origin server.
+
+ :param tr: The TestRun to associate the origin's Process with.
+ """
+ tr.Setup.CopyAs(self._pp_server, tr.RunDirectory)
+ cert_file = os.path.join(Test.Variables.AtsTestToolsDir, "ssl", "server.pem")
+ key_file = os.path.join(Test.Variables.AtsTestToolsDir, "ssl", "server.key")
+ tr.Setup.Copy(cert_file)
+ tr.Setup.Copy(key_file)
+ server = tr.Processes.Process(
+ f'server-{ProxyProtocolOutTest._server_counter}')
+ ProxyProtocolOutTest._server_counter += 1
+ server_port = get_port(server, "external_port")
+ internal_port = get_port(server, "internal_port")
+ command = (
+ f'{sys.executable} {self._pp_server} '
+ f'server.pem server.key 127.0.0.1 {server_port} {internal_port}')
+ if not self._is_tunnel:
+ command += ' --plaintext'
+ server.Command = command
+ server.Ready = When.PortOpenv4(server_port)
+
+ self._server = server
+
+ def setupDNS(self, tr: 'TestRun') -> None:
+ """Configure the DNS server.
+
+ :param tr: The TestRun to associate the DNS's Process with.
+ """
+ self._dns = tr.MakeDNServer(
+ f'dns-{ProxyProtocolOutTest._dns_counter}',
+ default='127.0.0.1')
+ ProxyProtocolOutTest._dns_counter += 1
+
+ def setupTS(self, tr: 'TestRun') -> None:
+ """Configure Traffic Server."""
+ process_name = f'ts-out-{ProxyProtocolOutTest._ts_counter}'
+ ProxyProtocolOutTest._ts_counter += 1
+ self._ts = tr.MakeATSProcess(process_name, enable_tls=True,
+ enable_cache=False)
+
+ self._ts.addDefaultSSLFiles()
+ self._ts.Disk.ssl_multicert_config.AddLine(
+ "dest_ip=* ssl_cert_name=server.pem ssl_key_name=server.key"
+ )
+
+ self._ts.Disk.remap_config.AddLine(
+ f"map / http://backend.pp.origin.com:{self._server.Variables.external_port}/")
+
+ self._ts.Disk.records_config.update({
+ "proxy.config.ssl.server.cert.path": f"{self._ts.Variables.SSLDir}",
+ "proxy.config.ssl.server.private_key.path": f"{self._ts.Variables.SSLDir}",
+ "proxy.config.diags.debug.enabled": 1,
+ "proxy.config.diags.debug.tags": "http|proxyprotocol",
+ "proxy.config.http.proxy_protocol_out": self._pp_version,
+ "proxy.config.dns.nameservers": f"127.0.0.1:{self._dns.Variables.Port}",
+ "proxy.config.dns.resolv_conf": 'NULL'
+ })
+
+ if self._is_tunnel:
+ self._ts.Disk.records_config.update({
+ "proxy.config.http.connect_ports": f'{self._server.Variables.external_port}',
+ })
+
+ self._ts.Disk.sni_yaml.AddLines([
+ 'sni:',
+ '- fqdn: pp.origin.com',
+ f' tunnel_route: backend.pp.origin.com:{self._server.Variables.external_port}',
+ ])
+
+ def setLogExpectations(self, tr: 'TestRun') -> None:
+
+ tr.Processes.Default.Streams.All += Testers.ContainsExpression(
+ "HTTP/1.1 200 OK",
+ "Verify that curl got a 200 response")
+
+ if self._pp_version in (1, 2):
+ expected_pp = (
+ 'PROXY TCP4 127.0.0.1 127.0.0.1 '
+ rf'\d+ {self._ts.Variables.ssl_port}'
+ )
+ self._server.Streams.All += Testers.ContainsExpression(
+ expected_pp,
+ "Verify the server got the expected Proxy Protocol string.")
+
+ self._server.Streams.All += Testers.ContainsExpression(
+ f'Received Proxy Protocol v{self._pp_version}',
+ "Verify the server got the expected Proxy Protocol version.")
+
+ if self._pp_version == -1:
+ self._server.Streams.All += Testers.ContainsExpression(
+ 'No Proxy Protocol string found',
+ 'There should be no Proxy Protocol string.')
+
+ def run(self) -> None:
+ """Run the test."""
+ description = f'Proxy Protocol v{self._pp_version} '
+ if self._is_tunnel:
+ description += "with blind tunneling"
+ else:
+ description += "without blind tunneling"
+ tr = Test.AddTestRun(description)
+
+ self.setupDNS(tr)
+ self.setupOriginServer(tr)
+ self.setupTS(tr)
+
+ self._ts.StartBefore(self._server)
+ self._ts.StartBefore(self._dns)
+ tr.Processes.Default.StartBefore(self._ts)
+
+ origin = f'pp.origin.com:{self._ts.Variables.ssl_port}'
+ command = (
+ 'sleep1; curl -vsk --http1.1 '
+ f'--resolve "{origin}:127.0.0.1" '
+ f'https://{origin}/get'
+ )
+
+ tr.Processes.Default.Command = command
+ tr.Processes.Default.ReturnCode = 0
+ # Its only one transaction, so this should complete quickly. The test
+ # server often hangs if there are issues parsing the Proxy Protocol
+ # string.
+ tr.TimeOut = 5
+ self.setLogExpectations(tr)
+
+
ProxyProtocolTest().run()
+
+ProxyProtocolOutTest(pp_version=-1, is_tunnel=False).run()
+ProxyProtocolOutTest(pp_version=1, is_tunnel=False).run()
+ProxyProtocolOutTest(pp_version=2, is_tunnel=False).run()
+ProxyProtocolOutTest(pp_version=1, is_tunnel=True).run()
+ProxyProtocolOutTest(pp_version=2, is_tunnel=True).run()
diff --git a/tests/gold_tests/proxy_protocol/proxy_protocol_server.py b/tests/gold_tests/proxy_protocol/proxy_protocol_server.py
new file mode 100644
index 0000000..af1a528
--- /dev/null
+++ b/tests/gold_tests/proxy_protocol/proxy_protocol_server.py
@@ -0,0 +1,381 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A simple server that expects and prints out the Proxy Protocol string."""
+
+import argparse
+import logging
+import socket
+import ssl
+import struct
+import sys
+import threading
+
+
+# Set a 10ms timeout for socket operations.
+TIMEOUT = .010
+
+PP_V2_PREFIX = b'\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a'
+
+
+# Create a condition variable for thread initialization.
+internal_thread_is_ready = threading.Condition()
+
+
+def parse_args() -> argparse.Namespace:
+ """Parse command line arguments."""
+ parser = argparse.ArgumentParser(description=__doc__)
+
+ parser.add_argument(
+ "certfile",
+ help="The path to the certificate file to use for TLS.")
+ parser.add_argument(
+ "keyfile",
+ help="The path to the key file to use for TLS.")
+ parser.add_argument(
+ "address",
+ help="The IP address to listen on.")
+ parser.add_argument(
+ "port",
+ type=int,
+ help="The port to listen on.")
+ parser.add_argument(
+ "internal_port",
+ type=int,
+ help="The internal port used to parse the TLS content.")
+ parser.add_argument(
+ "--plaintext",
+ action="store_true",
+ help="Listen for plaintext connections instead of TLS.")
+
+ return parser.parse_args()
+
+
+def receive_and_send_http(sock: socket.socket) -> None:
+ """Receive and send an HTTP request and response.
+
+ :param sock: The socket to receive the request on.
+ """
+ sock.settimeout(TIMEOUT)
+
+ # Read the request until the final CRLF is received.
+ received_request = b''
+ while True:
+ data = None
+ try:
+ data = sock.recv(1024)
+ logging.debug(f'Internal: received {len(data)} bytes')
+ except socket.timeout:
+ continue
+ if not data:
+ break
+ received_request += data
+
+ if b'\r\n\r\n' in received_request:
+ break
+ logging.info("Received request:")
+ logging.info(received_request.decode("utf-8"))
+
+ # Send a response.
+ response = (
+ "HTTP/1.1 200 OK\r\n"
+ "Content-Length: 0\r\n"
+ "Connection: close\r\n"
+ "\r\n"
+ )
+ logging.info(f'Sending:\n{response}')
+ try:
+ sock.sendall(response.encode("utf-8"))
+ except socket.timeout:
+ logging.error("Timeout sending a response.")
+
+
+def run_internal_server(cert_file: str, key_file: str,
+ address: str, port: int,
+ plaintext: bool) -> None:
+ """Run the internal server.
+
+ This server is receives the HTTP content with the Proxy Protocol prefix
+ stripped off by the client.
+
+ :param cert_file: The path to the certificate file to use for TLS.
+ :param key_file: The path to the key file to use for TLS.
+ :param address: The IP address to listen on.
+ :param port: The port to listen on.
+ :param plaintext: Whether to listen for HTTP rather than HTTPS traffic.
+ """
+
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.bind((address, port))
+ sock.listen()
+ logging.info(f"Internal HTTPS server listening on {address}:{port}")
+
+ if plaintext:
+ # Notify the waiting thread that the internal server is ready.
+ with internal_thread_is_ready:
+ internal_thread_is_ready.notify()
+ conn, addr_in = sock.accept()
+ logging.info(f"Internal server accepted plaintext connection from {addr_in}")
+ with conn:
+ receive_and_send_http(conn)
+ else:
+ # Wrap the server socket to handle TLS.
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ context.load_cert_chain(certfile=cert_file, keyfile=key_file)
+
+ with context.wrap_socket(sock, server_side=True) as ssock:
+ with internal_thread_is_ready:
+ internal_thread_is_ready.notify()
+ conn, addr_in = ssock.accept()
+ logging.info(f"Internal server accepted TLS connection from {addr_in}")
+ with conn:
+ receive_and_send_http(conn)
+
+
+def parse_pp_v1(pp_bytes: bytes) -> int:
+ """Parse and print the Proxy Protocol v1 string.
+
+ :param pp_bytes: The bytes containing the Proxy Protocol string. There may
+ be more bytes than the Proxy Protocol string.
+
+ :returns: The number of bytes occupied by the proxy v1 protcol.
+ """
+ # Proxy Protocol v1 string ends with CRLF.
+ end = pp_bytes.find(b'\r\n')
+ if end == -1:
+ raise ValueError("Proxy Protocol v1 string ending not found")
+ logging.info(pp_bytes[:end].decode("utf-8"))
+ return end + 2
+
+
+def parse_pp_v2(pp_bytes: bytes) -> int:
+ """Parse and print the Proxy Protocol v2 string.
+
+ :param pp_bytes: The bytes containing the Proxy Protocol string. There may
+ be more bytes than the Proxy Protocol string.
+
+ :returns: The number of bytes occupied by the proxy v2 protocol string.
+ """
+
+ # Skip the 12 byte header.
+ pp_bytes = pp_bytes[12:]
+ version_command = pp_bytes[0]
+ pp_bytes = pp_bytes[1:]
+ family_protocol = pp_bytes[0]
+ pp_bytes = pp_bytes[1:]
+ tuple_length = int.from_bytes(pp_bytes[:2], byteorder='big')
+ pp_bytes = pp_bytes[2:]
+
+ # Of version_command, the highest 4 bits is the version and the lowest is
+ # the command.
+ version = version_command >> 4
+ command = version_command & 0x0F
+
+ if version != 2:
+ raise ValueError(
+ f'Invalid version: {version} (by spec, should always be 0x02)')
+
+ if command == 0x0:
+ command_description = 'LOCAL'
+ elif command == 0x1:
+ command_description = 'PROXY'
+ else:
+ raise ValueError(
+ f'Invalid command: {command} (by spec, should be 0x00 or 0x01)')
+
+ # Of address_family, the highest 4 bits is the address family and the
+ # lowest is the transport protocol.
+ if family_protocol == 0x0:
+ transport_protocol_description = 'UNSPEC'
+ elif family_protocol == 0x11:
+ transport_protocol_description = 'TCP4'
+ elif family_protocol == 0x12:
+ transport_protocol_description = 'UDP4'
+ elif family_protocol == 0x21:
+ transport_protocol_description = 'TCP6'
+ elif family_protocol == 0x22:
+ transport_protocol_description = 'UDP6'
+ elif family_protocol == 0x31:
+ transport_protocol_description = 'UNIX_STREAM'
+ elif family_protocol == 0x32:
+ transport_protocol_description = 'UNIX_DGRAM'
+ else:
+ raise ValueError(
+ f'Invalid address family: {family_protocol} (by spec, should be '
+ '0x00, 0x11, 0x12, 0x21, 0x22, 0x31, or 0x32)')
+
+ if family_protocol in (0x11, 0x12):
+ if tuple_length != 12:
+ raise ValueError(
+ "Unexpected tuple length for TCP4/UDP4: "
+ f"{tuple_length} (by spec, should be 12)"
+ )
+ src_addr = socket.inet_ntop(socket.AF_INET, pp_bytes[:4])
+ pp_bytes = pp_bytes[4:]
+ dst_addr = socket.inet_ntop(socket.AF_INET, pp_bytes[:4])
+ pp_bytes = pp_bytes[4:]
+ src_port = int.from_bytes(pp_bytes[:2], byteorder='big')
+ pp_bytes = pp_bytes[2:]
+ dst_port = int.from_bytes(pp_bytes[:2], byteorder='big')
+ pp_bytes = pp_bytes[2:]
+
+ tuple_description = f'{src_addr} {dst_addr} {src_port} {dst_port}'
+ logging.info(
+ f'{command_description} {transport_protocol_description} '
+ f'{tuple_description}')
+
+ return 16 + tuple_length
+
+
+def accept_pp_connection(sock: socket.socket, address: str, internal_port: int) -> bool:
+ """Accept a connection and parse the proxy protocol header.
+
+ :param sock: The socket to accept the connection on.
+ :param address: The address of the internal server to connect to.
+ :param internal_port: The port of the internal server to connect to.
+
+ :returns: True if the connection had a payload, False otherwise.
+ """
+ client_conn, address_in = sock.accept()
+ logging.info(f'Accepted connection from {address_in}')
+ with client_conn:
+ has_pp = False
+ pp_length = 0
+ # Read the Proxy Protocol prefix, which ends with the first CRLF.
+ received_data = b''
+ while True:
+ data = client_conn.recv(1024)
+ if data:
+ logging.debug(f"Received: {len(data)} bytes")
+ else:
+ logging.info("No data received while waiting for "
+ "Proxy Protocol prefix")
+ return False
+ received_data += data
+
+ if (received_data.startswith(b'PROXY') and
+ b'\r\n' in received_data):
+ logging.info("Received Proxy Protocol v1")
+ pp_length = parse_pp_v1(received_data)
+ has_pp = True
+ break
+
+ if received_data.startswith(PP_V2_PREFIX):
+ logging.info("Received Proxy Protocol v2")
+ pp_length = parse_pp_v2(received_data)
+ has_pp = True
+ break
+
+ if len(received_data) > 108:
+ # The spec gaurantees that the prefix will be no more than
+ # 108 bytes.
+ logging.info("No Proxy Protocol string found.")
+ break
+ if has_pp:
+ # Now, strip the received_data of the prefix and blind tunnel
+ # the rest of the content.
+ for_internal = received_data[pp_length:]
+ logging.debug(
+ f"Stripped the prefix, now thare are {len(for_internal)} "
+ "bytes for the internal server.")
+ else:
+ for_internal = received_data
+ client_conn.settimeout(TIMEOUT)
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as internal_sock:
+ logging.debug(f"Connecting to internal server on {address}:{internal_port}")
+ internal_sock.connect((address, internal_port))
+ internal_sock.settimeout(TIMEOUT)
+ if for_internal:
+ logging.debug('Sending remaining data to internal server: '
+ f'{len(for_internal)} bytes')
+ internal_sock.sendall(for_internal)
+ while True:
+
+ logging.debug("entering loop")
+
+ try:
+ from_internal = internal_sock.recv(1024)
+ logging.debug(f'Received {len(from_internal)} bytes from internal server')
+ if not from_internal:
+ logging.debug('No more data from internal server, closing connection')
+ break
+ client_conn.sendall(from_internal)
+ logging.debug(f'Sent {len(from_internal)} bytes to client')
+ except socket.timeout:
+ pass
+
+ try:
+ for_internal = client_conn.recv(1024)
+ logging.debug(f'Received {len(for_internal)} bytes from client')
+ if not for_internal:
+ logging.debug('No more data from client, closing connection')
+ break
+ internal_sock.sendall(for_internal)
+ logging.debug(f'Sent {len(for_internal)} bytes to internal server')
+ except socket.timeout:
+ pass
+
+
+def receive_pp_request(address: str, port: int, internal_port: int) -> None:
+ """Start a server to receive a connection which may have a proxy protocol
+ header.
+
+ :param address: The address to listen on.
+ :param port: The port to listen on.
+ :param internal_port: The port of the internal server to connect to.
+ """
+
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.bind((address, port))
+ sock.listen()
+
+ # The PortOpen logic will create an empty request to the server. Ignore
+ # those until we have a connection with a real request which comes in.
+ request_received = False
+ while not request_received:
+ request_received = accept_pp_connection(sock, address,
+ internal_port)
+
+
+def main() -> int:
+ """Run the server listening for Proxy Protocol."""
+ args = parse_args()
+
+ with internal_thread_is_ready:
+ """Start the threads to receive requests."""
+ internal_server = threading.Thread(
+ target=run_internal_server,
+ args=(args.certfile, args.keyfile, args.address,
+ args.internal_port, args.plaintext))
+ internal_server.start()
+
+ # Wait for the internal server to start before proceeding.
+ internal_thread_is_ready.wait()
+
+ receive_pp_request(args.address, args.port, args.internal_port)
+ internal_server.join()
+
+ return 0
+
+
+if __name__ == "__main__":
+ logging.basicConfig(
+ level=logging.DEBUG,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ sys.exit(main())