summaryrefslogtreecommitdiff
path: root/cmd/generate
diff options
context:
space:
mode:
authorMathias Magnusson <mathias@magnusson.space>2024-11-17 20:03:48 +0100
committerMathias Magnusson <mathias@magnusson.space>2024-11-17 20:03:48 +0100
commit2778c52e4da52fd33f2df7fc9024252c2470b172 (patch)
tree1134be8e985f828e31c609ac061b1275ed35f972 /cmd/generate
parent19fa57e67bcc4af13a252c17c0e18adab162d2d1 (diff)
downloadhh-2778c52e4da52fd33f2df7fc9024252c2470b172.tar.gz
use a template instead of a bunch of buffer.WriteString
Diffstat (limited to 'cmd/generate')
-rw-r--r--cmd/generate/main.go162
-rw-r--r--cmd/generate/templates.go.tmpl62
2 files changed, 137 insertions, 87 deletions
diff --git a/cmd/generate/main.go b/cmd/generate/main.go
index 4494690..bdb46f2 100644
--- a/cmd/generate/main.go
+++ b/cmd/generate/main.go
@@ -1,18 +1,20 @@
package main
import (
- "bytes"
+ _ "embed"
"errors"
- "fmt"
"go/ast"
"go/parser"
"go/token"
- "io"
"os"
"reflect"
"strings"
+ "text/template"
)
+//go:embed templates.go.tmpl
+var fileTemplateString string
+
func main() {
if err := run(); err != nil {
panic(err)
@@ -23,16 +25,40 @@ func slice(fileContents string, fset *token.FileSet, start token.Pos, end token.
return fileContents[fset.Position(start).Offset:fset.Position(end).Offset]
}
-func indent(s string) string {
- var output bytes.Buffer
- for _, line := range strings.SplitAfter(s, "\n") {
- output.WriteByte('\t')
- output.WriteString(line)
- }
- return output.String()
+type File struct {
+ PackageName string
+ Functions []Function
+}
+
+type Function struct {
+ Name string
+ Pattern string
+ RequestTypeDef string
+ RequestTypeFields []RequestTypeField
+ DoParseForm bool
+}
+
+type RequestTypeField struct {
+ Name string
+ Extractor string
+ Optional bool
+ NameInReq string
+ TypeDef string
}
func run() error {
+ fileTemplate, err := template.New("").Funcs(template.FuncMap{
+ "quote": func(s string) string {
+ return `"` + strings.NewReplacer(`\`, `\\`, `"`, `\"`, "\n", `\n`).Replace(s) + `"`
+ },
+ "error": func() struct{} {
+ panic("error")
+ },
+ }).Parse(fileTemplateString)
+ if err != nil {
+ return err
+ }
+
filename := "examples/basic.go"
fileBytes, err := os.ReadFile(filename)
if err != nil {
@@ -44,11 +70,7 @@ func run() error {
if err != nil {
return err
}
- var output bytes.Buffer
- output.WriteString("package ")
- output.WriteString(f.Name.Name)
- output.WriteByte('\n')
- handlers := map[string]string{}
+ parsedFile := File{PackageName: f.Name.Name}
for _, decl := range f.Decls {
f, ok := decl.(*ast.FuncDecl)
if !ok {
@@ -58,104 +80,70 @@ func run() error {
continue
}
hhRoute := f.Doc.List[len(f.Doc.List)-1].Text
- var routeSpec string
- if routeSpec, ok = strings.CutPrefix(hhRoute, "//hh:route "); !ok {
+ var pattern string
+ if pattern, ok = strings.CutPrefix(hhRoute, "//hh:route "); !ok {
continue
}
- handlers[routeSpec] = "hh_" + f.Name.String()
- output.WriteString("\nfunc hh_")
- output.WriteString(f.Name.String())
- output.WriteString("[S any](s S, w http.ResponseWriter, r *http.Request) {")
parsedRequestType, ok := f.Type.Params.List[1].Type.(*ast.StructType)
if !ok {
return errors.New("Parsed request type must be a struct")
}
+ parsedFunction := Function{
+ Name: f.Name.Name,
+ Pattern: pattern,
+ RequestTypeDef: slice(fileContents, &fset, parsedRequestType.Pos(), parsedRequestType.End()),
+ }
for _, field := range parsedRequestType.Fields.List {
for _, nameIdent := range field.Names {
typ := field.Type
- name := nameIdent.Name
+ parsedField := RequestTypeField{
+ Name: nameIdent.Name,
+ Extractor: "",
+ Optional: false,
+ TypeDef: slice(fileContents, &fset, typ.Pos(), typ.End()),
+ }
+ if parsedField.TypeDef == "*http.Request" {
+ parsedFunction.RequestTypeFields = append(parsedFunction.RequestTypeFields, parsedField)
+ continue
+ }
var tag string
if field.Tag != nil {
tag = reflect.StructTag(field.Tag.Value[1 : len(field.Tag.Value)-1]).Get("hh")
}
- fmt.Println(typ, name, tag)
- if t1, ok := typ.(*ast.StarExpr); ok {
- if t2, ok := t1.X.(*ast.SelectorExpr); ok && t2.Sel.Name == "Request" {
- if id, ok := t2.X.(*ast.Ident); ok && id.Name == "http" {
- continue
- }
- }
- }
- if t1, ok := typ.(*ast.SelectorExpr); ok && t1.Sel.Name == "ResponseWriter" {
- if id, ok := t1.X.(*ast.Ident); ok && id.Name == "http" {
- continue
- }
- }
if tag == "" {
- return errors.New("Don't know what to do with '" + name + "'. You must add a tag to specify")
+ return errors.New("Don't know what to do with '" + parsedField.Name + "'. You must add a `hh:\"...\"` tag to specify")
}
tags := strings.Split(tag, ",")
- // TODO: handle raw request. Or maybe that should be a separate parameter
- optional := false
if tags[0] == "optional" {
- optional = true
+ parsedField.Optional = true
tags = tags[1:]
}
- switch tags[0] {
+ if len(tags) == 0 {
+ return errors.New("Must specify extractor for '" + parsedField.Name + "' in `" + tag + "`")
+ }
+ parsedField.Extractor = tags[0]
+ tags = tags[1:]
+ switch parsedField.Extractor {
case "form":
- output.WriteString("\n\t")
- output.WriteString(name)
- output.WriteString(" := r.FormValue(\"")
- output.WriteString(name)
- output.WriteString("\")")
+ parsedFunction.DoParseForm = true
case "cookie":
- output.WriteString("\n\tvar ")
- output.WriteString(name)
- output.WriteString(" string\n\t")
- output.WriteString(name)
- output.WriteString("0, _ := r.Cookie(\"")
- output.WriteString(name) // TODO: optionally get cookie name from tags[1]
- output.WriteString("\")\n\tif ")
- output.WriteString(name)
- output.WriteString("0 != nil {\n\t\t")
- output.WriteString(name)
- output.WriteString(" = ")
- output.WriteString(name)
- output.WriteString("0.Value\n\t}")
- if !optional {
- output.WriteString(" else {\n\t\tw.WriteHeader(http.StatusBadRequest)\n\t\tw.Write([]byte(`Bad request. Missing cookie '")
- output.WriteString(name)
- output.WriteString("'`))\n\t}")
- }
default:
- return errors.New("Unknown extractor " + tags[0])
+ return errors.New("Unknown extractor '" + tags[0] + "' on field " + nameIdent.Name)
}
- output.WriteString("\n")
- }
- }
- output.WriteString("\n\t")
- output.WriteString(f.Name.Name)
- output.WriteString("(w, ")
- structDef := slice(fileContents, &fset, parsedRequestType.Pos(), parsedRequestType.End())
- output.WriteString(indent(structDef)[1:])
- output.WriteString("{")
- for i, field := range parsedRequestType.Fields.List {
- for j, nameIdent := range field.Names {
- if i+j > 0 {
- output.WriteString(", ")
+ if len(tags) >= 1 {
+ parsedField.NameInReq = tags[0]
+ tags = tags[1:]
+ } else {
+ parsedField.NameInReq = parsedField.Name
}
- typ := field.Type
- name := nameIdent.Name
- _, _ = typ, name
- output.WriteString(name)
- output.WriteString(": ")
- output.WriteString(name)
+ if len(tags) > 0 {
+ return errors.New("Unexpected rest of tag '" + tags[0] + "' in tag `" + tag + "` on field " + nameIdent.Name)
+ }
+ parsedFunction.RequestTypeFields = append(parsedFunction.RequestTypeFields, parsedField)
}
}
- output.WriteString("})\n")
- output.WriteString("}\n")
+ parsedFile.Functions = append(parsedFile.Functions, parsedFunction)
}
- io.Copy(os.Stdout, &output)
- fmt.Println(handlers)
+ fileTemplate.Execute(os.Stdout, parsedFile)
return nil
}
diff --git a/cmd/generate/templates.go.tmpl b/cmd/generate/templates.go.tmpl
new file mode 100644
index 0000000..9201ebf
--- /dev/null
+++ b/cmd/generate/templates.go.tmpl
@@ -0,0 +1,62 @@
+package {{ .PackageName }}
+
+{{ range $_, $fn := .Functions }}
+func hh_{{ $fn.Name }}[S any](s S, w http.ResponseWriter, r *http.Request) {
+ {{- if $fn.DoParseForm }}
+ if err := r.ParseForm(); err != nil {
+ panic("todo: Bad request")
+ }
+ {{ end }}
+ {{ range $_, $f := $fn.RequestTypeFields }}
+ {{ if eq $f.TypeDef "*http.Request" }}
+ {{ continue }}
+ {{ end }}
+ var {{ $f.Name }}0 string
+ {{- if $f.Optional }}
+ {{ $f.Name }}Skipped := false
+ {{- end }}
+ {{- if eq $f.Extractor "form" }}
+ {{ $f.Name }}1 := r.Form[{{ $f.NameInReq | quote }}]
+ if len({{ $f.Name }}1) != 0 {
+ {{ $f.Name }} = {{ $f.Name }}1[0]
+ } else {
+ {{- if not $f.Optional }}
+ panic("todo: Bad request: form value " + {{ $f.NameInReq | quote }} + " missing")
+ {{- else }}
+ {{ $f.Name }}Skipped = true
+ {{- end }}
+ }
+ {{ else if eq $f.Extractor "cookie" }}
+ {{ $f.Name }}1, _ := r.Cookie({{ $f.NameInReq | quote }})
+ if {{ $f.Name }}1 != nil {
+ {{ $f.Name }} = {{ $f.Name }}1.Value
+ } else {
+ {{- if not $f.Optional }}
+ panic("todo: Bad request: cookie " + {{ $f.NameInReq | quote }} + " missing")
+ {{- else }}
+ {{ $f.Name }}Skipped = true
+ {{- end }}
+ }
+ {{ else }}
+ {{ error }}
+ {{ end -}}
+ {{ if eq $f.TypeDef "string" -}}
+ {{ $f.Name }} := {{ $f.Name }}0
+ {{ else if eq $f.TypeDef "int" -}}
+ var {{ $f.Name }} int
+ {{ if $f.Optional }} if !{{ $f.Name }}Skipped { {{ end -}}
+ var err error
+ {{ $f.Name }}, err = strconv.Atoi({{ $f.Name }}0)
+ if err != nil {
+ panic("todo: Bad request: " + {{ $f.NameInReq | quote }} + " must be a valid int")
+ }
+ {{ if $f.Optional }} } {{ end }}
+ {{ end }}
+ {{ end }}
+ {{ $fn.Name }}(w, {{ $fn.RequestTypeDef }}{
+ {{ range $_, $f := $fn.RequestTypeFields -}}
+ {{ $f.Name }}: {{ $f.Name }},
+ {{ end }}
+ })
+}
+{{ end }}