| #!/usr/bin/env python3 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| """merge_pr.py - Merge Apache Zeppelin pull requests via the GitHub API. |
| |
| Optionally cherry-picks into release branches and resolves JIRA issues. |
| No external dependencies — uses only Python 3 built-in libraries. |
| |
| Usage: |
| python3 dev/merge_pr.py --pr 5167 --dry-run |
| python3 dev/merge_pr.py --pr 5167 --resolve-jira --fix-versions 0.13.0 |
| python3 dev/merge_pr.py --pr 5167 --resolve-jira --release-branches branch-0.12 |
| """ |
| |
| import argparse |
| import json |
| import os |
| import re |
| import subprocess |
| import sys |
| import urllib.error |
| import urllib.request |
| |
# REST endpoints. Both are used as prefixes: callers append "/pulls/<n>",
# "/issue/<key>", etc.
GITHUB_API_BASE = "https://api.github.com/repos/apache/zeppelin"
JIRA_API_BASE = "https://issues.apache.org/jira/rest/api/2"

# Git defaults: the branch treated as the development line and the remote
# used for fetch/push when neither --push-remote nor PUSH_REMOTE_NAME is given.
DEFAULT_BRANCH = "master"
DEFAULT_REMOTE = "apache"

# JIRA workflow: the transition looked up by name to resolve an issue, and the
# status names for which an issue is skipped as already done.
JIRA_RESOLVE_TRANSITION = "Resolve Issue"
JIRA_CLOSED_STATUSES = frozenset({"Resolved", "Closed"})

# A canonical issue key, e.g. "ZEPPELIN-1234".
JIRA_ID_RE = re.compile(r"ZEPPELIN-\d{3,6}")
# A title that is already in the standard form:
# "[ZEPPELIN-1234][COMPONENT] text" with at least one component tag.
TITLE_FORMATTED_RE = re.compile(r"^\[ZEPPELIN-\d{3,6}](\[[A-Z0-9_\s,]+] )+\S+")
# A loosely written issue reference (any case, optional "-"/whitespace
# between the project name and the number).
TITLE_REF_RE = re.compile(r"(?i)(ZEPPELIN[-\s]*\d{3,6})")
# A bracketed component tag such as "[spark]" or "[zeppelin-web]".
COMPONENT_RE = re.compile(r"(?i)(\[[\w\s,.\-]+])")
# Runs of whitespace, and leading non-word characters, used for title cleanup.
WHITESPACE_RE = re.compile(r"\s+")
LEADING_NON_WORD_RE = re.compile(r"^\W+")
# A strict three-part version name such as "0.12.0".
SEMANTIC_VER_RE = re.compile(r"^\d+\.\d+\.\d+$")
| |
| |
class MergePR:
    """Merge one GitHub pull request and perform the follow-up housekeeping.

    The workflow implemented by :meth:`run` is: fetch the PR, squash-merge it
    through the GitHub API, optionally cherry-pick the merge commit into
    release branches with local git, comment a summary on the PR, and
    optionally resolve the JIRA issue(s) referenced in the PR title.

    Every failure is reported as ``RuntimeError``. Steps that run after the
    merge itself are best-effort: they print a warning to stderr and let the
    remaining steps continue.
    """

    # Upper bound, in seconds, for a single HTTP request so that an
    # unresponsive server cannot hang the tool forever.
    HTTP_TIMEOUT = 60

    def __init__(self, args):
        """Capture the parsed command-line options.

        ``args`` must provide the attributes ``pr``, ``target``,
        ``fix_versions``, ``release_branches``, ``resolve_jira``, ``dry_run``,
        ``push_remote``, ``github_token`` and ``jira_token``. Empty
        ``push_remote`` / token values fall back to the ``PUSH_REMOTE_NAME``,
        ``GITHUB_OAUTH_KEY`` and ``JIRA_ACCESS_TOKEN`` environment variables.
        """
        self.pr = args.pr
        self.target = args.target or ""
        self.fix_versions = _parse_csv(args.fix_versions) if args.fix_versions else []
        self.release_branches = _parse_csv(args.release_branches) if args.release_branches else []
        self.resolve_jira = args.resolve_jira
        self.dry_run = args.dry_run
        self.push_remote = args.push_remote or os.environ.get("PUSH_REMOTE_NAME", DEFAULT_REMOTE)
        self.github_token = args.github_token or os.environ.get("GITHUB_OAUTH_KEY", "")
        self.jira_token = args.jira_token or os.environ.get("JIRA_ACCESS_TOKEN", "")

    # ── Git ──────────────────────────────────────────────────────────────

    def _git(self, *args):
        """Run ``git <args>`` and return its stripped stdout.

        Raises RuntimeError when git cannot be started or exits non-zero; the
        message contains the combined stdout/stderr of the failed command.
        """
        try:
            result = subprocess.run(
                ["git", *args],
                capture_output=True, text=True,
            )
        except OSError as e:
            # e.g. git is not installed; keep the single error type callers expect.
            raise RuntimeError(f"git {' '.join(args)} failed:\n{e}") from e
        if result.returncode != 0:
            output = (result.stdout + result.stderr).strip()
            raise RuntimeError(f"git {' '.join(args)} failed:\n{output}")
        return result.stdout.strip()

    def _git_current_ref(self):
        """Return the current branch name, or the commit hash on a detached HEAD."""
        ref = self._git("rev-parse", "--abbrev-ref", "HEAD")
        # "--abbrev-ref" prints the literal "HEAD" when no branch is checked out.
        return self._git("rev-parse", "HEAD") if ref == "HEAD" else ref

    # ── HTTP ─────────────────────────────────────────────────────────────

    def _http(self, method, url, payload=None, auth=""):
        """Send a JSON request and return ``(status_code, decoded_body)``.

        ``payload`` (if not None) is sent as a JSON body; ``auth`` (if
        non-empty) is sent verbatim as the ``Authorization`` header. An empty
        response body decodes to ``{}``.

        HTTP error statuses are *returned*, not raised, so callers can inspect
        the code; an error body that is not JSON is returned as
        ``{"error": <raw text>}``.

        Raises RuntimeError when the request cannot be completed at all
        (connection failure, timeout) or when a successful response carries a
        body that is not valid JSON.
        """
        data = json.dumps(payload).encode() if payload is not None else None
        req = urllib.request.Request(url, data=data, method=method)
        req.add_header("Content-Type", "application/json")
        req.add_header("Accept", "application/json")
        if auth:
            req.add_header("Authorization", auth)
        try:
            with urllib.request.urlopen(req, timeout=self.HTTP_TIMEOUT) as resp:
                status = resp.status
                body = resp.read().decode("utf-8", errors="replace")
        except urllib.error.HTTPError as e:
            err_body = e.read().decode("utf-8", errors="replace") if e.fp else ""
            try:
                return e.code, json.loads(err_body) if err_body else {}
            except json.JSONDecodeError:
                return e.code, {"error": err_body}
        except OSError as e:
            # URLError and socket timeouts are OSError subclasses. Callers only
            # handle RuntimeError, so translate to keep post-merge steps
            # best-effort instead of dying with a traceback.
            raise RuntimeError(f"{method} {url} failed: {e}") from e
        try:
            return status, json.loads(body) if body else {}
        except json.JSONDecodeError as e:
            raise RuntimeError(f"{method} {url}: response is not valid JSON") from e

    @staticmethod
    def _error_detail(data):
        """Return ``": <text>"`` describing an API error body, or ``""``.

        Looks at the ``message`` key (GitHub), ``errorMessages`` (JIRA) and
        ``error`` (the raw-text fallback produced by ``_http``).
        """
        if not isinstance(data, dict):
            return ""
        detail = data.get("message") or data.get("errorMessages") or data.get("error")
        if isinstance(detail, (list, tuple)):
            detail = "; ".join(str(d) for d in detail)
        return f": {detail}" if detail else ""

    # ── GitHub ───────────────────────────────────────────────────────────

    def _gh_auth(self):
        """Return the GitHub ``Authorization`` header value, or ``""`` without a token."""
        return f"token {self.github_token}" if self.github_token else ""

    def _gh_get_pr(self, num):
        """Fetch pull request ``num``; raise RuntimeError unless the API answers 200."""
        code, data = self._http("GET", f"{GITHUB_API_BASE}/pulls/{num}", auth=self._gh_auth())
        if code != 200:
            raise RuntimeError(f"GET PR #{num}: HTTP {code}{self._error_detail(data)}")
        return data

    def _gh_merge_pr(self, num, title, msg):
        """Squash-merge pull request ``num`` with the given commit title and message.

        Returns the decoded API response. Raises RuntimeError on any status
        other than 200.
        """
        payload = {"commit_title": title, "commit_message": msg, "merge_method": "squash"}
        code, data = self._http("PUT", f"{GITHUB_API_BASE}/pulls/{num}/merge", payload, self._gh_auth())
        detail = self._error_detail(data)
        if code == 405:
            raise RuntimeError(f"Merge PR #{num} is not allowed{detail}")
        if code != 200:
            raise RuntimeError(f"Merge PR #{num}: HTTP {code}{detail}")
        return data

    def _gh_comment_pr(self, num, comment):
        """Post ``comment`` on pull request ``num``.

        Raises RuntimeError unless the API answers 201, so that the caller
        does not report success for a comment that was never created.
        """
        code, data = self._http("POST", f"{GITHUB_API_BASE}/issues/{num}/comments",
                                {"body": comment}, self._gh_auth())
        if code != 201:
            raise RuntimeError(f"comment PR #{num}: HTTP {code}{self._error_detail(data)}")

    # ── JIRA ─────────────────────────────────────────────────────────────

    def _jira_auth(self):
        """Return the JIRA ``Authorization`` header value, or ``""`` without a token."""
        return f"Bearer {self.jira_token}" if self.jira_token else ""

    def _jira_get_issue(self, key):
        """Fetch JIRA issue ``key``; raise RuntimeError unless the API answers 200."""
        code, data = self._http("GET", f"{JIRA_API_BASE}/issue/{key}", auth=self._jira_auth())
        if code != 200:
            raise RuntimeError(f"GET {key}: HTTP {code}{self._error_detail(data)}")
        return data

    def _jira_unreleased_versions(self):
        """Return the project's open versions as ``{"id", "name"}`` dicts, newest first.

        Only versions that are neither released nor archived and whose name is
        a plain ``X.Y.Z`` are kept. Raises RuntimeError unless the API answers 200.
        """
        code, data = self._http("GET", f"{JIRA_API_BASE}/project/ZEPPELIN/versions", auth=self._jira_auth())
        if code != 200:
            raise RuntimeError(f"GET versions: HTTP {code}{self._error_detail(data)}")
        versions = []
        for v in data:
            name = v.get("name", "")
            if not v.get("released") and not v.get("archived") and SEMANTIC_VER_RE.match(name):
                versions.append({"id": str(v["id"]), "name": name})
        versions.sort(key=lambda v: _ver_tuple(v["name"]), reverse=True)
        return versions

    def _jira_transitions(self, key):
        """Return the transitions available for issue ``key`` as ``{"id", "name"}`` dicts.

        Raises RuntimeError unless the API answers 200.
        """
        code, data = self._http("GET", f"{JIRA_API_BASE}/issue/{key}/transitions", auth=self._jira_auth())
        if code != 200:
            raise RuntimeError(f"GET transitions {key}: HTTP {code}{self._error_detail(data)}")
        return [{"id": t["id"], "name": t["name"]} for t in data.get("transitions", [])]

    def _jira_resolve(self, key, transition_id, fix_ver, comment):
        """Apply ``transition_id`` to issue ``key``, adding a comment and fix versions.

        ``fix_ver`` is a list of ``{"id", "name"}`` dicts. Raises RuntimeError
        unless the API answers 204.
        """
        payload = {
            "transition": {"id": transition_id},
            "update": {
                "comment": [{"add": {"body": comment}}],
                "fixVersions": [{"add": {"id": fv["id"], "name": fv["name"]}} for fv in fix_ver],
            },
        }
        code, data = self._http("POST", f"{JIRA_API_BASE}/issue/{key}/transitions", payload, self._jira_auth())
        if code != 204:
            raise RuntimeError(f"Resolve {key}: HTTP {code}{self._error_detail(data)}")

    # ── Fix version resolution ───────────────────────────────────────────

    def _resolve_fix_versions(self, branches, versions):
        """Resolve fix version objects from explicit --fix-versions and branch inference.

        ``versions`` is the list of available version dicts and is expected to
        be ordered newest first (as ``_jira_unreleased_versions`` returns it);
        it may be empty.

        Explicit versions come first, de-duplicated. Then, per branch:
        the default branch contributes the newest version (only when no
        explicit versions were given); a release branch ``branch-X.Y``
        contributes the oldest available ``X.Y.*`` version. An inferred
        ``X.Y.0`` is dropped when ``X.(Y-1).0`` was also selected.

        Returns a list of version dicts ({"id": ..., "name": ...}).
        Raises RuntimeError if an explicit fix version is not found.
        """
        vm = {v["name"]: v for v in versions}
        fix_ver, seen = [], set()

        for fv in self.fix_versions:
            if fv not in vm:
                raise RuntimeError(f'fix version "{fv}" not found')
            if fv in seen:
                continue  # same version listed twice on the command line
            fix_ver.append(vm[fv])
            seen.add(fv)

        infer_master = not self.fix_versions
        # Guarded so an empty version list yields "nothing inferred" rather
        # than an IndexError.
        latest = versions[0]["name"] if versions else None
        names = []
        for branch in branches:
            if branch == DEFAULT_BRANCH:
                if infer_master and latest is not None and latest not in seen:
                    names.append(latest)
                    seen.add(latest)
            else:
                prefix = branch[len("branch-"):] if branch.startswith("branch-") else branch
                found = [v["name"] for v in versions if v["name"].startswith(prefix + ".") or v["name"] == prefix]
                if found:
                    pick = found[-1]  # smallest matching (list is desc-sorted)
                    if pick not in seen:
                        names.append(pick)
                        seen.add(pick)
                else:
                    print(f"Warning: no version found for {branch}, skipping", file=sys.stderr)

        # Remove redundant X.Y.0 when X.(Y-1).0 is also present
        filtered = []
        for v in names:
            parts = v.split(".")
            if len(parts) == 3 and parts[2] == "0":
                minor = int(parts[1])
                if minor > 0 and f"{parts[0]}.{minor - 1}.0" in seen:
                    continue
            filtered.append(v)

        inferred = [vm[n] for n in filtered if n in vm]
        if inferred:
            print(f"Auto-inferred fix version(s): {', '.join(filtered)}")
        fix_ver.extend(inferred)
        return fix_ver

    # ── Effective command ────────────────────────────────────────────────

    def _print_effective_command(self, target_branch, fix_ver):
        """Print the fully spelled-out command line equivalent to this dry run.

        Options equal to their defaults (default branch, default remote) are
        omitted; resolved fix versions are written out explicitly.
        """
        parts = ["python3 dev/merge_pr.py", f"--pr {self.pr}"]
        if target_branch and target_branch != DEFAULT_BRANCH:
            parts.append(f"--target {target_branch}")
        if self.release_branches:
            parts.append(f"--release-branches {','.join(self.release_branches)}")
        if self.resolve_jira:
            parts.append("--resolve-jira")
        if fix_ver:
            parts.append(f"--fix-versions {','.join(fv['name'] for fv in fix_ver)}")
        if self.push_remote != DEFAULT_REMOTE:
            parts.append(f"--push-remote {self.push_remote}")
        print(f"[dry-run] Effective command:\n {' '.join(parts)}")

    # ── Main flow ────────────────────────────────────────────────────────

    def run(self):
        """Execute the whole workflow (or only describe it with --dry-run).

        Raises RuntimeError when the PR cannot be fetched, is not mergeable,
        or the merge request is rejected. Everything after a successful merge
        (fetch, cherry-picks, PR comment, JIRA resolution) only warns on failure.
        """
        original_head = self._git_current_ref()

        pr_data = self._gh_get_pr(self.pr)
        if not pr_data.get("mergeable"):
            raise RuntimeError(f"PR #{self.pr} is not mergeable")
        pr_title = pr_data["title"]
        if "[WIP]" in pr_title:
            print(f"WARNING: PR title contains [WIP]: {pr_title}", file=sys.stderr)

        base_branch = pr_data["base"]["ref"]
        target_branch = self.target or base_branch
        if target_branch != base_branch:
            # The merge request sent below carries no branch, so --target only
            # affects reporting and fix-version inference.
            print(
                f"WARNING: --target {target_branch} differs from the PR base branch "
                f"{base_branch}; GitHub merges into the base branch",
                file=sys.stderr,
            )
        title = _standardize_title(pr_title)
        src = f"{pr_data['user']['login']}/{pr_data['head']['ref']}"
        pr_body = pr_data.get("body", "") or ""

        print(f"=== Pull Request #{self.pr} ===")
        print(f"title: {title}")
        print(f"source: {src}")
        print(f"target: {target_branch}")
        print(f"url: {pr_data['url']}")
        if self.release_branches:
            print(f"release-branches: {', '.join(self.release_branches)}")

        # Resolve fix versions once (used for both dry-run display and actual JIRA resolution)
        fix_ver = []
        if self.resolve_jira and self.jira_token and JIRA_ID_RE.search(title):
            try:
                versions = self._jira_unreleased_versions()
                branches = [target_branch] + self.release_branches
                # Called even for an empty list so that an explicit
                # --fix-versions that cannot be found is reported.
                fix_ver = self._resolve_fix_versions(branches, versions)
            except RuntimeError as e:
                print(f"Warning: failed to resolve fix versions: {e}", file=sys.stderr)

        if self.dry_run:
            print()
            self._print_effective_command(target_branch, fix_ver)
            return

        # Merge
        msg = self._build_commit_message(pr_body, src)
        merge_data = self._gh_merge_pr(self.pr, title, msg)
        sha = merge_data["sha"]
        print(f"\nPR #{self.pr} merged! (hash: {_short_sha(sha)})")

        # The merge commit must be available locally for the cherry-picks below.
        try:
            self._git("fetch", self.push_remote, target_branch)
        except RuntimeError as e:
            print(f"Warning: fetch {target_branch} failed: {e}", file=sys.stderr)

        # Cherry-pick into release branches
        merged = [target_branch]
        for branch in self.release_branches:
            if self._cherry_pick_into(branch, sha, original_head):
                merged.append(branch)

        self._comment_merge_summary(merged, sha)

        if self.resolve_jira:
            try:
                self._do_resolve_jira(title, fix_ver)
            except RuntimeError as e:
                print(f"Warning: JIRA resolution failed: {e}", file=sys.stderr)

    def _git_config(self, key):
        """Return the git config value for ``key``, or ``""`` when it is unset."""
        try:
            return self._git("config", "--get", key)
        except RuntimeError:
            return ""

    def _build_commit_message(self, pr_body, src):
        """Build the squash commit message from the PR body and source ref.

        "@" is replaced by "<at>" in the body. A Signed-off-by trailer is
        appended only when both git user.name and user.email are configured,
        so the message never carries an empty "Signed-off-by:  <>" line.
        """
        body = pr_body.replace("@", "<at>")
        msg = f"{body}\n\nCloses #{self.pr} from {src}."
        name = self._git_config("user.name")
        email = self._git_config("user.email")
        if name and email:
            msg += f"\n\nSigned-off-by: {name} <{email}>"
        else:
            print("Warning: git user.name/user.email not set, omitting Signed-off-by",
                  file=sys.stderr)
        return msg

    def _cherry_pick_into(self, branch, sha, original_head):
        """Cherry-pick ``sha`` onto ``branch`` via a temporary local branch and push it.

        Returns True when the pick was pushed. Never raises: every failure is
        reported as a warning, and the original checkout is restored and the
        temporary branch deleted on a best-effort basis.
        """
        pick = _pick_branch_name(self.pr, branch)
        try:
            self._git("fetch", self.push_remote, f"{branch}:{pick}")
        except RuntimeError as e:
            print(f"Warning: fetch {branch} failed: {e}", file=sys.stderr)
            return False

        picked = False
        try:
            self._git("checkout", pick)
            self._git("cherry-pick", "-sx", sha)
            self._git("push", self.push_remote, f"{pick}:{branch}")
            h = self._git("rev-parse", pick)
            print(f"Picked into {branch} (hash: {_short_sha(h)})")
            picked = True
        except RuntimeError as e:
            print(f"Warning: cherry-pick/push into {branch} failed: {e}", file=sys.stderr)
            try:
                self._git("cherry-pick", "--abort")
            except RuntimeError:
                # Nothing to abort when the failure was not in the pick itself.
                pass
        finally:
            # Each cleanup step is attempted independently; a failure here
            # must not prevent the PR comment and JIRA steps that follow.
            for cleanup in (("checkout", original_head), ("branch", "-D", pick)):
                try:
                    self._git(*cleanup)
                except RuntimeError as e:
                    print(f"Warning: cleanup after {branch} failed: {e}", file=sys.stderr)
        return picked

    def _comment_merge_summary(self, merged, sha):
        """Comment on the PR listing where the change landed.

        ``merged[0]`` is reported as the merge target, the remaining entries
        as cherry-pick targets. A failed comment only produces a warning.
        """
        lines = [f"Merged into {merged[0]} ({_short_sha(sha)})."]
        for branch in merged[1:]:
            lines.append(f"Cherry-picked into {branch}.")
        try:
            self._gh_comment_pr(self.pr, "\n".join(lines))
            print("Commented on PR with merge summary.")
        except RuntimeError as e:
            print(f"Warning: failed to comment on PR: {e}", file=sys.stderr)

    def _do_resolve_jira(self, title, fix_ver):
        """Resolve every JIRA issue referenced in ``title``.

        Issues already in a closed status are skipped. A failure on one issue
        is reported as a warning and does not stop the remaining ones.

        Raises RuntimeError when no JIRA token is configured.
        """
        if not self.jira_token:
            raise RuntimeError("JIRA_ACCESS_TOKEN is not set")

        # dict.fromkeys keeps first-seen order while dropping repeated IDs.
        ids = list(dict.fromkeys(JIRA_ID_RE.findall(title)))
        if not ids:
            print("No JIRA ID found in PR title, skipping.")
            return

        for jira_id in ids:
            try:
                issue = self._jira_get_issue(jira_id)
            except RuntimeError as e:
                print(f"Warning: get {jira_id}: {e}", file=sys.stderr)
                continue
            status = issue.get("fields", {}).get("status", {}).get("name", "")
            if status in JIRA_CLOSED_STATUSES:
                print(f'JIRA {jira_id} already "{status}", skipping.')
                continue

            print(f"=== JIRA {jira_id} ===")
            print(f"Summary: {issue.get('fields', {}).get('summary', '')}")
            print(f"Status: {status}")

            try:
                transitions = self._jira_transitions(jira_id)
            except RuntimeError as e:
                print(f"Warning: transitions {jira_id}: {e}", file=sys.stderr)
                continue
            resolve_id = next((t["id"] for t in transitions if t["name"] == JIRA_RESOLVE_TRANSITION), None)
            if not resolve_id:
                print(f"Warning: no '{JIRA_RESOLVE_TRANSITION}' transition for {jira_id}", file=sys.stderr)
                continue

            jira_comment = (
                f"Issue resolved by pull request {self.pr}"
                f"\n[https://github.com/apache/zeppelin/pull/{self.pr}]"
            )
            try:
                self._jira_resolve(jira_id, resolve_id, fix_ver, jira_comment)
                print(f"Resolved {jira_id}!")
            except RuntimeError as e:
                print(f"Warning: resolve {jira_id}: {e}", file=sys.stderr)
| |
| |
| # ── Module-level utilities ─────────────────────────────────────────────── |
| |
| def _parse_csv(value): |
| return [s.strip() for s in value.split(",") if s.strip()] if value else [] |
| |
| |
| def _ver_tuple(v): |
| return tuple(int(x) for x in v.split(".")) |
| |
| |
| def _short_sha(sha): |
| return sha[:8] if len(sha) > 8 else sha |
| |
| |
| def _pick_branch_name(pr_num, branch): |
| return f"PR_TOOL_PICK_PR_{pr_num}_{branch.upper()}" |
| |
| |
def _standardize_title(text):
    """Rewrite a PR title as ``[ZEPPELIN-N][COMPONENT] text``.

    Trailing periods are stripped. ``Revert "..."`` titles and titles that
    already match the standard form are returned otherwise untouched.

    For everything else, issue references written in any case and with any
    mix of "-"/whitespace (or none) between "ZEPPELIN" and the number are
    normalised to ``[ZEPPELIN-<number>]`` and moved to the front, followed by
    the upper-cased bracketed component tags, followed by the remaining text
    with leading punctuation removed and whitespace collapsed. Repeated
    references and components are emitted once.
    """
    text = text.rstrip(".")
    if text.startswith('Revert "') and text.endswith('"'):
        return text
    if TITLE_FORMATTED_RE.match(text):
        return text

    jira_refs = []
    for m in TITLE_REF_RE.finditer(text):
        # Rebuild the key from its digits so that "zeppelin1234",
        # "ZEPPELIN 1234" and "ZEPPELIN - 1234" all become "ZEPPELIN-1234".
        number = "".join(ch for ch in m.group(1) if ch.isdigit())
        tag = f"[ZEPPELIN-{number}]"
        if tag not in jira_refs:
            jira_refs.append(tag)
    # Remove exactly the matched spans; a plain str.replace of one reference
    # would also eat the prefix of a longer one (ZEPPELIN-123 vs ZEPPELIN-1234).
    text = TITLE_REF_RE.sub("", text)

    # Drop the empty "[]" / "()" shells left where a reference was wrapped,
    # before they can be mistaken for a (whitespace-only) component tag.
    text = re.sub(r"\[\s*\]|\(\s*\)", "", text)

    components = []
    for comp in COMPONENT_RE.findall(text):
        comp = comp.upper()
        if comp not in components:
            components.append(comp)
    text = COMPONENT_RE.sub("", text)

    text = LEADING_NON_WORD_RE.sub("", text)
    result = "".join(jira_refs) + "".join(components) + " " + text
    return WHITESPACE_RE.sub(" ", result.strip())
| |
| |
| # ── Entry point ────────────────────────────────────────────────────────── |
| |
def main():
    """Parse the command-line flags and run the merge workflow.

    When the workflow raises RuntimeError, the message is printed to stderr
    and the process exits with status 1 instead of showing a traceback.
    """
    parser = argparse.ArgumentParser(
        description="Merge Apache Zeppelin pull requests",
        usage="python3 dev/merge_pr.py [flags]",
    )
    parser.add_argument("--pr", type=int, required=True, help="Pull request number")
    parser.add_argument("--target", default="", help="Target branch (default: PR base branch)")
    parser.add_argument("--fix-versions", default="", help="JIRA fix version(s), comma-separated")
    parser.add_argument("--release-branches", default="", help="Release branch(es) to cherry-pick into, comma-separated")
    parser.add_argument("--resolve-jira", action="store_true", help="Resolve associated JIRA issue(s)")
    parser.add_argument("--dry-run", action="store_true", help="Show what would be done without making changes")
    parser.add_argument("--push-remote", default="", help="Git remote for pushing (default: apache)")
    parser.add_argument("--github-token", default="", help="GitHub OAuth token (env: GITHUB_OAUTH_KEY)")
    parser.add_argument("--jira-token", default="", help="JIRA access token (env: JIRA_ACCESS_TOKEN)")

    args = parser.parse_args()
    try:
        MergePR(args).run()
    except RuntimeError as e:
        # Expected operational failures (PR not mergeable, API rejection, git
        # error) are user-facing conditions, not programming errors.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()