blob: 4c953555fde94afc9743b54aa372ea6222c039c2 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import argparse
import json
import os
import re
import subprocess
from typing import List
from urllib.request import Request, urlopen
# Define patterns for each group of files you're interested in
PATTERNS = {
"python": [
r"^\.github/workflows/.*python",
r"^tests/",
r"^superset/",
r"^scripts/",
r"^setup\.py",
r"^requirements/.+\.txt",
r"^.pylintrc",
],
"frontend": [
r"^\.github/workflows/.*(bashlib|frontend|e2e)",
r"^superset-frontend/",
],
"docker": [
r"^Dockerfile$",
r"^docker/",
],
"docs": [
r"^docs/",
],
}
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")
def fetch_files_github_api(url: str): # type: ignore
"""Fetches data using GitHub API."""
req = Request(url)
req.add_header("Authorization", f"token {GITHUB_TOKEN}")
req.add_header("Accept", "application/vnd.github.v3+json")
print(f"Fetching from {url}")
with urlopen(req) as response:
body = response.read()
return json.loads(body)
def fetch_changed_files_pr(repo: str, pr_number: str) -> List[str]:
"""Fetches files changed in a PR using the GitHub API."""
url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/files"
files = fetch_files_github_api(url)
return [file_info["filename"] for file_info in files]
def fetch_changed_files_push(repo: str, sha: str) -> List[str]:
"""Fetches files changed in the last commit for push events using GitHub API."""
# Fetch commit details to get the parent SHA
commit_url = f"https://api.github.com/repos/{repo}/commits/{sha}"
commit_data = fetch_files_github_api(commit_url)
if "parents" not in commit_data or len(commit_data["parents"]) < 1:
raise RuntimeError("No parent commit found for comparison.")
parent_sha = commit_data["parents"][0]["sha"]
# Compare the current commit against its parent
compare_url = f"https://api.github.com/repos/{repo}/compare/{parent_sha}...{sha}"
comparison_data = fetch_files_github_api(compare_url)
return [file["filename"] for file in comparison_data["files"]]
def detect_changes(files: List[str], check_patterns: List) -> bool: # type: ignore
"""Detects if any of the specified files match the provided patterns."""
for file in files:
for pattern in check_patterns:
if re.match(pattern, file):
return True
return False
def print_files(files: List[str]) -> None:
print("\n".join([f"- {s}" for s in files]))
def main(event_type: str, sha: str, repo: str) -> None:
"""Main function to check for file changes based on event context."""
print("SHA:", sha)
if event_type == "pull_request":
pr_number = os.getenv("GITHUB_REF", "").split("/")[-2]
files = fetch_changed_files_pr(repo, pr_number)
print("PR files:")
print_files(files)
elif event_type == "push":
files = fetch_changed_files_push(repo, sha)
print("Files touched since previous commit:")
print_files(files)
else:
raise ValueError("Unsupported event type")
changes_detected = {}
for group, regex_patterns in PATTERNS.items():
patterns_compiled = [re.compile(p) for p in regex_patterns]
changes_detected[group] = detect_changes(files, patterns_compiled)
# Output results
output_path = os.getenv("GITHUB_OUTPUT") or "/tmp/GITHUB_OUTPUT.txt"
with open(output_path, "a") as f:
for check, changed in changes_detected.items():
if changed:
print(f"{check}={str(changed).lower()}", file=f)
print(f"Triggering group: {check}")
def get_git_sha() -> str:
return os.getenv("GITHUB_SHA") or subprocess.check_output(
["git", "rev-parse", "HEAD"]
).strip().decode("utf-8")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Detect file changes based on event context"
)
parser.add_argument(
"--event-type",
default=os.getenv("GITHUB_EVENT_NAME") or "push",
help="The type of event that triggered the workflow",
)
parser.add_argument(
"--sha",
default=get_git_sha(),
help="The commit SHA for push events or PR head SHA",
)
parser.add_argument(
"--repo",
default=os.getenv("GITHUB_REPOSITORY") or "apache/superset",
help="GitHub repository in the format owner/repo",
)
args = parser.parse_args()
main(args.event_type, args.sha, args.repo)