blob: f54e00f488982ac79fedcc7abdcba6675a90e343 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from functools import wraps
from typing import TYPE_CHECKING, Callable, TypeVar, cast
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException
from airflow.metrics.protocols import DeltaType, Timer, TimerProtocol
from airflow.metrics.validators import (
AllowListValidator,
BlockListValidator,
ListValidator,
validate_stat,
)
if TYPE_CHECKING:
from statsd import StatsClient
T = TypeVar("T", bound=Callable)
log = logging.getLogger(__name__)
def prepare_stat_with_tags(fn: T) -> T:
"""Add tags to stat with influxdb standard format if influxdb_tags_enabled is True."""
@wraps(fn)
def wrapper(
self, stat: str | None = None, *args, tags: dict[str, str] | None = None, **kwargs
) -> Callable[[str], str]:
if self.influxdb_tags_enabled:
if stat is not None and tags is not None:
for k, v in tags.items():
if self.metric_tags_validator.test(k):
if all(c not in [",", "="] for c in v + k):
stat += f",{k}={v}"
else:
log.error("Dropping invalid tag: %s=%s.", k, v)
return fn(self, stat, *args, tags=tags, **kwargs)
return cast(T, wrapper)
class SafeStatsdLogger:
"""StatsD Logger."""
def __init__(
self,
statsd_client: StatsClient,
metrics_validator: ListValidator = AllowListValidator(),
influxdb_tags_enabled: bool = False,
metric_tags_validator: ListValidator = AllowListValidator(),
) -> None:
self.statsd = statsd_client
self.metrics_validator = metrics_validator
self.influxdb_tags_enabled = influxdb_tags_enabled
self.metric_tags_validator = metric_tags_validator
@prepare_stat_with_tags
@validate_stat
def incr(
self,
stat: str,
count: int = 1,
rate: float = 1,
*,
tags: dict[str, str] | None = None,
) -> None:
"""Increment stat."""
if self.metrics_validator.test(stat):
return self.statsd.incr(stat, count, rate)
return None
@prepare_stat_with_tags
@validate_stat
def decr(
self,
stat: str,
count: int = 1,
rate: float = 1,
*,
tags: dict[str, str] | None = None,
) -> None:
"""Decrement stat."""
if self.metrics_validator.test(stat):
return self.statsd.decr(stat, count, rate)
return None
@prepare_stat_with_tags
@validate_stat
def gauge(
self,
stat: str,
value: int | float,
rate: float = 1,
delta: bool = False,
*,
tags: dict[str, str] | None = None,
) -> None:
"""Gauge stat."""
if self.metrics_validator.test(stat):
return self.statsd.gauge(stat, value, rate, delta)
return None
@prepare_stat_with_tags
@validate_stat
def timing(
self,
stat: str,
dt: DeltaType,
*,
tags: dict[str, str] | None = None,
) -> None:
"""Stats timing."""
if self.metrics_validator.test(stat):
return self.statsd.timing(stat, dt)
return None
@prepare_stat_with_tags
@validate_stat
def timer(
self,
stat: str | None = None,
*args,
tags: dict[str, str] | None = None,
**kwargs,
) -> TimerProtocol:
"""Timer metric that can be cancelled."""
if stat and self.metrics_validator.test(stat):
return Timer(self.statsd.timer(stat, *args, **kwargs))
return Timer()
def get_statsd_logger(cls) -> SafeStatsdLogger:
"""Returns logger for StatsD."""
# no need to check for the scheduler/statsd_on -> this method is only called when it is set
# and previously it would crash with None is callable if it was called without it.
from statsd import StatsClient
stats_class = conf.getimport("metrics", "statsd_custom_client_path", fallback=None)
metrics_validator: ListValidator
if stats_class:
if not issubclass(stats_class, StatsClient):
raise AirflowConfigException(
"Your custom StatsD client must extend the statsd.StatsClient in order to ensure "
"backwards compatibility."
)
else:
log.info("Successfully loaded custom StatsD client")
else:
stats_class = StatsClient
statsd = stats_class(
host=conf.get("metrics", "statsd_host"),
port=conf.getint("metrics", "statsd_port"),
prefix=conf.get("metrics", "statsd_prefix"),
)
if conf.get("metrics", "metrics_allow_list", fallback=None):
metrics_validator = AllowListValidator(conf.get("metrics", "metrics_allow_list"))
if conf.get("metrics", "metrics_block_list", fallback=None):
log.warning(
"Ignoring metrics_block_list as both metrics_allow_list "
"and metrics_block_list have been set"
)
elif conf.get("metrics", "metrics_block_list", fallback=None):
metrics_validator = BlockListValidator(conf.get("metrics", "metrics_block_list"))
else:
metrics_validator = AllowListValidator()
influxdb_tags_enabled = conf.getboolean("metrics", "statsd_influxdb_enabled", fallback=False)
metric_tags_validator = BlockListValidator(conf.get("metrics", "statsd_disabled_tags", fallback=None))
return SafeStatsdLogger(statsd, metrics_validator, influxdb_tags_enabled, metric_tags_validator)