package main import ( "bytes" _ "embed" "errors" "go/ast" "go/format" "go/parser" "go/token" "os" "reflect" "strings" "text/template" ) //go:embed templates.go.tmpl var fileTemplateString string func main() { if err := run(); err != nil { panic(err) } } func slice(fileContents string, fset *token.FileSet, start token.Pos, end token.Pos) string { return fileContents[fset.Position(start).Offset:fset.Position(end).Offset] } type Package struct { PackageName string Functions []Function } type Function struct { Name string Pattern string RequestTypeDef string RequestTypeFields []RequestTypeField DoParseForm bool } type RequestTypeField struct { Name string Extractor string Optional bool NameInReq string TypeDef string } func run() error { fileTemplate, err := template.New("").Funcs(template.FuncMap{ "quote": func(s string) string { return `"` + strings.NewReplacer(`\`, `\\`, `"`, `\"`, "\n", `\n`).Replace(s) + `"` }, "error": func(msg string) struct{} { panic("error in template: " + msg) }, }).Parse(fileTemplateString) if err != nil { return err } dirEntries, err := os.ReadDir(".") if err != nil { return err } parsedPackage := Package{} for _, ent := range dirEntries { if !strings.HasSuffix(ent.Name(), ".go") { continue } fileBytes, err := os.ReadFile(ent.Name()) if err != nil { return err } fileContents := string(fileBytes) var fset token.FileSet f, err := parser.ParseFile(&fset, ent.Name(), fileBytes, parser.ParseComments|parser.SkipObjectResolution) if err != nil { return err } if strings.HasSuffix(f.Name.String(), "_test") { continue } if parsedPackage.PackageName != "" && parsedPackage.PackageName != f.Name.String() { return errors.New("Found two different package names in directory: " + parsedPackage.PackageName + " and " + f.Name.String()) } parsedPackage.PackageName = f.Name.String() for _, decl := range f.Decls { f, ok := decl.(*ast.FuncDecl) if !ok { continue } if f.Doc == nil { continue } hhRoute := f.Doc.List[len(f.Doc.List)-1].Text var pattern string if pattern, ok = strings.CutPrefix(hhRoute, "//hh:route "); !ok { continue } parsedRequestType, ok := f.Type.Params.List[1].Type.(*ast.StructType) if !ok { return errors.New("Parsed request type must be a struct") } parsedFunction := Function{ Name: f.Name.Name, Pattern: pattern, RequestTypeDef: slice(fileContents, &fset, parsedRequestType.Pos(), parsedRequestType.End()), } for _, field := range parsedRequestType.Fields.List { for _, nameIdent := range field.Names { typ := field.Type parsedField := RequestTypeField{ Name: nameIdent.Name, Extractor: "", Optional: false, TypeDef: slice(fileContents, &fset, typ.Pos(), typ.End()), } if parsedField.TypeDef == "*http.Request" { parsedFunction.RequestTypeFields = append(parsedFunction.RequestTypeFields, parsedField) continue } var tag string if field.Tag != nil { tag = reflect.StructTag(field.Tag.Value[1 : len(field.Tag.Value)-1]).Get("hh") } if tag == "" { return errors.New("Don't know what to do with '" + parsedField.Name + "'. You must add a `hh:\"...\"` tag to specify") } tags := strings.Split(tag, ",") if tags[0] == "optional" { parsedField.Optional = true tags = tags[1:] } if len(tags) == 0 { return errors.New("Must specify extractor for '" + parsedField.Name + "' in `" + tag + "`") } parsedField.Extractor = tags[0] tags = tags[1:] switch parsedField.Extractor { case "form": parsedFunction.DoParseForm = true case "cookie", "path": default: return errors.New("Unknown extractor '" + tags[0] + "' on field " + nameIdent.Name) } if len(tags) >= 1 { parsedField.NameInReq = tags[0] tags = tags[1:] } else { parsedField.NameInReq = parsedField.Name } if len(tags) > 0 { return errors.New("Unexpected rest of tag '" + tags[0] + "' in tag `" + tag + "` on field " + nameIdent.Name) } parsedFunction.RequestTypeFields = append(parsedFunction.RequestTypeFields, parsedField) } } parsedPackage.Functions = append(parsedPackage.Functions, parsedFunction) } } var unformatted bytes.Buffer if err := fileTemplate.Execute(&unformatted, parsedPackage); err != nil { return err } formatted, err := format.Source(unformatted.Bytes()) if err != nil { return err } if err := os.WriteFile("hh.gen.go", formatted, 0o660); err != nil { return err } return nil }