548 lines
13 KiB
Go
548 lines
13 KiB
Go
|
package main
|
||
|
|
||
|
import (
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"flag"
|
||
|
"fmt"
|
||
|
"go/ast"
|
||
|
"go/parser"
|
||
|
"go/token"
|
||
|
"os"
|
||
|
"path"
|
||
|
"path/filepath"
|
||
|
"reflect"
|
||
|
"runtime"
|
||
|
"strings"
|
||
|
)
|
||
|
|
||
|
// Global variables.

// ErrParams is the sentinel error returned by parseFuncDoc when an annotation
// has the wrong number of arguments or names a model missing from definitions.
var ErrParams = errors.New("err params")

// _gopath holds the entries of the GOPATH environment variable, split on the
// OS-specific list separator when the package is initialised.
var _gopath = filepath.SplitList(os.Getenv("GOPATH"))
|
||
|
|
||
|
var (
|
||
|
dir string
|
||
|
pkgs = make(map[string]*ast.Package)
|
||
|
rlpkgs = make(map[string]*ast.Package)
|
||
|
definitions = make(map[string]*Schema)
|
||
|
swagger = Swagger{
|
||
|
Definitions: make(map[string]*Schema),
|
||
|
Paths: make(map[string]*Item),
|
||
|
SwaggerVersion: "2.0",
|
||
|
Infos: Information{
|
||
|
Title: "go-common api",
|
||
|
Description: "api",
|
||
|
Version: "1.0",
|
||
|
Contact: Contact{
|
||
|
EMail: "lintanghui@bilibili.com",
|
||
|
},
|
||
|
License: &License{
|
||
|
Name: "Apache 2.0",
|
||
|
URL: "http://www.apache.org/licenses/LICENSE-2.0.html",
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
stdlibObject = map[string]string{
|
||
|
"&{time Time}": "time.Time",
|
||
|
}
|
||
|
)
|
||
|
|
||
|
// basicTypes maps Go builtin type names (see builtin.go) to a
// "swaggerType:format" pair. The part after the colon may be empty.
var basicTypes = map[string]string{
	// booleans and strings
	"bool":   "boolean:",
	"string": "string:",
	"byte":   "string:byte",
	"rune":   "string:byte",

	// signed integers
	"int":   "integer:int64",
	"int8":  "integer:int32",
	"int16": "integer:int32",
	"int32": "integer:int32",
	"int64": "integer:int64",

	// unsigned integers
	"uint":    "integer:int32",
	"uint8":   "integer:int32",
	"uint16":  "integer:int32",
	"uint32":  "integer:int32",
	"uint64":  "integer:int64",
	"uintptr": "integer:int64",

	// floating point and complex numbers
	"float32":    "number:float",
	"float64":    "number:double",
	"complex64":  "number:float",
	"complex128": "number:double",

	// builtin golang objects
	"time.Time": "string:string",
}
|
||
|
|
||
|
func main() {
|
||
|
flag.StringVar(&dir, "d", "./", "specific project dir")
|
||
|
flag.Parse()
|
||
|
err := ParseFromDir(dir)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
parseModel(pkgs)
|
||
|
parseModel(rlpkgs)
|
||
|
parseRouter()
|
||
|
fd, err := os.Create(path.Join(dir, "swagger.json"))
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
b, _ := json.MarshalIndent(swagger, "", " ")
|
||
|
fd.Write(b)
|
||
|
}
|
||
|
|
||
|
// ParseFromDir parse ast pkg from dir.
|
||
|
func ParseFromDir(dir string) (err error) {
|
||
|
filepath.Walk(dir, func(fpath string, fileInfo os.FileInfo, err error) error {
|
||
|
if err != nil {
|
||
|
return nil
|
||
|
}
|
||
|
if !fileInfo.IsDir() {
|
||
|
return nil
|
||
|
}
|
||
|
err = parseFromDir(fpath)
|
||
|
return err
|
||
|
})
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func parseFromDir(dir string) (err error) {
|
||
|
fset := token.NewFileSet()
|
||
|
pkgFolder, err := parser.ParseDir(fset, dir, func(info os.FileInfo) bool {
|
||
|
name := info.Name()
|
||
|
return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
|
||
|
}, parser.ParseComments)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
for k, p := range pkgFolder {
|
||
|
pkgs[k] = p
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
func parseImport(dir string) (err error) {
|
||
|
fset := token.NewFileSet()
|
||
|
pkgFolder, err := parser.ParseDir(fset, dir, func(info os.FileInfo) bool {
|
||
|
name := info.Name()
|
||
|
return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
|
||
|
}, parser.ParseComments)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
for k, p := range pkgFolder {
|
||
|
rlpkgs[k] = p
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
func parseModel(pkgs map[string]*ast.Package) {
|
||
|
for _, p := range pkgs {
|
||
|
for _, f := range p.Files {
|
||
|
for _, im := range f.Imports {
|
||
|
if !isSystemPackage(im.Path.Value) {
|
||
|
for _, gp := range _gopath {
|
||
|
path := gp + "/src/" + strings.Trim(im.Path.Value, "\"")
|
||
|
if isExist(path) {
|
||
|
parseImport(path)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
scom := parseStructComment(f)
|
||
|
for _, obj := range f.Scope.Objects {
|
||
|
if obj.Kind == ast.Typ {
|
||
|
objName := obj.Name
|
||
|
schema := &Schema{
|
||
|
Title: objName,
|
||
|
Type: "object",
|
||
|
}
|
||
|
ts, ok := obj.Decl.(*ast.TypeSpec)
|
||
|
if !ok {
|
||
|
fmt.Printf("obj type error %v ", obj.Kind)
|
||
|
}
|
||
|
st, ok := ts.Type.(*ast.StructType)
|
||
|
if !ok {
|
||
|
continue
|
||
|
}
|
||
|
properites := make(map[string]*Propertie)
|
||
|
for _, fd := range st.Fields.List {
|
||
|
if len(fd.Names) == 0 {
|
||
|
continue
|
||
|
}
|
||
|
name, required, omit, desc := parseFieldTag(fd)
|
||
|
if omit {
|
||
|
continue
|
||
|
}
|
||
|
isSlice, realType, sType := typeAnalyser(fd)
|
||
|
if (isSlice && isBasicType(realType)) || sType == "object" {
|
||
|
if len(strings.Split(realType, " ")) > 1 {
|
||
|
realType = strings.Replace(realType, " ", ".", -1)
|
||
|
realType = strings.Replace(realType, "&", "", -1)
|
||
|
realType = strings.Replace(realType, "{", "", -1)
|
||
|
realType = strings.Replace(realType, "}", "", -1)
|
||
|
}
|
||
|
}
|
||
|
mp := &Propertie{}
|
||
|
if isSlice {
|
||
|
mp.Type = "array"
|
||
|
if isBasicType(strings.Replace(realType, "[]", "", -1)) {
|
||
|
typeFormat := strings.Split(sType, ":")
|
||
|
mp.Items = &Propertie{
|
||
|
Type: typeFormat[0],
|
||
|
Format: typeFormat[1],
|
||
|
}
|
||
|
} else {
|
||
|
ss := strings.Split(realType, ".")
|
||
|
mp.RefImport = ss[len(ss)-1]
|
||
|
mp.Type = "array"
|
||
|
mp.Items = &Propertie{
|
||
|
Ref: "#/definitions/" + mp.RefImport,
|
||
|
Type: sType,
|
||
|
}
|
||
|
}
|
||
|
} else {
|
||
|
if sType == "object" {
|
||
|
ss := strings.Split(realType, ".")
|
||
|
mp.RefImport = ss[len(ss)-1]
|
||
|
mp.Type = sType
|
||
|
mp.Ref = "#/definitions/" + mp.RefImport
|
||
|
} else if isBasicType(realType) {
|
||
|
typeFormat := strings.Split(sType, ":")
|
||
|
mp.Type = typeFormat[0]
|
||
|
mp.Format = typeFormat[1]
|
||
|
} else if realType == "map" {
|
||
|
typeFormat := strings.Split(sType, ":")
|
||
|
mp.AdditionalProperties = &Propertie{
|
||
|
Type: typeFormat[0],
|
||
|
Format: typeFormat[1],
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
if name == "" {
|
||
|
name = fd.Names[0].Name
|
||
|
}
|
||
|
if required {
|
||
|
schema.Required = append(schema.Required, name)
|
||
|
}
|
||
|
mp.Description = desc
|
||
|
if scm, ok := scom[obj.Name]; ok {
|
||
|
if cm, ok := scm.field[fd.Names[0].Name]; ok {
|
||
|
mp.Description = cm + desc
|
||
|
}
|
||
|
}
|
||
|
properites[name] = mp
|
||
|
}
|
||
|
if scm, ok := scom[obj.Name]; ok {
|
||
|
schema.Description = scm.comment
|
||
|
}
|
||
|
schema.Properties = properites
|
||
|
definitions[schema.Title] = schema
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
// parseFieldTag extracts swagger-relevant information from a struct field's tag.
//
// Returns:
//   - name: the first element of the `form` tag, or "" when there is none.
//   - required: true when the `validate` tag contains the rule "required".
//   - omit: true when the first element of the `json` tag is "-".
//   - tagDes: a description assembled from the `form` "split" option, the
//     `default` tag and the `min=`/`max=` validate rules.
//
// A field without a tag yields all zero values. A min/max rule without a
// "=value" part is ignored instead of panicking.
func parseFieldTag(field *ast.Field) (name string, required, omit bool, tagDes string) {
	if field.Tag == nil {
		return
	}
	tag := reflect.StructTag(strings.Trim(field.Tag.Value, "`"))

	if form := tag.Get("form"); form != "" {
		params := strings.Split(form, ",")
		name = params[0]
		if len(params) == 2 && params[1] == "split" {
			tagDes = "数组,按逗号分隔"
		}
	}
	if def := tag.Get("default"); def != "" {
		tagDes = fmt.Sprintf("%s 默认值 %s", tagDes, def)
	}

	if validate := tag.Get("validate"); validate != "" {
		for _, rule := range strings.Split(validate, ",") {
			switch {
			case rule == "required":
				required = true
			case strings.HasPrefix(rule, "min"):
				// Only rules of the form "min...=value" carry a value to show.
				if kv := strings.Split(rule, "="); len(kv) > 1 {
					tagDes = fmt.Sprintf("%s 最小值 %s", tagDes, kv[1])
				}
			case strings.HasPrefix(rule, "max"):
				if kv := strings.Split(rule, "="); len(kv) > 1 {
					tagDes = fmt.Sprintf("%s 最大值 %s", tagDes, kv[1])
				}
			}
		}
	}

	// A json name of "-" marks the field as excluded from the response.
	if jsonTag := tag.Get("json"); jsonTag != "" {
		if strings.Split(jsonTag, ",")[0] == "-" {
			omit = true
		}
	}
	return
}
|
||
|
|
||
|
func parseRouter() {
|
||
|
for _, p := range pkgs {
|
||
|
if p.Name != "http" {
|
||
|
continue
|
||
|
}
|
||
|
fmt.Printf("开始解析生成swagger文档\n")
|
||
|
for _, f := range p.Files {
|
||
|
for _, decl := range f.Decls {
|
||
|
if fdecl, ok := decl.(*ast.FuncDecl); ok {
|
||
|
if fdecl.Doc != nil {
|
||
|
path, req, resp, item, err := parseFuncDoc(fdecl.Doc)
|
||
|
if err != nil {
|
||
|
fmt.Printf("解析失败 注解错误 %v\n", err)
|
||
|
continue
|
||
|
}
|
||
|
if path != "" && err == nil {
|
||
|
fmt.Printf("解析 %s 完成 请求参数为 %s 返回结构为 %s\n", path, req, resp)
|
||
|
swagger.Paths[path] = item
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
func parseFuncDoc(f *ast.CommentGroup) (path, reqObj, respObj string, item *Item, err error) {
|
||
|
item = new(Item)
|
||
|
op := new(Operation)
|
||
|
params := make([]*Parameter, 0)
|
||
|
response := make(map[string]*Response)
|
||
|
for _, d := range f.List {
|
||
|
t := strings.TrimSpace(strings.TrimPrefix(d.Text, "//"))
|
||
|
content := strings.Split(t, " ")
|
||
|
switch content[0] {
|
||
|
case "@params":
|
||
|
if len(content) < 2 {
|
||
|
err = fmt.Errorf("err params %s", content)
|
||
|
return
|
||
|
}
|
||
|
reqObj = content[1]
|
||
|
if model, ok := definitions[content[1]]; ok {
|
||
|
for n, p := range model.Properties {
|
||
|
param := &Parameter{
|
||
|
In: "query",
|
||
|
Name: n,
|
||
|
Description: p.Description,
|
||
|
Type: p.Type,
|
||
|
Format: p.Format,
|
||
|
}
|
||
|
for _, p := range model.Required {
|
||
|
if p == n {
|
||
|
param.Required = true
|
||
|
}
|
||
|
}
|
||
|
params = append(params, param)
|
||
|
}
|
||
|
} else {
|
||
|
err = ErrParams
|
||
|
return
|
||
|
}
|
||
|
case "@router":
|
||
|
if len(content) != 3 {
|
||
|
err = ErrParams
|
||
|
return
|
||
|
}
|
||
|
switch content[1] {
|
||
|
case "get":
|
||
|
item.Get = op
|
||
|
case "post":
|
||
|
item.Post = op
|
||
|
}
|
||
|
path = content[2]
|
||
|
op.OperationID = path
|
||
|
case "@response":
|
||
|
if len(content) < 2 {
|
||
|
err = fmt.Errorf("err response %s", content)
|
||
|
return
|
||
|
}
|
||
|
var (
|
||
|
isarray bool
|
||
|
ismap bool
|
||
|
)
|
||
|
if strings.HasPrefix(content[1], "[]") {
|
||
|
isarray = true
|
||
|
respObj = content[1][2:]
|
||
|
} else if strings.HasPrefix(content[1], "map[]") {
|
||
|
ismap = true
|
||
|
respObj = content[1][5:]
|
||
|
} else {
|
||
|
respObj = content[1]
|
||
|
}
|
||
|
defini, ok := definitions[respObj]
|
||
|
if !ok {
|
||
|
err = ErrParams
|
||
|
return
|
||
|
}
|
||
|
var resp *Propertie
|
||
|
if isarray {
|
||
|
resp = &Propertie{
|
||
|
Type: "array",
|
||
|
Items: &Propertie{
|
||
|
Type: "object",
|
||
|
Ref: "#/definitions/" + respObj,
|
||
|
},
|
||
|
}
|
||
|
} else if ismap {
|
||
|
resp = &Propertie{
|
||
|
Type: "object",
|
||
|
AdditionalProperties: &Propertie{
|
||
|
Ref: "#/definitions/" + respObj,
|
||
|
},
|
||
|
}
|
||
|
} else {
|
||
|
resp = &Propertie{
|
||
|
Type: "object",
|
||
|
Ref: "#/definitions/" + respObj,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
response["200"] = &Response{
|
||
|
Schema: &Schema{
|
||
|
Type: "object",
|
||
|
Properties: map[string]*Propertie{
|
||
|
"code": &Propertie{
|
||
|
Type: "integer",
|
||
|
Description: "错误码描述",
|
||
|
},
|
||
|
"data": resp,
|
||
|
"message": &Propertie{
|
||
|
Type: "string",
|
||
|
Description: "错误码文本描述",
|
||
|
},
|
||
|
"ttl": &Propertie{
|
||
|
Type: "integer",
|
||
|
Format: "int64",
|
||
|
Description: "客户端限速时间",
|
||
|
},
|
||
|
},
|
||
|
},
|
||
|
Description: "服务成功响应内容",
|
||
|
}
|
||
|
op.Responses = response
|
||
|
for _, rl := range defini.Properties {
|
||
|
if rl.RefImport != "" {
|
||
|
swagger.Definitions[rl.RefImport] = definitions[rl.RefImport]
|
||
|
}
|
||
|
}
|
||
|
swagger.Definitions[respObj] = defini
|
||
|
case "@description":
|
||
|
op.Description = content[1]
|
||
|
}
|
||
|
}
|
||
|
op.Parameters = params
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// structComment holds the documentation gathered for one struct type.
type structComment struct {
	// comment is the struct's doc comment without its trailing newline.
	comment string
	// field maps a field name to its comment text; fields without any
	// comment have no entry.
	field map[string]string
}

// parseStructComment collects the comments of every struct type declared in
// f, keyed by type name.
//
// The struct comment is the doc comment attached to the type spec itself
// (the case inside a grouped "type ( ... )" declaration) and otherwise the doc
// comment of the enclosing declaration. A field's comment is its trailing
// line comment, falling back to the doc comment above the field. Only the
// first name of a multi-name field is recorded and embedded fields are
// skipped.
func parseStructComment(f *ast.File) (scom map[string]structComment) {
	scom = make(map[string]structComment)
	for _, d := range f.Decls {
		gen, ok := d.(*ast.GenDecl)
		if !ok || gen.Tok != token.TYPE {
			continue
		}
		for _, s := range gen.Specs {
			spec, ok := s.(*ast.TypeSpec)
			if !ok {
				continue
			}
			st, ok := spec.Type.(*ast.StructType)
			if !ok {
				continue
			}

			fcom := make(map[string]string)
			for _, fd := range st.Fields.List {
				if len(fd.Names) == 0 {
					continue
				}
				// Text is safe to call on a nil comment group and returns "".
				text := fd.Comment.Text()
				if text == "" {
					text = fd.Doc.Text()
				}
				if text != "" {
					fcom[fd.Names[0].Name] = strings.TrimSuffix(text, "\n")
				}
			}

			// The parser only fills TypeSpec.Doc for specs inside a
			// parenthesised group; a lone declaration keeps its doc on the
			// GenDecl.
			doc := spec.Doc
			if doc == nil {
				doc = gen.Doc
			}
			scom[spec.Name.String()] = structComment{
				comment: strings.TrimSuffix(doc.Text(), "\n"),
				field:   fcom,
			}
		}
	}
	return scom
}
|
||
|
func isBasicType(Type string) bool {
|
||
|
if _, ok := basicTypes[Type]; ok {
|
||
|
return true
|
||
|
}
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
func typeAnalyser(f *ast.Field) (isSlice bool, realType, swaggerType string) {
|
||
|
if arr, ok := f.Type.(*ast.ArrayType); ok {
|
||
|
if isBasicType(fmt.Sprint(arr.Elt)) {
|
||
|
return true, fmt.Sprintf("[]%v", arr.Elt), basicTypes[fmt.Sprint(arr.Elt)]
|
||
|
}
|
||
|
if mp, ok := arr.Elt.(*ast.MapType); ok {
|
||
|
return false, fmt.Sprintf("map[%v][%v]", mp.Key, mp.Value), "object"
|
||
|
}
|
||
|
if star, ok := arr.Elt.(*ast.StarExpr); ok {
|
||
|
return true, fmt.Sprint(star.X), "object"
|
||
|
}
|
||
|
basicType := fmt.Sprint(arr.Elt)
|
||
|
if object, isStdLibObject := stdlibObject[basicType]; isStdLibObject {
|
||
|
basicType = object
|
||
|
|
||
|
}
|
||
|
if k, ok := basicTypes[basicType]; ok {
|
||
|
return true, basicType, k
|
||
|
}
|
||
|
return true, fmt.Sprint(arr.Elt), "object"
|
||
|
}
|
||
|
switch t := f.Type.(type) {
|
||
|
case *ast.StarExpr:
|
||
|
basicType := fmt.Sprint(t.X)
|
||
|
if k, ok := basicTypes[basicType]; ok {
|
||
|
return false, basicType, k
|
||
|
}
|
||
|
return false, basicType, "object"
|
||
|
case *ast.MapType:
|
||
|
val := fmt.Sprintf("%v", t.Value)
|
||
|
if isBasicType(val) {
|
||
|
return false, "map", basicTypes[val]
|
||
|
}
|
||
|
return false, val, "object"
|
||
|
}
|
||
|
basicType := fmt.Sprint(f.Type)
|
||
|
if object, isStdLibObject := stdlibObject[basicType]; isStdLibObject {
|
||
|
basicType = object
|
||
|
}
|
||
|
if k, ok := basicTypes[basicType]; ok {
|
||
|
return false, basicType, k
|
||
|
}
|
||
|
return false, basicType, "object"
|
||
|
}
|
||
|
|
||
|
func isSystemPackage(pkgpath string) bool {
|
||
|
goroot := os.Getenv("GOROOT")
|
||
|
if goroot == "" {
|
||
|
goroot = runtime.GOROOT()
|
||
|
}
|
||
|
wg, _ := filepath.EvalSymlinks(filepath.Join(goroot, "src", "pkg", pkgpath))
|
||
|
if isExist(wg) {
|
||
|
return true
|
||
|
}
|
||
|
wg, _ = filepath.EvalSymlinks(filepath.Join(goroot, "src", pkgpath))
|
||
|
return isExist(wg)
|
||
|
}
|
||
|
|
||
|
// isExist reports whether path can be stat'ed. A failing os.Stat only counts
// as "exists" when the error itself satisfies os.IsExist.
func isExist(path string) bool {
	if _, statErr := os.Stat(path); statErr != nil {
		return os.IsExist(statErr)
	}
	return true
}
|