blob: 66027dc5556e24bd7906602facf09c2c2ceca773 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// specialize is a low-level tool to generate type-specialized code. It is a
// convenience wrapper over text/template suitable for go generate. Unlike
// many other template tools, it does not parse Go code and allows use of
// text/template control within the template itself.
package main
import (
"bytes"
"flag"
"fmt"
"log"
"math"
"os"
"path/filepath"
"strings"
"text/template"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
// noheader suppresses the "File generated by specialize" line that main
// otherwise writes before the template output.
var noheader = flag.Bool("noheader", false, "Omit auto-generated header")

// pack is copied verbatim into Top.Package.
var pack = flag.String("package", "", "Package name (optional)")

// imports is expanded with packageMacros and stored in Top.Imports.
var imports = flag.String("imports", "", "Comma-separated list of extra imports (optional)")

// x, y and z are the type lists (expanded with macros) that populate the
// X, Y and Z values handed to the template.
var x = flag.String("x", "", "Comma-separated list of X types (optional)")
var y = flag.String("y", "", "Comma-separated list of Y types (optional)")
var z = flag.String("z", "", "Comma-separated list of Z types (optional)")

// input is the template path; main refuses to run without it.
var input = flag.String("input", "", "Template file.")

// output is the destination file; when empty, main derives it from input.
var output = flag.String("output", "", "Filename for generated code. If not provided, a file next to the input is generated.")
// Top is the top-level struct to be passed to the template: main executes
// the template with a Top value as its data.
type Top struct {
	// Name is the base form of the filename: "foo/bar.tmpl" -> "bar".
	// The base name is cut at its first '.' (unless the name starts with
	// one), so "foo/bar.x.tmpl" also yields "bar".
	Name string
	// Package is the package name, copied verbatim from --package (empty if unset).
	Package string
	// Imports is a list of custom imports, if provided: the --imports entries
	// after macro expansion (e.g. "typex" expands to the Beam typex import path).
	Imports []string
	// X is the list of X type values, one per --x entry after macro expansion.
	X []*X
}
// X is the concrete type to be iterated over in the user template.
// main creates one X per --x entry after macro expansion.
type X struct {
	// Name is the name of X for use as identifier: "int" -> "Int", "[]byte" -> "ByteSlice".
	Name string
	// Type is the textual type of X: "int", "float32", "foo.Baz".
	Type string
	// Y is the list of Y type values for this X. Every X holds the same
	// slice, so all X iterate over the same Y values.
	Y []*Y
}
// Y is the concrete type to be iterated over in the user template for each X.
// Each combination of X and Y will be present.
// main creates Y values only when --y is non-empty, one per entry after
// macro expansion.
type Y struct {
	// Name is the name of Y for use as identifier: "int" -> "Int", "[]byte" -> "ByteSlice".
	Name string
	// Type is the textual type of Y: "int", "float32", "foo.Baz".
	Type string
	// Z is the list of Z type values for this Y. Every Y holds the same
	// slice, so all Y iterate over the same Z values.
	Z []*Z
}
// Z is the concrete type to be iterated over in the user template for each Y.
// Each combination of X, Y and Z will be present.
// Z values come from --z, but are only reachable when --y is also
// non-empty, because they hang off the Y values.
type Z struct {
	// Name is the name of Z for use as identifier: "int" -> "Int", "[]byte" -> "ByteSlice".
	Name string
	// Type is the textual type of Z: "int", "float32", "foo.Baz".
	Type string
}
// integers lists every built-in Go integer type.
var integers = []string{"int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64"}

// floats lists the built-in floating-point types.
var floats = []string{"float32", "float64"}

// primitives is "bool" and "string", followed by integers and then floats.
var primitives = append(append([]string{"bool", "string"}, integers...), floats...)

// macros maps a lower-case macro name usable in --x, --y and --z to the
// list of types it stands for.
var macros = map[string][]string{
	"data":       append([]string{"[]byte"}, primitives...),
	"floats":     floats,
	"integers":   integers,
	"primitives": primitives,
	"universals": {"typex.T", "typex.U", "typex.V", "typex.W", "typex.X", "typex.Y", "typex.Z"},
}

// packageMacros maps a lower-case macro name usable in --imports to the
// import paths it stands for.
var packageMacros = map[string][]string{
	"typex": {"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"},
}
// usage prints a one-line synopsis followed by the defaults of every
// registered flag to standard error. main installs it as flag.Usage.
func usage() {
	// The synopsis previously lacked the closing '>' after "filename.tmpl".
	fmt.Fprintf(os.Stderr, "Usage: %v [options] --input=<filename.tmpl> --x=<types>\n", filepath.Base(os.Args[0]))
	flag.PrintDefaults()
}
// main parses the flags and builds the Top value from the --x/--y/--z type
// lists. It then executes the --input template with that value and writes
// the result to --output (by default "<name>.go" next to the input).
func main() {
	flag.Usage = usage
	flag.Parse()

	log.SetFlags(0)
	log.SetPrefix("specialize: ")

	if *input == "" {
		flag.Usage()
		log.Fatalf("no template file")
	}

	// base is the name under which ParseFiles registers the input template.
	base := filepath.Base(*input)
	// name is the base name cut at its first '.': "foo/bar.tmpl" -> "bar".
	// A leading '.' is kept, so a hidden file is not reduced to "".
	name := base
	if index := strings.Index(name, "."); index > 0 {
		name = name[:index]
	}
	if *output == "" {
		*output = filepath.Join(filepath.Dir(*input), name+".go")
	}

	top := Top{Name: name, Package: *pack, Imports: expand(packageMacros, *imports)}

	// Every X shares one Y slice and every Y shares one Z slice, so templates
	// see the full X × Y × Z cross product. Z values are only attached when
	// --y is non-empty; expand already returns nil for an empty --z.
	var ys []*Y
	if *y != "" {
		var zs []*Z
		for _, zt := range expand(macros, *z) {
			zs = append(zs, &Z{Name: makeName(zt), Type: zt})
		}
		for _, yt := range expand(macros, *y) {
			ys = append(ys, &Y{Name: makeName(yt), Type: yt, Z: zs})
		}
	}
	for _, xt := range expand(macros, *x) {
		top.X = append(top.X, &X{Name: makeName(xt), Type: xt, Y: ys})
	}

	// ParseFiles registers the parsed file under its base name, so the root
	// template must carry that same name. Naming it after the full path
	// leaves the returned root template empty whenever --input contains a
	// directory component (even "./x.tmpl"), and Execute then reports an
	// incomplete or empty template.
	tmpl, err := template.New(base).Funcs(funcMap).ParseFiles(*input)
	if err != nil {
		log.Fatalf("template parse failed: %v", err)
	}

	var buf bytes.Buffer
	if !*noheader {
		buf.WriteString("// File generated by specialize. Do not edit.\n\n")
	}
	// funcMap was attached before parsing; re-adding it here is unnecessary.
	if err := tmpl.Execute(&buf, top); err != nil {
		log.Fatalf("specialization failed: %v", err)
	}
	if err := os.WriteFile(*output, buf.Bytes(), 0644); err != nil {
		log.Fatalf("write failed: %v", err)
	}
}
// expand parses, cleans up and expands macros for a comma-separated list.
//
// Each entry is trimmed of surrounding whitespace, and empty entries are
// dropped. An entry whose lower-cased form is a key of subst is replaced by
// that key's values, in order; any other entry is kept verbatim. The
// returned slice never shares storage with the slices in subst. It is nil
// when list has no non-empty entries.
func expand(subst map[string][]string, list string) []string {
	var ret []string
	for _, item := range strings.Split(list, ",") {
		item = strings.TrimSpace(item)
		if item == "" {
			continue
		}
		if exp, ok := subst[strings.ToLower(item)]; ok {
			// One variadic append copies the whole expansion into ret.
			ret = append(ret, exp...)
			continue
		}
		ret = append(ret, item)
	}
	return ret
}
// makeName creates a capitalized identifier from a type: "int" -> "Int",
// "[]byte" -> "ByteSlice". Each leading "[]" is turned into a trailing
// "Slice" (innermost first), and '.', '[' and ']' become '_' before
// title-casing.
func makeName(t string) string {
	for strings.HasPrefix(t, "[]") {
		t = strings.TrimPrefix(t, "[]") + "Slice"
	}
	ident := strings.NewReplacer(".", "_", "[", "_", "]", "_").Replace(t)
	return cases.Title(language.Und, cases.NoLower).String(ident)
}
// Useful template functions
var funcMap template.FuncMap = map[string]any{
"join": strings.Join,
"upto": upto,
"mkargs": mkargs,
"mktuple": mktuple,
"mktuplef": mktuplef,
"add": add,
"mult": mult,
"dict": dict,
"list": list,
"genericTypingRepresentation": genericTypingRepresentation,
"possibleBundleLifecycleParameterCombos": possibleBundleLifecycleParameterCombos,
}
// mkargs(n, format, typ) returns "<fmt.Sprintf(format, 0)>, .., <fmt.Sprintf(format, n-1)> typ",
// e.g. mkargs(2, "x%v", "int") -> "x0, x1 int".
// If n is 0 or negative, it returns the empty string. A negative n used to
// yield a dangling " typ"; the sibling tuple helpers return "" for such n.
func mkargs(n int, format, typ string) string {
	if n <= 0 {
		return ""
	}
	return fmt.Sprintf("%v %v", mktuplef(n, format), typ)
}
// mktuple(n, v) returns "v, v, ..., v" with n copies of v, or "" when n <= 0.
func mktuple(n int, v string) string {
	var sb strings.Builder
	for i := 0; i < n; i++ {
		if i > 0 {
			sb.WriteString(", ")
		}
		sb.WriteString(v)
	}
	return sb.String()
}
// mktuplef(n, format) returns "<fmt.Sprintf(format, 0)>, .., <fmt.Sprintf(format, n-1)>",
// or "" when n <= 0.
func mktuplef(n int, format string) string {
	if n <= 0 {
		return ""
	}
	items := make([]string, n)
	for i := range items {
		items[i] = fmt.Sprintf(format, i)
	}
	return strings.Join(items, ", ")
}
// upto(n) returns []int{0, 1, .., n-1}, or nil when n <= 0.
func upto(i int) []int {
	if i <= 0 {
		return nil
	}
	seq := make([]int, i)
	for k := range seq {
		seq[k] = k
	}
	return seq
}
// add returns the integer sum of i and j.
func add(i, j int) int {
	sum := i + j
	return sum
}
// mult returns the integer product of i and j.
func mult(i, j int) int {
	product := i * j
	return product
}
// dict builds a map from alternating key/value arguments, e.g.
// dict "k1" v1 "k2" v2 in a template. It panics with "Invalid dictionary
// call" on an odd argument count. A non-string key makes the type
// assertion panic.
func dict(values ...any) map[string]any {
	if len(values)%2 == 1 {
		panic("Invalid dictionary call")
	}
	out := make(map[string]any, len(values)/2)
	for rest := values; len(rest) > 0; rest = rest[2:] {
		key := rest[0].(string)
		out[key] = rest[1]
	}
	return out
}
// list returns its string arguments as a slice, letting a template build a
// literal []string. The variadic slice itself is returned without copying.
func list(values ...string) []string {
	result := values
	return result
}
// genericTypingRepresentation returns a type-parameter list for in input
// and out output type parameters, e.g. (2, 1, true) -> "[I0, I1, R0 any]".
// Inputs are named I0..I(in-1) and outputs R0..R(out-1). When includeType
// is set, a single shared "any" constraint is appended. It returns "" when
// there are no type parameters at all.
func genericTypingRepresentation(in int, out int, includeType bool) string {
	var names []string
	for i := 0; i < in; i++ {
		names = append(names, fmt.Sprintf("I%v", i))
	}
	for i := 0; i < out; i++ {
		names = append(names, fmt.Sprintf("R%v", i))
	}
	if len(names) == 0 {
		return ""
	}

	result := "["
	for idx, nm := range names {
		if idx > 0 {
			result += ", "
		}
		result += nm
	}
	if includeType {
		result += " any"
	}
	return result + "]"
}
// possibleBundleLifecycleParameterCombos returns every parameter list of
// length numIn that a bundle lifecycle function may declare. Each list is an
// ordered subset of the known optional parameters, followed by the trailing
// ProcessElement inputs needed to reach numIn, written "I<k>" and ending at
// I<processElementIn-1>. Subsets needing more inputs than ProcessElement has,
// or longer than numIn, are skipped. Results follow the bitmask order of the
// subsets. Both arguments must hold ints; the type assertions panic otherwise.
func possibleBundleLifecycleParameterCombos(numInInterface any, processElementInInterface any) [][]string {
	numIn := numInInterface.(int)
	processElementIn := processElementInInterface.(int)

	// Known parameters must come first and in this order. After them, the
	// remaining slots are the last ProcessElement inputs, because every
	// iter/emitter has to be passed along.
	known := []string{"context.Context", "typex.PaneInfo", "[]typex.Window", "typex.EventTime", "typex.BundleFinalization"}
	total := int(math.Pow(2, float64(len(known))))
	result := make([][]string, 0, total)

	for mask := 0; mask < total; mask++ {
		// Bit b of mask selects known[b] for this member of the powerset.
		var params []string
		for bit, p := range known {
			if mask>>uint(bit)&1 == 1 {
				params = append(params, p)
			}
		}

		missing := numIn - len(params)
		if missing < 0 || missing > processElementIn {
			continue
		}
		for k := processElementIn - missing; k < processElementIn; k++ {
			params = append(params, fmt.Sprintf("I%v", k))
		}
		result = append(result, params)
	}
	return result
}