go-common/app/tool/gdoc/gdoc.go
2019-04-22 18:49:16 +08:00

548 lines
13 KiB
Go

package main
import (
"encoding/json"
"errors"
"flag"
"fmt"
"go/ast"
"go/parser"
"go/token"
"os"
"path"
"path/filepath"
"reflect"
"runtime"
"strings"
)
// Global variables.
var (
// ErrParams is returned by parseFuncDoc when an annotation has the wrong
// number of fields or names a model that is not present in definitions.
ErrParams = errors.New("err params")
// _gopath holds the individual entries of the GOPATH environment variable.
_gopath = filepath.SplitList(os.Getenv("GOPATH"))
)
// Package-level state shared by the parsing passes.
var (
// dir is the project directory to scan; main binds it to the -d flag.
dir string
// pkgs holds the packages parsed from the project directory (filled by parseFromDir).
pkgs = make(map[string]*ast.Package)
// rlpkgs holds the packages parsed from imported GOPATH directories (filled by parseImport).
rlpkgs = make(map[string]*ast.Package)
// definitions maps a struct name to the schema parseModel built for it.
definitions = make(map[string]*Schema)
// swagger is the document that main serialises to swagger.json.
swagger = Swagger{
Definitions: make(map[string]*Schema),
Paths: make(map[string]*Item),
SwaggerVersion: "2.0",
Infos: Information{
Title: "go-common api",
Description: "api",
Version: "1.0",
Contact: Contact{
EMail: "lintanghui@bilibili.com",
},
License: &License{
Name: "Apache 2.0",
URL: "http://www.apache.org/licenses/LICENSE-2.0.html",
},
},
}
// stdlibObject maps the fmt.Sprint rendering of a selector expression to the
// qualified type name that is used as a key in basicTypes.
stdlibObject = map[string]string{
"&{time Time}": "time.Time",
}
)
// basicTypes maps a Go type name (see builtin.go) to a "type:format" pair.
// Consumers split the value on ":" to obtain the swagger type and format;
// an empty format is written as a trailing colon.
var basicTypes = map[string]string{
"bool": "boolean:",
"uint": "integer:int32",
"uint8": "integer:int32",
"uint16": "integer:int32",
"uint32": "integer:int32",
"uint64": "integer:int64",
"int": "integer:int64",
"int8": "integer:int32",
"int16": "integer:int32",
"int32": "integer:int32",
"int64": "integer:int64",
"uintptr": "integer:int64",
"float32": "number:float",
"float64": "number:double",
"string": "string:",
"complex64": "number:float",
"complex128": "number:double",
"byte": "string:byte",
"rune": "string:byte",
// builtin golang objects
"time.Time": "string:string",
}
// main parses the Go packages under the directory given by -d, builds the
// model definitions and the route table from them, and writes the resulting
// document to swagger.json inside that directory. Any failure aborts the
// program with a panic.
func main() {
	flag.StringVar(&dir, "d", "./", "specific project dir")
	flag.Parse()
	if err := ParseFromDir(dir); err != nil {
		panic(err)
	}
	parseModel(pkgs)
	parseModel(rlpkgs)
	parseRouter()
	// Marshal before touching the file system so that a marshalling failure
	// does not leave a truncated swagger.json behind.
	b, err := json.MarshalIndent(swagger, "", " ")
	if err != nil {
		panic(err)
	}
	fd, err := os.Create(path.Join(dir, "swagger.json"))
	if err != nil {
		panic(err)
	}
	if _, err = fd.Write(b); err != nil {
		fd.Close()
		panic(err)
	}
	// Close reports deferred write errors, so it is checked as well.
	if err = fd.Close(); err != nil {
		panic(err)
	}
}
// ParseFromDir walks dir recursively and hands every directory it visits to
// parseFromDir. Entries the walk itself cannot read are skipped. The first
// error returned by parseFromDir stops the walk and is returned to the
// caller; previously that error was discarded and nil was always returned.
func ParseFromDir(dir string) (err error) {
	return filepath.Walk(dir, func(fpath string, fileInfo os.FileInfo, walkErr error) error {
		if walkErr != nil {
			// Unreadable entry: ignore it and keep walking.
			return nil
		}
		if !fileInfo.IsDir() {
			return nil
		}
		return parseFromDir(fpath)
	})
}
// parseFromDir parses the non-hidden .go files located directly in dir,
// keeping their comments, and stores every resulting package in the global
// pkgs map under its package name. When parsing fails nothing is stored and
// the error is returned.
func parseFromDir(dir string) (err error) {
	isGoSource := func(fi os.FileInfo) bool {
		if fi.IsDir() {
			return false
		}
		fileName := fi.Name()
		return strings.HasSuffix(fileName, ".go") && !strings.HasPrefix(fileName, ".")
	}
	parsed, err := parser.ParseDir(token.NewFileSet(), dir, isGoSource, parser.ParseComments)
	if err != nil {
		return err
	}
	for pkgName, pkg := range parsed {
		pkgs[pkgName] = pkg
	}
	return nil
}
// parseImport parses the non-hidden .go files located directly in dir,
// keeping their comments, and stores every resulting package in the global
// rlpkgs map under its package name. When parsing fails nothing is stored
// and the error is returned.
func parseImport(dir string) (err error) {
	keep := func(fi os.FileInfo) bool {
		if fi.IsDir() {
			return false
		}
		base := fi.Name()
		if strings.HasPrefix(base, ".") {
			return false
		}
		return strings.HasSuffix(base, ".go")
	}
	found, err := parser.ParseDir(token.NewFileSet(), dir, keep, parser.ParseComments)
	if err != nil {
		return err
	}
	for importName, imported := range found {
		rlpkgs[importName] = imported
	}
	return nil
}
// parseModel builds a Schema for every struct type declared in pkgs and
// stores it in the global definitions map under the struct's name.
//
// For each file it first parses, via parseImport, every imported package
// that is not part of the Go distribution and can be found under a GOPATH
// entry. It then converts each named struct field into a Propertie:
//   - slices become "array" properties whose items are either a basic
//     type/format pair or a reference to another definition;
//   - struct-like types become references to "#/definitions/<Name>";
//   - basic types become a type/format pair taken from basicTypes;
//   - maps with a basic value type become additionalProperties.
//
// Embedded fields and fields tagged json:"-" are skipped. Field and struct
// descriptions come from parseStructComment and the field tags.
func parseModel(pkgs map[string]*ast.Package) {
	// Turns the fmt rendering of a selector expression ("&{pkg Name}") into
	// the qualified name "pkg.Name".
	typeNameCleaner := strings.NewReplacer(" ", ".", "&", "", "{", "", "}", "")
	for _, p := range pkgs {
		for _, f := range p.Files {
			for _, im := range f.Imports {
				// Path.Value is the source literal and still carries its quotes.
				importPath := strings.Trim(im.Path.Value, "\"")
				if isSystemPackage(importPath) {
					continue
				}
				for _, gp := range _gopath {
					if gp == "" {
						continue
					}
					importDir := filepath.Join(gp, "src", importPath)
					if !isExist(importDir) {
						continue
					}
					// Best effort: a broken dependency must not stop the run.
					if err := parseImport(importDir); err != nil {
						fmt.Printf("parse import %s error %v\n", importDir, err)
					}
				}
			}
			if f.Scope == nil {
				continue
			}
			scom := parseStructComment(f)
			for _, obj := range f.Scope.Objects {
				if obj.Kind != ast.Typ {
					continue
				}
				ts, ok := obj.Decl.(*ast.TypeSpec)
				if !ok {
					// Without a TypeSpec there is nothing to inspect; skipping
					// avoids dereferencing a nil ts below.
					fmt.Printf("obj type error %v ", obj.Kind)
					continue
				}
				st, ok := ts.Type.(*ast.StructType)
				if !ok {
					continue
				}
				schema := &Schema{
					Title: obj.Name,
					Type:  "object",
				}
				structDoc, hasStructDoc := scom[obj.Name]
				props := make(map[string]*Propertie)
				for _, fd := range st.Fields.List {
					if len(fd.Names) == 0 {
						continue
					}
					name, required, omit, desc := parseFieldTag(fd)
					if omit {
						continue
					}
					isSlice, realType, sType := typeAnalyser(fd)
					if (isSlice && isBasicType(realType)) || sType == "object" {
						if strings.Contains(realType, " ") {
							realType = typeNameCleaner.Replace(realType)
						}
					}
					mp := &Propertie{}
					if isSlice {
						mp.Type = "array"
						if isBasicType(strings.Replace(realType, "[]", "", -1)) {
							typeFormat := strings.Split(sType, ":")
							mp.Items = &Propertie{
								Type:   typeFormat[0],
								Format: typeFormat[1],
							}
						} else {
							ss := strings.Split(realType, ".")
							mp.RefImport = ss[len(ss)-1]
							mp.Items = &Propertie{
								Ref:  "#/definitions/" + mp.RefImport,
								Type: sType,
							}
						}
					} else {
						switch {
						case sType == "object":
							ss := strings.Split(realType, ".")
							mp.RefImport = ss[len(ss)-1]
							mp.Type = sType
							mp.Ref = "#/definitions/" + mp.RefImport
						case isBasicType(realType):
							typeFormat := strings.Split(sType, ":")
							mp.Type = typeFormat[0]
							mp.Format = typeFormat[1]
						case realType == "map":
							typeFormat := strings.Split(sType, ":")
							mp.AdditionalProperties = &Propertie{
								Type:   typeFormat[0],
								Format: typeFormat[1],
							}
						}
					}
					if name == "" {
						// No form tag: fall back to the Go field name.
						name = fd.Names[0].Name
					}
					if required {
						schema.Required = append(schema.Required, name)
					}
					mp.Description = desc
					if hasStructDoc {
						if cm, ok := structDoc.field[fd.Names[0].Name]; ok {
							mp.Description = cm + desc
						}
					}
					props[name] = mp
				}
				if hasStructDoc {
					schema.Description = structDoc.comment
				}
				schema.Properties = props
				definitions[schema.Title] = schema
			}
		}
	}
}
// parseFieldTag reads the struct tag of field and reports:
//   - name: the first element of the "form" tag, empty when there is none;
//   - required: whether the "validate" tag contains the rule "required";
//   - omit: whether the "json" tag is exactly "-";
//   - tagDes: a description assembled from the "form" split option, the
//     "default" tag and the min/max rules of the "validate" tag.
//
// A field without a tag yields the zero value for every result. A min or max
// rule without "=value" is ignored instead of causing an index panic.
func parseFieldTag(field *ast.Field) (name string, required, omit bool, tagDes string) {
	if field.Tag == nil {
		return
	}
	tag := reflect.StructTag(strings.Trim(field.Tag.Value, "`"))
	if form := tag.Get("form"); form != "" {
		params := strings.Split(form, ",")
		name = params[0]
		if len(params) == 2 && params[1] == "split" {
			tagDes = "数组,按逗号分隔"
		}
	}
	if def := tag.Get("default"); def != "" {
		tagDes = fmt.Sprintf("%s 默认值 %s", tagDes, def)
	}
	if validate := tag.Get("validate"); validate != "" {
		for _, rule := range strings.Split(validate, ",") {
			switch {
			case rule == "required":
				required = true
			case strings.HasPrefix(rule, "min"):
				if kv := strings.Split(rule, "="); len(kv) > 1 {
					tagDes = fmt.Sprintf("%s 最小值 %s", tagDes, kv[1])
				}
			case strings.HasPrefix(rule, "max"):
				if kv := strings.Split(rule, "="); len(kv) > 1 {
					tagDes = fmt.Sprintf("%s 最大值 %s", tagDes, kv[1])
				}
			}
		}
	}
	// encoding/json omits a field only when the whole tag is "-"; a tag such
	// as "-," names the key "-" and must be kept.
	if tag.Get("json") == "-" {
		omit = true
	}
	return
}
// parseRouter scans every parsed package named "http" and, for each function
// declaration carrying a doc comment, lets parseFuncDoc interpret the
// annotations. Functions whose annotations fail to parse are reported and
// skipped; functions that yield a non-empty route are registered in
// swagger.Paths under that route.
func parseRouter() {
	for _, pkg := range pkgs {
		if pkg.Name != "http" {
			continue
		}
		fmt.Printf("开始解析生成swagger文档\n")
		for _, file := range pkg.Files {
			for _, decl := range file.Decls {
				fn, isFunc := decl.(*ast.FuncDecl)
				if !isFunc || fn.Doc == nil {
					continue
				}
				route, req, resp, item, err := parseFuncDoc(fn.Doc)
				if err != nil {
					fmt.Printf("解析失败 注解错误 %v\n", err)
					continue
				}
				if route == "" {
					continue
				}
				fmt.Printf("解析 %s 完成 请求参数为 %s 返回结构为 %s\n", route, req, resp)
				swagger.Paths[route] = item
			}
		}
	}
}
// parseFuncDoc interprets the annotations found in a function's doc comment.
// Each comment line is split on whitespace and dispatched on its first word:
//
//	@params <Model>       one query parameter per property of Model
//	@router <method> <path>   route; "get" and "post" attach the operation
//	@response <Model>     payload of the 200 response; "[]Model" and
//	                      "map[]Model" describe an array or a map of Model
//	@description <text>   operation description (the rest of the line)
//
// It returns the route path, the request and response model names and the
// Item holding the operation. ErrParams is returned when @router does not
// have exactly two arguments or when a model is missing from definitions; a
// formatted error is returned when @params or @response has no argument.
// Lines that are empty or start with another word are ignored.
func parseFuncDoc(f *ast.CommentGroup) (path, reqObj, respObj string, item *Item, err error) {
	item = new(Item)
	op := new(Operation)
	params := make([]*Parameter, 0)
	response := make(map[string]*Response)
	for _, d := range f.List {
		// Fields tolerates tabs and repeated spaces between the words.
		content := strings.Fields(strings.TrimPrefix(d.Text, "//"))
		if len(content) == 0 {
			continue
		}
		switch content[0] {
		case "@params":
			if len(content) < 2 {
				err = fmt.Errorf("err params %s", content)
				return
			}
			reqObj = content[1]
			model, ok := definitions[reqObj]
			if !ok {
				err = ErrParams
				return
			}
			for n, p := range model.Properties {
				param := &Parameter{
					In:          "query",
					Name:        n,
					Description: p.Description,
					Type:        p.Type,
					Format:      p.Format,
				}
				for _, requiredName := range model.Required {
					if requiredName == n {
						param.Required = true
					}
				}
				params = append(params, param)
			}
		case "@router":
			if len(content) != 3 {
				err = ErrParams
				return
			}
			switch content[1] {
			case "get":
				item.Get = op
			case "post":
				item.Post = op
			}
			path = content[2]
			op.OperationID = path
		case "@response":
			if len(content) < 2 {
				err = fmt.Errorf("err response %s", content)
				return
			}
			var (
				isarray bool
				ismap   bool
			)
			switch {
			case strings.HasPrefix(content[1], "[]"):
				isarray = true
				respObj = content[1][2:]
			case strings.HasPrefix(content[1], "map[]"):
				ismap = true
				respObj = content[1][5:]
			default:
				respObj = content[1]
			}
			defini, ok := definitions[respObj]
			if !ok {
				err = ErrParams
				return
			}
			var resp *Propertie
			switch {
			case isarray:
				resp = &Propertie{
					Type: "array",
					Items: &Propertie{
						Type: "object",
						Ref:  "#/definitions/" + respObj,
					},
				}
			case ismap:
				resp = &Propertie{
					Type: "object",
					AdditionalProperties: &Propertie{
						Ref: "#/definitions/" + respObj,
					},
				}
			default:
				resp = &Propertie{
					Type: "object",
					Ref:  "#/definitions/" + respObj,
				}
			}
			response["200"] = &Response{
				Schema: &Schema{
					Type: "object",
					Properties: map[string]*Propertie{
						"code": {
							Type:        "integer",
							Description: "错误码描述",
						},
						"data": resp,
						"message": {
							Type:        "string",
							Description: "错误码文本描述",
						},
						"ttl": {
							Type:        "integer",
							Format:      "int64",
							Description: "客户端限速时间",
						},
					},
				},
				Description: "服务成功响应内容",
			}
			op.Responses = response
			for _, rl := range defini.Properties {
				if rl.RefImport == "" {
					continue
				}
				// Only publish definitions that were actually parsed; a nil
				// entry would be written to the document as null.
				if ref, ok := definitions[rl.RefImport]; ok {
					swagger.Definitions[rl.RefImport] = ref
				}
			}
			swagger.Definitions[respObj] = defini
		case "@description":
			// A bare "@description" carries no text; keep every word otherwise.
			if len(content) > 1 {
				op.Description = strings.Join(content[1:], " ")
			}
		}
	}
	op.Parameters = params
	return
}
// structComment holds the comments collected for one struct type.
type structComment struct {
// comment is the doc comment text attached to the struct's type declaration.
comment string
// field maps a field name to the text of its trailing line comment.
field map[string]string
}
// parseStructComment collects the comments of every struct type declared in
// f. The result is keyed by type name. For each struct it records the doc
// comment of the type and, per named field, the trailing line comment (only
// the first name of a multi-name field is recorded; embedded fields are
// skipped).
//
// Inside a grouped declaration, type ( ... ), a type's own doc comment is
// preferred; the doc comment of the group is used only when the type has
// none. Non-struct types produce no entry.
func parseStructComment(f *ast.File) (scom map[string]structComment) {
	scom = make(map[string]structComment)
	for _, d := range f.Decls {
		gen, ok := d.(*ast.GenDecl)
		if !ok || gen.Tok != token.TYPE {
			continue
		}
		for _, s := range gen.Specs {
			spec, ok := s.(*ast.TypeSpec)
			if !ok {
				continue
			}
			st, ok := spec.Type.(*ast.StructType)
			if !ok {
				continue
			}
			fcom := make(map[string]string)
			for _, fd := range st.Fields.List {
				if len(fd.Names) == 0 {
					continue
				}
				// Text is safe to call on a nil comment group.
				if text := fd.Comment.Text(); len(text) > 0 {
					fcom[fd.Names[0].Name] = strings.TrimSuffix(text, "\n")
				}
			}
			// The parser sets TypeSpec.Doc only for specs inside a
			// parenthesised group; ungrouped declarations carry their doc
			// on the GenDecl.
			doc := spec.Doc
			if doc == nil {
				doc = gen.Doc
			}
			scom[spec.Name.String()] = structComment{
				comment: strings.TrimSuffix(doc.Text(), "\n"),
				field:   fcom,
			}
		}
	}
	return
}
// isBasicType reports whether Type is a key of the basicTypes table.
func isBasicType(Type string) bool {
	_, known := basicTypes[Type]
	return known
}
func typeAnalyser(f *ast.Field) (isSlice bool, realType, swaggerType string) {
if arr, ok := f.Type.(*ast.ArrayType); ok {
if isBasicType(fmt.Sprint(arr.Elt)) {
return true, fmt.Sprintf("[]%v", arr.Elt), basicTypes[fmt.Sprint(arr.Elt)]
}
if mp, ok := arr.Elt.(*ast.MapType); ok {
return false, fmt.Sprintf("map[%v][%v]", mp.Key, mp.Value), "object"
}
if star, ok := arr.Elt.(*ast.StarExpr); ok {
return true, fmt.Sprint(star.X), "object"
}
basicType := fmt.Sprint(arr.Elt)
if object, isStdLibObject := stdlibObject[basicType]; isStdLibObject {
basicType = object
}
if k, ok := basicTypes[basicType]; ok {
return true, basicType, k
}
return true, fmt.Sprint(arr.Elt), "object"
}
switch t := f.Type.(type) {
case *ast.StarExpr:
basicType := fmt.Sprint(t.X)
if k, ok := basicTypes[basicType]; ok {
return false, basicType, k
}
return false, basicType, "object"
case *ast.MapType:
val := fmt.Sprintf("%v", t.Value)
if isBasicType(val) {
return false, "map", basicTypes[val]
}
return false, val, "object"
}
basicType := fmt.Sprint(f.Type)
if object, isStdLibObject := stdlibObject[basicType]; isStdLibObject {
basicType = object
}
if k, ok := basicTypes[basicType]; ok {
return false, basicType, k
}
return false, basicType, "object"
}
// isSystemPackage reports whether pkgpath names a package shipped with the Go
// distribution, i.e. whether a directory for it exists under GOROOT/src/pkg
// or GOROOT/src. GOROOT is taken from the environment, falling back to
// runtime.GOROOT().
//
// pkgpath may be passed as written in an import declaration: surrounding
// double quotes or backquotes are removed before the lookup. Without this the
// quoted literal never matched a directory and the function always returned
// false.
func isSystemPackage(pkgpath string) bool {
	pkgpath = strings.Trim(pkgpath, "\"`")
	goroot := os.Getenv("GOROOT")
	if goroot == "" {
		goroot = runtime.GOROOT()
	}
	candidates := []string{
		filepath.Join(goroot, "src", "pkg", pkgpath), // layout used by old Go releases
		filepath.Join(goroot, "src", pkgpath),
	}
	for _, candidate := range candidates {
		resolved, err := filepath.EvalSymlinks(candidate)
		if err != nil {
			// The path cannot be resolved, so it cannot be a known package.
			continue
		}
		if isExist(resolved) {
			return true
		}
	}
	return false
}
// isExist reports whether path can be stat'ed. When os.Stat fails, the result
// is whatever os.IsExist says about the returned error.
func isExist(path string) bool {
	if _, statErr := os.Stat(path); statErr != nil {
		return os.IsExist(statErr)
	}
	return true
}