blob: 7f1af1f1003ae77033058f85971cbe3a92d2d8be [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Database sub-commands"""
import os
import subprocess
import textwrap
from tempfile import NamedTemporaryFile
from airflow import settings
from airflow.exceptions import AirflowException
from airflow.utils import cli as cli_utils, db
def initdb(args):
"""Initializes the metadata database"""
print("DB: " + repr(settings.engine.url))
db.initdb()
print("Done.")
def resetdb(args):
"""Resets the metadata database"""
print("DB: " + repr(settings.engine.url))
if args.yes or input("This will drop existing tables "
"if they exist. Proceed? "
"(y/n)").upper() == "Y":
db.resetdb()
else:
print("Bail.")
@cli_utils.action_logging
def upgradedb(args):
"""Upgrades the metadata database"""
print("DB: " + repr(settings.engine.url))
db.upgradedb()
@cli_utils.action_logging
def shell(args):
"""Run a shell that allows to access metadata database"""
url = settings.engine.url
print("DB: " + repr(url))
if url.get_backend_name() == 'mysql':
with NamedTemporaryFile(suffix="my.cnf") as f:
content = textwrap.dedent(f"""
[client]
host = {url.host}
user = {url.username}
password = {url.password or ""}
port = {url.port or ""}
database = {url.database}
""").strip()
f.write(content.encode())
f.flush()
subprocess.Popen(["mysql", f"--defaults-extra-file={f.name}"]).wait()
elif url.get_backend_name() == 'sqlite':
subprocess.Popen(["sqlite3", url.database]).wait()
elif url.get_backend_name() == 'postgresql':
env = os.environ.copy()
env['PGHOST'] = url.host or ""
env['PGPORT'] = str(url.port or "")
env['PGUSER'] = url.username or ""
# PostgreSQL does not allow the use of PGPASSFILE if the current user is root.
env["PGPASSWORD"] = url.password or ""
env['PGDATABASE'] = url.database
subprocess.Popen(["psql"], env=env).wait()
else:
raise AirflowException(f"Unknown driver: {url.drivername}")
@cli_utils.action_logging
def check(_):
"""Runs a check command that checks if db is available."""
db.check()