summaryrefslogtreecommitdiff
path: root/cmd/generate
diff options
context:
space:
mode:
authorMathias Magnusson <mathias@magnusson.space>2025-04-10 19:06:03 +0200
committerMathias Magnusson <mathias@magnusson.space>2025-04-10 19:06:03 +0200
commitb9bf8a23c75db82e1aff8295a97dcfdf789735f3 (patch)
treeb6c57cb6747e9a1d56f243dd07f64cd15f1a9541 /cmd/generate
parent2778c52e4da52fd33f2df7fc9024252c2470b172 (diff)
downloadhh-b9bf8a23c75db82e1aff8295a97dcfdf789735f3.tar.gz
target is a module, not file; support path params; add function to mount all routes
Diffstat (limited to 'cmd/generate')
-rw-r--r--cmd/generate/main.go182
-rw-r--r--cmd/generate/templates.go.tmpl47
2 files changed, 147 insertions, 82 deletions
diff --git a/cmd/generate/main.go b/cmd/generate/main.go
index bdb46f2..06cf8fd 100644
--- a/cmd/generate/main.go
+++ b/cmd/generate/main.go
@@ -1,9 +1,11 @@
package main
import (
+ "bytes"
_ "embed"
"errors"
"go/ast"
+ "go/format"
"go/parser"
"go/token"
"os"
@@ -25,7 +27,7 @@ func slice(fileContents string, fset *token.FileSet, start token.Pos, end token.
return fileContents[fset.Position(start).Offset:fset.Position(end).Offset]
}
-type File struct {
+type Package struct {
PackageName string
Functions []Function
}
@@ -51,99 +53,129 @@ func run() error {
"quote": func(s string) string {
return `"` + strings.NewReplacer(`\`, `\\`, `"`, `\"`, "\n", `\n`).Replace(s) + `"`
},
- "error": func() struct{} {
- panic("error")
+ "error": func(msg string) struct{} {
+ panic("error in template: " + msg)
},
}).Parse(fileTemplateString)
if err != nil {
return err
}
- filename := "examples/basic.go"
- fileBytes, err := os.ReadFile(filename)
+ dirEntries, err := os.ReadDir(".")
if err != nil {
return err
}
- fileContents := string(fileBytes)
- var fset token.FileSet
- f, err := parser.ParseFile(&fset, "examples/basic.go", fileBytes, parser.ParseComments|parser.SkipObjectResolution)
- if err != nil {
- return err
- }
- parsedFile := File{PackageName: f.Name.Name}
- for _, decl := range f.Decls {
- f, ok := decl.(*ast.FuncDecl)
- if !ok {
+
+ parsedPackage := Package{}
+ for _, ent := range dirEntries {
+ if !strings.HasSuffix(ent.Name(), ".go") {
continue
}
- if f.Doc == nil {
- continue
+
+ fileBytes, err := os.ReadFile(ent.Name())
+ if err != nil {
+ return err
}
- hhRoute := f.Doc.List[len(f.Doc.List)-1].Text
- var pattern string
- if pattern, ok = strings.CutPrefix(hhRoute, "//hh:route "); !ok {
- continue
+ fileContents := string(fileBytes)
+ var fset token.FileSet
+ f, err := parser.ParseFile(&fset, ent.Name(), fileBytes, parser.ParseComments|parser.SkipObjectResolution)
+ if err != nil {
+ return err
}
- parsedRequestType, ok := f.Type.Params.List[1].Type.(*ast.StructType)
- if !ok {
- return errors.New("Parsed request type must be a struct")
+ if strings.HasSuffix(f.Name.String(), "_test") {
+ continue
}
- parsedFunction := Function{
- Name: f.Name.Name,
- Pattern: pattern,
- RequestTypeDef: slice(fileContents, &fset, parsedRequestType.Pos(), parsedRequestType.End()),
+ if parsedPackage.PackageName != "" && parsedPackage.PackageName != f.Name.String() {
+ return errors.New("Found two different package names in directory: " + parsedPackage.PackageName + " and " + f.Name.String())
}
- for _, field := range parsedRequestType.Fields.List {
- for _, nameIdent := range field.Names {
- typ := field.Type
- parsedField := RequestTypeField{
- Name: nameIdent.Name,
- Extractor: "",
- Optional: false,
- TypeDef: slice(fileContents, &fset, typ.Pos(), typ.End()),
- }
- if parsedField.TypeDef == "*http.Request" {
- parsedFunction.RequestTypeFields = append(parsedFunction.RequestTypeFields, parsedField)
- continue
- }
- var tag string
- if field.Tag != nil {
- tag = reflect.StructTag(field.Tag.Value[1 : len(field.Tag.Value)-1]).Get("hh")
- }
- if tag == "" {
- return errors.New("Don't know what to do with '" + parsedField.Name + "'. You must add a `hh:\"...\"` tag to specify")
- }
- tags := strings.Split(tag, ",")
- if tags[0] == "optional" {
- parsedField.Optional = true
- tags = tags[1:]
- }
- if len(tags) == 0 {
- return errors.New("Must specify extractor for '" + parsedField.Name + "' in `" + tag + "`")
- }
- parsedField.Extractor = tags[0]
- tags = tags[1:]
- switch parsedField.Extractor {
- case "form":
- parsedFunction.DoParseForm = true
- case "cookie":
- default:
- return errors.New("Unknown extractor '" + tags[0] + "' on field " + nameIdent.Name)
- }
- if len(tags) >= 1 {
- parsedField.NameInReq = tags[0]
+ parsedPackage.PackageName = f.Name.String()
+
+ for _, decl := range f.Decls {
+ f, ok := decl.(*ast.FuncDecl)
+ if !ok {
+ continue
+ }
+ if f.Doc == nil {
+ continue
+ }
+ hhRoute := f.Doc.List[len(f.Doc.List)-1].Text
+ var pattern string
+ if pattern, ok = strings.CutPrefix(hhRoute, "//hh:route "); !ok {
+ continue
+ }
+ parsedRequestType, ok := f.Type.Params.List[1].Type.(*ast.StructType)
+ if !ok {
+ return errors.New("Parsed request type must be a struct")
+ }
+ parsedFunction := Function{
+ Name: f.Name.Name,
+ Pattern: pattern,
+ RequestTypeDef: slice(fileContents, &fset, parsedRequestType.Pos(), parsedRequestType.End()),
+ }
+ for _, field := range parsedRequestType.Fields.List {
+ for _, nameIdent := range field.Names {
+ typ := field.Type
+ parsedField := RequestTypeField{
+ Name: nameIdent.Name,
+ Extractor: "",
+ Optional: false,
+ TypeDef: slice(fileContents, &fset, typ.Pos(), typ.End()),
+ }
+ if parsedField.TypeDef == "*http.Request" {
+ parsedFunction.RequestTypeFields = append(parsedFunction.RequestTypeFields, parsedField)
+ continue
+ }
+ var tag string
+ if field.Tag != nil {
+ tag = reflect.StructTag(field.Tag.Value[1 : len(field.Tag.Value)-1]).Get("hh")
+ }
+ if tag == "" {
+ return errors.New("Don't know what to do with '" + parsedField.Name + "'. You must add a `hh:\"...\"` tag to specify")
+ }
+ tags := strings.Split(tag, ",")
+ if tags[0] == "optional" {
+ parsedField.Optional = true
+ tags = tags[1:]
+ }
+ if len(tags) == 0 {
+ return errors.New("Must specify extractor for '" + parsedField.Name + "' in `" + tag + "`")
+ }
+ parsedField.Extractor = tags[0]
tags = tags[1:]
- } else {
- parsedField.NameInReq = parsedField.Name
- }
- if len(tags) > 0 {
- return errors.New("Unexpected rest of tag '" + tags[0] + "' in tag `" + tag + "` on field " + nameIdent.Name)
+ switch parsedField.Extractor {
+ case "form":
+ parsedFunction.DoParseForm = true
+ case "cookie", "path":
+ default:
+ return errors.New("Unknown extractor '" + tags[0] + "' on field " + nameIdent.Name)
+ }
+ if len(tags) >= 1 {
+ parsedField.NameInReq = tags[0]
+ tags = tags[1:]
+ } else {
+ parsedField.NameInReq = parsedField.Name
+ }
+ if len(tags) > 0 {
+ return errors.New("Unexpected rest of tag '" + tags[0] + "' in tag `" + tag + "` on field " + nameIdent.Name)
+ }
+ parsedFunction.RequestTypeFields = append(parsedFunction.RequestTypeFields, parsedField)
}
- parsedFunction.RequestTypeFields = append(parsedFunction.RequestTypeFields, parsedField)
}
+ parsedPackage.Functions = append(parsedPackage.Functions, parsedFunction)
}
- parsedFile.Functions = append(parsedFile.Functions, parsedFunction)
}
- fileTemplate.Execute(os.Stdout, parsedFile)
+
+ var unformatted bytes.Buffer
+ if err := fileTemplate.Execute(&unformatted, parsedPackage); err != nil {
+ return err
+ }
+ formatted, err := format.Source(unformatted.Bytes())
+ if err != nil {
+ return err
+ }
+ if err := os.WriteFile("hh.gen.go", formatted, 0o660); err != nil {
+ return err
+ }
+
return nil
}
diff --git a/cmd/generate/templates.go.tmpl b/cmd/generate/templates.go.tmpl
index 9201ebf..0e684e3 100644
--- a/cmd/generate/templates.go.tmpl
+++ b/cmd/generate/templates.go.tmpl
@@ -1,13 +1,35 @@
+// WARNING: this file has been automatically generated by
+// codeberg.org/fooelevator/hh. DO NOT EDIT MANUALLY!
+
package {{ .PackageName }}
+import (
+ "net/http"
+)
+
+func hhMountRoutes[S any](s S, mux *http.ServeMux) {
+ if mux == nil {
+ mux = http.DefaultServeMux
+ }
+ wrapper := func(handler func(s S, w http.ResponseWriter, r *http.Request)) http.Handler {
+ return http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
+ handler(s, w, r)
+ })
+ }
+
+ {{ range $_, $fn := .Functions -}}
+ mux.Handle({{ $fn.Pattern | quote }}, wrapper(hh_{{ $fn.Name }}))
+ {{ end }}
+}
+
{{ range $_, $fn := .Functions }}
func hh_{{ $fn.Name }}[S any](s S, w http.ResponseWriter, r *http.Request) {
{{- if $fn.DoParseForm }}
- if err := r.ParseForm(); err != nil {
- panic("todo: Bad request")
- }
+ if err := r.ParseForm(); err != nil {
+ panic("todo: Bad request")
+ }
{{ end }}
- {{ range $_, $f := $fn.RequestTypeFields }}
+ {{- range $_, $f := $fn.RequestTypeFields }}
{{ if eq $f.TypeDef "*http.Request" }}
{{ continue }}
{{ end }}
@@ -18,7 +40,7 @@ func hh_{{ $fn.Name }}[S any](s S, w http.ResponseWriter, r *http.Request) {
{{- if eq $f.Extractor "form" }}
{{ $f.Name }}1 := r.Form[{{ $f.NameInReq | quote }}]
if len({{ $f.Name }}1) != 0 {
- {{ $f.Name }} = {{ $f.Name }}1[0]
+ {{ $f.Name }}0 = {{ $f.Name }}1[0]
} else {
{{- if not $f.Optional }}
panic("todo: Bad request: form value " + {{ $f.NameInReq | quote }} + " missing")
@@ -26,10 +48,21 @@ func hh_{{ $fn.Name }}[S any](s S, w http.ResponseWriter, r *http.Request) {
{{ $f.Name }}Skipped = true
{{- end }}
}
+ {{ else if eq $f.Extractor "path" }}
+ {{ $f.Name }}1 := r.PathValue({{ $f.NameInReq | quote }})
+ if {{ $f.Name }}1 != "" {
+ {{ $f.Name }}0 = {{ $f.Name }}1
+ } else {
+ {{- if not $f.Optional }}
+ panic("todo: Bad request: path value " + {{ $f.NameInReq | quote }} + " missing")
+ {{- else }}
+ {{ $f.Name }}Skipped = true
+ {{- end }}
+ }
{{ else if eq $f.Extractor "cookie" }}
{{ $f.Name }}1, _ := r.Cookie({{ $f.NameInReq | quote }})
if {{ $f.Name }}1 != nil {
- {{ $f.Name }} = {{ $f.Name }}1.Value
+ {{ $f.Name }}0 = {{ $f.Name }}1.Value
} else {
{{- if not $f.Optional }}
panic("todo: Bad request: cookie " + {{ $f.NameInReq | quote }} + " missing")
@@ -38,7 +71,7 @@ func hh_{{ $fn.Name }}[S any](s S, w http.ResponseWriter, r *http.Request) {
{{- end }}
}
{{ else }}
- {{ error }}
+ {{ error "unknown extractor" }}
{{ end -}}
{{ if eq $f.TypeDef "string" -}}
{{ $f.Name }} := {{ $f.Name }}0