blob: aceb279da5647cc205b84b7dea7dc7f7fcf75bca [file] [log] [blame]
package main
import (
"context"
"fmt"
"os"
"regexp"
"strings"
"github.com/google/go-github/v45/github"
"github.com/sethvargo/go-githubactions"
"golang.org/x/oauth2"
"github.com/apache/pulsar-test-infra/docbot/pkg/logger"
)
// Messages posted as comments on the pull request (prefixed with an
// @-mention of the PR author) and also returned as the error text of the run
// that posted them.
const (
// MessageLabelMissing is used when no documentation label is selected and the
// missing-label feature is enabled.
MessageLabelMissing = `Please provide a correct documentation label for your PR.
Instructions see [Pulsar Documentation Label Guide](https://docs.google.com/document/d/1Qw7LHQdXWBW9t2-r-A7QdFDBwmZh6ytB4guwMoXHqc0).`
// MessageLabelMultiple is used when more than one documentation label is
// selected while multiple labels are not enabled.
MessageLabelMultiple = `Please select only one documentation label for your PR.
Instructions see [Pulsar Documentation Label Guide](https://docs.google.com/document/d/1Qw7LHQdXWBW9t2-r-A7QdFDBwmZh6ytB4guwMoXHqc0).`
)
// ActionConfig holds the bot settings read from the environment by
// NewActionConfig, plus the per-run pull request state (number, labels) that
// main fills in from the event payload.
type ActionConfig struct {
// token is the GitHub API token (GITHUB_TOKEN).
token *string
// repo and owner are the two halves of GITHUB_REPOSITORY ("owner/repo").
repo *string
owner *string
// number is the pull request number; not set by NewActionConfig.
number *int
// labelPattern is the regular expression used to find label checkboxes in
// the PR body (LABEL_PATTERN).
labelPattern *string
// labelWatchSet is the set of label names the bot manages (LABEL_WATCH_LIST).
labelWatchSet map[string]struct{}
// labelMissing is the name of the marker label applied when no watched label
// is selected (LABEL_MISSING).
labelMissing *string
// enableLabelMissing toggles adding the marker label (ENABLE_LABEL_MISSING).
enableLabelMissing *bool
// enableLabelMultiple allows more than one watched label at a time
// (ENABLE_LABEL_MULTIPLE).
enableLabelMultiple *bool
// labels extracted from PR body: label name -> whether its checkbox is checked
labels map[string]bool
}
// NewActionConfig builds an ActionConfig from environment variables.
//
// Variables read:
//   - GITHUB_REPOSITORY: required, must have the form "owner/repo".
//   - GITHUB_TOKEN: API token, used as-is.
//   - LABEL_PATTERN: checkbox regular expression; defaults to
//     "- \\[(.*?)\\] ?`(.+?)`" when empty.
//   - LABEL_WATCH_LIST: comma-separated label names. Each entry is trimmed of
//     surrounding whitespace and empty entries are dropped, so an unset
//     variable yields an empty set and "a, b" yields {"a", "b"}.
//   - ENABLE_LABEL_MISSING: enabled unless the value is exactly "false".
//   - LABEL_MISSING: marker label name; defaults to "label-missing".
//   - ENABLE_LABEL_MULTIPLE: disabled unless the value is exactly "true".
//
// The pull request number and the labels parsed from the PR body are not set
// here. An error is returned when GITHUB_REPOSITORY does not split into
// exactly two parts on "/".
func NewActionConfig() (*ActionConfig, error) {
	ownerRepoSlug := os.Getenv("GITHUB_REPOSITORY")
	ownerRepo := strings.Split(ownerRepoSlug, "/")
	if len(ownerRepo) != 2 {
		return nil, fmt.Errorf("GITHUB_REPOSITORY is not found")
	}
	owner, repo := ownerRepo[0], ownerRepo[1]

	token := os.Getenv("GITHUB_TOKEN")

	labelPattern := os.Getenv("LABEL_PATTERN")
	if len(labelPattern) == 0 {
		labelPattern = "- \\[(.*?)\\] ?`(.+?)`"
	}

	// strings.Split on an empty string returns one empty element, and entries
	// written as "a, b" carry a leading space; normalise every entry so the
	// set only ever contains real label names.
	labelWatchSet := make(map[string]struct{})
	for _, l := range strings.Split(os.Getenv("LABEL_WATCH_LIST"), ",") {
		l = strings.TrimSpace(l)
		if l == "" {
			continue
		}
		labelWatchSet[l] = struct{}{}
	}

	enableLabelMissing := os.Getenv("ENABLE_LABEL_MISSING") != "false"

	labelMissing := os.Getenv("LABEL_MISSING")
	if len(labelMissing) == 0 {
		labelMissing = "label-missing"
	}

	enableLabelMultiple := os.Getenv("ENABLE_LABEL_MULTIPLE") == "true"

	return &ActionConfig{
		token:               &token,
		repo:                &repo,
		owner:               &owner,
		labelPattern:        &labelPattern,
		labelWatchSet:       labelWatchSet,
		labelMissing:        &labelMissing,
		enableLabelMissing:  &enableLabelMissing,
		enableLabelMultiple: &enableLabelMultiple,
	}, nil
}
// GetToken returns the GitHub token, or "" when the receiver or the token
// field is nil.
func (ac *ActionConfig) GetToken() string {
	if ac != nil && ac.token != nil {
		return *ac.token
	}
	return ""
}
// GetOwner returns the repository owner, or "" when the receiver or the owner
// field is nil.
func (ac *ActionConfig) GetOwner() string {
	if ac != nil && ac.owner != nil {
		return *ac.owner
	}
	return ""
}
// GetRepo returns the repository name, or "" when the receiver or the repo
// field is nil.
func (ac *ActionConfig) GetRepo() string {
	if ac != nil && ac.repo != nil {
		return *ac.repo
	}
	return ""
}
// GetNumber returns the pull request number, or 0 when the receiver or the
// number field is nil.
func (ac *ActionConfig) GetNumber() int {
	if ac != nil && ac.number != nil {
		return *ac.number
	}
	return 0
}
// GetLabelPattern returns the checkbox regular expression, or "" when the
// receiver or the labelPattern field is nil.
func (ac *ActionConfig) GetLabelPattern() string {
	if ac != nil && ac.labelPattern != nil {
		return *ac.labelPattern
	}
	return ""
}
// GetLabelMissing returns the name of the missing-label marker, or "" when
// the receiver or the labelMissing field is nil.
func (ac *ActionConfig) GetLabelMissing() string {
	if ac != nil && ac.labelMissing != nil {
		return *ac.labelMissing
	}
	return ""
}
// GetEnableLabelMissing reports whether the missing-label feature is enabled.
// It returns false when the receiver or the enableLabelMissing field is nil.
func (ac *ActionConfig) GetEnableLabelMissing() bool {
	if ac != nil && ac.enableLabelMissing != nil {
		return *ac.enableLabelMissing
	}
	return false
}
// GetEnableLabelMultiple reports whether more than one watched label is
// allowed. It returns false when the receiver or the enableLabelMultiple
// field is nil.
func (ac *ActionConfig) GetEnableLabelMultiple() bool {
	if ac != nil && ac.enableLabelMultiple != nil {
		return *ac.enableLabelMultiple
	}
	return false
}
// Action runs the label bot for a single pull request through the GitHub API.
type Action struct {
// config holds the bot settings and the per-PR state (number, body labels).
config *ActionConfig
// globalContext is the context passed to the GitHub API calls made directly
// by the event handlers.
globalContext context.Context
// client is the GitHub API client, authenticated with the configured token.
client *github.Client
// event is the action type of the triggering event, recorded by Run:
// opened, edited, labeled, unlabeled
event string
}
// NewAction creates an Action for the given configuration. The GitHub client
// it carries sends ac's token as a static OAuth2 access token, and a
// background context is stored for later API calls.
func NewAction(ac *ActionConfig) *Action {
	background := context.Background()

	tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: ac.GetToken()})
	httpClient := oauth2.NewClient(background, tokenSource)

	action := &Action{
		config:        ac,
		globalContext: background,
	}
	action.client = github.NewClient(httpClient)
	return action
}
// Run records actionType as the current event and dispatches to the matching
// handler: "opened"/"edited" sync labels from the PR body, "labeled"/
// "unlabeled" sync the PR body from the labels. Any other action type is
// ignored and nil is returned.
func (a *Action) Run(actionType string) error {
	a.event = actionType

	if actionType == "opened" || actionType == "edited" {
		return a.onPullRequestOpenedOrEdited()
	}
	if actionType == "labeled" || actionType == "unlabeled" {
		return a.onPullRequestLabeledOrUnlabeled()
	}
	return nil
}
// onPullRequestOpenedOrEdited makes the labels on the pull request match the
// checkboxes parsed from the PR body (a.config.labels).
//
// Steps, in order:
//  1. Fetch the PR, the repository labels and the labels currently on the PR.
//  2. Reduce the PR labels to the watched ones plus the missing-label marker,
//     and reduce the body labels to those that exist in the repository.
//  3. If more than one label is checked and multiple labels are not enabled,
//     comment on the PR and return an error without changing any label.
//  4. Remove watched labels that are no longer checked (all watched labels
//     when the body lists none) and the marker once something is checked;
//     then add the checked labels that are not yet on the PR.
//  5. If nothing is checked and the missing-label feature is enabled, add the
//     marker, comment on the PR and return an error.
//
// API failures are returned wrapped with %w so callers can inspect the
// underlying error, except where the original flow deliberately only logs
// them (adding the checked labels, posting the missing-label comment).
func (a *Action) onPullRequestOpenedOrEdited() error {
	pr, _, err := a.client.PullRequests.Get(a.globalContext, a.config.GetOwner(), a.config.GetRepo(), a.config.GetNumber())
	if err != nil {
		return fmt.Errorf("get PR: %w", err)
	}

	// Get repo labels
	logger.Infoln("@List repo labels")
	repoLabels, err := a.getRepoLabels()
	if err != nil {
		return fmt.Errorf("list repo labels: %w", err)
	}
	logger.Infof("Repo labels: %v\n", a.labelsToString(repoLabels))

	repoLabelsSet := make(map[string]struct{}, len(repoLabels))
	for _, label := range repoLabels {
		repoLabelsSet[label.GetName()] = struct{}{}
	}

	// Get current labels on this PR
	logger.Infoln("@List issue labels")
	issueLabels, err := a.getIssueLabels()
	if err != nil {
		return fmt.Errorf("list current issue labels: %w", err)
	}
	logger.Infof("Issue labels: %v\n", a.labelsToString(issueLabels))

	// Get the intersection of issueLabels and labelWatchSet, including labelMissing
	logger.Infoln("@List current labels")
	currentLabelsSet := make(map[string]struct{})
	for _, label := range issueLabels {
		if _, exist := a.config.labelWatchSet[label.GetName()]; !exist && label.GetName() != a.config.GetLabelMissing() {
			continue
		}
		currentLabelsSet[label.GetName()] = struct{}{}
	}
	logger.Infof("Current labels: %v\n", a.labelsSetToString(currentLabelsSet))

	// Get expected labels
	// Only handle labels already exist in repo
	logger.Infoln("@List expected labels")
	expectedLabelsMap := make(map[string]bool)
	for label, checked := range a.config.labels {
		if _, exist := repoLabelsSet[label]; !exist {
			logger.Infof("Found label %v not exist int repo\n", label)
			continue
		}
		expectedLabelsMap[label] = checked
	}
	logger.Infof("Expected labels: %v\n", expectedLabelsMap)

	// Remove labels
	logger.Infoln("@Remove labels")
	labelsToRemove := make(map[string]struct{})
	if len(expectedLabelsMap) == 0 { // Remove current labels when PR body is empty
		for l := range a.config.labelWatchSet {
			if _, exist := currentLabelsSet[l]; exist {
				labelsToRemove[l] = struct{}{}
			}
		}
	} else {
		for label := range currentLabelsSet {
			// The marker is handled separately below.
			if label == a.config.GetLabelMissing() {
				continue
			}
			if checked, exist := expectedLabelsMap[label]; exist && checked {
				continue
			}
			labelsToRemove[label] = struct{}{}
		}
	}

	// Count the labels whose checkbox is checked in the PR body.
	checkedCount := 0
	for _, checked := range expectedLabelsMap {
		if checked {
			checkedCount++
		}
	}

	// Reject multiple checked labels before any label is modified.
	if !a.config.GetEnableLabelMultiple() && checkedCount > 1 {
		logger.Infoln("Multiple labels detected")
		_, _, err = a.client.Issues.CreateComment(a.globalContext,
			a.config.GetOwner(), a.config.GetRepo(), a.config.GetNumber(),
			&github.IssueComment{
				Body: github.String(fmt.Sprintf("@%s %s", pr.GetUser().GetLogin(), MessageLabelMultiple))})
		if err != nil {
			return fmt.Errorf("create issue comment: %w", err)
		}
		return fmt.Errorf("%s", MessageLabelMultiple)
	}

	// Remove missing label
	if _, exist := currentLabelsSet[a.config.GetLabelMissing()]; exist && checkedCount > 0 {
		labelsToRemove[a.config.GetLabelMissing()] = struct{}{}
	}

	logger.Infof("Labels to remove: %v\n", a.labelsSetToString(labelsToRemove))
	for label := range labelsToRemove {
		_, err := a.client.Issues.RemoveLabelForIssue(a.globalContext, a.config.GetOwner(), a.config.GetRepo(), a.config.GetNumber(), label)
		if err != nil {
			return fmt.Errorf("remove label %v: %w", label, err)
		}
	}

	// Add labels
	logger.Infoln("@Add labels")
	labelsToAdd := []string{}
	for label, checked := range expectedLabelsMap {
		if !checked {
			continue
		}
		if _, exist := currentLabelsSet[label]; !exist {
			labelsToAdd = append(labelsToAdd, label)
		}
	}

	if len(labelsToAdd) == 0 {
		logger.Infoln("No labels to add.")
	} else {
		logger.Infof("Labels to add: %v\n", labelsToAdd)
		_, _, err = a.client.Issues.AddLabelsToIssue(a.globalContext, a.config.GetOwner(), a.config.GetRepo(), a.config.GetNumber(), labelsToAdd)
		if err != nil {
			// Best effort: a failure here is logged and the run continues.
			logger.Infof("Add labels %v: %v\n", labelsToAdd, err)
		}
	}

	// Add missing label
	if a.config.GetEnableLabelMissing() && checkedCount == 0 {
		logger.Infoln("@Add missing label")
		_, _, err = a.client.Issues.AddLabelsToIssue(a.globalContext,
			a.config.GetOwner(), a.config.GetRepo(), a.config.GetNumber(),
			[]string{a.config.GetLabelMissing()})
		if err != nil {
			return fmt.Errorf("add missing label %v: %w", a.config.GetLabelMissing(), err)
		}

		_, _, err = a.client.Issues.CreateComment(a.globalContext,
			a.config.GetOwner(), a.config.GetRepo(), a.config.GetNumber(),
			&github.IssueComment{
				Body: github.String(fmt.Sprintf("@%s %s", pr.GetUser().GetLogin(), MessageLabelMissing))})
		if err != nil {
			// Best effort: the run fails with MessageLabelMissing either way.
			logger.Infof("Create issue comment: %v\n", err)
		}
		return fmt.Errorf("%s", MessageLabelMissing)
	}

	return nil
}
// onPullRequestLabeledOrUnlabeled reacts to a label being added to or removed
// from the pull request.
//
// Steps, in order:
//  1. Fetch the PR, the repository labels and the labels currently on the PR.
//  2. Reduce the PR labels to the watched ones plus the missing-label marker,
//     and reduce the body labels to those that exist in the repository.
//  3. If more than one watched label is applied and multiple labels are not
//     enabled, comment on the PR and return an error.
//  4. Remove the marker when at least one watched label is applied.
//  5. If no watched label is applied and the missing-label feature is enabled,
//     add the marker, comment on the PR and return an error.
//  6. For "labeled" events only, rewrite the PR body so its checkboxes match
//     the applied labels: check (or append) applied labels and uncheck labels
//     that are checked in the body but not applied.
//
// The missing-label marker is never written into the PR body. API failures
// are returned wrapped with %w, except the missing-label comment, which is
// only logged.
func (a *Action) onPullRequestLabeledOrUnlabeled() error {
	pr, _, err := a.client.PullRequests.Get(a.globalContext, a.config.GetOwner(), a.config.GetRepo(), a.config.GetNumber())
	if err != nil {
		return fmt.Errorf("get PR: %w", err)
	}

	// Get repo labels
	logger.Infoln("@List repo labels")
	repoLabels, err := a.getRepoLabels()
	if err != nil {
		return fmt.Errorf("list repo labels: %w", err)
	}
	logger.Infof("Repo labels: %v\n", a.labelsToString(repoLabels))

	repoLabelsSet := make(map[string]struct{}, len(repoLabels))
	for _, label := range repoLabels {
		repoLabelsSet[label.GetName()] = struct{}{}
	}

	// Get current labels on this PR
	logger.Infoln("@List issue labels")
	issueLabels, err := a.getIssueLabels()
	if err != nil {
		return fmt.Errorf("list current issue labels: %w", err)
	}
	logger.Infof("Issue labels: %v\n", a.labelsToString(issueLabels))

	// Get the intersection of issueLabels and labelWatchSet, including labelMissing
	logger.Infoln("@List current labels")
	currentLabelsSet := make(map[string]struct{})
	for _, label := range issueLabels {
		if _, exist := a.config.labelWatchSet[label.GetName()]; !exist && label.GetName() != a.config.GetLabelMissing() {
			continue
		}
		currentLabelsSet[label.GetName()] = struct{}{}
	}
	logger.Infof("Current labels: %v\n", a.labelsSetToString(currentLabelsSet))

	// Get expected labels
	// Only handle labels already exist in repo
	logger.Infoln("@List expected labels")
	expectedLabelsMap := make(map[string]bool)
	for label, checked := range a.config.labels {
		if _, exist := repoLabelsSet[label]; !exist {
			logger.Infof("Found label %v not exist int repo\n", label)
			continue
		}
		expectedLabelsMap[label] = checked
	}
	logger.Infof("Expected labels: %v\n", expectedLabelsMap)

	// Count the watched labels applied to the PR; the marker does not count.
	labelsToRemove := make(map[string]struct{})
	checkedCount := 0
	for label := range currentLabelsSet {
		if label != a.config.GetLabelMissing() {
			checkedCount++
		}
	}

	if !a.config.GetEnableLabelMultiple() && checkedCount > 1 {
		logger.Infoln("Multiple labels detected")
		_, _, err = a.client.Issues.CreateComment(a.globalContext,
			a.config.GetOwner(), a.config.GetRepo(), a.config.GetNumber(),
			&github.IssueComment{
				Body: github.String(fmt.Sprintf("@%s %s", pr.GetUser().GetLogin(), MessageLabelMultiple))})
		if err != nil {
			return fmt.Errorf("create issue comment: %w", err)
		}
		return fmt.Errorf("%s", MessageLabelMultiple)
	}

	// Remove missing label
	if _, exist := currentLabelsSet[a.config.GetLabelMissing()]; exist && checkedCount > 0 {
		labelsToRemove[a.config.GetLabelMissing()] = struct{}{}
	}

	logger.Infof("Labels to remove: %v\n", a.labelsSetToString(labelsToRemove))
	for label := range labelsToRemove {
		_, err := a.client.Issues.RemoveLabelForIssue(a.globalContext, a.config.GetOwner(), a.config.GetRepo(), a.config.GetNumber(), label)
		if err != nil {
			return fmt.Errorf("remove label %v: %w", label, err)
		}
	}

	// Add missing label
	if a.config.GetEnableLabelMissing() && checkedCount == 0 {
		logger.Infoln("@Add missing label")
		_, _, err = a.client.Issues.AddLabelsToIssue(a.globalContext,
			a.config.GetOwner(), a.config.GetRepo(), a.config.GetNumber(),
			[]string{a.config.GetLabelMissing()})
		if err != nil {
			return fmt.Errorf("add missing label %v: %w", a.config.GetLabelMissing(), err)
		}

		_, _, err = a.client.Issues.CreateComment(a.globalContext,
			a.config.GetOwner(), a.config.GetRepo(), a.config.GetNumber(),
			&github.IssueComment{
				Body: github.String(fmt.Sprintf("@%s %s", pr.GetUser().GetLogin(), MessageLabelMissing))})
		if err != nil {
			// Best effort: the run fails with MessageLabelMissing either way.
			logger.Infof("Create issue comment: %v\n", err)
		}
		return fmt.Errorf("%s", MessageLabelMissing)
	}

	// Update PR Body
	// Compare current labels and expected labels
	if a.event == "unlabeled" {
		return nil
	}

	// changeList maps a label to the checkbox state the body should end up
	// with: true = checked, false = unchecked.
	changeList := make(map[string]bool)
	for label := range currentLabelsSet {
		// The marker is bookkeeping for the bot, not a documentation choice;
		// it must never be written into the PR body as a checkbox.
		if label == a.config.GetLabelMissing() {
			continue
		}
		if checked, exist := expectedLabelsMap[label]; exist && checked {
			continue
		}
		// If not exist, need to add
		// If exist but not checked, need to update
		changeList[label] = true
	}

	for label, checked := range expectedLabelsMap {
		if _, exist := currentLabelsSet[label]; !exist && checked {
			changeList[label] = false
		}
	}

	body := pr.GetBody()
	for label, checked := range changeList {
		src := fmt.Sprintf("- [ ] `%s`", label)
		dst := fmt.Sprintf("- [x] `%s`", label)
		if !checked {
			src = fmt.Sprintf("- [x] `%s`", label)
			dst = fmt.Sprintf("- [ ] `%s`", label)
		}

		if strings.Contains(body, src) { // Update the label
			body = strings.Replace(body, src, dst, 1)
		} else { // Add the label
			body = fmt.Sprintf("%s\r\n%s\r\n", body, dst)
		}
	}

	if len(changeList) > 0 {
		logger.Infoln("@Update PR body")
		logger.Infof("ChangeList: %v\n", changeList)
		_, _, err = a.client.PullRequests.Edit(a.globalContext, a.config.GetOwner(), a.config.GetRepo(), a.config.GetNumber(),
			&github.PullRequest{Body: &body})
		if err != nil {
			return fmt.Errorf("edit PR: %w", err)
		}
	}

	return nil
}
// extractLabels scans prBody with the configured label pattern and returns,
// for every matched label that is in the watch set, whether its checkbox is
// checked.
//
// The pattern's first capture group is the checkbox content (checked when it
// is "x" or "X" after trimming) and the second is the label name (trimmed).
// Watched labels that do not appear in the body are absent from the result;
// if a label appears more than once, the last occurrence wins. The pattern is
// compiled with regexp.MustCompile, so an invalid pattern panics.
func (a *Action) extractLabels(prBody string) map[string]bool {
	pattern := regexp.MustCompile(a.config.GetLabelPattern())

	found := make(map[string]bool)
	for _, match := range pattern.FindAllStringSubmatch(prBody, -1) {
		mark, name := strings.TrimSpace(match[1]), strings.TrimSpace(match[2])
		// Only labels from the watch set are of interest.
		if _, watched := a.config.labelWatchSet[name]; watched {
			found[name] = strings.ToLower(mark) == "x"
		}
	}
	return found
}
// getRepoLabels returns every label defined in the configured repository,
// requesting pages of 100 until the response reports no next page. The first
// API error aborts the listing and is returned with a nil slice.
func (a *Action) getRepoLabels() ([]*github.Label, error) {
	var (
		ctx  = context.Background()
		opts = &github.ListOptions{PerPage: 100}
		all  = make([]*github.Label, 0)
	)

	for {
		page, resp, err := a.client.Issues.ListLabels(ctx, a.config.GetOwner(), a.config.GetRepo(), opts)
		if err != nil {
			return nil, err
		}
		all = append(all, page...)

		if resp.NextPage == 0 {
			return all, nil
		}
		opts.Page = resp.NextPage
	}
}
// getIssueLabels returns every label currently attached to the configured
// pull request, requesting pages of 100 until the response reports no next
// page. The first API error aborts the listing and is returned with a nil
// slice.
func (a *Action) getIssueLabels() ([]*github.Label, error) {
	var (
		ctx  = context.Background()
		opts = &github.ListOptions{PerPage: 100}
		all  = make([]*github.Label, 0)
	)

	for {
		page, resp, err := a.client.Issues.ListLabelsByIssue(ctx, a.config.GetOwner(), a.config.GetRepo(), a.config.GetNumber(), opts)
		if err != nil {
			return nil, err
		}
		all = append(all, page...)

		if resp.NextPage == 0 {
			return all, nil
		}
		opts.Page = resp.NextPage
	}
}
// labelsToString returns the names of the given labels, in the same order.
// The result is never nil.
func (a *Action) labelsToString(labels []*github.Label) []string {
	names := make([]string, 0, len(labels))
	for i := range labels {
		names = append(names, labels[i].GetName())
	}
	return names
}
// labelsSetToString returns the keys of the given label set as a slice. The
// order follows map iteration and is therefore unspecified; the result is
// never nil.
func (a *Action) labelsSetToString(labels map[string]struct{}) []string {
	names := make([]string, 0, len(labels))
	for name := range labels {
		names = append(names, name)
	}
	return names
}
// main is the entry point of the docbot action. It loads the configuration
// from the environment, reads the GitHub Actions event context and, for
// pull_request / pull_request_target events, extracts the PR number and the
// label checkboxes from the event payload before running the action for the
// event's action type. "issues" events are only logged; other events are
// ignored. Any failure is reported through the logger's fatal helpers.
func main() {
	logger.Infoln("@Start docbot")

	actionConfig, err := NewActionConfig()
	if err != nil {
		logger.Fatalf("Get action config: %v\n", err)
	}

	action := NewAction(actionConfig)

	githubContext, err := githubactions.Context()
	if err != nil {
		logger.Fatalf("Get github context: %v\n", err)
	}

	switch githubContext.EventName {
	case "issues":
		logger.Infoln("@EventName is issues")
	case "pull_request", "pull_request_target":
		logger.Infoln("@EventName is PR")

		actionType, ok := githubContext.Event["action"].(string)
		if !ok {
			logger.Fatalln("Action type is not string")
		}

		pr := githubContext.Event["pull_request"]
		pullRequest, ok := pr.(map[string]interface{})
		if !ok {
			logger.Fatalln("PR event is not map")
		}

		// The payload is decoded JSON, so numbers arrive as float64. Check the
		// assertion like the others instead of panicking on a bad payload.
		rawNumber, ok := githubContext.Event["number"].(float64)
		if !ok {
			logger.Fatalln("PR number is not number")
		}
		number := int(rawNumber)

		// A pull request without a description has a null (or absent) body in
		// the payload; treat that as an empty body rather than aborting, so the
		// "no label selected" handling still runs.
		var prBody string
		switch body := pullRequest["body"].(type) {
		case nil:
		case string:
			prBody = body
		default:
			logger.Fatalln("PR body is not string")
		}

		// Get expected labels
		labels := action.extractLabels(prBody)

		actionConfig.number = &number
		actionConfig.labels = labels

		if err := action.Run(actionType); err != nil {
			logger.Fatalln(err)
		}
	}
}