| package cobra |
| |
| import ( |
| "fmt" |
| "io" |
| "os" |
| "sort" |
| "strings" |
| |
| "github.com/spf13/pflag" |
| ) |
| |
// Annotation keys understood by the bash completion generator. They are
// attached to a pflag.Flag's Annotations (see MarkFlagRequired and
// MarkFlagFilename) and change how that flag is completed.
const (
	// BashCompFilenameExt marks a flag whose value is a filename. If the
	// annotation carries values, completion is restricted to files with
	// those extensions. The "extentions" spelling is part of the key's value;
	// correcting it would change the annotation key itself.
	BashCompFilenameExt = "cobra_annotation_bash_completion_filename_extentions"
	// BashCompOneRequiredFlag marks a flag as required: while none of the
	// flags carrying this annotation has been typed, completion offers only
	// those flags.
	BashCompOneRequiredFlag = "cobra_annotation_bash_completion_one_required_flag"
	// BashCompSubdirsInDir marks a flag whose value is a directory. With
	// exactly one annotation value, completion lists the subdirectories of
	// that directory; otherwise it falls back to plain directory completion.
	BashCompSubdirsInDir = "cobra_annotation_bash_completion_subdirs_in_dir"
)
| |
// preamble writes the fixed head of the generated script: a comment line
// naming the program, followed by the shared bash helper functions
// (__debug, __handle_reply, __handle_flag, __handle_word, ...) that the
// generated per-command functions and the __start_ entry point call into.
//
// The helper script is written with io.WriteString, not fmt.Fprintf: it is
// not a format string, and Fprintf would collapse the bash "%%" in
// ${flagname%%=*} to a single "%", turning "strip from the first =" into
// "strip from the last =" and mis-parsing words such as --flag=a=b.
//
// Script fixes relative to the previous version:
//   - flaghash is only written when associative arrays are available
//     (the entry point's `declare -A flaghash` is allowed to fail on bash 3);
//   - emptiness tests on flag values are quoted, so values with spaces work;
//   - __handle_subdirs_in_dir_flag always pops the directory it pushed, so
//     the user's interactive shell is never left in another directory;
//   - basename is given a quoted argument.
func preamble(out io.Writer, name string) error {
	if _, err := fmt.Fprintf(out, "# bash completion for %-36s -*- shell-script -*-\n", name); err != nil {
		return err
	}
	_, err := io.WriteString(out, `
__debug()
{
    if [[ -n ${BASH_COMP_DEBUG_FILE} ]]; then
        echo "$*" >> "${BASH_COMP_DEBUG_FILE}"
    fi
}

# Homebrew on Macs have version 1.3 of bash-completion which doesn't include
# _init_completion. This is a very minimal version of that function.
__my_init_completion()
{
    COMPREPLY=()
    _get_comp_words_by_ref cur prev words cword
}

__index_of_word()
{
    local w word=$1
    shift
    index=0
    for w in "$@"; do
        [[ $w = "$word" ]] && return
        index=$((index+1))
    done
    index=-1
}

__contains_word()
{
    local w word=$1; shift
    for w in "$@"; do
        [[ $w = "$word" ]] && return
    done
    return 1
}

__handle_reply()
{
    __debug "${FUNCNAME}"
    case $cur in
        -*)
            if [[ $(type -t compopt) = "builtin" ]]; then
                compopt -o nospace
            fi
            local allflags
            if [ ${#must_have_one_flag[@]} -ne 0 ]; then
                allflags=("${must_have_one_flag[@]}")
            else
                allflags=("${flags[*]} ${two_word_flags[*]}")
            fi
            COMPREPLY=( $(compgen -W "${allflags[*]}" -- "$cur") )
            if [[ $(type -t compopt) = "builtin" ]]; then
                [[ $COMPREPLY == *= ]] || compopt +o nospace
            fi
            return 0;
            ;;
    esac

    # check if we are handling a flag with special work handling
    local index
    __index_of_word "${prev}" "${flags_with_completion[@]}"
    if [[ ${index} -ge 0 ]]; then
        ${flags_completion[${index}]}
        return
    fi

    # we are parsing a flag and don't have a special handler, no completion
    if [[ ${cur} != "${words[cword]}" ]]; then
        return
    fi

    local completions
    if [[ ${#must_have_one_flag[@]} -ne 0 ]]; then
        completions=("${must_have_one_flag[@]}")
    elif [[ ${#must_have_one_noun[@]} -ne 0 ]]; then
        completions=("${must_have_one_noun[@]}")
    else
        completions=("${commands[@]}")
    fi
    COMPREPLY=( $(compgen -W "${completions[*]}" -- "$cur") )

    if [[ ${#COMPREPLY[@]} -eq 0 ]]; then
        declare -F __custom_func >/dev/null && __custom_func
    fi

    __ltrim_colon_completions "$cur"
}

# The arguments should be in the form "ext1|ext2|extn"
__handle_filename_extension_flag()
{
    local ext="$1"
    _filedir "@(${ext})"
}

__handle_subdirs_in_dir_flag()
{
    local dir="$1"
    pushd "${dir}" >/dev/null 2>&1 || return
    _filedir -d
    popd >/dev/null 2>&1
}

__handle_flag()
{
    __debug "${FUNCNAME}: c is $c words[c] is ${words[c]}"

    # if a command required a flag, and we found it, unset must_have_one_flag()
    local flagname=${words[c]}
    local flagvalue
    # if the word contained an =
    if [[ ${words[c]} == *"="* ]]; then
        flagvalue=${flagname#*=} # take in as flagvalue after the =
        flagname=${flagname%%=*} # strip everything after the =
        flagname="${flagname}=" # but put the = back
    fi
    __debug "${FUNCNAME}: looking for ${flagname}"
    if __contains_word "${flagname}" "${must_have_one_flag[@]}"; then
        must_have_one_flag=()
    fi

    # keep flag value with flagname as flaghash
    # flaghash is an associative array, which bash 3 does not support
    if [[ -z "${BASH_VERSION}" || "${BASH_VERSINFO[0]}" -gt 3 ]]; then
        if [ -n "${flagvalue}" ] ; then
            flaghash[${flagname}]=${flagvalue}
        elif [ -n "${words[ $((c+1)) ]}" ] ; then
            flaghash[${flagname}]=${words[ $((c+1)) ]}
        else
            flaghash[${flagname}]="true" # pad "true" for bool flag
        fi
    fi

    # skip the argument to a two word flag
    if __contains_word "${words[c]}" "${two_word_flags[@]}"; then
        c=$((c+1))
        # if we are looking for a flags value, don't show commands
        if [[ $c -eq $cword ]]; then
            commands=()
        fi
    fi

    c=$((c+1))

}

__handle_noun()
{
    __debug "${FUNCNAME}: c is $c words[c] is ${words[c]}"

    if __contains_word "${words[c]}" "${must_have_one_noun[@]}"; then
        must_have_one_noun=()
    fi

    nouns+=("${words[c]}")
    c=$((c+1))
}

__handle_command()
{
    __debug "${FUNCNAME}: c is $c words[c] is ${words[c]}"

    local next_command
    if [[ -n ${last_command} ]]; then
        next_command="_${last_command}_${words[c]//:/__}"
    else
        if [[ $c -eq 0 ]]; then
            next_command="_$(basename "${words[c]//:/__}")"
        else
            next_command="_${words[c]//:/__}"
        fi
    fi
    c=$((c+1))
    __debug "${FUNCNAME}: looking for ${next_command}"
    declare -F $next_command >/dev/null && $next_command
}

__handle_word()
{
    if [[ $c -ge $cword ]]; then
        __handle_reply
        return
    fi
    __debug "${FUNCNAME}: c is $c words[c] is ${words[c]}"
    if [[ "${words[c]}" == -* ]]; then
        __handle_flag
    elif __contains_word "${words[c]}" "${commands[@]}"; then
        __handle_command
    elif [[ $c -eq 0 ]] && __contains_word "$(basename "${words[c]}")" "${commands[@]}"; then
        __handle_command
    else
        __handle_noun
    fi
    __handle_word
}

`)
	return err
}
| |
// postscript writes the tail of the generated script: the __start_<name>
// entry-point function, its registration with the `complete` builtin, and a
// vim modeline.
//
// Colons in name are replaced with "__" wherever name becomes part of a
// generated bash identifier. The `complete` builtin, however, must be given
// the command name exactly as the user types it; registering the sanitized
// form would attach the completion to a command that does not exist.
func postscript(w io.Writer, name string) error {
	fnName := strings.Replace(name, ":", "__", -1)
	if _, err := fmt.Fprintf(w, "__start_%s()\n", fnName); err != nil {
		return err
	}
	if _, err := fmt.Fprintf(w, `{
    local cur prev words cword
    declare -A flaghash 2>/dev/null || :
    if declare -F _init_completion >/dev/null 2>&1; then
        _init_completion -s || return
    else
        __my_init_completion || return
    fi

    local c=0
    local flags=()
    local two_word_flags=()
    local flags_with_completion=()
    local flags_completion=()
    local commands=("%s")
    local must_have_one_flag=()
    local must_have_one_noun=()
    local last_command
    local nouns=()

    __handle_word
}

`, fnName); err != nil {
		return err
	}
	// compopt (bash 4+) lets __handle_reply toggle nospace per reply; without
	// it, nospace has to be set for the whole registration.
	if _, err := fmt.Fprintf(w, `if [[ $(type -t compopt) = "builtin" ]]; then
    complete -o default -F __start_%s %s
else
    complete -o default -o nospace -F __start_%s %s
fi

`, fnName, name, fnName, name); err != nil {
		return err
	}
	_, err := io.WriteString(w, "# ex: ts=4 sw=4 et filetype=sh\n")
	return err
}
| |
| func writeCommands(cmd *Command, w io.Writer) error { |
| if _, err := fmt.Fprintf(w, " commands=()\n"); err != nil { |
| return err |
| } |
| for _, c := range cmd.Commands() { |
| if !c.IsAvailableCommand() || c == cmd.helpCommand { |
| continue |
| } |
| if _, err := fmt.Fprintf(w, " commands+=(%q)\n", c.Name()); err != nil { |
| return err |
| } |
| } |
| _, err := fmt.Fprintf(w, "\n") |
| return err |
| } |
| |
| func writeFlagHandler(name string, annotations map[string][]string, w io.Writer) error { |
| for key, value := range annotations { |
| switch key { |
| case BashCompFilenameExt: |
| _, err := fmt.Fprintf(w, " flags_with_completion+=(%q)\n", name) |
| if err != nil { |
| return err |
| } |
| |
| if len(value) > 0 { |
| ext := "__handle_filename_extension_flag " + strings.Join(value, "|") |
| _, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext) |
| } else { |
| ext := "_filedir" |
| _, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext) |
| } |
| if err != nil { |
| return err |
| } |
| case BashCompSubdirsInDir: |
| _, err := fmt.Fprintf(w, " flags_with_completion+=(%q)\n", name) |
| |
| if len(value) == 1 { |
| ext := "__handle_subdirs_in_dir_flag " + value[0] |
| _, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext) |
| } else { |
| ext := "_filedir -d" |
| _, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext) |
| } |
| if err != nil { |
| return err |
| } |
| } |
| } |
| return nil |
| } |
| |
| func writeShortFlag(flag *pflag.Flag, w io.Writer) error { |
| b := (flag.Value.Type() == "bool") |
| name := flag.Shorthand |
| format := " " |
| if !b { |
| format += "two_word_" |
| } |
| format += "flags+=(\"-%s\")\n" |
| if _, err := fmt.Fprintf(w, format, name); err != nil { |
| return err |
| } |
| return writeFlagHandler("-"+name, flag.Annotations, w) |
| } |
| |
| func writeFlag(flag *pflag.Flag, w io.Writer) error { |
| b := (flag.Value.Type() == "bool") |
| name := flag.Name |
| format := " flags+=(\"--%s" |
| if !b { |
| format += "=" |
| } |
| format += "\")\n" |
| if _, err := fmt.Fprintf(w, format, name); err != nil { |
| return err |
| } |
| return writeFlagHandler("--"+name, flag.Annotations, w) |
| } |
| |
| func writeFlags(cmd *Command, w io.Writer) error { |
| _, err := fmt.Fprintf(w, ` flags=() |
| two_word_flags=() |
| flags_with_completion=() |
| flags_completion=() |
| |
| `) |
| if err != nil { |
| return err |
| } |
| var visitErr error |
| cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) { |
| if err := writeFlag(flag, w); err != nil { |
| visitErr = err |
| return |
| } |
| if len(flag.Shorthand) > 0 { |
| if err := writeShortFlag(flag, w); err != nil { |
| visitErr = err |
| return |
| } |
| } |
| }) |
| if visitErr != nil { |
| return visitErr |
| } |
| cmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) { |
| if err := writeFlag(flag, w); err != nil { |
| visitErr = err |
| return |
| } |
| if len(flag.Shorthand) > 0 { |
| if err := writeShortFlag(flag, w); err != nil { |
| visitErr = err |
| return |
| } |
| } |
| }) |
| if visitErr != nil { |
| return visitErr |
| } |
| |
| _, err = fmt.Fprintf(w, "\n") |
| return err |
| } |
| |
| func writeRequiredFlag(cmd *Command, w io.Writer) error { |
| if _, err := fmt.Fprintf(w, " must_have_one_flag=()\n"); err != nil { |
| return err |
| } |
| flags := cmd.NonInheritedFlags() |
| var visitErr error |
| flags.VisitAll(func(flag *pflag.Flag) { |
| for key := range flag.Annotations { |
| switch key { |
| case BashCompOneRequiredFlag: |
| format := " must_have_one_flag+=(\"--%s" |
| b := (flag.Value.Type() == "bool") |
| if !b { |
| format += "=" |
| } |
| format += "\")\n" |
| if _, err := fmt.Fprintf(w, format, flag.Name); err != nil { |
| visitErr = err |
| return |
| } |
| |
| if len(flag.Shorthand) > 0 { |
| if _, err := fmt.Fprintf(w, " must_have_one_flag+=(\"-%s\")\n", flag.Shorthand); err != nil { |
| visitErr = err |
| return |
| } |
| } |
| } |
| } |
| }) |
| return visitErr |
| } |
| |
| func writeRequiredNoun(cmd *Command, w io.Writer) error { |
| if _, err := fmt.Fprintf(w, " must_have_one_noun=()\n"); err != nil { |
| return err |
| } |
| sort.Sort(sort.StringSlice(cmd.ValidArgs)) |
| for _, value := range cmd.ValidArgs { |
| if _, err := fmt.Fprintf(w, " must_have_one_noun+=(%q)\n", value); err != nil { |
| return err |
| } |
| } |
| return nil |
| } |
| |
| func gen(cmd *Command, w io.Writer) error { |
| for _, c := range cmd.Commands() { |
| if !c.IsAvailableCommand() || c == cmd.helpCommand { |
| continue |
| } |
| if err := gen(c, w); err != nil { |
| return err |
| } |
| } |
| commandName := cmd.CommandPath() |
| commandName = strings.Replace(commandName, " ", "_", -1) |
| commandName = strings.Replace(commandName, ":", "__", -1) |
| if _, err := fmt.Fprintf(w, "_%s()\n{\n", commandName); err != nil { |
| return err |
| } |
| if _, err := fmt.Fprintf(w, " last_command=%q\n", commandName); err != nil { |
| return err |
| } |
| if err := writeCommands(cmd, w); err != nil { |
| return err |
| } |
| if err := writeFlags(cmd, w); err != nil { |
| return err |
| } |
| if err := writeRequiredFlag(cmd, w); err != nil { |
| return err |
| } |
| if err := writeRequiredNoun(cmd, w); err != nil { |
| return err |
| } |
| if _, err := fmt.Fprintf(w, "}\n\n"); err != nil { |
| return err |
| } |
| return nil |
| } |
| |
| func (cmd *Command) GenBashCompletion(w io.Writer) error { |
| if err := preamble(w, cmd.Name()); err != nil { |
| return err |
| } |
| if len(cmd.BashCompletionFunction) > 0 { |
| if _, err := fmt.Fprintf(w, "%s\n", cmd.BashCompletionFunction); err != nil { |
| return err |
| } |
| } |
| if err := gen(cmd, w); err != nil { |
| return err |
| } |
| return postscript(w, cmd.Name()) |
| } |
| |
| func (cmd *Command) GenBashCompletionFile(filename string) error { |
| outFile, err := os.Create(filename) |
| if err != nil { |
| return err |
| } |
| defer outFile.Close() |
| |
| return cmd.GenBashCompletion(outFile) |
| } |
| |
| // MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag, if it exists. |
| func (cmd *Command) MarkFlagRequired(name string) error { |
| return MarkFlagRequired(cmd.Flags(), name) |
| } |
| |
| // MarkPersistentFlagRequired adds the BashCompOneRequiredFlag annotation to the named persistent flag, if it exists. |
| func (cmd *Command) MarkPersistentFlagRequired(name string) error { |
| return MarkFlagRequired(cmd.PersistentFlags(), name) |
| } |
| |
| // MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag in the flag set, if it exists. |
| func MarkFlagRequired(flags *pflag.FlagSet, name string) error { |
| return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"}) |
| } |
| |
| // MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag, if it exists. |
| // Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided. |
| func (cmd *Command) MarkFlagFilename(name string, extensions ...string) error { |
| return MarkFlagFilename(cmd.Flags(), name, extensions...) |
| } |
| |
| // MarkPersistentFlagFilename adds the BashCompFilenameExt annotation to the named persistent flag, if it exists. |
| // Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided. |
| func (cmd *Command) MarkPersistentFlagFilename(name string, extensions ...string) error { |
| return MarkFlagFilename(cmd.PersistentFlags(), name, extensions...) |
| } |
| |
| // MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag in the flag set, if it exists. |
| // Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided. |
| func MarkFlagFilename(flags *pflag.FlagSet, name string, extensions ...string) error { |
| return flags.SetAnnotation(name, BashCompFilenameExt, extensions) |
| } |