Skip to content

Commit

Permalink
路由注册使用 dst 管理 import
Browse files Browse the repository at this point in the history
  • Loading branch information
taoso committed Jan 12, 2022
1 parent 7697b19 commit b12f0d2
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 118 deletions.
129 changes: 44 additions & 85 deletions cmd/sniper/rpc/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,24 @@ package rpc
import (
"bytes"
"fmt"
"go/parser"
"go/token"
"os"
"strconv"
"strings"
"text/template"

"github.com/dave/dst"
"github.com/dave/dst/decorator"
"github.com/dave/dst/decorator/resolver/goast"
"github.com/dave/dst/decorator/resolver/simple"
)

func serverImported(imports []*dst.ImportSpec) bool {
rpc := server + "_v" + version
for _, i := range imports {
if i.Name != nil && i.Name.Name == rpc {
return true
}
}
return false
}

func serverRegistered(gen *dst.FuncDecl) bool {
for _, s := range gen.Body.List {
bs, ok := s.(*dst.BlockStmt)
if !ok {
continue
}
// s := &foo_v1.FooServer{}
ue, ok := bs.List[0].(*dst.AssignStmt).Rhs[0].(*dst.UnaryExpr)
if !ok {
continue
Expand All @@ -48,12 +40,13 @@ func serverRegistered(gen *dst.FuncDecl) bool {
return false
}

func genServerRoute(fd *dst.FuncDecl) {
if serverRegistered(fd) {
func genServerRoute(initMux *dst.FuncDecl) {
if serverRegistered(initMux) {
return
}

args := &regSrvTpl{
Package: module(),
Server: server,
Version: version,
Service: upper1st(service),
Expand All @@ -67,103 +60,69 @@ func genServerRoute(fd *dst.FuncDecl) {
panic(err)
}

s := token.NewFileSet()
f, err := decorator.ParseFile(s, "", buf.String(), parser.ParseComments)
if err != nil {
panic(err)
}

stmt := f.Decls[0].(*dst.FuncDecl).Body.List[0].(*dst.BlockStmt)
if len(fd.Body.List) > 0 {
stmt.Decs.Start.Replace("\n")
}
fd.Body.List = append(fd.Body.List, stmt)
}

func genImport(file *dst.File) {
if serverImported(file.Imports) {
return
}

args := impTpl{
Name: server + "_v" + version,
Path: fmt.Sprintf(`"%s/rpc/%s/v%s"`, module(), server, version),
}
t, err := template.New("sniper").Parse(args.tpl())
if err != nil {
panic(err)
}
buf := &bytes.Buffer{}
if err := t.Execute(buf, args); err != nil {
panic(err)
}

s := token.NewFileSet()
f, err := decorator.ParseFile(s, "", buf.String(), parser.ParseComments)
d := decorator.NewDecoratorWithImports(nil, "http", goast.New())
f, err := d.Parse(buf)
if err != nil {
panic(err)
}

spec := f.Decls[0].(*dst.GenDecl).Specs[0].(*dst.ImportSpec)
for _, decl := range file.Decls {
pkg, ok := decl.(*dst.GenDecl)
// http.go 导包分三组
//
// "net/http"
//
// "sniper/cmd/http/hooks"
//
// "github.com/go-kiss/sniper/pkg/twirp"
//
// 下面代码 rpc 包导入语句插入到上面第二组中
if ok && pkg.Tok == token.IMPORT {
i := 0 // 记录第二组最后一行的位置
for _, s := range pkg.Specs {
i++
is := s.(*dst.ImportSpec)
// 第二组的包都以项目包名开头
if !strings.Contains(is.Path.Value, module()+"/") {
continue
}
// 最后一行的 After 为 EmtyLine,表示下面是空行
if is.Decs.After != dst.EmptyLine {
continue
}
// 注册新路由后需要清理倒数第二行后面的空行
is.Decs.After = dst.NewLine
break
for _, d := range f.Decls {
if fd, ok := d.(*dst.FuncDecl); ok {
stmt := fd.Body.List[0].(*dst.BlockStmt)
if len(initMux.Body.List) > 0 {
stmt.Decs.Start.Replace("\n")
}
pkg.Specs = append(pkg.Specs[:i+1], pkg.Specs[i:]...)
pkg.Specs[i] = spec
initMux.Body.List = append(initMux.Body.List, stmt)
return
}
}
}

func registerServer() {
httpFile := "cmd/http/http.go"
fset := token.NewFileSet()
httpAST, err := decorator.ParseFile(fset, httpFile, nil, parser.ParseComments)
routeFile := "cmd/http/http.go"
b, err := os.ReadFile(routeFile)
if err != nil {
panic(err)
}
d := decorator.NewDecoratorWithImports(nil, "http", goast.New())
routeAst, err := d.Parse(b)
if err != nil {
panic(err)
}

genImport(httpAST)

// 处理注册路由
for _, decl := range httpAST.Decls {
for _, decl := range routeAst.Decls {
f, ok := decl.(*dst.FuncDecl)
if ok && f.Name.Name == "initMux" {
genServerRoute(f)
break
}
}

f, err := os.OpenFile(httpFile, os.O_WRONLY|os.O_CREATE, 0766)
f, err := os.OpenFile(routeFile, os.O_WRONLY|os.O_CREATE, 0766)
if err != nil {
return
}
defer f.Close()
if err := decorator.Fprint(f, httpAST); err != nil {

alias := server + "_v" + version
path := fmt.Sprintf(`%s/rpc/%s/v%s`, module(), server, version)
rr := simple.RestorerResolver{path: alias}
for _, i := range routeAst.Imports {
alias := ""
path, _ := strconv.Unquote(i.Path.Value)
if i.Name != nil {
alias = i.Name.Name
} else {
parts := strings.Split(path, "/")
alias = parts[len(parts)-1]
}
rr[path] = alias
}
r := decorator.NewRestorerWithImports("http", rr)
fr := r.FileRestorer()
fr.Alias[path] = alias
if err := fr.Fprint(f, routeAst); err != nil {
panic(err)
}
}
19 changes: 3 additions & 16 deletions cmd/sniper/rpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ func genOrUpdateServer() {
serverPath := fmt.Sprintf("rpc/%s/v%s/%s.go", server, version, service)
twirpPath := fmt.Sprintf("rpc/%s/v%s/%s.twirp.go", server, version, service)

serverAst := parseServerAst(serverPath, serverPkg)
twirpAst := parseServerAst(twirpPath, serverPkg)
serverAst := parseAst(serverPath, serverPkg)
twirpAst := parseAst(twirpPath, serverPkg)

for _, d := range twirpAst.Decls {
if it, ok := isInterfaceType(d); ok {
Expand Down Expand Up @@ -50,7 +50,7 @@ func isInterfaceType(d dst.Decl) (*dst.InterfaceType, bool) {
return it, true
}

func parseServerAst(path, pkg string) *dst.File {
func parseAst(path, pkg string) *dst.File {
d := decorator.NewDecoratorWithImports(nil, pkg, goast.New())
ast, err := d.Parse(readCode(path))
if err != nil {
Expand All @@ -59,19 +59,6 @@ func parseServerAst(path, pkg string) *dst.File {
return ast
}

func parseTwirpAst(path string) *dst.File {
b, err := os.ReadFile(path)
if err != nil {
panic(err)
}
ast, err := decorator.Parse(b)
if err != nil {
panic(err)
}

return ast
}

func readCode(serverFile string) []byte {
var code []byte
if fileExists(serverFile) {
Expand Down
22 changes: 5 additions & 17 deletions cmd/sniper/rpc/tpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,18 @@ type funcTpl struct {
func (t *funcTpl) tpl() string {
return `
func (s *{{.Service}}Server) {{.Name}}(ctx context.Context, req *{{.ReqType}}) (resp *{{.RespType}}, err error) {
{{ if eq .Name "Echo" }}
{{if eq .Name "Echo"}}
return &{{.Service}}EchoResp{Msg: req.Msg}, nil
{{ else }}
{{else}}
// FIXME 请开始你的表演
return
{{ end }}
{{end}}
}
`
}

type regSrvTpl struct {
Package string // 包名
Server string // 服务
Version string // 版本
Service string // 子服务
Expand All @@ -66,6 +67,7 @@ type regSrvTpl struct {
func (t *regSrvTpl) tpl() string {
return strings.TrimLeft(`
package main
import {{.Server}}_v{{.Version}} "{{.Package}}/rpc/{{.Server}}/v{{.Version}}"
func main() {
{
s := &{{.Server}}_v{{.Version}}.{{.Service}}Server{}
Expand All @@ -77,20 +79,6 @@ func main() {
`, "\n")
}

type impTpl struct {
Name string
Path string
}

func (t *impTpl) tpl() string {
return strings.TrimLeft(`
package main
import(
{{.Name}} {{.Path}}
)
`, "\n")
}

type protoTpl struct {
Server string // 服务
Version string // 版本
Expand Down

0 comments on commit b12f0d2

Please sign in to comment.