diff options
Diffstat (limited to 'cmd/generate')
-rw-r--r-- | cmd/generate/main.go | 211 |
1 files changed, 68 insertions, 143 deletions
diff --git a/cmd/generate/main.go b/cmd/generate/main.go index 67e745c..4494690 100644 --- a/cmd/generate/main.go +++ b/cmd/generate/main.go @@ -6,14 +6,10 @@ import ( "fmt" "go/ast" "go/parser" - "go/scanner" "go/token" "io" - "net/http" "os" "reflect" - "slices" - "strconv" "strings" ) @@ -23,8 +19,17 @@ func main() { } } -func slice(fileContents string, filePosInfo *token.File, start token.Pos, end token.Pos) string { - return fileContents[filePosInfo.Position(start).Offset:filePosInfo.Position(end).Offset] +func slice(fileContents string, fset *token.FileSet, start token.Pos, end token.Pos) string { + return fileContents[fset.Position(start).Offset:fset.Position(end).Offset] +} + +func indent(s string) string { + var output bytes.Buffer + for _, line := range strings.SplitAfter(s, "\n") { + output.WriteByte('\t') + output.WriteString(line) + } + return output.String() } func run() error { @@ -33,10 +38,8 @@ func run() error { if err != nil { return err } - // fileContents := string(fileBytes) + fileContents := string(fileBytes) var fset token.FileSet - fp := fset.AddFile(filename, -1, len(fileBytes)) - _ = fp f, err := parser.ParseFile(&fset, "examples/basic.go", fileBytes, parser.ParseComments|parser.SkipObjectResolution) if err != nil { return err @@ -45,6 +48,7 @@ func run() error { output.WriteString("package ") output.WriteString(f.Name.Name) output.WriteByte('\n') + handlers := map[string]string{} for _, decl := range f.Decls { f, ok := decl.(*ast.FuncDecl) if !ok { @@ -58,30 +62,7 @@ func run() error { if routeSpec, ok = strings.CutPrefix(hhRoute, "//hh:route "); !ok { continue } - split := strings.Split(routeSpec, " ") - var method, path string - if len(split) == 1 { - path = split[0] - } else if len(split) == 2 { - method = split[0] - path = split[1] - } else { - return errors.New("Invalid route spec. Expected `//hh:route [method] [path]` or `//hh:route [path]`") - } - if !slices.ContainsFunc([]string{ - "", - http.MethodGet, - http.MethodHead, - http.MethodPost, - http.MethodPut, - http.MethodPatch, - http.MethodDelete, - http.MethodConnect, - http.MethodOptions, - http.MethodTrace, - }, func(m string) bool { return m == method }) { - return errors.New("Invalid http method " + method) - } + handlers[routeSpec] = "hh_" + f.Name.String() output.WriteString("\nfunc hh_") output.WriteString(f.Name.String()) output.WriteString("[S any](s S, w http.ResponseWriter, r *http.Request) {") @@ -98,10 +79,27 @@ func run() error { tag = reflect.StructTag(field.Tag.Value[1 : len(field.Tag.Value)-1]).Get("hh") } fmt.Println(typ, name, tag) + if t1, ok := typ.(*ast.StarExpr); ok { + if t2, ok := t1.X.(*ast.SelectorExpr); ok && t2.Sel.Name == "Request" { + if id, ok := t2.X.(*ast.Ident); ok && id.Name == "http" { + continue + } + } + } + if t1, ok := typ.(*ast.SelectorExpr); ok && t1.Sel.Name == "ResponseWriter" { + if id, ok := t1.X.(*ast.Ident); ok && id.Name == "http" { + continue + } + } + if tag == "" { + return errors.New("Don't know what to do with '" + name + "'. You must add a tag to specify") + } tags := strings.Split(tag, ",") // TODO: handle raw request. Or maybe that should be a separate parameter - if len(tags) == 0 { - return errors.New("Don't know what to do with '" + name + "'. You must add a tag to specify") + optional := false + if tags[0] == "optional" { + optional = true + tags = tags[1:] } switch tags[0] { case "form": @@ -111,126 +109,53 @@ func run() error { output.WriteString(name) output.WriteString("\")") case "cookie": - // panic("todo") + output.WriteString("\n\tvar ") + output.WriteString(name) + output.WriteString(" string\n\t") + output.WriteString(name) + output.WriteString("0, _ := r.Cookie(\"") + output.WriteString(name) // TODO: optionally get cookie name from tags[1] + output.WriteString("\")\n\tif ") + output.WriteString(name) + output.WriteString("0 != nil {\n\t\t") + output.WriteString(name) + output.WriteString(" = ") + output.WriteString(name) + output.WriteString("0.Value\n\t}") + if !optional { + output.WriteString(" else {\n\t\tw.WriteHeader(http.StatusBadRequest)\n\t\tw.Write([]byte(`Bad request. Missing cookie '") + output.WriteString(name) + output.WriteString("'`))\n\t}") + } + default: + return errors.New("Unknown extractor " + tags[0]) } + output.WriteString("\n") } } - i := 0 output.WriteString("\n\t") output.WriteString(f.Name.Name) output.WriteString("(w, ") - i = 0 - for _, field := range parsedRequestType.Fields.List { - for _, nameIdent := range field.Names { - typ := field.Type - name := nameIdent.Name - if i > 0 { + structDef := slice(fileContents, &fset, parsedRequestType.Pos(), parsedRequestType.End()) + output.WriteString(indent(structDef)[1:]) + output.WriteString("{") + for i, field := range parsedRequestType.Fields.List { + for j, nameIdent := range field.Names { + if i+j > 0 { output.WriteString(", ") } - output.WriteString("var") - output.WriteString(strconv.Itoa(i)) - i++ + typ := field.Type + name := nameIdent.Name + _, _ = typ, name + output.WriteString(name) + output.WriteString(": ") + output.WriteString(name) } } - output.WriteString(")\n") + output.WriteString("})\n") output.WriteString("}\n") - - fmt.Printf("`%v`\n`%v`\n", path, f.Name.Name) } io.Copy(os.Stdout, &output) + fmt.Println(handlers) return nil } - -type routeSpec struct { - method string - path string - parameters []routeSpecParam -} - -type routeSpecParam struct { - name string - extractor string -} - -func parseRouteSpec(s string) (routeSpec, error) { - s = strings.TrimRight(s, " \n") - var rs routeSpec - for _, method := range []string{ - http.MethodGet, - http.MethodHead, - http.MethodPost, - http.MethodPut, - http.MethodPatch, - http.MethodDelete, - http.MethodConnect, - http.MethodOptions, - http.MethodTrace, - } { - if rest, ok := strings.CutPrefix(s, method+" "); ok { - s = rest - rs.method = method - break - } - } - - for { - if commaPos := strings.IndexByte(s, ','); commaPos != -1 { - rs.path = s[:commaPos] - s = s[commaPos+1:] - } else { - rs.path = s - return rs, nil - } - - s = strings.TrimLeft(s, " ") - if s == "" { - break - } - - end := -1 - for i := 0; i < len(s); i++ { - if 'a' <= s[i] && s[i] <= 'z' || - 'A' <= s[i] && s[i] <= 'Z' || - s[i] == '_' { - continue - } - if s[i] == ':' { - end = i - break - } - return rs, errors.New("Expected ':' to mark end of parameter name, got " + s[i:i+1]) - } - if end == -1 { - return rs, errors.New("Expected ':' to mark end of parameter name, got end of line") - } - var p routeSpecParam - p.name = s[:end] - s = s[end+1:] - - s = strings.TrimLeft(s, " ") - - expr, err := parser.ParseExpr(s) - if el, ok := err.(scanner.ErrorList); err == nil || ok && el[0].Msg == "expected 'EOF', found ','" { - switch expr := expr.(type) { - case *ast.Ident: - switch expr.Name { - case "query": - p.extractor = ":" + expr.Name - default: - return rs, errors.New("Unexpected extractor " + expr.Name) - } - case *ast.CallExpr: - p.extractor = s[expr.Pos()-1 : expr.End()-1] - default: - return rs, errors.New("Unexpected extractor" + s[expr.Pos()-1:expr.End()-1]) - } - } else { - return rs, err - } - s = s[expr.End()-1:] - - rs.parameters = append(rs.parameters, p) - } - - return rs, nil -} |