PEP 8 cleanup; split request scanning into raw post-data (blob) and form-data dict paths; prep for advanced scanning
diff --git a/aardvark.py b/aardvark.py
index 2ae9ec9..6cbb21d 100644
--- a/aardvark.py
+++ b/aardvark.py
@@ -24,9 +24,10 @@
import re
import asfpy.syslog
import typing
+import multidict
# Shadow print with our syslog wrapper
-print = asfpy.syslog.Printer(stdout=True, identity='aardvark')
+print = asfpy.syslog.Printer(stdout=True, identity="aardvark")
# Some defaults to keep this running without a yaml
@@ -34,9 +35,10 @@
DEFAULT_BACKEND = "http://localhost:8080"
DEFAULT_IPHEADER = "x-forwarded-for"
DEFAULT_DEBUG = False
+MINIMUM_SCAN_LENGTH = 16 # We don't normally scan form data elements with fewer than 16 chars
+
class Aardvark:
-
def __init__(self, config_file: str = "aardvark.yaml"):
""" Load and parse the config """
@@ -69,9 +71,9 @@
# If config file, load that into the vars
if config_file:
self.config = yaml.safe_load(open(config_file, "r"))
- self.debug = self.config.get('debug', self.debug)
+ self.debug = self.config.get("debug", self.debug)
self.proxy_url = self.config.get("proxy_url", self.proxy_url)
- self.port = int(self.config.get('port', self.port))
+ self.port = int(self.config.get("port", self.port))
self.ipheader = self.config.get("ipheader", self.ipheader)
for pm in self.config.get("postmatches", []):
r = re.compile(bytes(pm, encoding="utf-8"), flags=re.IGNORECASE)
@@ -89,7 +91,7 @@
r = re.compile(bytes(req, encoding="utf-8"), flags=re.IGNORECASE)
self.multispam_auxiliary.add(r)
- def scan(self, request_url: str, post_data: bytes):
+ def scan_simple(self, request_url: str, post_data: bytes = None):
"""Scans post data for spam"""
bad_items = []
@@ -116,6 +118,16 @@
return bad_items
+ def scan_dict(self, post_dict: multidict.MultiDictProxy):
+ """Scans form data dicts for spam"""
+ bad_items = []
+ for k, v in post_dict.items():
+ if v and isinstance(v, str) and len(v) >= MINIMUM_SCAN_LENGTH:
+ b = bytes(v, encoding="utf-8")
+ bad_items.extend(self.scan_simple(f"formdata::{k}", b))
+
+ return bad_items
+
async def proxy(self, request: aiohttp.web.Request):
"""Handles each proxy request"""
request_url = "/" + request.match_info["path"]
@@ -145,7 +157,10 @@
print("Average request scan time is %.2f ms" % (avg * 1000.0))
# Read POST data and query string
- post_data = await request.read()
+ post_dict = await request.post() # Request data as key/value pairs if applicable
+ post_data = None
+ if not post_dict:
+ post_data = await request.read() # Request data as a blob if not valid form data
get_data = request.rel_url.query
# Perform scan!
@@ -155,7 +170,11 @@
if remote_ip in self.offenders:
bad_items.append("Client is on the list of bad offenders.")
else:
- bad_items = self.scan(request_url, post_data)
+ bad_items = []
+ if post_data:
+ bad_items.extend(self.scan_simple(request_url, post_data))
+ elif post_dict:
+ bad_items.extend(self.scan_dict(post_dict))
# If this URL is actually to be ignored, forget all we just did!
if bad_items:
for iu in self.ignoreurls:
@@ -180,15 +199,23 @@
async with aiohttp.ClientSession() as session:
try:
+ req_headers = request.headers.copy()
+ if post_dict:
+ del req_headers["content-length"]
async with session.request(
- request.method, target_url, headers=request.headers, params=get_data, data=post_data
+ request.method,
+ target_url,
+ headers=req_headers,
+ params=get_data,
+ data=post_data or post_dict,
+ timeout=30,
) as resp:
result = resp
raw = await result.read()
headers = result.headers.copy()
# We do NOT want chunked T-E! Leave it to aiohttp
- if 'Transfer-Encoding' in headers:
- del headers['Transfer-Encoding']
+ if "Transfer-Encoding" in headers:
+ del headers["Transfer-Encoding"]
self.processing_times.append(time.time() - now)
return aiohttp.web.Response(body=raw, status=result.status, headers=headers)
except aiohttp.client_exceptions.ClientConnectorError as e: