summaryrefslogtreecommitdiff
path: root/cmd
diff options
context:
space:
mode:
Diffstat (limited to 'cmd')
-rw-r--r--cmd/generate/main.go211
1 files changed, 68 insertions, 143 deletions
diff --git a/cmd/generate/main.go b/cmd/generate/main.go
index 67e745c..4494690 100644
--- a/cmd/generate/main.go
+++ b/cmd/generate/main.go
@@ -6,14 +6,10 @@ import (
"fmt"
"go/ast"
"go/parser"
- "go/scanner"
"go/token"
"io"
- "net/http"
"os"
"reflect"
- "slices"
- "strconv"
"strings"
)
@@ -23,8 +19,17 @@ func main() {
}
}
-func slice(fileContents string, filePosInfo *token.File, start token.Pos, end token.Pos) string {
- return fileContents[filePosInfo.Position(start).Offset:filePosInfo.Position(end).Offset]
+func slice(fileContents string, fset *token.FileSet, start token.Pos, end token.Pos) string {
+ return fileContents[fset.Position(start).Offset:fset.Position(end).Offset]
+}
+
+func indent(s string) string {
+ var output bytes.Buffer
+ for _, line := range strings.SplitAfter(s, "\n") {
+ output.WriteByte('\t')
+ output.WriteString(line)
+ }
+ return output.String()
}
func run() error {
@@ -33,10 +38,8 @@ func run() error {
if err != nil {
return err
}
- // fileContents := string(fileBytes)
+ fileContents := string(fileBytes)
var fset token.FileSet
- fp := fset.AddFile(filename, -1, len(fileBytes))
- _ = fp
f, err := parser.ParseFile(&fset, "examples/basic.go", fileBytes, parser.ParseComments|parser.SkipObjectResolution)
if err != nil {
return err
@@ -45,6 +48,7 @@ func run() error {
output.WriteString("package ")
output.WriteString(f.Name.Name)
output.WriteByte('\n')
+ handlers := map[string]string{}
for _, decl := range f.Decls {
f, ok := decl.(*ast.FuncDecl)
if !ok {
@@ -58,30 +62,7 @@ func run() error {
if routeSpec, ok = strings.CutPrefix(hhRoute, "//hh:route "); !ok {
continue
}
- split := strings.Split(routeSpec, " ")
- var method, path string
- if len(split) == 1 {
- path = split[0]
- } else if len(split) == 2 {
- method = split[0]
- path = split[1]
- } else {
- return errors.New("Invalid route spec. Expected `//hh:route [method] [path]` or `//hh:route [path]`")
- }
- if !slices.ContainsFunc([]string{
- "",
- http.MethodGet,
- http.MethodHead,
- http.MethodPost,
- http.MethodPut,
- http.MethodPatch,
- http.MethodDelete,
- http.MethodConnect,
- http.MethodOptions,
- http.MethodTrace,
- }, func(m string) bool { return m == method }) {
- return errors.New("Invalid http method " + method)
- }
+ handlers[routeSpec] = "hh_" + f.Name.String()
output.WriteString("\nfunc hh_")
output.WriteString(f.Name.String())
output.WriteString("[S any](s S, w http.ResponseWriter, r *http.Request) {")
@@ -98,10 +79,27 @@ func run() error {
tag = reflect.StructTag(field.Tag.Value[1 : len(field.Tag.Value)-1]).Get("hh")
}
fmt.Println(typ, name, tag)
+ if t1, ok := typ.(*ast.StarExpr); ok {
+ if t2, ok := t1.X.(*ast.SelectorExpr); ok && t2.Sel.Name == "Request" {
+ if id, ok := t2.X.(*ast.Ident); ok && id.Name == "http" {
+ continue
+ }
+ }
+ }
+ if t1, ok := typ.(*ast.SelectorExpr); ok && t1.Sel.Name == "ResponseWriter" {
+ if id, ok := t1.X.(*ast.Ident); ok && id.Name == "http" {
+ continue
+ }
+ }
+ if tag == "" {
+ return errors.New("Don't know what to do with '" + name + "'. You must add a tag to specify")
+ }
tags := strings.Split(tag, ",")
// TODO: handle raw request. Or maybe that should be a separate parameter
- if len(tags) == 0 {
- return errors.New("Don't know what to do with '" + name + "'. You must add a tag to specify")
+ optional := false
+ if tags[0] == "optional" {
+ optional = true
+ tags = tags[1:]
}
switch tags[0] {
case "form":
@@ -111,126 +109,53 @@ func run() error {
output.WriteString(name)
output.WriteString("\")")
case "cookie":
- // panic("todo")
+ output.WriteString("\n\tvar ")
+ output.WriteString(name)
+ output.WriteString(" string\n\t")
+ output.WriteString(name)
+ output.WriteString("0, _ := r.Cookie(\"")
+ output.WriteString(name) // TODO: optionally get cookie name from tags[1]
+ output.WriteString("\")\n\tif ")
+ output.WriteString(name)
+ output.WriteString("0 != nil {\n\t\t")
+ output.WriteString(name)
+ output.WriteString(" = ")
+ output.WriteString(name)
+ output.WriteString("0.Value\n\t}")
+ if !optional {
+ output.WriteString(" else {\n\t\tw.WriteHeader(http.StatusBadRequest)\n\t\tw.Write([]byte(`Bad request. Missing cookie '")
+ output.WriteString(name)
+ output.WriteString("'`))\n\t}")
+ }
+ default:
+ return errors.New("Unknown extractor " + tags[0])
}
+ output.WriteString("\n")
}
}
- i := 0
output.WriteString("\n\t")
output.WriteString(f.Name.Name)
output.WriteString("(w, ")
- i = 0
- for _, field := range parsedRequestType.Fields.List {
- for _, nameIdent := range field.Names {
- typ := field.Type
- name := nameIdent.Name
- if i > 0 {
+ structDef := slice(fileContents, &fset, parsedRequestType.Pos(), parsedRequestType.End())
+ output.WriteString(indent(structDef)[1:])
+ output.WriteString("{")
+ for i, field := range parsedRequestType.Fields.List {
+ for j, nameIdent := range field.Names {
+ if i+j > 0 {
output.WriteString(", ")
}
- output.WriteString("var")
- output.WriteString(strconv.Itoa(i))
- i++
+ typ := field.Type
+ name := nameIdent.Name
+ _, _ = typ, name
+ output.WriteString(name)
+ output.WriteString(": ")
+ output.WriteString(name)
}
}
- output.WriteString(")\n")
+ output.WriteString("})\n")
output.WriteString("}\n")
-
- fmt.Printf("`%v`\n`%v`\n", path, f.Name.Name)
}
io.Copy(os.Stdout, &output)
+ fmt.Println(handlers)
return nil
}
-
-type routeSpec struct {
- method string
- path string
- parameters []routeSpecParam
-}
-
-type routeSpecParam struct {
- name string
- extractor string
-}
-
-func parseRouteSpec(s string) (routeSpec, error) {
- s = strings.TrimRight(s, " \n")
- var rs routeSpec
- for _, method := range []string{
- http.MethodGet,
- http.MethodHead,
- http.MethodPost,
- http.MethodPut,
- http.MethodPatch,
- http.MethodDelete,
- http.MethodConnect,
- http.MethodOptions,
- http.MethodTrace,
- } {
- if rest, ok := strings.CutPrefix(s, method+" "); ok {
- s = rest
- rs.method = method
- break
- }
- }
-
- for {
- if commaPos := strings.IndexByte(s, ','); commaPos != -1 {
- rs.path = s[:commaPos]
- s = s[commaPos+1:]
- } else {
- rs.path = s
- return rs, nil
- }
-
- s = strings.TrimLeft(s, " ")
- if s == "" {
- break
- }
-
- end := -1
- for i := 0; i < len(s); i++ {
- if 'a' <= s[i] && s[i] <= 'z' ||
- 'A' <= s[i] && s[i] <= 'Z' ||
- s[i] == '_' {
- continue
- }
- if s[i] == ':' {
- end = i
- break
- }
- return rs, errors.New("Expected ':' to mark end of parameter name, got " + s[i:i+1])
- }
- if end == -1 {
- return rs, errors.New("Expected ':' to mark end of parameter name, got end of line")
- }
- var p routeSpecParam
- p.name = s[:end]
- s = s[end+1:]
-
- s = strings.TrimLeft(s, " ")
-
- expr, err := parser.ParseExpr(s)
- if el, ok := err.(scanner.ErrorList); err == nil || ok && el[0].Msg == "expected 'EOF', found ','" {
- switch expr := expr.(type) {
- case *ast.Ident:
- switch expr.Name {
- case "query":
- p.extractor = ":" + expr.Name
- default:
- return rs, errors.New("Unexpected extractor " + expr.Name)
- }
- case *ast.CallExpr:
- p.extractor = s[expr.Pos()-1 : expr.End()-1]
- default:
- return rs, errors.New("Unexpected extractor" + s[expr.Pos()-1:expr.End()-1])
- }
- } else {
- return rs, err
- }
- s = s[expr.End()-1:]
-
- rs.parameters = append(rs.parameters, p)
- }
-
- return rs, nil
-}