blob: 124c73931f4e52d17f0a37db4df27a9277359b39 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
from datetime import datetime, timedelta
from functools import partial
from typing import Any, Dict, SupportsInt
from superset.jinja_context import PrestoTemplateProcessor
def DATE(
ts: datetime, day_offset: SupportsInt = 0, hour_offset: SupportsInt = 0
) -> str:
"""Current day as a string"""
day_offset, hour_offset = int(day_offset), int(hour_offset)
offset_day = (ts + timedelta(days=day_offset, hours=hour_offset)).date()
return str(offset_day)
class CustomPrestoTemplateProcessor(PrestoTemplateProcessor):
"""A custom presto template processor for test."""
engine = "db_for_macros_testing"
def process_template(self, sql: str, **kwargs) -> str:
"""Processes a sql template with $ style macro using regex."""
# Add custom macros functions.
macros = {"DATE": partial(DATE, datetime.utcnow())} # type: Dict[str, Any]
# Update with macros defined in context and kwargs.
macros.update(self._context)
macros.update(kwargs)
def replacer(match):
"""Expands $ style macros with corresponding function calls."""
macro_name, args_str = match.groups()
args = [a.strip() for a in args_str.split(",")]
if args == [""]:
args = []
f = macros[macro_name[1:]]
return f(*args)
macro_names = ["$" + name for name in macros.keys()]
pattern = r"(%s)\s*\(([^()]*)\)" % "|".join(map(re.escape, macro_names))
return re.sub(pattern, replacer, sql)