Add persistent storage of block list, allow for unblocking via CLI
diff --git a/aardvark.py b/aardvark.py
index 44ac18c..ff6e4eb 100644
--- a/aardvark.py
+++ b/aardvark.py
@@ -30,11 +30,11 @@
import spamfilter
import aiofile
import platform
+import datetime
# Shadow print with our syslog wrapper
print = asfpy.syslog.Printer(stdout=True, identity="aardvark")
-
# Some defaults to keep this running without a yaml
DEFAULT_PORT = 1729
DEFAULT_BACKEND = "http://localhost:8080"
@@ -45,6 +45,7 @@
DEFAULT_NAIVE = True
DEFAULT_SPAM_NAIVE_THRESHOLD = 60
MINIMUM_SCAN_LENGTH = 16 # We don't normally scan form data elements with fewer than 16 chars
+BLOCKFILE = "blocklist.txt"
class Aardvark:
@@ -63,7 +64,9 @@
# Init vars with defaults
self.config = {} # Our config, unless otherwise specified in init
+ self.myuid = str(uuid.uuid4())
self.debug = False # Debug prints, spammy!
+ self.persistence = False # Persistent block list
self.block_msg = DEFAULT_BLOCK_MSG
self.proxy_url = DEFAULT_BACKEND # Backend URL to proxy to
self.port = DEFAULT_PORT # Port we listen on
@@ -82,7 +85,7 @@
self.naive_threshold = DEFAULT_SPAM_NAIVE_THRESHOLD
self.enable_naive = DEFAULT_NAIVE
self.lock = asyncio.Lock()
-
+
if platform.system() == 'Linux':
major, minor, _ = platform.release().split('.', 2)
if major > "4" or (major >= "4" and minor >= "18"):
@@ -101,6 +104,7 @@
self.ipheader = self.config.get("ipheader", self.ipheader)
self.savepath = self.config.get("savedata", self.savepath)
self.asyncwrite = self.config.get("asyncwrite", self.asyncwrite)
+ self.persistence = self.config.get("persistence", self.persistence)
self.block_msg = self.config.get("spam_response", self.block_msg)
self.enable_naive = self.config.get("enable_naive_scan", self.enable_naive)
self.naive_threshold = self.config.get("naive_spam_threshold", self.naive_threshold)
@@ -119,10 +123,33 @@
for req in multimatch.get("auxiliary", []):
r = re.compile(bytes(req, encoding="utf-8"), flags=re.IGNORECASE)
self.multispam_auxiliary.add(r)
+ if self.persistence:
+ if os.path.exists(BLOCKFILE):
+ offenders = 0
+ with open(BLOCKFILE, "r") as bl:
+ for line in bl:
+ if line.strip() and not line.startswith("#"):
+ offenders += 1
+ self.offenders.add(line.strip())
+ print(f"Loaded {offenders} offenders from persistent storage.")
+
if self.enable_naive:
print("Loading Naïve Bayesian spam filter...")
self.spamfilter = spamfilter.BayesScanner()
-
+
+ async def save_block_list_async(self):
+ async with aiofile.async_open(BLOCKFILE, "w") as f:
+ bl = f"# Block list generated at {datetime.datetime.now().isoformat()}\n# UUID: {self.myuid}\n"
+ bl += "\n".join(self.offenders)
+ await f.write(bl)
+
+ def save_block_list_sync(self):
+ with open(BLOCKFILE, "w") as f:
+ bl = f"# Block list generated at {datetime.datetime.now().isoformat()}\n# UUID: {self.myuid}\n"
+ bl += "\n".join(self.offenders)
+ f.write(bl)
+
+
async def save_request_data(
self, request: aiohttp.web.Request, remote_ip: str, post: typing.Union[multidict.MultiDictProxy, bytes]
):
@@ -196,7 +223,8 @@
if self.enable_naive:
res = self.spamfilter.scan_text(v)
if res >= self.naive_threshold:
- bad_items.append(f"Form element {k} has spam score of {res}, crosses threshold of {self.naive_threshold}!")
+ bad_items.append(
+ f"Form element {k} has spam score of {res}, crosses threshold of {self.naive_threshold}!")
return bad_items
async def proxy(self, request: aiohttp.web.Request):
@@ -211,6 +239,16 @@
if self.debug:
print(f"Proxying request to {target_url}...") # This can get spammy, default is to not show it.
+ if request.path == '/aardvark-unblock':
+ ip = request.query_string
+ theiruid = request.headers.get('X-Aardvark-Key', '')
+ if theiruid == self.myuid:
+ if ip in self.offenders:
+ self.offenders.remove(ip)
+ print(f"Removed IP {ip} from block list.")
+ return aiohttp.web.Response(text="Block removed", status=200)
+ return aiohttp.web.Response(text="No such block", status=404)
+
# Debug output for syslog
self.last_batches.append(time.time())
if len(self.last_batches) >= 5000:
@@ -285,16 +323,17 @@
del req_headers["content-type"]
for k, v in post_dict.items():
if isinstance(v, aiohttp.web.FileField): # This sets multipart properly in the request
- form_data.add_field(name=v.name, filename=v.filename, value=v.file.raw, content_type=v.content_type)
+ form_data.add_field(name=v.name, filename=v.filename, value=v.file.raw,
+ content_type=v.content_type)
else:
form_data.add_field(name=k, value=v)
async with session.request(
- request.method,
- target_url,
- headers=req_headers,
- params=get_data,
- data=form_data or post_data,
- timeout=30,
+ request.method,
+ target_url,
+ headers=req_headers,
+ params=get_data,
+ data=form_data or post_data,
+ timeout=30,
) as resp:
result = resp
headers = result.headers.copy()
@@ -328,6 +367,15 @@
self.processing_times.append(time.time() - now)
return aiohttp.web.Response(text=self.block_msg, status=403)
+ async def sync_block_list(self):
+ while True:
+ if self.persistence:
+ if self.asyncwrite:
+ await self.save_block_list_async()
+ else:
+ self.save_block_list_sync()
+ await asyncio.sleep(900)
+
async def main():
aar = Aardvark()
@@ -340,8 +388,8 @@
print("Starting Aardvark Anti Spam Proxy")
await site.start()
print(f"Started on port {aar.port}")
- while True:
- await asyncio.sleep(60)
+ print(f"Unblock UUID: {aar.myuid}")
+ await aar.sync_block_list()
if __name__ == "__main__":