go-common/app/tool/warden/generator/genservice.go

276 lines
6.5 KiB
Go
Raw Normal View History

2019-04-22 10:49:16 +00:00
package generator
import (
"bufio"
"fmt"
"io"
"log"
"os"
"strings"
"text/template"
assets "go-common/app/tool/warden/generator/templates"
"go-common/app/tool/warden/types"
)
const (
// protoTemplateName is the asset name of the template used to render the
// service proto; it is also used as the name of the parsed template.
protoTemplateName = "service.tmpl"
// contextType is the type string of context arguments; fields of this type
// are skipped when a message is rendered.
contextType = "context.Context"
)
// ProtoMessage describes one message of the generated proto file.
type ProtoMessage struct {
// Name is the message name; it is also the key under which the message is
// stored in ProtoValue.Messages.
Name string
// Fields are the message fields in the order they were rendered.
Fields []ProtoField
}
// ProtoField describes a single field of a ProtoMessage.
type ProtoField struct {
// FieldID is the field number; renderMessage sets it to the field's
// position in the source field list plus one.
FieldID int
// FieldType is the proto type string produced by convertType (or a
// "//FIXME type ..." placeholder when conversion failures are ignored).
FieldType string
// FieldName is the underscore form of the source field name, or
// "data_<index>" when the source field has no name.
FieldName string
}
// ProtoMethod holds the information about one service method that is
// rendered into the proto file.
type ProtoMethod struct {
// Comments are the comments copied from the source method.
Comments []string
// Name is the method name.
Name string
// Req is the name of the request message ("<Name>Req").
Req string
// Reply is the name of the reply message ("<Name>Reply").
Reply string
}
// ProtoValue is the value passed to the proto template when a service is
// rendered. It is populated by render.
type ProtoValue struct {
// Package is the proto package name, taken from
// ServiceProtoOptions.ProtoPackage.
Package string
// Name is the service name, taken from the service spec.
Name string
// GoPackage is copied from ServiceProtoOptions.GoPackage.
GoPackage string
// Imports is the set of proto files to import, keyed by import path.
Imports map[string]bool
// Messages holds the generated messages, keyed by message name.
Messages map[string]ProtoMessage
// Methods lists the service methods in the order they were rendered.
Methods []ProtoMethod
// options are the generation options captured by render.
options *ServiceProtoOptions
}
// ServiceProtoOptions configures the generation of a service proto file.
type ServiceProtoOptions struct {
// GoPackage is copied to ProtoValue.GoPackage for use by the template.
GoPackage string
// ProtoPackage is the proto package of the generated file. Struct types
// whose proto file declares the same package are referenced by their bare
// name instead of a fully qualified one.
ProtoPackage string
// IgnoreType, when true, makes a failed type conversion non-fatal: the
// failure is logged and a "//FIXME type ..." placeholder is emitted as the
// field type.
IgnoreType bool
// ImportPaths are path prefixes that are stripped from proto file paths
// when they are recorded as imports.
ImportPaths []string
}
// readProtoPackage returns the package name declared by the first
// `package <name>;` statement found in protoFile.
//
// A line is treated as the package statement only when the keyword is
// followed by whitespace, so identifiers that merely start with "package" are
// ignored. Everything from the first ';' onwards (for example a trailing
// comment) is dropped from the result. The last line of the file is inspected
// even when it is not terminated by a newline.
//
// It returns an error when the file cannot be opened or read, or when no
// package statement is found.
func readProtoPackage(protoFile string) (string, error) {
	fp, err := os.Open(protoFile)
	if err != nil {
		return "", err
	}
	defer fp.Close()

	const keyword = "package"
	buf := bufio.NewReader(fp)
	for {
		// ReadString hands back the data read so far together with io.EOF when
		// the final line has no trailing newline, so the line must be examined
		// before the EOF check.
		line, err := buf.ReadString('\n')
		if err != nil && err != io.EOF {
			return "", err
		}
		line = strings.TrimSpace(line)
		if strings.HasPrefix(line, keyword) {
			rest := line[len(keyword):]
			if rest != "" && (rest[0] == ' ' || rest[0] == '\t') {
				if end := strings.Index(rest, ";"); end >= 0 {
					rest = rest[:end]
				}
				return strings.TrimSpace(rest), nil
			}
		}
		if err == io.EOF {
			break
		}
	}
	return "", fmt.Errorf("proto %s miss package define", protoFile)
}
// underscore converts an identifier to lower-case underscore form.
//
// Every ASCII upper-case letter is lower-cased. An underscore is inserted in
// front of it unless it is the first byte of the string or directly follows
// another upper-case letter, so "FieldID" becomes "field_id". All other bytes
// are copied unchanged.
func underscore(s string) string {
	const caseOffset = 'a' - 'A'

	out := make([]byte, 0, len(s)+3)
	// Starts as true so that a leading upper-case letter gets no underscore.
	afterUpper := true
	for i := 0; i < len(s); i++ {
		c := s[i]
		isUpper := c >= 'A' && c <= 'Z'
		if isUpper {
			if !afterUpper {
				out = append(out, '_')
			}
			c += caseOffset
		}
		out = append(out, c)
		afterUpper = isUpper
	}
	return string(out)
}
// convertType maps a parsed Go type to the proto type string used in the
// generated file.
//
//   - basic types go through convertBasicType;
//   - arrays become "repeated <elt>", except arrays of byte, which become
//     "bytes";
//   - maps become "map<key, value>";
//   - structs without a proto file are rendered as a new message named
//     Title(package)+IdentName; structs with a proto file are imported and
//     referenced by bare name (same proto package) or as ".<package>.<name>".
//
// Any other type yields an error.
func (p *ProtoValue) convertType(t types.Typer) (string, error) {
	switch typ := t.(type) {
	case *types.BasicType:
		return convertBasicType(typ.String())

	case *types.ArrayType:
		if typ.EltType.String() == "byte" {
			return "bytes", nil
		}
		elem, err := p.convertType(typ.EltType)
		if err != nil {
			return "", err
		}
		return "repeated " + elem, nil

	case *types.MapType:
		key, err := p.convertType(typ.KeyType)
		if err != nil {
			return "", err
		}
		val, err := p.convertType(typ.ValueType)
		if err != nil {
			return "", err
		}
		return "map<" + key + ", " + val + ">", nil

	case *types.StructType:
		if typ.ProtoFile == "" {
			// No proto file is known for the struct: render it as a message of
			// the generated file.
			msgName := strings.Title(typ.Package) + typ.IdentName
			if err := p.renderMessage(msgName, typ.Fields); err != nil {
				return "", err
			}
			return msgName, nil
		}
		pkg, err := readProtoPackage(typ.ProtoFile)
		if err != nil {
			return "", err
		}
		p.importPackage(typ.ProtoFile)
		if pkg != p.Package {
			return "." + pkg + "." + typ.IdentName, nil
		}
		return typ.IdentName, nil
	}
	return "", fmt.Errorf("unsupport type %s", t)
}
// convertBasicType maps the name of a Go basic type to the corresponding
// proto scalar type.
//
// float64 and float32 map to double and float. The integer types narrower
// than 32 bits, as well as int, map to int32. The remaining supported names
// (int32, int64, uint32, uint64, string, bool) are identical in proto and are
// returned unchanged. The aliases byte and rune are mapped like their
// underlying types uint8 and int32. Any other name yields an error.
func convertBasicType(gt string) (string, error) {
	switch gt {
	case "float64":
		return "double", nil
	case "float32":
		return "float", nil
	case "int", "int8", "uint8", "byte", "int16", "uint16", "rune":
		return "int32", nil
	case "int64", "int32", "uint32", "uint64", "string", "bool":
		return gt, nil
	}
	return "", fmt.Errorf("unsupport basic type %s", gt)
}
// render populates p from the service spec and the generation options, then
// renders every method of the spec. The import and message collections are
// replaced with fresh empty maps first. p itself is returned together with
// the error from rendering the methods, if any.
func (p *ProtoValue) render(spec *types.ServiceSpec, options *ServiceProtoOptions) (*ProtoValue, error) {
	p.options = options
	p.Name = spec.Name
	p.GoPackage, p.Package = options.GoPackage, options.ProtoPackage
	p.Imports = map[string]bool{}
	p.Messages = map[string]ProtoMessage{}

	err := p.renderMethods(spec.Methods)
	return p, err
}
// renderMethods renders the request and reply messages of every method and
// appends the method descriptions to p.Methods. It stops at the first message
// that fails to render and returns that error.
//
// Each method always gets its own "<Name>Req" and "<Name>Reply" messages. An
// earlier variant that substituted a shared empty message for methods with no
// parameters or results (apart from context / error) is disabled.
func (p *ProtoValue) renderMethods(methods []*types.Method) error {
	for _, m := range methods {
		reqName := fmt.Sprintf("%sReq", m.Name)
		if err := p.renderMessage(reqName, m.Parameters); err != nil {
			return err
		}

		replyName := fmt.Sprintf("%sReply", m.Name)
		if err := p.renderMessage(replyName, m.Results); err != nil {
			return err
		}

		p.Methods = append(p.Methods, ProtoMethod{
			Comments: m.Comments,
			Name:     m.Name,
			Req:      reqName,
			Reply:    replyName,
		})
	}
	return nil
}
// importPackage records the proto file imp as an import of the generated
// file.
//
// When imp lies below one of the configured import paths, that prefix (and
// the separating slashes) is stripped so the import is relative to the import
// path. The first matching import path wins. A prefix only counts as a match
// when it ends on a path boundary: import path "/go/src" matches
// "/go/src/a.proto" but not "/go/srcx/a.proto". When no import path matches,
// imp is recorded unchanged.
func (p *ProtoValue) importPackage(imp string) {
	for _, importPath := range p.options.ImportPaths {
		if !strings.HasPrefix(imp, importPath) {
			continue
		}
		rest := imp[len(importPath):]
		// Reject partial matches of a path component.
		if strings.HasSuffix(importPath, "/") || strings.HasPrefix(rest, "/") {
			p.Imports[strings.TrimLeft(rest, "/")] = true
			return
		}
	}
	p.Imports[imp] = true
}
// renderMessage builds the message called name from fields and stores it in
// p.Messages. A message that already exists under that name is left
// untouched.
//
// Fields whose type string is "error" or the context type are skipped. The
// remaining fields keep their position-based field number (index + 1), and
// unnamed fields are called "data_<index>". When a field type cannot be
// converted, the error is returned unless options.IgnoreType is set, in which
// case a warning is logged and a "//FIXME type ..." placeholder is used as the
// field type.
func (p *ProtoValue) renderMessage(name string, fields []*types.Field) error {
	if _, exists := p.Messages[name]; exists {
		return nil
	}

	msg := ProtoMessage{Name: name}
	for idx, f := range fields {
		switch f.Type.String() {
		case "error", contextType:
			continue
		}

		protoName := underscore(f.Name)
		if protoName == "" {
			protoName = fmt.Sprintf("data_%d", idx)
		}

		protoType, err := p.convertType(f.Type)
		if err != nil {
			if !p.options.IgnoreType {
				return err
			}
			log.Printf("warning convert type fail %s", err)
			protoType = fmt.Sprintf("//FIXME type %s", f.Type)
		}

		msg.Fields = append(msg.Fields, ProtoField{
			FieldID:   idx + 1,
			FieldType: protoType,
			FieldName: protoName,
		})
	}

	p.Messages[name] = msg
	return nil
}
// renderProtoValue creates a fresh ProtoValue and renders spec into it using
// options.
func renderProtoValue(spec *types.ServiceSpec, options *ServiceProtoOptions) (*ProtoValue, error) {
	return new(ProtoValue).render(spec, options)
}
// GenServiceProto generates the proto service definition for spec and writes
// it to out.
//
// The spec is first rendered into a ProtoValue according to options, which is
// then passed to the bundled service template. An error is returned when the
// spec cannot be rendered, the template cannot be parsed, or executing the
// template fails.
func GenServiceProto(out io.Writer, spec *types.ServiceSpec, options *ServiceProtoOptions) error {
	value, err := renderProtoValue(spec, options)
	if err != nil {
		return err
	}
	// The template asset is loaded exactly once; the former extra MustAsset
	// call discarded its result.
	// NOTE(review): MustAsset presumably panics when the asset is missing —
	// confirm against the generated assets package.
	t, err := template.New(protoTemplateName).Parse(string(assets.MustAsset(protoTemplateName)))
	if err != nil {
		return err
	}
	return t.Execute(out, value)
}