|  | // Licensed to the Apache Software Foundation (ASF) under one or more | 
|  | // contributor license agreements.  See the NOTICE file distributed with | 
|  | // this work for additional information regarding copyright ownership. | 
|  | // The ASF licenses this file to You under the Apache License, Version 2.0 | 
|  | // (the "License"); you may not use this file except in compliance with | 
|  | // the License.  You may obtain a copy of the License at | 
|  | // | 
|  | //    http://www.apache.org/licenses/LICENSE-2.0 | 
|  | // | 
|  | // Unless required by applicable law or agreed to in writing, software | 
|  | // distributed under the License is distributed on an "AS IS" BASIS, | 
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | // See the License for the specific language governing permissions and | 
|  | // limitations under the License. | 
|  |  | 
|  | // specialize is a low-level tool to generate type-specialized code. It is a | 
|  | // convenience wrapper over text/template suitable for go generate. Unlike | 
|  | // many other template tools, it does not parse Go code and allows use of | 
|  | // text/template control within the template itself. | 
|  | package main | 
|  |  | 
|  | import ( | 
|  | "bytes" | 
|  | "flag" | 
|  | "fmt" | 
|  | "log" | 
|  | "math" | 
|  | "os" | 
|  | "path/filepath" | 
|  | "strings" | 
|  | "text/template" | 
|  |  | 
|  | "golang.org/x/text/cases" | 
|  | "golang.org/x/text/language" | 
|  | ) | 
|  |  | 
// Flags naming the template to expand and where the expansion is written.
// Only --input is mandatory; main exits if it is missing.
var (
	input  = flag.String("input", "", "Template file.")
	output = flag.String("output", "", "Filename for generated code. If not provided, a file next to the input is generated.")
)

// Flags listing the types to specialize over. Each is a comma-separated list
// whose entries are either literal Go types or macro names (see macros).
var (
	x = flag.String("x", "", "Comma-separated list of X types (optional)")
	y = flag.String("y", "", "Comma-separated list of Y types (optional)")
	z = flag.String("z", "", "Comma-separated list of Z types (optional)")
)

// Flags controlling the generated header and the package/import values that
// are handed to the template.
var (
	noheader = flag.Bool("noheader", false, "Omit auto-generated header")
	pack     = flag.String("package", "", "Package name (optional)")
	imports  = flag.String("imports", "", "Comma-separated list of extra imports (optional)")
)
|  |  | 
// Data model handed to the user template. Top is the root; it fans out into
// X values, each X into Y values and each Y into Z values, so templates can
// iterate over every requested combination of types.
type (
	// Top is the top-level value the template is executed with.
	Top struct {
		// Name is the template file's base name up to its first dot:
		// "foo/bar.tmpl" -> "bar".
		Name string
		// Package is the package name given on the command line (may be empty).
		Package string
		// Imports lists any extra imports requested, after macro expansion.
		Imports []string
		// X holds one entry per requested X type.
		X []*X
	}

	// X is one concrete type the template iterates over at the outermost level.
	X struct {
		// Name is the identifier form of the type: "int" -> "Int", "[]byte" -> "ByteSlice".
		Name string
		// Type is the type as written in Go source: "int", "float32", "foo.Baz".
		Type string
		// Y holds the Y types to pair with this X.
		Y []*Y
	}

	// Y is one concrete type iterated over for each X; every X/Y pairing is present.
	Y struct {
		// Name is the identifier form of the type: "int" -> "Int", "[]byte" -> "ByteSlice".
		Name string
		// Type is the type as written in Go source: "int", "float32", "foo.Baz".
		Type string
		// Z holds the Z types to pair with this Y.
		Z []*Z
	}

	// Z is one concrete type iterated over for each Y; every X/Y/Z triple is present.
	Z struct {
		// Name is the identifier form of the type: "int" -> "Int", "[]byte" -> "ByteSlice".
		Name string
		// Type is the type as written in Go source: "int", "float32", "foo.Baz".
		Type string
	}
)
|  |  | 
// Built-in type lists that back the type macros below.
var (
	integers = []string{
		"int", "int8", "int16", "int32", "int64",
		"uint", "uint8", "uint16", "uint32", "uint64",
	}
	floats = []string{"float32", "float64"}
	// primitives is bool and string followed by every integer and float type.
	primitives = append(append([]string{"bool", "string"}, integers...), floats...)
)

// macros maps a macro name that may stand in for a type in the --x/--y/--z
// lists to the types it expands to. expand lower-cases an entry before the
// lookup, so the keys here must be lower case.
var macros = map[string][]string{
	"integers":   integers,
	"floats":     floats,
	"primitives": primitives,
	// data is every primitive plus []byte, with []byte first.
	"data":       append([]string{"[]byte"}, primitives...),
	"universals": {"typex.T", "typex.U", "typex.V", "typex.W", "typex.X", "typex.Y", "typex.Z"},
}

// packageMacros maps a short name accepted by --imports to the import path(s)
// it expands to.
var packageMacros = map[string][]string{
	"typex": {"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"},
}
|  |  | 
// usage writes a one-line synopsis followed by the defaults of every
// registered flag to stderr. main installs it as flag.Usage.
func usage() {
	// The synopsis previously read "--input=<filename.tmpl --x=<types>",
	// missing the closing '>' on the input placeholder.
	fmt.Fprintf(os.Stderr, "Usage: %v [options] --input=<filename.tmpl> --x=<types>\n", filepath.Base(os.Args[0]))
	flag.PrintDefaults()
}
|  |  | 
|  | func main() { | 
|  | flag.Usage = usage | 
|  | flag.Parse() | 
|  |  | 
|  | log.SetFlags(0) | 
|  | log.SetPrefix("specialize: ") | 
|  |  | 
|  | if *input == "" { | 
|  | flag.Usage() | 
|  | log.Fatalf("no template file") | 
|  | } | 
|  |  | 
|  | name := filepath.Base(*input) | 
|  | if index := strings.Index(name, "."); index > 0 { | 
|  | name = name[:index] | 
|  | } | 
|  | if *output == "" { | 
|  | *output = filepath.Join(filepath.Dir(*input), name+".go") | 
|  | } | 
|  |  | 
|  | top := Top{Name: name, Package: *pack, Imports: expand(packageMacros, *imports)} | 
|  | var ys []*Y | 
|  | if *y != "" { | 
|  | var zs []*Z | 
|  | if *z != "" { | 
|  | for _, zt := range expand(macros, *z) { | 
|  | zs = append(zs, &Z{Name: makeName(zt), Type: zt}) | 
|  | } | 
|  | } | 
|  | for _, yt := range expand(macros, *y) { | 
|  | ys = append(ys, &Y{Name: makeName(yt), Type: yt, Z: zs}) | 
|  | } | 
|  | } | 
|  | for _, xt := range expand(macros, *x) { | 
|  | top.X = append(top.X, &X{Name: makeName(xt), Type: xt, Y: ys}) | 
|  | } | 
|  |  | 
|  | tmpl, err := template.New(*input).Funcs(funcMap).ParseFiles(*input) | 
|  | if err != nil { | 
|  | log.Fatalf("template parse failed: %v", err) | 
|  | } | 
|  | var buf bytes.Buffer | 
|  | if !*noheader { | 
|  | buf.WriteString("// File generated by specialize. Do not edit.\n\n") | 
|  | } | 
|  | if err := tmpl.Funcs(funcMap).Execute(&buf, top); err != nil { | 
|  | log.Fatalf("specialization failed: %v", err) | 
|  | } | 
|  | if err := os.WriteFile(*output, buf.Bytes(), 0644); err != nil { | 
|  | log.Fatalf("write failed: %v", err) | 
|  | } | 
|  | } | 
|  |  | 
// expand parses, cleans up and expands macros for a comma-separated list.
//
// The list is split on commas and surrounding whitespace is trimmed from each
// entry; empty entries are dropped. An entry whose lower-cased form is a key
// of subst is replaced by that key's values, in order. Any other entry is kept
// verbatim. It returns nil when no entries remain.
//
// Example: expand({"ints": {"int", "int8"}}, " INTS, foo,,") -> [int int8 foo].
func expand(subst map[string][]string, list string) []string {
	var ret []string
	for _, item := range strings.Split(list, ",") {
		item = strings.TrimSpace(item)
		if item == "" {
			continue
		}
		if exp, ok := subst[strings.ToLower(item)]; ok {
			// Appending the whole slice at once replaces the manual
			// element-by-element loop (staticcheck S1011).
			ret = append(ret, exp...)
			continue
		}
		ret = append(ret, item)
	}
	return ret
}
|  |  | 
|  | // makeName creates a capitalized identifier from a type. | 
|  | func makeName(t string) string { | 
|  | if strings.HasPrefix(t, "[]") { | 
|  | return makeName(t[2:] + "Slice") | 
|  | } | 
|  |  | 
|  | t = strings.Replace(t, ".", "_", -1) | 
|  | t = strings.Replace(t, "[", "_", -1) | 
|  | t = strings.Replace(t, "]", "_", -1) | 
|  | return cases.Title(language.Und, cases.NoLower).String(t) | 
|  | } | 
|  |  | 
|  | // Useful template functions | 
|  |  | 
|  | var funcMap template.FuncMap = map[string]any{ | 
|  | "join":                                   strings.Join, | 
|  | "upto":                                   upto, | 
|  | "mkargs":                                 mkargs, | 
|  | "mktuple":                                mktuple, | 
|  | "mktuplef":                               mktuplef, | 
|  | "add":                                    add, | 
|  | "mult":                                   mult, | 
|  | "dict":                                   dict, | 
|  | "list":                                   list, | 
|  | "genericTypingRepresentation":            genericTypingRepresentation, | 
|  | "possibleBundleLifecycleParameterCombos": possibleBundleLifecycleParameterCombos, | 
|  | } | 
|  |  | 
|  | // mkargs(n, type) returns "<fmt.Sprintf(format, 0)>, .., <fmt.Sprintf(format, n-1)> type". | 
|  | // If n is 0, it returns the empty string. | 
|  | func mkargs(n int, format, typ string) string { | 
|  | if n == 0 { | 
|  | return "" | 
|  | } | 
|  | return fmt.Sprintf("%v %v", mktuplef(n, format), typ) | 
|  | } | 
|  |  | 
// mktuple(n, v) repeats v n times, separated by ", ": mktuple(3, "T") is
// "T, T, T". It returns "" when n is 0 or negative.
func mktuple(n int, v string) string {
	if n <= 0 {
		return ""
	}
	return strings.Repeat(v+", ", n-1) + v
}
|  |  | 
// mktuplef(n, format) formats each index 0..n-1 with format and joins the
// results with ", ": mktuplef(3, "x%v") is "x0, x1, x2". It returns "" when
// n is 0 or negative.
func mktuplef(n int, format string) string {
	var sb strings.Builder
	for i := 0; i < n; i++ {
		if i > 0 {
			sb.WriteString(", ")
		}
		fmt.Fprintf(&sb, format, i)
	}
	return sb.String()
}
|  |  | 
// upto(i) returns the sequence 0, 1, .., i-1, which templates can range over.
// For i <= 0 it returns a nil slice.
func upto(i int) []int {
	if i <= 0 {
		return nil
	}
	seq := make([]int, i)
	for k := range seq {
		seq[k] = k
	}
	return seq
}
|  |  | 
// add returns the sum of its two arguments; templates lack arithmetic operators.
func add(a, b int) int {
	sum := a + b
	return sum
}
|  |  | 
// mult returns the product of its two arguments; templates lack arithmetic operators.
func mult(a, b int) int {
	product := a * b
	return product
}
|  |  | 
// dict builds a map from alternating key/value arguments:
// dict "a" 1 "b" 2 -> map[a:1 b:2]. This is handy for passing several named
// values to a nested template. Keys must be strings; a later duplicate key
// overwrites an earlier one.
//
// It panics with a descriptive message if given an odd number of arguments or
// a non-string key. Previously a bad key surfaced as an opaque runtime
// type-assertion panic.
func dict(values ...any) map[string]any {
	// Validate before allocating anything.
	if len(values)%2 != 0 {
		panic(fmt.Sprintf("Invalid dictionary call: expected key/value pairs, got %d arguments", len(values)))
	}
	m := make(map[string]any, len(values)/2)
	for i := 0; i < len(values); i += 2 {
		key, ok := values[i].(string)
		if !ok {
			panic(fmt.Sprintf("Invalid dictionary call: key at argument %d is %T, not string", i, values[i]))
		}
		m[key] = values[i+1]
	}
	return m
}
|  |  | 
// list gathers its string arguments into a slice so that templates can build
// ad-hoc lists, e.g. {{join (list "a" "b") ", "}}.
func list(items ...string) []string {
	return items
}
|  |  | 
// genericTypingRepresentation renders a type-parameter list containing the
// input parameters I0..I{in-1} followed by the result parameters
// R0..R{out-1}. If includeType is set, the constraint " any" is added before
// the closing bracket.
//
// Examples: (1, 2, true) -> "[I0, R0, R1 any]"; (1, 2, false) -> "[I0, R0, R1]".
// It returns "" when there are no parameters (in and out both <= 0).
func genericTypingRepresentation(in int, out int, includeType bool) string {
	var params []string
	for i := 0; i < in; i++ {
		params = append(params, fmt.Sprintf("I%v", i))
	}
	for i := 0; i < out; i++ {
		params = append(params, fmt.Sprintf("R%v", i))
	}
	if len(params) == 0 {
		return ""
	}
	joined := strings.Join(params, ", ")
	if includeType {
		joined += " any"
	}
	return "[" + joined + "]"
}
|  |  | 
// possibleBundleLifecycleParameterCombos lists every parameter signature with
// numIn parameters that a bundle lifecycle method may have. processElementIn
// is the number of ProcessElement inputs. Both arguments must hold an int; the
// function panics otherwise.
//
// Known parameters must come first and in their fixed order, and any unknown
// (ProcessElement-derived) parameters follow. So each subset of the known
// options, visited in bitmask order, is topped up with the trailing ProcessElement
// inputs I{processElementIn-k}..I{processElementIn-1}, where k is the number of
// slots still free. Subsets that are too long, or that leave more free slots
// than there are ProcessElement inputs, are skipped.
func possibleBundleLifecycleParameterCombos(numInInterface any, processElementInInterface any) [][]string {
	numIn := numInInterface.(int)
	processElementIn := processElementInInterface.(int)

	known := []string{"context.Context", "typex.PaneInfo", "[]typex.Window", "typex.EventTime", "typex.BundleFinalization"}
	subsetCount := int(math.Pow(2, float64(len(known))))
	result := make([][]string, 0, subsetCount)

	for mask := 0; mask < subsetCount; mask++ {
		// Bit b of mask selects known[b], preserving the required order.
		var params []string
		for b, opt := range known {
			if mask&(1<<uint(b)) != 0 {
				params = append(params, opt)
			}
		}

		missing := numIn - len(params)
		if missing < 0 || missing > processElementIn {
			continue
		}
		for ; missing > 0; missing-- {
			params = append(params, fmt.Sprintf("I%v", processElementIn-missing))
		}
		result = append(result, params)
	}

	return result
}