1404 lines
39 KiB
Go
1404 lines
39 KiB
Go
// Copyright 2018 Twitch Interactive, Inc. All Rights Reserved.
|
||
//
|
||
// Licensed under the Apache License, Version 2.0 (the "License"). You may not
|
||
// use this file except in compliance with the License. A copy of the License is
|
||
// located at
|
||
//
|
||
// http://www.apache.org/licenses/LICENSE-2.0
|
||
//
|
||
// or in the "license" file accompanying this file. This file is distributed on
|
||
// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||
// express or implied. See the License for the specific language governing
|
||
// permissions and limitations under the License.
|
||
|
||
package main
|
||
|
||
import (
	"bufio"
	"bytes"
	"compress/gzip"
	"fmt"
	"go/ast"
	"go/parser"
	"go/printer"
	"go/token"
	"io/ioutil"
	"os"
	"path"
	"path/filepath"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"unicode"
	"unicode/utf8"

	"go-common/app/tool/liverpc/protoc-gen-liverpc/gen"
	"go-common/app/tool/liverpc/protoc-gen-liverpc/gen/stringutils"
	"go-common/app/tool/liverpc/protoc-gen-liverpc/gen/typemap"

	"github.com/golang/protobuf/proto"
	"github.com/golang/protobuf/protoc-gen-go/descriptor"
	plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
	"github.com/pkg/errors"
	"github.com/siddontang/go/ioutil2"
)
|
||
|
||
// legacyPathMapping maps the first two dot-separated segments of a legacy
// proto package name to the hyphenated form used in old-style URLs, e.g.
// "live.webroom" -> "live.web-room". convertLegacyPackage performs the lookup.
var legacyPathMapping = map[string]string{
	"live.appblink":         "live.app-blink",
	"live.appinterface":     "live.app-interface",
	"live.approom":          "live.app-room",
	"live.appucenter":       "live.app-ucenter",
	"live.liveadmin":        "live.live-admin",
	"live.livedemo":         "live.live-demo",
	"live.lotteryinterface": "live.lottery-interface",
	"live.resource":         "live.resource",
	"live.webroom":          "live.web-room",
	"live.webucenter":       "live.web-ucenter",
}
|
||
|
||
// bm holds the state of the blademaster (bm) protoc plugin while it handles
// one CodeGeneratorRequest. Create it with bmGenerator and drive it with
// Generate.
type bm struct {
	// number of .bm.go files generated so far; generateFileHeader only writes
	// the package doc comment while this is 0
	filesHandled int

	// type registry built from every proto file in the request
	reg *typemap.Registry

	// package name -> alias recorded for it by registerPackageName
	pkgs map[string]string
	// set of aliases already taken; registerPackageName keeps aliases unique
	pkgNamesInUse map[string]bool

	importPrefix string            // String to prefix to imported package file names.
	importMap    map[string]string // Mapping from .proto file name to import path.
	// when true, Generate additionally emits service implementation templates
	// (generateServiceImpl); the .bm.go files and markdown docs are produced
	// either way
	tpl bool

	// Package naming:
	genPkgName          string // Name of the package that we're generating
	fileToGoPackageName map[*descriptor.FileDescriptorProto]string

	// List of files that were inputs to the generator. We need to hold this in
	// the struct so we can write a header for the file that lists its inputs.
	genFiles []*descriptor.FileDescriptorProto

	// Output buffer that holds the bytes we want to write out for a single file.
	// Gets reset after working on a file.
	output *bytes.Buffer

	// package name -> quoted import path of message packages imported by the
	// current .bm.go file; set by generateImports and reused by
	// generateServiceImpl
	deps map[string]string
}
|
||
|
||
// projectInfo describes the go-common project that contains the proto files
// being compiled. initProjInfo fills it in when the working directory is a
// go-common project, or a sub-directory of one such as its internal
// directory.
type projectInfo struct {
	// filesystem path of the project root (the directory holding cmd and api)
	absolutePath string
	// Go import path of the project, starting with "go-common/app/"
	importPath string
	// project name: third path element after go-common/app/
	name string
	// department: second path element after go-common/app/
	department string
	// project type, first path element after go-common/app/: interface, service, admin ...
	typ string
	// whether the project root has an "internal" directory
	hasInternalPkg bool
	// relative path from the working directory to the project directory, e.g. ./ or ../
	pathRefToProj string
}
|
||
|
||
// projInfo is the project that contains the current working directory. It is
// set by initProjInfo and is nil when the project could not be resolved.
var projInfo *projectInfo
|
||
|
||
func bmGenerator() *bm {
|
||
t := &bm{
|
||
pkgs: make(map[string]string),
|
||
pkgNamesInUse: make(map[string]bool),
|
||
importMap: make(map[string]string),
|
||
fileToGoPackageName: make(map[*descriptor.FileDescriptorProto]string),
|
||
output: bytes.NewBuffer(nil),
|
||
}
|
||
|
||
return t
|
||
}
|
||
|
||
func (t *bm) Generate(in *plugin.CodeGeneratorRequest) *plugin.CodeGeneratorResponse {
|
||
params, err := parseCommandLineParams(in.GetParameter())
|
||
if err != nil {
|
||
gen.Fail("could not parse parameters passed to --bm_out", err.Error())
|
||
}
|
||
t.importPrefix = params.importPrefix
|
||
t.importMap = params.importMap
|
||
t.tpl = params.tpl
|
||
|
||
t.genFiles = gen.FilesToGenerate(in)
|
||
|
||
// Collect information on types.
|
||
t.reg = typemap.New(in.ProtoFile)
|
||
|
||
t.registerPackageName("context")
|
||
t.registerPackageName("ioutil")
|
||
t.registerPackageName("proto")
|
||
|
||
// Time to figure out package names of objects defined in protobuf. First,
|
||
// we'll figure out the name for the package we're generating.
|
||
genPkgName, err := deduceGenPkgName(t.genFiles)
|
||
if err != nil {
|
||
gen.Fail(err.Error())
|
||
}
|
||
t.genPkgName = genPkgName
|
||
|
||
// Next, we need to pick names for all the files that are dependencies.
|
||
if len(in.ProtoFile) > 0 {
|
||
t.initProjInfo(in.ProtoFile[0])
|
||
}
|
||
for _, f := range in.ProtoFile {
|
||
if fileDescSliceContains(t.genFiles, f) {
|
||
// This is a file we are generating. It gets the shared package name.
|
||
t.fileToGoPackageName[f] = t.genPkgName
|
||
} else {
|
||
// This is a dependency. Use its package name.
|
||
name := f.GetPackage()
|
||
if name == "" {
|
||
name = stringutils.BaseName(f.GetName())
|
||
}
|
||
name = stringutils.CleanIdentifier(name)
|
||
alias := t.registerPackageName(name)
|
||
t.fileToGoPackageName[f] = alias
|
||
}
|
||
}
|
||
|
||
// Showtime! Generate the response.
|
||
resp := new(plugin.CodeGeneratorResponse)
|
||
for _, f := range t.genFiles {
|
||
respFile := t.generate(f)
|
||
if respFile != nil {
|
||
resp.File = append(resp.File, respFile)
|
||
}
|
||
for _, s := range f.Service {
|
||
docResp := t.generateDoc(f, s)
|
||
if docResp != nil {
|
||
resp.File = append(resp.File, docResp)
|
||
}
|
||
}
|
||
|
||
if t.tpl {
|
||
if projInfo != nil {
|
||
for _, s := range f.Service {
|
||
serviceResp := t.generateServiceImpl(f, s)
|
||
if serviceResp != nil {
|
||
resp.File = append(resp.File, serviceResp)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
return resp
|
||
}
|
||
|
||
// lookupProjPath get project path by proto absolute path
|
||
// assume that proto is in the project's model directory
|
||
func lookupProjPath(protoAbs string) (result string) {
|
||
lastIndex := len(protoAbs)
|
||
curPath := protoAbs
|
||
|
||
for lastIndex > 0 {
|
||
if ioutil2.FileExists(curPath+"/cmd") && ioutil2.FileExists(curPath+"/api") {
|
||
result = curPath
|
||
return
|
||
}
|
||
lastIndex = strings.LastIndex(curPath, string(os.PathSeparator))
|
||
curPath = protoAbs[:lastIndex]
|
||
}
|
||
result = ""
|
||
return
|
||
}
|
||
|
||
func (t *bm) initProjInfo(file *descriptor.FileDescriptorProto) {
|
||
var err error
|
||
projInfo = &projectInfo{}
|
||
defer func() {
|
||
if err != nil {
|
||
projInfo = nil
|
||
}
|
||
}()
|
||
wd, err := os.Getwd()
|
||
if err != nil {
|
||
panic("cannot get working directory")
|
||
}
|
||
protoAbs := wd + "/" + file.GetName()
|
||
appIndex := strings.Index(wd, "go-common/app/")
|
||
if appIndex == -1 {
|
||
err = errors.New("not in go-common/app/")
|
||
return
|
||
}
|
||
|
||
projPath := lookupProjPath(protoAbs)
|
||
if projPath == "" {
|
||
err = errors.New("not in project")
|
||
return
|
||
}
|
||
if strings.Contains(wd, projPath) {
|
||
rest := strings.Replace(wd, projPath, "", 1)
|
||
projInfo.pathRefToProj = "./"
|
||
if rest != "" {
|
||
split := strings.Split(rest, "/")
|
||
ref := ""
|
||
for i := 0; i < len(split)-1; i++ {
|
||
ref = ref + "../"
|
||
}
|
||
projInfo.pathRefToProj = ref
|
||
}
|
||
}
|
||
projInfo.absolutePath = projPath
|
||
if ioutil2.FileExists(projPath + "/internal") {
|
||
projInfo.hasInternalPkg = true
|
||
}
|
||
|
||
relativePath := projInfo.absolutePath[appIndex+len("go-common/app/"):]
|
||
projInfo.importPath = "go-common/app/" + relativePath
|
||
split := strings.Split(relativePath, "/")
|
||
projInfo.typ = split[0]
|
||
projInfo.department = split[1]
|
||
projInfo.name = split[2]
|
||
}
|
||
|
||
// getLineTag extracts the struct tag written between the first and the last
// backtick of line.
//
// start and end are the byte offsets of those backticks, or -1 when line has
// none. tag is empty unless end > start.
func getLineTag(line string) (tag reflect.StructTag, start int, end int) {
	start, end = strings.Index(line, "`"), strings.LastIndex(line, "`")
	if start < end {
		tag = reflect.StructTag(line[start+1 : end])
	}
	return tag, start, end
}
|
||
|
||
func getCommentWithoutTag(comment string) []string {
|
||
var lines []string
|
||
if comment == "" {
|
||
return lines
|
||
}
|
||
split := strings.Split(strings.TrimRight(comment, "\n\r"), "\n")
|
||
for _, line := range split {
|
||
tag, _, _ := getLineTag(line)
|
||
if tag == "" {
|
||
lines = append(lines, line)
|
||
}
|
||
}
|
||
return lines
|
||
}
|
||
|
||
func getTagsInComment(comment string) []reflect.StructTag {
|
||
split := strings.Split(comment, "\n")
|
||
var tagsInComment []reflect.StructTag
|
||
for _, line := range split {
|
||
tag, _, _ := getLineTag(line)
|
||
if tag != "" {
|
||
tagsInComment = append(tagsInComment, tag)
|
||
}
|
||
}
|
||
return tagsInComment
|
||
}
|
||
|
||
// getTagValue returns the first non-empty value stored under key in tags.
// It returns "" when no tag defines key with a non-empty value.
func getTagValue(key string, tags []reflect.StructTag) string {
	for i := range tags {
		if v := tags[i].Get(key); v != "" {
			return v
		}
	}
	return ""
}
|
||
|
||
// Is this field repeated?
|
||
func isRepeated(field *descriptor.FieldDescriptorProto) bool {
|
||
return field.Label != nil && *field.Label == descriptor.FieldDescriptorProto_LABEL_REPEATED
|
||
}
|
||
|
||
func (t *bm) isMap(field *descriptor.FieldDescriptorProto) bool {
|
||
if field.GetType() != descriptor.FieldDescriptorProto_TYPE_MESSAGE {
|
||
return false
|
||
}
|
||
md := t.reg.MessageDefinition(field.GetTypeName())
|
||
if md == nil || !md.Descriptor.GetOptions().GetMapEntry() {
|
||
return false
|
||
}
|
||
return true
|
||
}
|
||
|
||
func (t *bm) generateToc(file *descriptor.FileDescriptorProto, service *descriptor.ServiceDescriptorProto) {
|
||
for _, method := range service.Method {
|
||
comment, _ := t.reg.MethodComments(file, service, method)
|
||
tags := getTagsInComment(comment.Leading)
|
||
_, path, newPath := t.getHttpInfo(file, service, method, tags)
|
||
|
||
cleanComments := getCommentWithoutTag(comment.Leading)
|
||
var title string
|
||
if len(cleanComments) > 0 {
|
||
title = cleanComments[0]
|
||
}
|
||
|
||
// 如果有老的路径,只显示老的路径文档
|
||
if path != "" {
|
||
anchor := strings.Replace(path, "/", "", -1)
|
||
t.P(fmt.Sprintf("- [%s](#%s) %s", path, anchor, title))
|
||
} else {
|
||
anchor := strings.Replace(newPath, "/", "", -1)
|
||
t.P(fmt.Sprintf("- [%s](#%s) %s", newPath, anchor, title))
|
||
}
|
||
}
|
||
}
|
||
|
||
func (t *bm) generateDoc(file *descriptor.FileDescriptorProto, service *descriptor.ServiceDescriptorProto) *plugin.CodeGeneratorResponse_File {
|
||
resp := new(plugin.CodeGeneratorResponse_File)
|
||
var name = goFileName(file, "."+lcFirst(service.GetName())+".md")
|
||
resp.Name = &name
|
||
t.P("<!-- package=" + file.GetPackage() + " -->")
|
||
t.generateToc(file, service)
|
||
for _, method := range service.Method {
|
||
comment, err := t.reg.MethodComments(file, service, method)
|
||
tags := getTagsInComment(comment.Leading)
|
||
cleanComments := getCommentWithoutTag(comment.Leading)
|
||
midwaresStr := getTagValue("midware", tags)
|
||
needAuth := false
|
||
if midwaresStr != "" {
|
||
split := strings.Split(midwaresStr, ",")
|
||
for _, m := range split {
|
||
if m == "auth" {
|
||
needAuth = true
|
||
break
|
||
}
|
||
}
|
||
}
|
||
t.P()
|
||
httpMethod, legacyPath, path := t.getHttpInfo(file, service, method, tags)
|
||
if legacyPath != "" {
|
||
path = legacyPath
|
||
}
|
||
t.P("## " + path)
|
||
|
||
if err == nil {
|
||
if len(cleanComments) == 0 {
|
||
t.P(`### 无标题`)
|
||
} else {
|
||
t.P(`###`, strings.Join(cleanComments, "\n"))
|
||
}
|
||
}
|
||
t.P()
|
||
|
||
if needAuth {
|
||
t.P(`> `, "需要登录")
|
||
t.P()
|
||
}
|
||
|
||
t.P("#### 方法:" + httpMethod)
|
||
t.P()
|
||
|
||
t.genRequestParam(file, service, method)
|
||
t.P("#### 响应")
|
||
t.P()
|
||
|
||
t.P("```javascript")
|
||
t.P(`{`)
|
||
t.P(` "code": 0,`)
|
||
t.P(` "message": "ok",`)
|
||
t.P(t.getExampleJson(file, service, method))
|
||
t.P(`}`)
|
||
t.P("```")
|
||
t.P()
|
||
}
|
||
resp.Content = proto.String(t.output.String())
|
||
t.output.Reset()
|
||
return resp
|
||
}
|
||
|
||
func (t *bm) genRequestParam(
|
||
file *descriptor.FileDescriptorProto,
|
||
svc *descriptor.ServiceDescriptorProto,
|
||
method *descriptor.MethodDescriptorProto) {
|
||
md := t.reg.MessageDefinition(method.GetInputType())
|
||
t.P(`#### 请求参数`)
|
||
t.P()
|
||
|
||
var outputs []string
|
||
for i, f := range md.Descriptor.Field {
|
||
if f.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE {
|
||
// 如果有message 只能以json的方式显示参数了
|
||
var buf = &[]string{}
|
||
t.exampleJsonForMsg(md, file, buf, "", 0, "")
|
||
j := strings.Join(*buf, "\n")
|
||
t.P("```javascript")
|
||
t.P(j)
|
||
t.P("```")
|
||
t.P()
|
||
return
|
||
}
|
||
if i == 0 {
|
||
outputs = append(outputs, `|参数名|必选|类型|描述|`)
|
||
outputs = append(outputs, `|:---|:---|:---|:---|`)
|
||
}
|
||
fComment, _ := t.reg.FieldComments(file, md, f)
|
||
var tags []reflect.StructTag
|
||
{
|
||
//get required info from gogoproto.moretags
|
||
moretags := getMoreTags(f)
|
||
if moretags != nil {
|
||
tags = []reflect.StructTag{reflect.StructTag(*moretags)}
|
||
}
|
||
}
|
||
if len(tags) == 0 {
|
||
tags = getTagsInComment(fComment.Leading)
|
||
}
|
||
validateTag := getTagValue("validate", tags)
|
||
var validateRules []string
|
||
if validateTag != "" {
|
||
validateRules = strings.Split(validateTag, ",")
|
||
}
|
||
required := false
|
||
for _, rule := range validateRules {
|
||
if rule == "required" {
|
||
required = true
|
||
}
|
||
}
|
||
requiredDesc := "是"
|
||
if !required {
|
||
requiredDesc = "否"
|
||
}
|
||
_, typeName := t.mockValueForField(f, tags)
|
||
split := strings.Split(fComment.Leading, "\n")
|
||
desc := ""
|
||
for _, line := range split {
|
||
if line != "" {
|
||
tag, _, _ := getLineTag(line)
|
||
if tag == "" {
|
||
desc += line
|
||
}
|
||
}
|
||
}
|
||
outputs = append(outputs, fmt.Sprintf(`|%s|%s|%s|%s|`, getJsonTag(f), requiredDesc, typeName, desc))
|
||
}
|
||
for _, s := range outputs {
|
||
t.P(s)
|
||
}
|
||
t.P()
|
||
}
|
||
|
||
func (t *bm) getExampleJson(file *descriptor.FileDescriptorProto,
|
||
svc *descriptor.ServiceDescriptorProto,
|
||
method *descriptor.MethodDescriptorProto) string {
|
||
md := t.reg.MessageDefinition(method.GetOutputType())
|
||
var buf = &[]string{}
|
||
t.exampleJsonForMsg(md, file, buf, "data", 4, "")
|
||
return strings.Join(*buf, "\n")
|
||
}
|
||
|
||
// makeIndentStr returns width spaces, used to indent example JSON lines.
// It panics if width is negative.
func makeIndentStr(width int) string {
	const space = " "
	return strings.Repeat(space, width)
}
|
||
|
||
func (t *bm) mockValueForField(field *descriptor.FieldDescriptorProto,
|
||
tags []reflect.StructTag) (mockVal string, typeName string) {
|
||
tagMock := getTagValue("mock", tags)
|
||
mockVal = "\"unknown\""
|
||
typeName = "unknown"
|
||
switch field.GetType() {
|
||
case descriptor.FieldDescriptorProto_TYPE_BOOL:
|
||
if tagMock == "true" || tagMock == "false" {
|
||
mockVal = tagMock
|
||
} else {
|
||
mockVal = "true"
|
||
}
|
||
typeName = "bool"
|
||
case descriptor.FieldDescriptorProto_TYPE_DOUBLE,
|
||
descriptor.FieldDescriptorProto_TYPE_FLOAT:
|
||
mockVal = "0.1"
|
||
if tagMock != "" {
|
||
if _, err := strconv.ParseFloat(tagMock, 64); err == nil {
|
||
mockVal = tagMock
|
||
}
|
||
}
|
||
typeName = "float"
|
||
case
|
||
descriptor.FieldDescriptorProto_TYPE_INT64,
|
||
descriptor.FieldDescriptorProto_TYPE_UINT64,
|
||
descriptor.FieldDescriptorProto_TYPE_INT32,
|
||
descriptor.FieldDescriptorProto_TYPE_FIXED64,
|
||
descriptor.FieldDescriptorProto_TYPE_FIXED32,
|
||
descriptor.FieldDescriptorProto_TYPE_ENUM,
|
||
descriptor.FieldDescriptorProto_TYPE_UINT32,
|
||
descriptor.FieldDescriptorProto_TYPE_SFIXED32,
|
||
descriptor.FieldDescriptorProto_TYPE_SFIXED64,
|
||
descriptor.FieldDescriptorProto_TYPE_SINT32,
|
||
descriptor.FieldDescriptorProto_TYPE_SINT64:
|
||
mockVal = "0"
|
||
if tagMock != "" {
|
||
if _, err := strconv.Atoi(tagMock); err == nil {
|
||
mockVal = tagMock
|
||
}
|
||
}
|
||
typeName = "integer"
|
||
case
|
||
descriptor.FieldDescriptorProto_TYPE_STRING,
|
||
descriptor.FieldDescriptorProto_TYPE_BYTES:
|
||
mockVal = `""`
|
||
if tagMock != "" {
|
||
mockVal = strconv.Quote(tagMock)
|
||
}
|
||
typeName = "string"
|
||
}
|
||
if isRepeated(field) {
|
||
typeName = "多个" + typeName
|
||
}
|
||
return
|
||
}
|
||
|
||
// exampleJsonForMsg appends to *buf the lines of an example JSON object for
// msg, indented by indent spaces.
//
// The object opens with a bare "{" when fieldName is empty, and with
// `"fieldName": {` otherwise. The closing "}" is followed by outEndComma.
// Before each field, its tag-free leading comment lines are emitted as
// "// ..." lines. Every field except the last gets a trailing comma.
//
// Scalar fields use mockValueForField, with tags taken from the field's
// comment. A repeated scalar becomes a one-element array.
//
// Message fields recurse:
//   - a repeated message is wrapped in a one-element array;
//   - a map field becomes an object with one sample key: "mapKey" for string
//     keys, "1" otherwise.
//
// It panics if the type of a message-typed field cannot be found in t.reg.
func (t *bm) exampleJsonForMsg(
	msg *typemap.MessageDefinition,
	file *descriptor.FileDescriptorProto,
	buf *[]string, fieldName string, indent int, outEndComma string) {
	if fieldName == "" {
		*buf = append(*buf, makeIndentStr(indent)+"{")
	} else {
		*buf = append(*buf, makeIndentStr(indent)+fmt.Sprintf(`"%s": {`, fieldName))
	}
	num := len(msg.Descriptor.Field)
	for i, f := range msg.Descriptor.Field {
		// NOTE: this local shadows the isScalar function for the rest of the
		// loop body (the right-hand side still refers to the function).
		isScalar := isScalar(f)
		fComment, _ := t.reg.FieldComments(file, msg, f)
		cleanComment := getCommentWithoutTag(fComment.Leading)
		for _, line := range cleanComment {
			if strings.Trim(line, " \t\n\r") != "" {
				*buf = append(*buf, makeIndentStr(indent+4)+"// "+line)
			}
		}

		// No comma after the last field, to keep the object well formed.
		endComma := ""
		if i < (num - 1) {
			endComma = ","
		}
		repeated := isRepeated(f)
		tags := getTagsInComment(fComment.Leading)
		if isScalar {
			mockVal, _ := t.mockValueForField(f, tags)

			if repeated {
				// "key" : [
				// value
				// ]
				*buf = append(*buf, makeIndentStr(indent+4)+`"`+getJsonTag(f)+`": [`)
				*buf = append(*buf, makeIndentStr(indent+8)+mockVal)
				*buf = append(*buf, makeIndentStr(indent+4)+`]`+endComma)
			} else {
				// "key" : value
				*buf = append(*buf, makeIndentStr(indent+4)+`"`+getJsonTag(f)+`": `+mockVal+endComma)
			}
		} else {
			isMap := t.isMap(f)
			// Map fields are repeated entries: open an object for a map and
			// an array for any other repeated message.
			if repeated {
				if isMap {
					*buf = append(*buf, makeIndentStr(indent+4)+`"`+getJsonTag(f)+`": {`)
				} else {
					*buf = append(*buf, makeIndentStr(indent+4)+`"`+getJsonTag(f)+`": [`)
				}
			}
			subMsg := t.reg.MessageDefinition(f.GetTypeName())
			if subMsg == nil {
				panic(fmt.Sprintf("%v%v", f.TypeName, f.Type))
			}
			// Elements inside an array/map are nested one level deeper and
			// carry no field name of their own.
			nextIndent := indent + 4
			nextFname := getJsonTag(f)
			if repeated {
				nextIndent = indent + 8
				nextFname = ""
			}
			if isMap {
				// Assumes the map-entry message lists the key field first and
				// the value field second (as protoc generates them).
				mapKeyField := subMsg.Descriptor.Field[0]
				mapValueField := subMsg.Descriptor.Field[1]
				keyDesc := "mapKey"
				if mapKeyField.GetType() != descriptor.FieldDescriptorProto_TYPE_STRING {
					keyDesc = "1"
				}

				if mapValueField.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE {
					// "mapKey" : {
					// ...
					// }
					mapValueMsg := t.reg.MessageDefinition(mapValueField.GetTypeName())
					t.exampleJsonForMsg(mapValueMsg, file, buf, keyDesc, nextIndent, "")
				} else {
					// "mapKey" : "map value"
					// (the mock tag comes from the map field's own comment)
					val, _ := t.mockValueForField(mapValueField, tags)

					*buf = append(*buf, makeIndentStr(indent+8)+`"`+keyDesc+`": `+val)
				}
				*buf = append(*buf, makeIndentStr(indent+4)+`}`+endComma)
			} else {
				if repeated {
					// One sample element, then close the array with this
					// field's comma.
					t.exampleJsonForMsg(subMsg, file, buf, nextFname, nextIndent, "")
					*buf = append(*buf, makeIndentStr(indent+4)+`]`+endComma)
				} else {
					// The nested object's closing brace carries this field's comma.
					t.exampleJsonForMsg(subMsg, file, buf, nextFname, nextIndent, endComma)
				}
			}
		}
	}
	*buf = append(*buf, makeIndentStr(indent)+"}"+outEndComma)
}
|
||
|
||
// Is this field a scalar numeric type?
|
||
func isScalar(field *descriptor.FieldDescriptorProto) bool {
|
||
if field.Type == nil {
|
||
return false
|
||
}
|
||
switch *field.Type {
|
||
case descriptor.FieldDescriptorProto_TYPE_DOUBLE,
|
||
descriptor.FieldDescriptorProto_TYPE_FLOAT,
|
||
descriptor.FieldDescriptorProto_TYPE_INT64,
|
||
descriptor.FieldDescriptorProto_TYPE_UINT64,
|
||
descriptor.FieldDescriptorProto_TYPE_INT32,
|
||
descriptor.FieldDescriptorProto_TYPE_FIXED64,
|
||
descriptor.FieldDescriptorProto_TYPE_FIXED32,
|
||
descriptor.FieldDescriptorProto_TYPE_BOOL,
|
||
descriptor.FieldDescriptorProto_TYPE_UINT32,
|
||
descriptor.FieldDescriptorProto_TYPE_ENUM,
|
||
descriptor.FieldDescriptorProto_TYPE_SFIXED32,
|
||
descriptor.FieldDescriptorProto_TYPE_SFIXED64,
|
||
descriptor.FieldDescriptorProto_TYPE_SINT32,
|
||
descriptor.FieldDescriptorProto_TYPE_SINT64,
|
||
descriptor.FieldDescriptorProto_TYPE_BYTES,
|
||
descriptor.FieldDescriptorProto_TYPE_STRING:
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
func (t *bm) registerPackageName(name string) (alias string) {
|
||
alias = name
|
||
i := 1
|
||
for t.pkgNamesInUse[alias] {
|
||
alias = name + strconv.Itoa(i)
|
||
i++
|
||
}
|
||
t.pkgNamesInUse[alias] = true
|
||
t.pkgs[name] = alias
|
||
return alias
|
||
}
|
||
|
||
// visitor is an ast.Visitor that collects the names of the function and
// method declarations in a Go file.
type visitor struct {
	// names of declared functions and methods (the value is always true)
	funcMap map[string]bool
}
|
||
|
||
func (v visitor) Visit(n ast.Node) ast.Visitor {
|
||
switch d := n.(type) {
|
||
case *ast.FuncDecl:
|
||
v.funcMap[d.Name.Name] = true
|
||
}
|
||
return v
|
||
}
|
||
|
||
// generateServiceImpl returns service implementation file service/{prefix}/service.go
|
||
// if file not exists
|
||
// else it returns nil
|
||
func (t *bm) generateServiceImpl(file *descriptor.FileDescriptorProto, svc *descriptor.ServiceDescriptorProto) *plugin.CodeGeneratorResponse_File {
|
||
resp := new(plugin.CodeGeneratorResponse_File)
|
||
prefix := t.getVersionPrefix()
|
||
importPath := t.getPbImportPath(file.GetName())
|
||
var alias = t.getPkgAlias()
|
||
confPath := projInfo.importPath + "/conf"
|
||
if projInfo.hasInternalPkg {
|
||
confPath = projInfo.importPath + "/internal/conf"
|
||
}
|
||
name := "service/" + prefix + "/" + lcFirst(svc.GetName()) + ".go"
|
||
if projInfo.hasInternalPkg {
|
||
name = "internal/" + name
|
||
}
|
||
name = projInfo.pathRefToProj + name
|
||
resp.Name = &name
|
||
if _, err := os.Stat(name); !os.IsNotExist(err) {
|
||
// Insert methods if file already exists
|
||
fset := token.NewFileSet()
|
||
astTree, err := parser.ParseFile(fset, name, nil, parser.ParseComments)
|
||
if err != nil {
|
||
panic("parse file error: " + name + " err: " + err.Error())
|
||
}
|
||
v := visitor{funcMap: map[string]bool{}}
|
||
ast.Walk(v, astTree)
|
||
|
||
t.output.Reset()
|
||
buf, err := ioutil.ReadFile(name)
|
||
if err != nil {
|
||
panic("cannot read file:" + name)
|
||
}
|
||
t.P(string(buf))
|
||
t.generateBmImpl(file, svc, v.funcMap)
|
||
resp.Content = proto.String(t.formattedOutput())
|
||
t.output.Reset()
|
||
return resp
|
||
}
|
||
|
||
tplPkg := "service"
|
||
if t.genPkgName[:1] == "v" {
|
||
tplPkg = t.genPkgName
|
||
}
|
||
t.P(`package `, tplPkg)
|
||
t.P()
|
||
t.P(`import (`)
|
||
t.P(` `, alias, ` "`, importPath, `"`)
|
||
t.P(` "`, confPath, `"`)
|
||
t.P(` "context"`)
|
||
t.P(`)`)
|
||
for pkg, importPath := range t.deps {
|
||
t.P(`import `, pkg, ` `, importPath)
|
||
}
|
||
svcStructName := serviceName(svc) + "Service"
|
||
t.P(`// `, svcStructName, ` struct`)
|
||
t.P(`type `, svcStructName, ` struct {`)
|
||
t.P(` conf *conf.Config`)
|
||
t.P(` // optionally add other properties here, such as dao`)
|
||
t.P(` // dao *dao.Dao`)
|
||
t.P(`}`)
|
||
t.P()
|
||
t.P(`//New`, svcStructName, ` init`)
|
||
t.P(`func New`, svcStructName, `(c *conf.Config) (s *`, svcStructName, `) {`)
|
||
t.P(` s = &`, svcStructName, `{`)
|
||
t.P(` conf: c,`)
|
||
t.P(` }`)
|
||
t.P(` return s`)
|
||
t.P(`}`)
|
||
|
||
comments, err := t.reg.ServiceComments(file, svc)
|
||
if err == nil {
|
||
t.printComments(comments)
|
||
}
|
||
t.P()
|
||
t.generateBmImpl(file, svc, map[string]bool{})
|
||
resp.Content = proto.String(t.formattedOutput())
|
||
t.output.Reset()
|
||
return resp
|
||
}
|
||
|
||
func (t *bm) generate(file *descriptor.FileDescriptorProto) *plugin.CodeGeneratorResponse_File {
|
||
resp := new(plugin.CodeGeneratorResponse_File)
|
||
if len(file.Service) == 0 {
|
||
return nil
|
||
}
|
||
|
||
t.generateFileHeader(file, t.genPkgName)
|
||
|
||
t.generateImports(file)
|
||
|
||
t.generateMiddlewareInfo(file)
|
||
for i, service := range file.Service {
|
||
t.generateService(file, service, i)
|
||
t.generateSingleRoute(file, service, i)
|
||
}
|
||
|
||
t.generateFileDescriptor(file)
|
||
|
||
resp.Name = proto.String(goFileName(file, ".bm.go"))
|
||
resp.Content = proto.String(t.formattedOutput())
|
||
t.output.Reset()
|
||
|
||
t.filesHandled++
|
||
return resp
|
||
}
|
||
|
||
func (t *bm) generateMiddlewareInfo(file *descriptor.FileDescriptorProto) {
|
||
t.P()
|
||
for _, service := range file.Service {
|
||
name := serviceName(service)
|
||
for _, method := range service.Method {
|
||
_, _, path := t.getHttpInfo(file, service, method, nil)
|
||
t.P(`var Path`, name, methodName(method), ` = "`, path, `"`)
|
||
}
|
||
t.P()
|
||
}
|
||
}
|
||
func (t *bm) generateFileHeader(file *descriptor.FileDescriptorProto, pkgName string) {
|
||
t.P("// Code generated by protoc-gen-bm ", gen.Version, ", DO NOT EDIT.")
|
||
t.P("// source: ", file.GetName())
|
||
t.P()
|
||
if t.filesHandled == 0 {
|
||
t.P("/*")
|
||
t.P("Package ", t.genPkgName, " is a generated blademaster stub package.")
|
||
t.P("This code was generated with go-common/app/tool/bmgen/protoc-gen-bm ", gen.Version, ".")
|
||
t.P()
|
||
comment, err := t.reg.FileComments(file)
|
||
if err == nil && comment.Leading != "" {
|
||
for _, line := range strings.Split(comment.Leading, "\n") {
|
||
line = strings.TrimPrefix(line, " ")
|
||
// ensure we don't escape from the block comment
|
||
line = strings.Replace(line, "*/", "* /", -1)
|
||
t.P(line)
|
||
}
|
||
t.P()
|
||
}
|
||
t.P("It is generated from these files:")
|
||
for _, f := range t.genFiles {
|
||
t.P("\t", f.GetName())
|
||
}
|
||
t.P("*/")
|
||
}
|
||
t.P(`package `, pkgName)
|
||
t.P()
|
||
}
|
||
|
||
func (t *bm) generateImports(file *descriptor.FileDescriptorProto) {
|
||
if len(file.Service) == 0 {
|
||
return
|
||
}
|
||
t.P(`import (`)
|
||
//t.P(` `,t.pkgs["context"], ` "context"`)
|
||
t.P(` "context"`)
|
||
t.P()
|
||
t.P(` bm "go-common/library/net/http/blademaster"`)
|
||
t.P(` "go-common/library/net/http/blademaster/binding"`)
|
||
|
||
t.P(`)`)
|
||
// It's legal to import a message and use it as an input or output for a
|
||
// method. Make sure to import the package of any such message. First, dedupe
|
||
// them.
|
||
deps := make(map[string]string) // Map of package name to quoted import path.
|
||
ourImportPath := path.Dir(goFileName(file, ""))
|
||
for _, s := range file.Service {
|
||
for _, m := range s.Method {
|
||
defs := []*typemap.MessageDefinition{
|
||
t.reg.MethodInputDefinition(m),
|
||
t.reg.MethodOutputDefinition(m),
|
||
}
|
||
for _, def := range defs {
|
||
// By default, import path is the dirname of the Go filename.
|
||
importPath := path.Dir(goFileName(def.File, ""))
|
||
if importPath == ourImportPath {
|
||
continue
|
||
}
|
||
if substitution, ok := t.importMap[def.File.GetName()]; ok {
|
||
importPath = substitution
|
||
}
|
||
importPath = t.importPrefix + importPath
|
||
pkg := t.goPackageName(def.File)
|
||
deps[pkg] = strconv.Quote(importPath)
|
||
}
|
||
}
|
||
}
|
||
t.deps = deps
|
||
for pkg, importPath := range deps {
|
||
t.P(`import `, pkg, ` `, importPath)
|
||
}
|
||
if len(deps) > 0 {
|
||
t.P()
|
||
}
|
||
t.P()
|
||
t.P(`// to suppressed 'imported but not used warning'`)
|
||
t.P(`var _ *bm.Context`)
|
||
t.P(`var _ context.Context`)
|
||
t.P(`var _ binding.StructValidator`)
|
||
|
||
}
|
||
|
||
// P forwards to g.gen.P, which prints output.
|
||
func (t *bm) P(args ...string) {
|
||
for _, v := range args {
|
||
t.output.WriteString(v)
|
||
}
|
||
t.output.WriteByte('\n')
|
||
}
|
||
|
||
// Big header comments to makes it easier to visually parse a generated file.
|
||
func (t *bm) sectionComment(sectionTitle string) {
|
||
t.P()
|
||
t.P(`// `, strings.Repeat("=", len(sectionTitle)))
|
||
t.P(`// `, sectionTitle)
|
||
t.P(`// `, strings.Repeat("=", len(sectionTitle)))
|
||
t.P()
|
||
}
|
||
|
||
func (t *bm) generateService(file *descriptor.FileDescriptorProto, service *descriptor.ServiceDescriptorProto, index int) {
|
||
servName := serviceName(service)
|
||
|
||
t.sectionComment(servName + ` Interface`)
|
||
t.generateBMInterface(file, service)
|
||
|
||
}
|
||
|
||
// import project/api的路径
|
||
func (t *bm) getPbImportPath(filename string) (importPath string) {
|
||
wd, err := os.Getwd()
|
||
if err != nil {
|
||
panic("cannot get working directory")
|
||
}
|
||
index := strings.Index(wd, "go-common")
|
||
if index == -1 {
|
||
gen.Fail("must use inside go-common")
|
||
}
|
||
dir := filepath.Dir(filename)
|
||
if dir != "." {
|
||
importPath = wd + "/" + dir
|
||
} else {
|
||
importPath = wd
|
||
}
|
||
importPath = importPath[index:]
|
||
return
|
||
}
|
||
|
||
// getProjPath return project path relative to GOPATH
|
||
func (t *bm) getProjPath() string {
|
||
wd, err := os.Getwd()
|
||
if err != nil {
|
||
panic("cannot get working directory")
|
||
}
|
||
index := strings.Index(wd, "go-common")
|
||
if index == -1 {
|
||
gen.Fail("must use inside go-common")
|
||
}
|
||
projPkgPath := wd[index:]
|
||
return projPkgPath
|
||
}
|
||
|
||
// lcFirst returns str with its first rune lower-cased. It returns "" for an
// empty string.
//
// The rest of the string is sliced at the decoded rune's byte width, so a
// multi-byte first rune (e.g. "É") no longer leaves a stray continuation
// byte behind. Invalid UTF-8 still decodes as a single byte.
func lcFirst(str string) string {
	r, size := utf8.DecodeRuneInString(str)
	if size == 0 {
		return ""
	}
	return string(unicode.ToLower(r)) + str[size:]
}
|
||
|
||
// TODO rename
|
||
func (t *bm) getLegacyPathPrefix(
|
||
svc *descriptor.ServiceDescriptorProto, pathParts []string, isInternal bool) (uriPrefix string) {
|
||
var parts []string
|
||
parts = append(parts, pathParts[0])
|
||
if isInternal {
|
||
parts = append(parts, "internal")
|
||
}
|
||
parts = append(parts, pathParts[1:]...)
|
||
uriPrefix = fmt.Sprintf("/x%s/%s", strings.Join(parts, "/"), lcFirst(svc.GetName()))
|
||
return
|
||
}
|
||
|
||
func (t *bm) getHttpInfo(
|
||
file *descriptor.FileDescriptorProto,
|
||
service *descriptor.ServiceDescriptorProto,
|
||
method *descriptor.MethodDescriptorProto,
|
||
tags []reflect.StructTag,
|
||
) (httpMethod string, oldPath string, newPath string) {
|
||
googleOptionInfo, err := ParseBMMethod(method)
|
||
if err == nil {
|
||
httpMethod = strings.ToUpper(googleOptionInfo.Method)
|
||
p := googleOptionInfo.PathPattern
|
||
if p != "" {
|
||
oldPath = p
|
||
newPath = p
|
||
return
|
||
}
|
||
}
|
||
|
||
if httpMethod == "" {
|
||
// resolve http method
|
||
httpMethod = getTagValue("method", tags)
|
||
if httpMethod == "" {
|
||
httpMethod = "GET"
|
||
} else {
|
||
httpMethod = strings.ToUpper(httpMethod)
|
||
}
|
||
}
|
||
|
||
isLegacy, parts := t.convertLegacyPackage(file.GetPackage())
|
||
if isLegacy {
|
||
apiInternal := getTagValue("internal", tags) == "true"
|
||
pathPrefix := t.getLegacyPathPrefix(service, parts, apiInternal)
|
||
oldPath = pathPrefix + `/` + method.GetName()
|
||
}
|
||
newPath = "/" + file.GetPackage() + "." + service.GetName() + "/" + method.GetName()
|
||
return
|
||
}
|
||
|
||
// 返回空,则不用考虑历史package
|
||
// 如果非空,则表示按照返回的pathParts做url规则
|
||
func (t *bm) convertLegacyPackage(pkgName string) (isLegacy bool, pathParts []string) {
|
||
var splits = strings.Split(pkgName, ".")
|
||
var remain []string
|
||
if len(splits) >= 2 {
|
||
splits = splits[0:2]
|
||
remain = splits[2:]
|
||
}
|
||
var pkgPrefix = strings.Join(splits, ".")
|
||
legacyPkg, isLegacy := legacyPathMapping[pkgPrefix]
|
||
if isLegacy {
|
||
legacyPkg = strings.Replace(pkgName, pkgPrefix, legacyPkg, 1)
|
||
pathParts = append(pathParts, strings.Split(legacyPkg, ".")...)
|
||
pathParts = append(pathParts, remain...)
|
||
}
|
||
return
|
||
}
|
||
|
||
func (t *bm) generateSingleRoute(
|
||
file *descriptor.FileDescriptorProto,
|
||
service *descriptor.ServiceDescriptorProto,
|
||
index int) {
|
||
// old mode is generate xx.route.go in the http pkg
|
||
// new mode is generate route code in the same .bm.go
|
||
// route rule /x{department}/{project-name}/{path_prefix}/method_name
|
||
// generate each route method
|
||
servName := serviceName(service)
|
||
versionPrefix := t.getVersionPrefix()
|
||
svcName := lcFirst(stringutils.CamelCase(versionPrefix)) + servName + "Svc"
|
||
t.P(`var `, svcName, ` `, servName, `BMServer`)
|
||
|
||
type methodInfo struct {
|
||
httpMethod string
|
||
midwares []string
|
||
routeFuncName string
|
||
path string
|
||
legacyPath string
|
||
methodName string
|
||
}
|
||
var methList []methodInfo
|
||
var allMidwareMap = make(map[string]bool)
|
||
var isLegacyPkg = false
|
||
for _, method := range service.Method {
|
||
var httpMethod string
|
||
var midwares []string
|
||
comments, _ := t.reg.MethodComments(file, service, method)
|
||
tags := getTagsInComment(comments.Leading)
|
||
if getTagValue("dynamic", tags) == "true" {
|
||
continue
|
||
}
|
||
httpMethod, legacyPath, path := t.getHttpInfo(file, service, method, tags)
|
||
if legacyPath != "" {
|
||
isLegacyPkg = true
|
||
}
|
||
|
||
midStr := getTagValue("midware", tags)
|
||
if midStr != "" {
|
||
midwares = strings.Split(midStr, ",")
|
||
for _, m := range midwares {
|
||
allMidwareMap[m] = true
|
||
}
|
||
}
|
||
|
||
methName := methodName(method)
|
||
inputType := t.goTypeName(method.GetInputType())
|
||
|
||
routeName := lcFirst(stringutils.CamelCase(servName) +
|
||
stringutils.CamelCase(methName))
|
||
|
||
methList = append(methList, methodInfo{
|
||
httpMethod: httpMethod,
|
||
midwares: midwares,
|
||
routeFuncName: routeName,
|
||
path: path,
|
||
legacyPath: legacyPath,
|
||
methodName: method.GetName(),
|
||
})
|
||
|
||
t.P(fmt.Sprintf("func %s (c *bm.Context) {", routeName))
|
||
t.P(` p := new(`, inputType, `)`)
|
||
t.P(` if err := c.BindWith(p, binding.Default(c.Request.Method, c.Request.Header.Get("Content-Type"))); err != nil {`)
|
||
t.P(` return`)
|
||
t.P(` }`)
|
||
t.P(` resp, err := `, svcName, `.`, methName, `(c, p)`)
|
||
t.P(` c.JSON(resp, err)`)
|
||
t.P(`}`)
|
||
t.P(``)
|
||
}
|
||
|
||
// generate route group
|
||
var midList []string
|
||
for m := range allMidwareMap {
|
||
midList = append(midList, m+" bm.HandlerFunc")
|
||
}
|
||
|
||
sort.Strings(midList)
|
||
|
||
// 注册老的路由的方法
|
||
if isLegacyPkg {
|
||
funcName := `Register` + stringutils.CamelCase(versionPrefix) + servName + `Service`
|
||
t.P(`// `, funcName, ` Register the blademaster route with middleware map`)
|
||
t.P(`// midMap is the middleware map, the key is defined in proto`)
|
||
t.P(`func `, funcName, `(e *bm.Engine, svc `, servName, "BMServer, midMap map[string]bm.HandlerFunc)", ` {`)
|
||
var keys []string
|
||
for m := range allMidwareMap {
|
||
keys = append(keys, m)
|
||
}
|
||
// to keep generated code consistent
|
||
sort.Strings(keys)
|
||
for _, m := range keys {
|
||
t.P(m, ` := midMap["`, m, `"]`)
|
||
}
|
||
|
||
t.P(svcName, ` = svc`)
|
||
for _, methInfo := range methList {
|
||
var midArgStr string
|
||
if len(methInfo.midwares) == 0 {
|
||
midArgStr = ""
|
||
} else {
|
||
midArgStr = strings.Join(methInfo.midwares, ", ") + ", "
|
||
}
|
||
t.P(`e.`, methInfo.httpMethod, `("`, methInfo.legacyPath, `", `, midArgStr, methInfo.routeFuncName, `)`)
|
||
}
|
||
t.P(` }`)
|
||
}
|
||
// 新的注册路由的方法
|
||
var bmFuncName = fmt.Sprintf("Register%sBMServer", servName)
|
||
t.P(`// `, bmFuncName, ` Register the blademaster route`)
|
||
t.P(`func `, bmFuncName, `(e *bm.Engine, server `, servName, `BMServer) {`)
|
||
t.P(svcName, ` = server`)
|
||
for _, methInfo := range methList {
|
||
t.P(`e.`, methInfo.httpMethod, `("`, methInfo.path, `",`, methInfo.routeFuncName, ` )`)
|
||
}
|
||
t.P(` }`)
|
||
}
|
||
|
||
func (t *bm) generateBMInterface(file *descriptor.FileDescriptorProto, service *descriptor.ServiceDescriptorProto) {
|
||
servName := serviceName(service)
|
||
|
||
comments, err := t.reg.ServiceComments(file, service)
|
||
if err == nil {
|
||
t.printComments(comments)
|
||
}
|
||
t.P(`type `, servName, `BMServer interface {`)
|
||
for _, method := range service.Method {
|
||
t.generateSignature(file, service, method, comments)
|
||
t.P()
|
||
}
|
||
t.P(`}`)
|
||
}
|
||
|
||
// pb包的别名
|
||
// 用户生成service实现模板时,对pb文件的引用
|
||
// 如果是v*的package 则为v*pb
|
||
// 其他为pb
|
||
func (t *bm) getPkgAlias() string {
|
||
if t.genPkgName == "" {
|
||
return "pb"
|
||
}
|
||
if t.genPkgName[:1] == "v" {
|
||
return t.genPkgName + "pb"
|
||
}
|
||
return "pb"
|
||
}
|
||
|
||
// 如果是v*开始的 返回v*
|
||
// 否则返回空
|
||
func (t *bm) getVersionPrefix() string {
|
||
if t.genPkgName == "" {
|
||
return ""
|
||
}
|
||
if t.genPkgName[:1] == "v" {
|
||
return t.genPkgName
|
||
}
|
||
return ""
|
||
}
|
||
|
||
func (t *bm) generateBmImpl(file *descriptor.FileDescriptorProto, service *descriptor.ServiceDescriptorProto,
|
||
existMap map[string]bool) {
|
||
var pkgName = t.getPkgAlias()
|
||
svcName := serviceName(service) + "Service"
|
||
|
||
for _, method := range service.Method {
|
||
methName := methodName(method)
|
||
if existMap[methName] {
|
||
continue
|
||
}
|
||
comments, err := t.reg.MethodComments(file, service, method)
|
||
tags := getTagsInComment(comments.Leading)
|
||
respDynamic := getTagValue("dynamic_resp", tags) == "true"
|
||
genImp := func(dynamicRet bool) {
|
||
t.P(`// `, methName, " implementation")
|
||
if err == nil {
|
||
t.printComments(comments)
|
||
}
|
||
outputType := t.goTypeName(method.GetOutputType())
|
||
inputType := t.goTypeName(method.GetInputType())
|
||
var body string
|
||
var ownPkg = t.isOwnPackage(method.GetOutputType())
|
||
var respType string
|
||
if ownPkg {
|
||
respType = pkgName + "." + outputType
|
||
} else {
|
||
respType = outputType
|
||
}
|
||
if dynamicRet {
|
||
body = fmt.Sprintf(`func (s *%s) %s(ctx context.Context, req *%s.%s) (resp interface{}, err error) {`,
|
||
svcName, methName, pkgName, inputType)
|
||
} else {
|
||
body = fmt.Sprintf(`func (s *%s) %s(ctx context.Context, req *%s.%s) (resp *%s, err error) {`,
|
||
svcName, methName, pkgName, inputType, respType)
|
||
}
|
||
|
||
t.P(body)
|
||
t.P(fmt.Sprintf("resp = &%s{}", respType))
|
||
t.P(` return`)
|
||
t.P(`}`)
|
||
t.P()
|
||
}
|
||
genImp(respDynamic)
|
||
}
|
||
}
|
||
|
||
func (t *bm) generateSignature(file *descriptor.FileDescriptorProto,
|
||
service *descriptor.ServiceDescriptorProto,
|
||
method *descriptor.MethodDescriptorProto,
|
||
comments typemap.DefinitionComments) {
|
||
comments, err := t.reg.MethodComments(file, service, method)
|
||
|
||
methName := methodName(method)
|
||
outputType := t.goTypeName(method.GetOutputType())
|
||
inputType := t.goTypeName(method.GetInputType())
|
||
tags := getTagsInComment(comments.Leading)
|
||
if getTagValue("dynamic", tags) == "true" {
|
||
return
|
||
}
|
||
|
||
if err == nil {
|
||
t.printComments(comments)
|
||
}
|
||
|
||
respDynamic := getTagValue("dynamic_resp", tags) == "true"
|
||
if respDynamic {
|
||
t.P(fmt.Sprintf(` %s(ctx context.Context, req *%s) (resp interface{}, err error)`,
|
||
methName, inputType))
|
||
} else {
|
||
t.P(fmt.Sprintf(` %s(ctx context.Context, req *%s) (resp *%s, err error)`,
|
||
methName, inputType, outputType))
|
||
}
|
||
}
|
||
|
||
func (t *bm) generateFileDescriptor(file *descriptor.FileDescriptorProto) {
|
||
// Copied straight of of protoc-gen-go, which trims out comments.
|
||
pb := proto.Clone(file).(*descriptor.FileDescriptorProto)
|
||
pb.SourceCodeInfo = nil
|
||
|
||
b, err := proto.Marshal(pb)
|
||
if err != nil {
|
||
gen.Fail(err.Error())
|
||
}
|
||
|
||
var buf bytes.Buffer
|
||
w, _ := gzip.NewWriterLevel(&buf, gzip.BestCompression)
|
||
w.Write(b)
|
||
w.Close()
|
||
buf.Bytes()
|
||
}
|
||
|
||
func (t *bm) printComments(comments typemap.DefinitionComments) bool {
|
||
text := strings.TrimSuffix(comments.Leading, "\n")
|
||
if len(strings.TrimSpace(text)) == 0 {
|
||
return false
|
||
}
|
||
split := strings.Split(text, "\n")
|
||
for _, line := range split {
|
||
t.P("// ", strings.TrimPrefix(line, " "))
|
||
}
|
||
return len(split) > 0
|
||
}
|
||
|
||
// Given a protobuf name for a Message, return the Go name we will use for that
|
||
// type, including its package prefix.
|
||
func (t *bm) goTypeName(protoName string) string {
|
||
def := t.reg.MessageDefinition(protoName)
|
||
if def == nil {
|
||
gen.Fail("could not find message for", protoName)
|
||
}
|
||
|
||
var prefix string
|
||
if pkg := t.goPackageName(def.File); pkg != t.genPkgName {
|
||
prefix = pkg + "."
|
||
}
|
||
|
||
var name string
|
||
for _, parent := range def.Lineage() {
|
||
name += parent.Descriptor.GetName() + "_"
|
||
}
|
||
name += def.Descriptor.GetName()
|
||
return prefix + name
|
||
}
|
||
|
||
func (t *bm) isOwnPackage(protoName string) bool {
|
||
def := t.reg.MessageDefinition(protoName)
|
||
if def == nil {
|
||
gen.Fail("could not find message for", protoName)
|
||
}
|
||
pkg := t.goPackageName(def.File)
|
||
return pkg == t.genPkgName
|
||
}
|
||
|
||
func (t *bm) goPackageName(file *descriptor.FileDescriptorProto) string {
|
||
return t.fileToGoPackageName[file]
|
||
}
|
||
|
||
func (t *bm) formattedOutput() string {
|
||
// Reformat generated code.
|
||
fset := token.NewFileSet()
|
||
raw := t.output.Bytes()
|
||
ast, err := parser.ParseFile(fset, "", raw, parser.ParseComments)
|
||
if err != nil {
|
||
// Print out the bad code with line numbers.
|
||
// This should never happen in practice, but it can while changing generated code,
|
||
// so consider this a debugging aid.
|
||
var src bytes.Buffer
|
||
s := bufio.NewScanner(bytes.NewReader(raw))
|
||
for line := 1; s.Scan(); line++ {
|
||
fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
|
||
}
|
||
gen.Fail("bad Go source code was generated:", err.Error(), "\n"+src.String())
|
||
}
|
||
|
||
out := bytes.NewBuffer(nil)
|
||
err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(out, fset, ast)
|
||
if err != nil {
|
||
gen.Fail("generated Go source code could not be reformatted:", err.Error())
|
||
}
|
||
|
||
return out.String()
|
||
}
|
||
|
||
func serviceName(service *descriptor.ServiceDescriptorProto) string {
|
||
return stringutils.CamelCase(service.GetName())
|
||
}
|
||
|
||
func methodName(method *descriptor.MethodDescriptorProto) string {
|
||
return stringutils.CamelCase(method.GetName())
|
||
}
|
||
|
||
func fileDescSliceContains(slice []*descriptor.FileDescriptorProto, f *descriptor.FileDescriptorProto) bool {
|
||
for _, sf := range slice {
|
||
if f == sf {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// deduceGenPkgName figures out the go package name to use for generated code.
|
||
// Will try to use the explicit go_package setting in a file (if set, must be
|
||
// consistent in all files). If no files have go_package set, then use the
|
||
// protobuf package name (must be consistent in all files)
|
||
func deduceGenPkgName(genFiles []*descriptor.FileDescriptorProto) (string, error) {
|
||
var genPkgName string
|
||
for _, f := range genFiles {
|
||
name, explicit := goPackageName(f)
|
||
if explicit {
|
||
name = stringutils.CleanIdentifier(name)
|
||
if genPkgName != "" && genPkgName != name {
|
||
// Make sure they're all set consistently.
|
||
return "", errors.Errorf("files have conflicting go_package settings, must be the same: %q and %q", genPkgName, name)
|
||
}
|
||
genPkgName = name
|
||
}
|
||
}
|
||
if genPkgName != "" {
|
||
return genPkgName, nil
|
||
}
|
||
|
||
// If there is no explicit setting, then check the implicit package name
|
||
// (derived from the protobuf package name) of the files and make sure it's
|
||
// consistent.
|
||
for _, f := range genFiles {
|
||
name, _ := goPackageName(f)
|
||
name = stringutils.CleanIdentifier(name)
|
||
if genPkgName != "" && genPkgName != name {
|
||
return "", errors.Errorf("files have conflicting package names, must be the same or overridden with go_package: %q and %q", genPkgName, name)
|
||
}
|
||
genPkgName = name
|
||
}
|
||
|
||
// All the files have the same name, so we're good.
|
||
return genPkgName, nil
|
||
}
|