Add persistent storage of block list, allow for unblocking via CLI
diff --git a/aardvark.py b/aardvark.py
index 44ac18c..ff6e4eb 100644
--- a/aardvark.py
+++ b/aardvark.py
@@ -30,11 +30,11 @@
 import spamfilter
 import aiofile
 import platform
+import datetime
 
 # Shadow print with our syslog wrapper
 print = asfpy.syslog.Printer(stdout=True, identity="aardvark")
 
-
 # Some defaults to keep this running without a yaml
 DEFAULT_PORT = 1729
 DEFAULT_BACKEND = "http://localhost:8080"
@@ -45,6 +45,7 @@
 DEFAULT_NAIVE = True
 DEFAULT_SPAM_NAIVE_THRESHOLD = 60
 MINIMUM_SCAN_LENGTH = 16  # We don't normally scan form data elements with fewer than 16 chars
+BLOCKFILE = "blocklist.txt"
 
 
 class Aardvark:
@@ -63,7 +64,9 @@
 
         # Init vars with defaults
         self.config = {}  # Our config, unless otherwise specified in init
+        self.myuid = str(uuid.uuid4())
         self.debug = False  # Debug prints, spammy!
+        self.persistence = False  # Persistent block list
         self.block_msg = DEFAULT_BLOCK_MSG
         self.proxy_url = DEFAULT_BACKEND  # Backend URL to proxy to
         self.port = DEFAULT_PORT  # Port we listen on
@@ -82,7 +85,7 @@
         self.naive_threshold = DEFAULT_SPAM_NAIVE_THRESHOLD
         self.enable_naive = DEFAULT_NAIVE
         self.lock = asyncio.Lock()
-        
+
         if platform.system() == 'Linux':
             major, minor, _ = platform.release().split('.', 2)
             if major > "4" or (major >= "4" and minor >= "18"):
@@ -101,6 +104,7 @@
             self.ipheader = self.config.get("ipheader", self.ipheader)
             self.savepath = self.config.get("savedata", self.savepath)
             self.asyncwrite = self.config.get("asyncwrite", self.asyncwrite)
+            self.persistence = self.config.get("persistence", self.persistence)
             self.block_msg = self.config.get("spam_response", self.block_msg)
             self.enable_naive = self.config.get("enable_naive_scan", self.enable_naive)
             self.naive_threshold = self.config.get("naive_spam_threshold", self.naive_threshold)
@@ -119,10 +123,33 @@
                 for req in multimatch.get("auxiliary", []):
                     r = re.compile(bytes(req, encoding="utf-8"), flags=re.IGNORECASE)
                     self.multispam_auxiliary.add(r)
+        if self.persistence:
+            if os.path.exists(BLOCKFILE):
+                offenders = 0
+                with open(BLOCKFILE, "r") as bl:
+                    for line in bl:
+                        if line.strip() and not line.startswith("#"):
+                            offenders += 1
+                            self.offenders.add(line.strip())
+                    print(f"Loaded {offenders} offenders from persistent storage.")
+
         if self.enable_naive:
             print("Loading Naïve Bayesian spam filter...")
             self.spamfilter = spamfilter.BayesScanner()
-    
+
+    async def save_block_list_async(self):
+        """Overwrite BLOCKFILE via aiofile: a two-line '#' header, then self.offenders one per line."""
+        async with aiofile.async_open(BLOCKFILE, "w") as blockfile:
+            header = f"# Block list generated at {datetime.datetime.now().isoformat()}\n# UUID: {self.myuid}\n"
+            await blockfile.write(header + "\n".join(self.offenders))
+
+    def save_block_list_sync(self):
+        """Overwrite BLOCKFILE (blocking I/O): a two-line '#' header, then self.offenders one per line."""
+        with open(BLOCKFILE, "w") as blockfile:
+            header = f"# Block list generated at {datetime.datetime.now().isoformat()}\n# UUID: {self.myuid}\n"
+            blockfile.writelines((header, "\n".join(self.offenders)))
+
+
     async def save_request_data(
             self, request: aiohttp.web.Request, remote_ip: str, post: typing.Union[multidict.MultiDictProxy, bytes]
     ):
@@ -196,7 +223,8 @@
                 if self.enable_naive:
                     res = self.spamfilter.scan_text(v)
                     if res >= self.naive_threshold:
-                        bad_items.append(f"Form element {k} has spam score of {res}, crosses threshold of {self.naive_threshold}!")
+                        bad_items.append(
+                            f"Form element {k} has spam score of {res}, crosses threshold of {self.naive_threshold}!")
         return bad_items
 
     async def proxy(self, request: aiohttp.web.Request):
@@ -211,6 +239,16 @@
         if self.debug:
             print(f"Proxying request to {target_url}...")  # This can get spammy, default is to not show it.
 
+        if request.path == '/aardvark-unblock':
+            ip = request.query_string
+            theiruid = request.headers.get('X-Aardvark-Key', '')
+            if theiruid == self.myuid:
+                if ip in self.offenders:
+                    self.offenders.remove(ip)
+                    print(f"Removed IP {ip} from block list.")
+                    return aiohttp.web.Response(text="Block removed", status=200)
+            return aiohttp.web.Response(text="No such block", status=404)
+
         # Debug output for syslog
         self.last_batches.append(time.time())
         if len(self.last_batches) >= 5000:
@@ -285,16 +323,17 @@
                         del req_headers["content-type"]
                     for k, v in post_dict.items():
                         if isinstance(v, aiohttp.web.FileField):  # This sets multipart properly in the request
-                            form_data.add_field(name=v.name, filename=v.filename, value=v.file.raw, content_type=v.content_type)
+                            form_data.add_field(name=v.name, filename=v.filename, value=v.file.raw,
+                                                content_type=v.content_type)
                         else:
                             form_data.add_field(name=k, value=v)
                 async with session.request(
-                    request.method,
-                    target_url,
-                    headers=req_headers,
-                    params=get_data,
-                    data=form_data or post_data,
-                    timeout=30,
+                        request.method,
+                        target_url,
+                        headers=req_headers,
+                        params=get_data,
+                        data=form_data or post_data,
+                        timeout=30,
                 ) as resp:
                     result = resp
                     headers = result.headers.copy()
@@ -328,6 +367,15 @@
         self.processing_times.append(time.time() - now)
         return aiohttp.web.Response(text=self.block_msg, status=403)
 
+    async def sync_block_list(self):
+        """Loop forever: save the block list if self.persistence is set, then sleep 900 seconds."""
+        while True:
+            if self.persistence and self.asyncwrite:
+                await self.save_block_list_async()
+            elif self.persistence:
+                self.save_block_list_sync()
+            await asyncio.sleep(900)
+
 
 async def main():
     aar = Aardvark()
@@ -340,8 +388,8 @@
     print("Starting Aardvark Anti Spam Proxy")
     await site.start()
     print(f"Started on port {aar.port}")
-    while True:
-        await asyncio.sleep(60)
+    print(f"Unblock UUID: {aar.myuid}")
+    await aar.sync_block_list()
 
 
 if __name__ == "__main__":