# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
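"""Chalice app that listens for GitHub ``check_run`` webhooks and scales an
EC2 Auto Scaling group of self-hosted Actions runners to match the number of
queued jobs (tracked in a DynamoDB counter)."""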
import codecs
import hmac
import json
import logging
import os
from typing import cast

import boto3
from chalice import BadRequestError, Chalice, ForbiddenError
from chalice.app import Request

app = Chalice(app_name='scale_out_runner')
app.log.setLevel(logging.INFO)

ASG_GROUP_NAME = os.getenv('ASG_NAME', 'AshbRunnerASG')
TABLE_NAME = os.getenv('COUNTER_TABLE', 'GithubRunnerQueue')

# Cached set of users allowed to use self-hosted runners; lazily loaded from SSM.
_committers = set()
GH_WEBHOOK_TOKEN = None

REPOS = os.getenv('REPOS')
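# When set, REPOS should be a JSON object mapping a repo's full name to the
# branches that may use self-hosted runners, e.g. '{"apache/airflow": ["main"]}'.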
if REPOS:
    REPO_CONFIGURATION = json.loads(REPOS)
else:
    REPO_CONFIGURATION = {
        # <repo>: set of branches to use self-hosted runners on
        'apache/airflow': {'main', 'master'},
    }
del REPOS


@app.route('/', methods=['POST'])
def index():
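    """Handle a GitHub webhook delivery.

    Returns a small JSON payload describing what was done (or why the event
    was ignored), which is also written to the application log.
    """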
    validate_gh_sig(app.current_request)

    if app.current_request.headers.get('X-GitHub-Event', None) != "check_run":
        # Ignore events about installs/permissions etc.
        return {'ignored': 'not about check_runs'}

    body = app.current_request.json_body

    repo = body['repository']['full_name']
    sender = body['sender']['login']

    # Other repos may be configured with this app, but we don't do anything
    # with them yet.
    if repo not in REPO_CONFIGURATION:
        app.log.info("Ignoring event for %r", repo)
        return {'ignored': 'Other repo'}

    interested_branches = REPO_CONFIGURATION[repo]
    branch = body['check_run']['check_suite']['head_branch']

    use_self_hosted = sender in committers() or branch in interested_branches

    payload = {'sender': sender, 'use_self_hosted': use_self_hosted}
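    # Dispatch on the webhook action: decrement the queue counter for jobs
    # cancelled before they started, ignore anything that is not a freshly
    # queued run, and otherwise count the new job and scale out if needed.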
    if body['action'] == 'completed' and body['check_run']['conclusion'] == 'cancelled':
        if use_self_hosted:
            # The only time we get a "cancelled" job is when it wasn't yet running.
            queue_length = increment_dynamodb_counter(-1)
            # Don't scale in the ASG -- let the CloudWatch alarm do that.
            payload['new_queue'] = queue_length
        else:
            payload = {'ignored': 'unknown sender'}
    elif body['action'] != 'created':
        payload = {'ignored': "action is not 'created'"}
    elif body['check_run']['status'] != 'queued':
        # Skipped runs are "created" but instantly completed; ignore anything
        # that is not queued.
        payload = {'ignored': "check_run.status is not 'queued'"}
    else:
        if use_self_hosted:
            # Increment the queued-jobs counter in DynamoDB.
            queue_length = increment_dynamodb_counter()
            payload.update(**scale_asg_if_needed(queue_length))

    app.log.info(
        "delivery=%s branch=%s: %r",
        app.current_request.headers.get('X-GitHub-Delivery', None),
        branch,
        payload,
    )
    return payload


def committers(ssm_repo_name: str = os.getenv('SSM_REPO_NAME', 'apache/airflow')):
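    """Return the set of users allowed to use self-hosted runners.

    The set is loaded once from the SSM parameter
    ``/runners/<repo>/configOverlay`` and cached in the module-level
    ``_committers`` for the life of the Lambda.
    """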
    global _committers
    if not _committers:
        client = boto3.client('ssm')
        param_path = os.path.join('/runners/', ssm_repo_name, 'configOverlay')
        app.log.info("Loading config overlay from %s", param_path)
        try:
            resp = client.get_parameter(Name=param_path, WithDecryption=True)
        except client.exceptions.ParameterNotFound:
            app.log.debug("Failed to load config overlay", exc_info=True)
            return set()
        try:
            overlay = json.loads(resp['Parameter']['Value'])
        except ValueError:
            app.log.debug("Failed to parse config overlay", exc_info=True)
            return set()
        _committers = set(overlay['pullRequestSecurity']['allowedAuthors'])
    return _committers


def validate_gh_sig(request: Request):
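    """Reject the request unless its X-Hub-Signature-256 header matches the
    HMAC-SHA256 of the body computed with the shared webhook secret."""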
    sig = request.headers.get('X-Hub-Signature-256', None)
    if not sig or not sig.startswith('sha256='):
        raise BadRequestError('X-Hub-Signature-256 not of expected format')

    sig = sig[len('sha256='):]
    calculated_sig = sign_request_body(request)
    app.log.debug('Checksum verification - expected %s got %s', calculated_sig, sig)
    if not hmac.compare_digest(sig, calculated_sig):
        raise ForbiddenError('Spoofed request')


def sign_request_body(request: Request) -> str:
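    """Compute the hex HMAC-SHA256 digest of the raw request body.

    The webhook secret comes from the GH_WEBHOOK_TOKEN env var (local dev)
    or, failing that, by KMS-decrypting GH_WEBHOOK_TOKEN_ENCRYPTED; it is
    cached in a module global between invocations.
    """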
    global GH_WEBHOOK_TOKEN
    if GH_WEBHOOK_TOKEN is None:
        if 'GH_WEBHOOK_TOKEN' in os.environ:
            # Local dev support: use the plaintext token directly.
            GH_WEBHOOK_TOKEN = os.environ['GH_WEBHOOK_TOKEN'].encode('utf-8')
        else:
            # Otherwise decrypt the base64-encoded, KMS-encrypted token.
            encrypted = os.environb[b'GH_WEBHOOK_TOKEN_ENCRYPTED']
            kms = boto3.client('kms')
            response = kms.decrypt(CiphertextBlob=codecs.decode(encrypted, 'base64'))
            GH_WEBHOOK_TOKEN = response['Plaintext']

    body = cast(bytes, request.raw_body)
    return hmac.new(GH_WEBHOOK_TOKEN, body, digestmod='SHA256').hexdigest()  # type: ignore


def increment_dynamodb_counter(delta: int = 1) -> int:
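    """Atomically adjust the queued-jobs counter and return the new value.

    Decrements are guarded by a ConditionExpression so the counter can never
    go below zero.
    """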
    dynamodb = boto3.client('dynamodb')
    args = dict(
        TableName=TABLE_NAME,
        Key={'id': {'S': 'queued_jobs'}},
        ExpressionAttributeValues={':delta': {'N': str(delta)}},
        UpdateExpression='ADD queued :delta',
        ReturnValues='UPDATED_NEW',
    )
    if delta < 0:
        # Make sure the counter never goes below zero!
        args['ExpressionAttributeValues'][':limit'] = {'N': str(-delta)}
        args['ConditionExpression'] = 'queued >= :limit'
    resp = dynamodb.update_item(**args)
    return int(resp['Attributes']['queued']['N'])


def scale_asg_if_needed(num_queued_jobs: int) -> dict:
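    """Grow the ASG so there is capacity for every queued and busy job, never
    exceeding the group's MaxSize. Scale-in is left to the CloudWatch alarm."""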
    asg = boto3.client('autoscaling')
    resp = asg.describe_auto_scaling_groups(
        AutoScalingGroupNames=[ASG_GROUP_NAME],
    )
    asg_info = resp['AutoScalingGroups'][0]

    current = asg_info['DesiredCapacity']
    max_size = asg_info['MaxSize']
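    # Instances protected from scale-in are treated as busy; presumably the
    # runner toggles scale-in protection on while it is executing a job.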
    busy = 0
    for instance in asg_info['Instances']:
        if instance['LifecycleState'] == 'InService' and instance['ProtectedFromScaleIn']:
            busy += 1

    app.log.info("Busy instances: %d, num_queued_jobs: %d, current_size: %d", busy, num_queued_jobs, current)

    new_size = num_queued_jobs + busy
    if new_size > current:
        if new_size <= max_size or current < max_size:
            try:
                new_size = min(new_size, max_size)
                asg.set_desired_capacity(AutoScalingGroupName=ASG_GROUP_NAME, DesiredCapacity=new_size)
                return {'new_capacity': new_size}
            except asg.exceptions.ScalingActivityInProgressFault as e:
                return {'error': str(e)}
        else:
            return {'capacity_at_max': True}
    else:
        return {'idle_instances': True}