# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
import uuid
from typing import Any, Dict, List, Optional, Tuple

import jwt
import redis
from flask import Flask, Request, Response, session

logger = logging.getLogger(__name__)


class AsyncQueryTokenException(Exception):
    pass


class AsyncQueryJobException(Exception):
    pass


def build_job_metadata(channel_id: str, job_id: str, **kwargs: Any) -> Dict[str, Any]:
    return {
        "channel_id": channel_id,
        "job_id": job_id,
        "user_id": session.get("user_id"),
        "status": kwargs.get("status"),
        "errors": kwargs.get("errors", []),
        "result_url": kwargs.get("result_url"),
    }
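
# A metadata dict built above looks roughly like this (illustrative values;
# "user_id" is read from the current Flask session, so it is None for
# anonymous users):
#
#   {
#       "channel_id": "0fe8...", "job_id": "c2f0...", "user_id": None,
#       "status": "pending", "errors": [], "result_url": None,
#   }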


def parse_event(event_data: Tuple[str, Dict[str, Any]]) -> Dict[str, Any]:
    event_id = event_data[0]
    event_payload = event_data[1]["data"]
    return {"id": event_id, **json.loads(event_payload)}


def increment_id(redis_id: str) -> str:
    # Redis stream IDs have the form '<timestamp>-<sequence>', e.g. '1607477697866-0'.
    # Split on the separator and increment the whole sequence number; slicing off
    # only the last character would miscompute IDs whose sequence ends in 9.
    try:
        prefix, last = redis_id.split("-")
        return f"{prefix}-{int(last) + 1}"
    except Exception:  # pylint: disable=broad-except
        return redis_id
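
# Illustrative behavior of increment_id():
#
#   increment_id("1607477697866-0")   # -> "1607477697866-1"
#   increment_id("1607477697866-19")  # -> "1607477697866-20"
#   increment_id("malformed")         # -> "malformed" (returned unchanged)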


class AsyncQueryManager:
    MAX_EVENT_COUNT = 100
    STATUS_PENDING = "pending"
    STATUS_RUNNING = "running"
    STATUS_ERROR = "error"
    STATUS_DONE = "done"

    def __init__(self) -> None:
        super().__init__()
        self._redis: redis.Redis
        self._stream_prefix: str = ""
        self._stream_limit: Optional[int]
        self._stream_limit_firehose: Optional[int]
        self._jwt_cookie_name: str
        self._jwt_cookie_secure: bool = False
        self._jwt_secret: str

    def init_app(self, app: Flask) -> None:
        config = app.config
        if (
            config["CACHE_CONFIG"]["CACHE_TYPE"] == "null"
            or config["DATA_CACHE_CONFIG"]["CACHE_TYPE"] == "null"
        ):
            raise Exception(
                """
                Cache backends (CACHE_CONFIG, DATA_CACHE_CONFIG) must be configured
                and non-null in order to enable async queries
                """
            )

        if len(config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]) < 32:
            raise AsyncQueryTokenException(
                "Please provide a JWT secret at least 32 bytes long"
            )

        self._redis = redis.Redis(  # type: ignore
            **config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True
        )
        self._stream_prefix = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"]
        self._stream_limit = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT"]
        self._stream_limit_firehose = config[
            "GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE"
        ]
        self._jwt_cookie_name = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"]
        self._jwt_cookie_secure = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE"]
        self._jwt_secret = config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]

        @app.after_request
        def validate_session(  # pylint: disable=unused-variable
            response: Response,
        ) -> Response:
            reset_token = False
            user_id = session.get("user_id")

            if "async_channel_id" not in session or "async_user_id" not in session:
                reset_token = True
            elif user_id != session["async_user_id"]:
                reset_token = True

            if reset_token:
                async_channel_id = str(uuid.uuid4())
                session["async_channel_id"] = async_channel_id
                session["async_user_id"] = user_id

                sub = str(user_id) if user_id else None
                token = self.generate_jwt({"channel": async_channel_id, "sub": sub})

                response.set_cookie(
                    self._jwt_cookie_name,
                    value=token,
                    httponly=True,
                    secure=self._jwt_cookie_secure,
                    # max_age=max_age or config.cookie_max_age,
                    # domain=config.cookie_domain,
                    # path=config.access_cookie_path,
                    # samesite=config.cookie_samesite
                )

            return response
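
    # Minimal wiring sketch (assumed usage, not part of this module): create
    # the manager once and initialize it with a Flask app whose config defines
    # the GLOBAL_ASYNC_QUERIES_* keys consumed above.
    #
    #   async_query_manager = AsyncQueryManager()
    #   async_query_manager.init_app(app)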

    def generate_jwt(self, data: Dict[str, Any]) -> str:
        encoded_jwt = jwt.encode(data, self._jwt_secret, algorithm="HS256")
        # PyJWT < 2.0 returns bytes from encode(); later versions return str
        if isinstance(encoded_jwt, bytes):
            return encoded_jwt.decode("utf-8")
        return encoded_jwt

    def parse_jwt(self, token: str) -> Dict[str, Any]:
        data = jwt.decode(token, self._jwt_secret, algorithms=["HS256"])
        return data
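
    # Round-trip sketch: generate_jwt() and parse_jwt() are symmetric, so for
    # a configured manager the following should hold:
    #
    #   token = manager.generate_jwt({"channel": "abc", "sub": "1"})
    #   manager.parse_jwt(token)  # -> {"channel": "abc", "sub": "1"}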

    def parse_jwt_from_request(self, request: Request) -> Dict[str, Any]:
        token = request.cookies.get(self._jwt_cookie_name)
        if not token:
            raise AsyncQueryTokenException("Token not present")

        try:
            return self.parse_jwt(token)
        except Exception as exc:
            logger.warning(exc)
            raise AsyncQueryTokenException("Failed to parse token") from exc

    def init_job(self, channel_id: str) -> Dict[str, Any]:
        job_id = str(uuid.uuid4())
        return build_job_metadata(channel_id, job_id, status=self.STATUS_PENDING)

    def read_events(
        self, channel: str, last_id: Optional[str]
    ) -> List[Optional[Dict[str, Any]]]:
        stream_name = f"{self._stream_prefix}{channel}"
        start_id = increment_id(last_id) if last_id else "-"
        results = self._redis.xrange(  # type: ignore
            stream_name, start_id, "+", self.MAX_EVENT_COUNT
        )
        return [] if not results else list(map(parse_event, results))
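
    # Polling sketch (hypothetical caller): pass the last seen event ID back
    # in to page through a channel's stream without re-reading old events.
    #
    #   events = manager.read_events(channel_id, None)  # from the beginning
    #   if events:
    #       last_id = events[-1]["id"]
    #       newer = manager.read_events(channel_id, last_id)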

    def update_job(
        self, job_metadata: Dict[str, Any], status: str, **kwargs: Any
    ) -> None:
        if "channel_id" not in job_metadata:
            raise AsyncQueryJobException("No channel ID specified")
        if "job_id" not in job_metadata:
            raise AsyncQueryJobException("No job ID specified")

        updates = {"status": status, **kwargs}
        event_data = {"data": json.dumps({**job_metadata, **updates})}

        full_stream_name = f"{self._stream_prefix}full"
        scoped_stream_name = f"{self._stream_prefix}{job_metadata['channel_id']}"

        logger.debug("Publishing event data to stream %s", scoped_stream_name)
        logger.debug(event_data)

        self._redis.xadd(  # type: ignore
            scoped_stream_name, event_data, "*", self._stream_limit
        )
        self._redis.xadd(  # type: ignore
            full_stream_name, event_data, "*", self._stream_limit_firehose
        )
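
    # Lifecycle sketch (hypothetical caller, e.g. a worker task): create the
    # job metadata, then publish status transitions. Each event lands both in
    # the channel-scoped stream and in the shared "full" firehose stream.
    #
    #   job = manager.init_job(channel_id)
    #   manager.update_job(job, AsyncQueryManager.STATUS_RUNNING)
    #   manager.update_job(
    #       job, AsyncQueryManager.STATUS_DONE, result_url="/path/to/result"
    #   )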