276 lines
6.5 KiB
Go
276 lines
6.5 KiB
Go
|
package generator
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"log"
|
||
|
"os"
|
||
|
"strings"
|
||
|
"text/template"
|
||
|
|
||
|
assets "go-common/app/tool/warden/generator/templates"
|
||
|
"go-common/app/tool/warden/types"
|
||
|
)
|
||
|
|
||
|
const (
	// contextType is how a context parameter's type prints; fields of this
	// type are left out of generated messages.
	contextType = "context.Context"

	// protoTemplateName names the template asset, loaded through the
	// generator's templates assets package, that renders the service .proto.
	protoTemplateName = "service.tmpl"
)
|
||
|
|
||
|
// ProtoMessage ProtoMessage
|
||
|
type ProtoMessage struct {
|
||
|
Name string
|
||
|
Fields []ProtoField
|
||
|
}
|
||
|
|
||
|
// ProtoField is a single field of a ProtoMessage.
type ProtoField struct {
	// FieldID is the field number. renderMessage uses the 1-based position of
	// the Go field in its list, so skipped context/error entries leave gaps.
	FieldID int

	// FieldType is the proto type expression, such as "int64",
	// "repeated string", "map<string, int32>" or a message name; when
	// IgnoreType is set it may instead be a "//FIXME type ..." placeholder.
	FieldType string

	// FieldName is the snake_case form of the Go name, or "data_<index>"
	// when the Go field is unnamed.
	FieldName string
}
|
||
|
|
||
|
// ProtoMethod describes one RPC of the generated service.
type ProtoMethod struct {
	// Comments are copied unchanged from the Go method's comments.
	Comments []string

	// Name is the RPC name, identical to the Go method name.
	Name string

	// Req is the request message name, "<Name>Req".
	Req string

	// Reply is the reply message name, "<Name>Reply".
	Reply string
}
|
||
|
|
||
|
// ProtoValue proto template render value
|
||
|
type ProtoValue struct {
|
||
|
Package string
|
||
|
Name string
|
||
|
GoPackage string
|
||
|
Imports map[string]bool
|
||
|
Messages map[string]ProtoMessage
|
||
|
Methods []ProtoMethod
|
||
|
|
||
|
options *ServiceProtoOptions
|
||
|
}
|
||
|
|
||
|
// ServiceProtoOptions controls how GenServiceProto renders a service .proto.
type ServiceProtoOptions struct {
	// GoPackage is copied into ProtoValue.GoPackage for the template.
	GoPackage string

	// ProtoPackage is the proto package of the generated file. Referenced
	// struct types whose .proto declares this same package are written
	// without a package qualifier.
	ProtoPackage string

	// IgnoreType, when true, makes an unconvertible field type a logged
	// warning plus a "//FIXME type ..." placeholder instead of an error.
	IgnoreType bool

	// ImportPaths are path prefixes stripped from a referenced .proto file's
	// path to form its import path; files under none of them are imported by
	// the path as given.
	ImportPaths []string
}
|
||
|
|
||
|
// readProtoPackage returns the package name declared by the `package`
// statement of the given .proto file, e.g. "demo.service.v1" for
// `package demo.service.v1;`.
//
// Only the first package statement counts. A trailing `//` comment on that
// line is ignored, and the last line is inspected even when the file does not
// end with a newline. An error is returned when the file cannot be opened or
// read, or when it contains no package statement.
func readProtoPackage(protoFile string) (string, error) {
	fp, err := os.Open(protoFile)
	if err != nil {
		return "", err
	}
	defer fp.Close()

	buf := bufio.NewReader(fp)
	for {
		line, err := buf.ReadString('\n')
		if err != nil && err != io.EOF {
			return "", err
		}
		// ReadString hands back the final, unterminated line together with
		// io.EOF, so the line is inspected before the EOF check.
		if pkg, ok := parseProtoPackageLine(line); ok {
			return pkg, nil
		}
		if err == io.EOF {
			break
		}
	}
	return "", fmt.Errorf("proto %s miss package define", protoFile)
}

// parseProtoPackageLine extracts the package name from a single line of the
// form `package foo.bar;` and reports false for any other line.
func parseProtoPackageLine(line string) (string, bool) {
	// Drop a trailing line comment such as `package foo; // legacy`.
	if i := strings.Index(line, "//"); i >= 0 {
		line = line[:i]
	}
	line = strings.TrimSpace(line)

	const keyword = "package"
	if !strings.HasPrefix(line, keyword) {
		return "", false
	}
	rest := line[len(keyword):]
	// Require whitespace after the keyword so identifiers that merely start
	// with "package" (e.g. "packages_x") are not taken for the statement.
	if rest == "" || (rest[0] != ' ' && rest[0] != '\t') {
		return "", false
	}
	name := strings.TrimSpace(strings.TrimRight(strings.TrimSpace(rest), ";"))
	if name == "" {
		return "", false
	}
	return name, true
}
|
||
|
|
||
|
// underscore converts a CamelCase identifier to lower snake_case, byte by
// byte: each ASCII capital is lowered and, unless it starts the string or
// directly follows another capital, preceded by '_'. Runs of capitals
// therefore collapse ("HTTPServer" -> "httpserver", "UserID" -> "user_id");
// every other byte is copied unchanged.
func underscore(s string) string {
	out := make([]byte, 0, len(s)+3)
	afterUpper := true // suppresses the separator before a leading capital
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c < 'A' || c > 'Z' {
			out = append(out, c)
			afterUpper = false
			continue
		}
		if !afterUpper {
			out = append(out, '_')
		}
		out = append(out, c+('a'-'A'))
		afterUpper = true
	}
	return string(out)
}
|
||
|
|
||
|
func (p *ProtoValue) convertType(t types.Typer) (string, error) {
|
||
|
switch v := t.(type) {
|
||
|
case *types.BasicType:
|
||
|
return convertBasicType(v.String())
|
||
|
case *types.ArrayType:
|
||
|
if v.EltType.String() == "byte" {
|
||
|
return "bytes", nil
|
||
|
}
|
||
|
elt, err := p.convertType(v.EltType)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
return fmt.Sprintf("repeated %s", elt), nil
|
||
|
case *types.MapType:
|
||
|
kt, err := p.convertType(v.KeyType)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
vt, err := p.convertType(v.ValueType)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
return fmt.Sprintf("map<%s, %s>", kt, vt), nil
|
||
|
case *types.StructType:
|
||
|
if v.ProtoFile == "" {
|
||
|
messageName := fmt.Sprintf("%s%s", strings.Title(v.Package), v.IdentName)
|
||
|
err := p.renderMessage(messageName, v.Fields)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
return messageName, nil
|
||
|
}
|
||
|
protoPackage, err := readProtoPackage(v.ProtoFile)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
p.importPackage(v.ProtoFile)
|
||
|
if p.Package == protoPackage {
|
||
|
return v.IdentName, nil
|
||
|
}
|
||
|
return fmt.Sprintf(".%s.%s", protoPackage, v.IdentName), nil
|
||
|
}
|
||
|
return "", fmt.Errorf("unsupport type %s", t)
|
||
|
}
|
||
|
|
||
|
// convertBasicType maps the name of a Go basic type to a protobuf scalar type.
//
// float64/float32 become double/float; int32, int64, uint32, uint64, string
// and bool keep their names; int and the 8/16-bit integer types become int32.
// NOTE(review): Go int is 64 bits on 64-bit platforms, so large int values do
// not fit int32 — the mapping is kept as-is for compatibility. The aliases
// byte and rune map like uint8 and int32. Any other name (uint, uintptr,
// complex types, ...) yields an error.
func convertBasicType(gt string) (string, error) {
	switch gt {
	case "float64":
		return "double", nil
	case "float32":
		return "float", nil
	case "int", "int8", "uint8", "int16", "uint16":
		return "int32", nil
	case "byte", "rune":
		// Aliases of uint8 and int32: map them exactly like the aliased types.
		return "int32", nil
	case "int64", "int32", "uint32", "uint64", "string", "bool":
		return gt, nil
	}
	return "", fmt.Errorf("unsupport basic type %s", gt)
}
|
||
|
|
||
|
func (p *ProtoValue) render(spec *types.ServiceSpec, options *ServiceProtoOptions) (*ProtoValue, error) {
|
||
|
p.options = options
|
||
|
p.Name = spec.Name
|
||
|
p.GoPackage = options.GoPackage
|
||
|
p.Package = options.ProtoPackage
|
||
|
p.Imports = make(map[string]bool)
|
||
|
p.Messages = make(map[string]ProtoMessage)
|
||
|
return p, p.renderMethods(spec.Methods)
|
||
|
}
|
||
|
|
||
|
func (p *ProtoValue) renderMethods(methods []*types.Method) error {
|
||
|
for _, method := range methods {
|
||
|
protoMethod := ProtoMethod{
|
||
|
Comments: method.Comments,
|
||
|
Name: method.Name,
|
||
|
}
|
||
|
//if len(method.Parameters) == 0 || (len(method.Parameters) == 1 && method.Parameters[0].Type.String() == contextType) {
|
||
|
// p.importPackage(emptyProtoFile)
|
||
|
// protoMethod.Req = emptyProtoMsg
|
||
|
//} else {
|
||
|
// protoMethod.Req = fmt.Sprintf("%sReq", method.Name)
|
||
|
// if err := p.renderMessage(protoMethod.Req, method.Parameters); err != nil {
|
||
|
// return err
|
||
|
// }
|
||
|
//}
|
||
|
|
||
|
//if len(method.Results) == 0 || (len(method.Results) == 1 && method.Results[0].Type.String() == "error") {
|
||
|
// p.importPackage(emptyProtoFile)
|
||
|
// protoMethod.Reply = emptyProtoMsg
|
||
|
//} else {
|
||
|
// protoMethod.Reply = fmt.Sprintf("%sReply", method.Name)
|
||
|
// if err := p.renderMessage(protoMethod.Reply, method.Results); err != nil {
|
||
|
// return err
|
||
|
// }
|
||
|
//}
|
||
|
protoMethod.Req = fmt.Sprintf("%sReq", method.Name)
|
||
|
if err := p.renderMessage(protoMethod.Req, method.Parameters); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
protoMethod.Reply = fmt.Sprintf("%sReply", method.Name)
|
||
|
if err := p.renderMessage(protoMethod.Reply, method.Results); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
p.Methods = append(p.Methods, protoMethod)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (p *ProtoValue) importPackage(imp string) {
|
||
|
for _, importPath := range p.options.ImportPaths {
|
||
|
if strings.HasPrefix(imp, importPath) {
|
||
|
p.Imports[strings.TrimLeft(imp[len(importPath):], "/")] = true
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
p.Imports[imp] = true
|
||
|
}
|
||
|
|
||
|
func (p *ProtoValue) renderMessage(name string, fields []*types.Field) error {
|
||
|
if _, ok := p.Messages[name]; ok {
|
||
|
return nil
|
||
|
}
|
||
|
message := ProtoMessage{
|
||
|
Name: name,
|
||
|
}
|
||
|
for i, field := range fields {
|
||
|
if field.Type.String() == "error" || field.Type.String() == contextType {
|
||
|
continue
|
||
|
}
|
||
|
fieldName := underscore(field.Name)
|
||
|
if fieldName == "" {
|
||
|
fieldName = fmt.Sprintf("data_%d", i)
|
||
|
}
|
||
|
pField := ProtoField{
|
||
|
FieldID: i + 1,
|
||
|
FieldName: fieldName,
|
||
|
}
|
||
|
ptype, err := p.convertType(field.Type)
|
||
|
if err != nil {
|
||
|
if p.options.IgnoreType {
|
||
|
log.Printf("warning convert type fail %s", err)
|
||
|
ptype = fmt.Sprintf("//FIXME type %s", field.Type)
|
||
|
} else {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
pField.FieldType = ptype
|
||
|
message.Fields = append(message.Fields, pField)
|
||
|
}
|
||
|
p.Messages[name] = message
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func renderProtoValue(spec *types.ServiceSpec, options *ServiceProtoOptions) (*ProtoValue, error) {
|
||
|
v := &ProtoValue{}
|
||
|
return v.render(spec, options)
|
||
|
}
|
||
|
|
||
|
// GenServiceProto generator proto service by service spec
|
||
|
func GenServiceProto(out io.Writer, spec *types.ServiceSpec, options *ServiceProtoOptions) error {
|
||
|
value, err := renderProtoValue(spec, options)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
assets.MustAsset(protoTemplateName)
|
||
|
t, err := template.New(protoTemplateName).Parse(string(assets.MustAsset(protoTemplateName)))
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return t.Execute(out, value)
|
||
|
}
|