diff options
author | Mathias Magnusson <mathias@magnusson.space> | 2024-11-17 20:03:48 +0100 |
---|---|---|
committer | Mathias Magnusson <mathias@magnusson.space> | 2024-11-17 20:03:48 +0100 |
commit | 2778c52e4da52fd33f2df7fc9024252c2470b172 (patch) | |
tree | 1134be8e985f828e31c609ac061b1275ed35f972 /cmd/generate | |
parent | 19fa57e67bcc4af13a252c17c0e18adab162d2d1 (diff) | |
download | hh-2778c52e4da52fd33f2df7fc9024252c2470b172.tar.gz |
use a template instead of a bunch of buffer.WriteString
Diffstat (limited to 'cmd/generate')
-rw-r--r-- | cmd/generate/main.go | 162 | ||||
-rw-r--r-- | cmd/generate/templates.go.tmpl | 62 |
2 files changed, 137 insertions, 87 deletions
diff --git a/cmd/generate/main.go b/cmd/generate/main.go index 4494690..bdb46f2 100644 --- a/cmd/generate/main.go +++ b/cmd/generate/main.go @@ -1,18 +1,20 @@ package main import ( - "bytes" + _ "embed" "errors" - "fmt" "go/ast" "go/parser" "go/token" - "io" "os" "reflect" "strings" + "text/template" ) +//go:embed templates.go.tmpl +var fileTemplateString string + func main() { if err := run(); err != nil { panic(err) @@ -23,16 +25,40 @@ func slice(fileContents string, fset *token.FileSet, start token.Pos, end token. return fileContents[fset.Position(start).Offset:fset.Position(end).Offset] } -func indent(s string) string { - var output bytes.Buffer - for _, line := range strings.SplitAfter(s, "\n") { - output.WriteByte('\t') - output.WriteString(line) - } - return output.String() +type File struct { + PackageName string + Functions []Function +} + +type Function struct { + Name string + Pattern string + RequestTypeDef string + RequestTypeFields []RequestTypeField + DoParseForm bool +} + +type RequestTypeField struct { + Name string + Extractor string + Optional bool + NameInReq string + TypeDef string } func run() error { + fileTemplate, err := template.New("").Funcs(template.FuncMap{ + "quote": func(s string) string { + return `"` + strings.NewReplacer(`\`, `\\`, `"`, `\"`, "\n", `\n`).Replace(s) + `"` + }, + "error": func() struct{} { + panic("error") + }, + }).Parse(fileTemplateString) + if err != nil { + return err + } + filename := "examples/basic.go" fileBytes, err := os.ReadFile(filename) if err != nil { @@ -44,11 +70,7 @@ func run() error { if err != nil { return err } - var output bytes.Buffer - output.WriteString("package ") - output.WriteString(f.Name.Name) - output.WriteByte('\n') - handlers := map[string]string{} + parsedFile := File{PackageName: f.Name.Name} for _, decl := range f.Decls { f, ok := decl.(*ast.FuncDecl) if !ok { @@ -58,104 +80,70 @@ func run() error { continue } hhRoute := f.Doc.List[len(f.Doc.List)-1].Text - var routeSpec string - if routeSpec, ok = strings.CutPrefix(hhRoute, "//hh:route "); !ok { + var pattern string + if pattern, ok = strings.CutPrefix(hhRoute, "//hh:route "); !ok { continue } - handlers[routeSpec] = "hh_" + f.Name.String() - output.WriteString("\nfunc hh_") - output.WriteString(f.Name.String()) - output.WriteString("[S any](s S, w http.ResponseWriter, r *http.Request) {") parsedRequestType, ok := f.Type.Params.List[1].Type.(*ast.StructType) if !ok { return errors.New("Parsed request type must be a struct") } + parsedFunction := Function{ + Name: f.Name.Name, + Pattern: pattern, + RequestTypeDef: slice(fileContents, &fset, parsedRequestType.Pos(), parsedRequestType.End()), + } for _, field := range parsedRequestType.Fields.List { for _, nameIdent := range field.Names { typ := field.Type - name := nameIdent.Name + parsedField := RequestTypeField{ + Name: nameIdent.Name, + Extractor: "", + Optional: false, + TypeDef: slice(fileContents, &fset, typ.Pos(), typ.End()), + } + if parsedField.TypeDef == "*http.Request" { + parsedFunction.RequestTypeFields = append(parsedFunction.RequestTypeFields, parsedField) + continue + } var tag string if field.Tag != nil { tag = reflect.StructTag(field.Tag.Value[1 : len(field.Tag.Value)-1]).Get("hh") } - fmt.Println(typ, name, tag) - if t1, ok := typ.(*ast.StarExpr); ok { - if t2, ok := t1.X.(*ast.SelectorExpr); ok && t2.Sel.Name == "Request" { - if id, ok := t2.X.(*ast.Ident); ok && id.Name == "http" { - continue - } - } - } - if t1, ok := typ.(*ast.SelectorExpr); ok && t1.Sel.Name == "ResponseWriter" { - if id, ok := t1.X.(*ast.Ident); ok && id.Name == "http" { - continue - } - } if tag == "" { - return errors.New("Don't know what to do with '" + name + "'. You must add a tag to specify") + return errors.New("Don't know what to do with '" + parsedField.Name + "'. You must add a `hh:\"...\"` tag to specify") } tags := strings.Split(tag, ",") - // TODO: handle raw request. Or maybe that should be a separate parameter - optional := false if tags[0] == "optional" { - optional = true + parsedField.Optional = true tags = tags[1:] } - switch tags[0] { + if len(tags) == 0 { + return errors.New("Must specify extractor for '" + parsedField.Name + "' in `" + tag + "`") + } + parsedField.Extractor = tags[0] + tags = tags[1:] + switch parsedField.Extractor { case "form": - output.WriteString("\n\t") - output.WriteString(name) - output.WriteString(" := r.FormValue(\"") - output.WriteString(name) - output.WriteString("\")") + parsedFunction.DoParseForm = true case "cookie": - output.WriteString("\n\tvar ") - output.WriteString(name) - output.WriteString(" string\n\t") - output.WriteString(name) - output.WriteString("0, _ := r.Cookie(\"") - output.WriteString(name) // TODO: optionally get cookie name from tags[1] - output.WriteString("\")\n\tif ") - output.WriteString(name) - output.WriteString("0 != nil {\n\t\t") - output.WriteString(name) - output.WriteString(" = ") - output.WriteString(name) - output.WriteString("0.Value\n\t}") - if !optional { - output.WriteString(" else {\n\t\tw.WriteHeader(http.StatusBadRequest)\n\t\tw.Write([]byte(`Bad request. Missing cookie '") - output.WriteString(name) - output.WriteString("'`))\n\t}") - } default: - return errors.New("Unknown extractor " + tags[0]) + return errors.New("Unknown extractor '" + tags[0] + "' on field " + nameIdent.Name) } - output.WriteString("\n") - } - } - output.WriteString("\n\t") - output.WriteString(f.Name.Name) - output.WriteString("(w, ") - structDef := slice(fileContents, &fset, parsedRequestType.Pos(), parsedRequestType.End()) - output.WriteString(indent(structDef)[1:]) - output.WriteString("{") - for i, field := range parsedRequestType.Fields.List { - for j, nameIdent := range field.Names { - if i+j > 0 { - output.WriteString(", ") + if len(tags) >= 1 { + parsedField.NameInReq = tags[0] + tags = tags[1:] + } else { + parsedField.NameInReq = parsedField.Name } - typ := field.Type - name := nameIdent.Name - _, _ = typ, name - output.WriteString(name) - output.WriteString(": ") - output.WriteString(name) + if len(tags) > 0 { + return errors.New("Unexpected rest of tag '" + tags[0] + "' in tag `" + tag + "` on field " + nameIdent.Name) + } + parsedFunction.RequestTypeFields = append(parsedFunction.RequestTypeFields, parsedField) } } - output.WriteString("})\n") - output.WriteString("}\n") + parsedFile.Functions = append(parsedFile.Functions, parsedFunction) } - io.Copy(os.Stdout, &output) - fmt.Println(handlers) + fileTemplate.Execute(os.Stdout, parsedFile) return nil } diff --git a/cmd/generate/templates.go.tmpl b/cmd/generate/templates.go.tmpl new file mode 100644 index 0000000..9201ebf --- /dev/null +++ b/cmd/generate/templates.go.tmpl @@ -0,0 +1,62 @@ +package {{ .PackageName }} + +{{ range $_, $fn := .Functions }} +func hh_{{ $fn.Name }}[S any](s S, w http.ResponseWriter, r *http.Request) { + {{- if $fn.DoParseForm }} + if err := r.ParseForm(); err != nil { + panic("todo: Bad request") + } + {{ end }} + {{ range $_, $f := $fn.RequestTypeFields }} + {{ if eq $f.TypeDef "*http.Request" }} + {{ continue }} + {{ end }} + var {{ $f.Name }}0 string + {{- if $f.Optional }} + {{ $f.Name }}Skipped := false + {{- end }} + {{- if eq $f.Extractor "form" }} + {{ $f.Name }}1 := r.Form[{{ $f.NameInReq | quote }}] + if len({{ $f.Name }}1) != 0 { + {{ $f.Name }} = {{ $f.Name }}1[0] + } else { + {{- if not $f.Optional }} + panic("todo: Bad request: form value " + {{ $f.NameInReq | quote }} + " missing") + {{- else }} + {{ $f.Name }}Skipped = true + {{- end }} + } + {{ else if eq $f.Extractor "cookie" }} + {{ $f.Name }}1, _ := r.Cookie({{ $f.NameInReq | quote }}) + if {{ $f.Name }}1 != nil { + {{ $f.Name }} = {{ $f.Name }}1.Value + } else { + {{- if not $f.Optional }} + panic("todo: Bad request: cookie " + {{ $f.NameInReq | quote }} + " missing") + {{- else }} + {{ $f.Name }}Skipped = true + {{- end }} + } + {{ else }} + {{ error }} + {{ end -}} + {{ if eq $f.TypeDef "string" -}} + {{ $f.Name }} := {{ $f.Name }}0 + {{ else if eq $f.TypeDef "int" -}} + var {{ $f.Name }} int + {{ if $f.Optional }} if !{{ $f.Name }}Skipped { {{ end -}} + var err error + {{ $f.Name }}, err = strconv.Atoi({{ $f.Name }}0) + if err != nil { + panic("todo: Bad request: " + {{ $f.NameInReq | quote }} + " must be a valid int") + } + {{ if $f.Optional }} } {{ end }} + {{ end }} + {{ end }} + {{ $fn.Name }}(w, {{ $fn.RequestTypeDef }}{ + {{ range $_, $f := $fn.RequestTypeFields -}} + {{ $f.Name }}: {{ $f.Name }}, + {{ end }} + }) +} +{{ end }} |