# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import atexit
import logging
import os
import pendulum
import sys
from typing import Any
from sqlalchemy import create_engine, exc
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import NullPool
from airflow.configuration import conf, AIRFLOW_HOME, WEBSERVER_CONFIG # NOQA F401
from airflow.contrib.kubernetes.pod import Pod
from airflow.logging_config import configure_logging
from airflow.utils.sqlalchemy import setup_event_handlers
log = logging.getLogger(__name__)
RBAC = conf.getboolean('webserver', 'rbac')
TIMEZONE = pendulum.timezone('UTC')
try:
    tz = conf.get("core", "default_timezone")
    if tz == "system":
        TIMEZONE = pendulum.local_timezone()
    else:
        TIMEZONE = pendulum.timezone(tz)
except Exception:
    pass
log.info("Configured default timezone %s", TIMEZONE)
class DummyStatsLogger(object):
    @classmethod
    def incr(cls, stat, count=1, rate=1):
        pass

    @classmethod
    def decr(cls, stat, count=1, rate=1):
        pass

    @classmethod
    def gauge(cls, stat, value, rate=1, delta=False):
        pass

    @classmethod
    def timing(cls, stat, dt):
        pass
class AllowListValidator:
    def __init__(self, allow_list=None):
        if allow_list:
            self.allow_list = tuple([item.strip().lower() for item in allow_list.split(',')])
        else:
            self.allow_list = None

    def test(self, stat):
        if self.allow_list is not None:
            return stat.strip().lower().startswith(self.allow_list)
        else:
            return True  # default is all metrics allowed
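# Comment-only sketch of the allow-list semantics (the metric names are invented).
# Matching is a case-insensitive prefix test against a comma-separated list:
#
#     validator = AllowListValidator("scheduler,executor")
#     validator.test("scheduler.heartbeat")    # True: matches the 'scheduler' prefix
#     validator.test("ti_failures")            # False: no configured prefix matches
#     AllowListValidator().test("anything")    # True: no list means allow everything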
class SafeStatsdLogger:
    def __init__(self, statsd_client, allow_list_validator=AllowListValidator()):
        self.statsd = statsd_client
        self.allow_list_validator = allow_list_validator

    def incr(self, stat, count=1, rate=1):
        if self.allow_list_validator.test(stat):
            return self.statsd.incr(stat, count, rate)

    def decr(self, stat, count=1, rate=1):
        if self.allow_list_validator.test(stat):
            return self.statsd.decr(stat, count, rate)

    def gauge(self, stat, value, rate=1, delta=False):
        if self.allow_list_validator.test(stat):
            return self.statsd.gauge(stat, value, rate, delta)

    def timing(self, stat, dt):
        if self.allow_list_validator.test(stat):
            return self.statsd.timing(stat, dt)
Stats = DummyStatsLogger # type: Any
if conf.getboolean('scheduler', 'statsd_on'):
    from statsd import StatsClient

    statsd = StatsClient(
        host=conf.get('scheduler', 'statsd_host'),
        port=conf.getint('scheduler', 'statsd_port'),
        prefix=conf.get('scheduler', 'statsd_prefix'))
    allow_list_validator = AllowListValidator(
        conf.get('scheduler', 'statsd_allow_list', fallback=None))
    Stats = SafeStatsdLogger(statsd, allow_list_validator)
else:
    Stats = DummyStatsLogger
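# Hedged usage sketch (comment only; the metric names are examples). With
# ``statsd_on = True`` under ``[scheduler]`` these calls reach the statsd client,
# subject to ``statsd_allow_list``; otherwise they are no-ops on DummyStatsLogger:
#
#     Stats.incr('scheduler_heartbeat')           # counter
#     Stats.gauge('executor.open_slots', 10)      # gauge
#     Stats.timing('collect_dags', 10.5)          # timer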
HEADER = '\n'.join([
    r'  ____________       _____________',
    r' ____    |__( )_________  __/__  /________      __',
    r'____  /| |_  /__  ___/_  /_ __  /_  __ \_ | /| / /',
    r'___  ___ |  / _  / _  __/ _  / / /_/ /_ |/ |/ /',
    r' _/_/  |_/_/  /_/  /_/  /_/  \____/____/|__/',
])
LOGGING_LEVEL = logging.INFO
# the prefix prepended to a gunicorn worker's process name once it has finished
# initializing and is ready to serve requests
GUNICORN_WORKER_READY_PREFIX = "[ready] "
LOG_FORMAT = conf.get('core', 'log_format')
SIMPLE_LOG_FORMAT = conf.get('core', 'simple_log_format')
SQL_ALCHEMY_CONN = None
DAGS_FOLDER = None
PLUGINS_FOLDER = None
LOGGING_CLASS_PATH = None
engine = None
Session = None
def policy(task_instance):
    """
    This policy setting allows altering task instances right before they
    are executed. It allows administrators to rewire some task parameters.

    Note that the ``TaskInstance`` object has an attribute ``task`` pointing
    to its related task object, which in turn has a reference to the DAG
    object. So you can use the attributes of all of these to define your
    policy.

    To define a policy, add an ``airflow_local_settings`` module
    to your PYTHONPATH that defines this ``policy`` function. It receives
    a ``TaskInstance`` object and can alter it where needed.

    Here are a few examples of how this can be useful:

    * You could enforce a specific queue (say the ``spark`` queue)
      for tasks using the ``SparkOperator`` to make sure that these
      task instances get wired to the right workers
    * You could force all task instances running on an
      ``execution_date`` older than a week old to run in a ``backfill``
      pool.
    * ...
    """
    pass
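# A hedged, comment-only sketch of what a user-supplied ``policy`` in
# ``airflow_local_settings`` might look like, following the docstring above. The
# operator class name, queue, pool and 7-day cutoff are invented examples, and the
# snippet assumes ``import pendulum`` in that module:
#
#     def policy(task_instance):
#         task = task_instance.task
#         # route Spark jobs to a dedicated queue
#         if task.__class__.__name__ == 'SparkSubmitOperator':
#             task.queue = 'spark'
#         # push runs older than a week into a backfill pool
#         if task_instance.execution_date < pendulum.now('UTC').subtract(days=7):
#             task.pool = 'backfill'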
def pod_mutation_hook(pod):  # type: (Pod) -> None
    """
    This setting allows altering ``Pod`` objects before they are passed to
    the Kubernetes client by the ``PodLauncher`` for scheduling.

    To define a pod mutation hook, add an ``airflow_local_settings`` module
    to your PYTHONPATH that defines this ``pod_mutation_hook`` function.
    It receives a ``Pod`` object and can alter it where needed.

    This could be used, for instance, to add sidecar or init containers
    to every worker pod launched by KubernetesExecutor or KubernetesPodOperator.
    """
    pass
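# A hedged, comment-only sketch of a ``pod_mutation_hook`` in
# ``airflow_local_settings``. The label/annotation keys and values are invented, and
# the ``labels``/``annotations`` dicts are assumed to be present on the contrib
# ``Pod`` object:
#
#     def pod_mutation_hook(pod):
#         # tag every worker pod launched by the KubernetesExecutor
#         pod.labels['airflow-mutated'] = 'true'
#         pod.annotations['example.com/mutated-by'] = 'airflow_local_settings'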
def configure_vars():
    global SQL_ALCHEMY_CONN
    global DAGS_FOLDER
    global PLUGINS_FOLDER
    SQL_ALCHEMY_CONN = conf.get('core', 'SQL_ALCHEMY_CONN')
    DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
    PLUGINS_FOLDER = conf.get(
        'core',
        'plugins_folder',
        fallback=os.path.join(AIRFLOW_HOME, 'plugins')
    )
def configure_orm(disable_connection_pool=False):
    log.debug("Setting up DB connection pool (PID %s)", os.getpid())
    global engine
    global Session
    engine_args = {}

    pool_connections = conf.getboolean('core', 'SQL_ALCHEMY_POOL_ENABLED')
    if disable_connection_pool or not pool_connections:
        engine_args['poolclass'] = NullPool
        log.debug("settings.configure_orm(): Using NullPool")
    elif 'sqlite' not in SQL_ALCHEMY_CONN:
        # Pool size engine args not supported by sqlite.
        # If no config value is defined for the pool size, select a reasonable value.
        # 0 means no limit, which could lead to exceeding the Database connection limit.
        pool_size = conf.getint('core', 'SQL_ALCHEMY_POOL_SIZE', fallback=5)

        # The maximum overflow size of the pool.
        # When the number of checked-out connections reaches the size set in pool_size,
        # additional connections will be returned up to this limit.
        # When those additional connections are returned to the pool, they are disconnected and discarded.
        # It follows then that the total number of simultaneous connections
        # the pool will allow is pool_size + max_overflow,
        # and the total number of “sleeping” connections the pool will allow is pool_size.
        # max_overflow can be set to -1 to indicate no overflow limit;
        # no limit will be placed on the total number
        # of concurrent connections. Defaults to 10.
        max_overflow = conf.getint('core', 'SQL_ALCHEMY_MAX_OVERFLOW', fallback=10)

        # The DB server already has a value for wait_timeout (number of seconds after
        # which an idle sleeping connection should be killed). Since other DBs may
        # co-exist on the same server, SQLAlchemy should set its
        # pool_recycle to an equal or smaller value.
        pool_recycle = conf.getint('core', 'SQL_ALCHEMY_POOL_RECYCLE', fallback=1800)

        # Check connection at the start of each connection pool checkout.
        # Typically, this is a simple statement like “SELECT 1”, but may also make use
        # of some DBAPI-specific method to test the connection for liveness.
        # More information here:
        # https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic
        pool_pre_ping = conf.getboolean('core', 'SQL_ALCHEMY_POOL_PRE_PING', fallback=True)

        log.info("settings.configure_orm(): Using pool settings. pool_size={}, max_overflow={}, "
                 "pool_recycle={}, pid={}".format(pool_size, max_overflow, pool_recycle, os.getpid()))
        engine_args['pool_size'] = pool_size
        engine_args['pool_recycle'] = pool_recycle
        engine_args['pool_pre_ping'] = pool_pre_ping
        engine_args['max_overflow'] = max_overflow

    # Allow the user to specify an encoding for their DB otherwise default
    # to utf-8 so jobs & users with non-latin1 characters can still use us.
    engine_args['encoding'] = conf.get('core', 'SQL_ENGINE_ENCODING', fallback='utf-8')
    # For Python2 we get back a newstr and need a str
    engine_args['encoding'] = engine_args['encoding'].__str__()

    engine = create_engine(SQL_ALCHEMY_CONN, **engine_args)
    setup_event_handlers(engine)

    Session = scoped_session(
        sessionmaker(autocommit=False,
                     autoflush=False,
                     bind=engine,
                     expire_on_commit=False))
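# Hedged illustration of what the pool settings above amount to (comment only; the
# connection URI and numbers are arbitrary examples, not defaults mandated here).
# With pool_size=5 and max_overflow=10, SQLAlchemy keeps at most 5 idle connections
# and allows up to 5 + 10 = 15 concurrent checkouts; pool_recycle=1800 retires
# connections after 30 minutes; pool_pre_ping runs a liveness check on checkout:
#
#     engine = create_engine(
#         'postgresql+psycopg2://airflow:***@localhost:5432/airflow',
#         pool_size=5, max_overflow=10, pool_recycle=1800,
#         pool_pre_ping=True, encoding='utf-8')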
def dispose_orm():
    """ Properly close pooled database connections """
    log.debug("Disposing DB connection pool (PID %s)", os.getpid())
    global engine
    global Session

    if Session:
        Session.remove()
        Session = None
    if engine:
        engine.dispose()
        engine = None
def configure_adapters():
    from pendulum import Pendulum

    try:
        from sqlite3 import register_adapter
        register_adapter(Pendulum, lambda val: val.isoformat(' '))
    except ImportError:
        pass

    try:
        import MySQLdb.converters
        MySQLdb.converters.conversions[Pendulum] = MySQLdb.converters.DateTime2literal
    except ImportError:
        pass

    try:
        import pymysql.converters
        pymysql.converters.conversions[Pendulum] = pymysql.converters.escape_datetime
    except ImportError:
        pass
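# Why these adapters matter, as a hedged, comment-only sketch (the table and value are
# made up). sqlite3 resolves adapters by exact type, so a Pendulum value would not be
# covered by the stock datetime adapter; the registration above stores it via
# ``isoformat(' ')`` instead of failing to bind:
#
#     import sqlite3
#     conn = sqlite3.connect(':memory:')
#     conn.execute('CREATE TABLE example (ts TEXT)')
#     conn.execute('INSERT INTO example VALUES (?)', (pendulum.now('UTC'),))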
def validate_session():
    worker_precheck = conf.getboolean('core', 'worker_precheck', fallback=False)
    if not worker_precheck:
        return True
    else:
        check_session = sessionmaker(bind=engine)
        session = check_session()
        try:
            session.execute("select 1")
            conn_status = True
        except exc.DBAPIError as err:
            log.error(err)
            conn_status = False
        session.close()
        return conn_status
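# Comment-only usage sketch: with ``worker_precheck = True`` under ``[core]`` in
# airflow.cfg, a worker process can gate its startup on the metadata DB being
# reachable, e.g.
#
#     if not validate_session():
#         sys.exit(1)   # exit code is an arbitrary example
#
# With the flag off (the default), validate_session() short-circuits to True.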
def configure_action_logging():
    """
    Any additional configuration (register callback) for airflow.utils.action_loggers
    module

    :rtype: None
    """
    pass
def prepare_syspath():
    """
    Ensures that certain subfolders of AIRFLOW_HOME are on the Python import
    path (``sys.path``).
    """
    if DAGS_FOLDER not in sys.path:
        sys.path.append(DAGS_FOLDER)

    # Add ./config/ for loading custom log parsers etc, or
    # airflow_local_settings etc.
    config_path = os.path.join(AIRFLOW_HOME, 'config')
    if config_path not in sys.path:
        sys.path.append(config_path)

    if PLUGINS_FOLDER not in sys.path:
        sys.path.append(PLUGINS_FOLDER)
def import_local_settings():
    try:
        import airflow_local_settings

        if hasattr(airflow_local_settings, "__all__"):
            for i in airflow_local_settings.__all__:
                globals()[i] = getattr(airflow_local_settings, i)
        else:
            for k, v in airflow_local_settings.__dict__.items():
                if not k.startswith("__"):
                    globals()[k] = v

        log.info("Loaded airflow_local_settings from %s.", airflow_local_settings.__file__)
    except ImportError:
        log.debug("Failed to import airflow_local_settings.", exc_info=True)
def initialize():
    configure_vars()
    prepare_syspath()
    import_local_settings()
    global LOGGING_CLASS_PATH
    LOGGING_CLASS_PATH = configure_logging()
    configure_adapters()
    # The webservers import this file from models.py with the default settings.
    configure_orm()
    configure_action_logging()

    # Ensure we close DB connections at scheduler and gunicorn worker terminations
    atexit.register(dispose_orm)
# Const stuff
KILOBYTE = 1024
MEGABYTE = KILOBYTE * KILOBYTE
WEB_COLORS = {'LIGHTBLUE': '#4d9de0',
              'LIGHTORANGE': '#FF9933'}
# Used by DAG context_managers
CONTEXT_MANAGER_DAG = None