Create & Init Project...
This commit is contained in:
236
app/service/main/account-recovery/dao/sqlbuilder/args.go
Normal file
236
app/service/main/account-recovery/dao/sqlbuilder/args.go
Normal file
@ -0,0 +1,236 @@
|
||||
// Copyright 2018 Huan Du. All rights reserved.
|
||||
// Licensed under the MIT license that can be found in the LICENSE file.
|
||||
|
||||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Args stores arguments associated with a SQL.
|
||||
type Args struct {
|
||||
// The default flavor used by `Args#Compile`
|
||||
Flavor Flavor
|
||||
|
||||
args []interface{}
|
||||
namedArgs map[string]int
|
||||
sqlNamedArgs map[string]int
|
||||
onlyNamed bool
|
||||
}
|
||||
|
||||
// Add adds an arg to Args and returns a placeholder.
|
||||
func (args *Args) Add(arg interface{}) string {
|
||||
return fmt.Sprintf("$%v", args.add(arg))
|
||||
}
|
||||
|
||||
func (args *Args) add(arg interface{}) int {
|
||||
idx := len(args.args)
|
||||
|
||||
switch a := arg.(type) {
|
||||
case sql.NamedArg:
|
||||
if args.sqlNamedArgs == nil {
|
||||
args.sqlNamedArgs = map[string]int{}
|
||||
}
|
||||
|
||||
if p, ok := args.sqlNamedArgs[a.Name]; ok {
|
||||
arg = args.args[p]
|
||||
break
|
||||
}
|
||||
|
||||
args.sqlNamedArgs[a.Name] = idx
|
||||
case namedArgs:
|
||||
if args.namedArgs == nil {
|
||||
args.namedArgs = map[string]int{}
|
||||
}
|
||||
|
||||
if p, ok := args.namedArgs[a.name]; ok {
|
||||
arg = args.args[p]
|
||||
break
|
||||
}
|
||||
|
||||
// Find out the real arg and add it to args.
|
||||
idx = args.add(a.arg)
|
||||
args.namedArgs[a.name] = idx
|
||||
return idx
|
||||
}
|
||||
|
||||
args.args = append(args.args, arg)
|
||||
return idx
|
||||
}
|
||||
|
||||
// Compile compiles builder's format to standard sql and returns associated args.
|
||||
//
|
||||
// The format string uses a special syntax to represent arguments.
|
||||
//
|
||||
// $? refers successive arguments passed in the call. It works similar as `%v` in `fmt.Sprintf`.
|
||||
// $0 $1 ... $n refers nth-argument passed in the call. Next $? will use arguments n+1.
|
||||
// ${name} refers a named argument created by `Named` with `name`.
|
||||
// $$ is a "$" string.
|
||||
func (args *Args) Compile(format string, intialValue ...interface{}) (query string, values []interface{}) {
|
||||
return args.CompileWithFlavor(format, args.Flavor, intialValue...)
|
||||
}
|
||||
|
||||
// CompileWithFlavor compiles builder's format to standard sql with flavor and returns associated args.
|
||||
//
|
||||
// See doc for `Compile` to learn details.
|
||||
func (args *Args) CompileWithFlavor(format string, flavor Flavor, intialValue ...interface{}) (query string, values []interface{}) {
|
||||
buf := &bytes.Buffer{}
|
||||
idx := strings.IndexRune(format, '$')
|
||||
offset := 0
|
||||
values = intialValue
|
||||
|
||||
if flavor == invalidFlavor {
|
||||
flavor = DefaultFlavor
|
||||
}
|
||||
|
||||
for idx >= 0 && len(format) > 0 {
|
||||
if idx > 0 {
|
||||
buf.WriteString(format[:idx])
|
||||
}
|
||||
|
||||
format = format[idx+1:]
|
||||
|
||||
// Should not happen.
|
||||
if len(format) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if format[0] == '$' {
|
||||
buf.WriteRune('$')
|
||||
format = format[1:]
|
||||
} else if format[0] == '{' {
|
||||
format, values = args.compileNamed(buf, flavor, format, values)
|
||||
} else if !args.onlyNamed && '0' <= format[0] && format[0] <= '9' {
|
||||
format, values, offset = args.compileDigits(buf, flavor, format, values, offset)
|
||||
} else if !args.onlyNamed && format[0] == '?' {
|
||||
format, values, offset = args.compileSuccessive(buf, flavor, format[1:], values, offset)
|
||||
}
|
||||
|
||||
idx = strings.IndexRune(format, '$')
|
||||
}
|
||||
|
||||
if len(format) > 0 {
|
||||
buf.WriteString(format)
|
||||
}
|
||||
|
||||
query = buf.String()
|
||||
|
||||
if len(args.sqlNamedArgs) > 0 {
|
||||
// Stabilize the sequence to make it easier to write test cases.
|
||||
ints := make([]int, 0, len(args.sqlNamedArgs))
|
||||
|
||||
for _, p := range args.sqlNamedArgs {
|
||||
ints = append(ints, p)
|
||||
}
|
||||
|
||||
sort.Ints(ints)
|
||||
|
||||
for _, i := range ints {
|
||||
values = append(values, args.args[i])
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (args *Args) compileNamed(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}) (string, []interface{}) {
|
||||
i := 1
|
||||
|
||||
for ; i < len(format) && format[i] != '}'; i++ {
|
||||
// Nothing.
|
||||
}
|
||||
|
||||
// Invalid $ format. Ignore it.
|
||||
if i == len(format) {
|
||||
return format, values
|
||||
}
|
||||
|
||||
name := format[1:i]
|
||||
format = format[i+1:]
|
||||
|
||||
if p, ok := args.namedArgs[name]; ok {
|
||||
format, values, _ = args.compileSuccessive(buf, flavor, format, values, p)
|
||||
}
|
||||
|
||||
return format, values
|
||||
}
|
||||
|
||||
func (args *Args) compileDigits(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) {
|
||||
i := 1
|
||||
|
||||
for ; i < len(format) && '0' <= format[i] && format[i] <= '9'; i++ {
|
||||
// Nothing.
|
||||
}
|
||||
|
||||
digits := format[:i]
|
||||
format = format[i:]
|
||||
|
||||
if pointer, err := strconv.Atoi(digits); err == nil {
|
||||
return args.compileSuccessive(buf, flavor, format, values, pointer)
|
||||
}
|
||||
|
||||
return format, values, offset
|
||||
}
|
||||
|
||||
func (args *Args) compileSuccessive(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) {
|
||||
if offset >= len(args.args) {
|
||||
return format, values, offset
|
||||
}
|
||||
|
||||
arg := args.args[offset]
|
||||
values = args.compileArg(buf, flavor, values, arg)
|
||||
|
||||
return format, values, offset + 1
|
||||
}
|
||||
|
||||
func (args *Args) compileArg(buf *bytes.Buffer, flavor Flavor, values []interface{}, arg interface{}) []interface{} {
|
||||
switch a := arg.(type) {
|
||||
case Builder:
|
||||
var s string
|
||||
s, values = a.BuildWithFlavor(flavor, values...)
|
||||
buf.WriteString(s)
|
||||
case sql.NamedArg:
|
||||
buf.WriteRune('@')
|
||||
buf.WriteString(a.Name)
|
||||
case rawArgs:
|
||||
buf.WriteString(a.expr)
|
||||
case listArgs:
|
||||
if len(a.args) > 0 {
|
||||
values = args.compileArg(buf, flavor, values, a.args[0])
|
||||
}
|
||||
|
||||
for i := 1; i < len(a.args); i++ {
|
||||
buf.WriteString(", ")
|
||||
values = args.compileArg(buf, flavor, values, a.args[i])
|
||||
}
|
||||
default:
|
||||
switch flavor {
|
||||
case MySQL:
|
||||
buf.WriteRune('?')
|
||||
case PostgreSQL:
|
||||
fmt.Fprintf(buf, "$%v", len(values)+1)
|
||||
default:
|
||||
panic(fmt.Errorf("Args.CompileWithFlavor: invalid flavor %v (%v)", flavor, int(flavor)))
|
||||
}
|
||||
|
||||
values = append(values, arg)
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
// Copy is
|
||||
func (args *Args) Copy() *Args {
|
||||
return &Args{
|
||||
Flavor: args.Flavor,
|
||||
args: args.args,
|
||||
namedArgs: args.namedArgs,
|
||||
sqlNamedArgs: args.sqlNamedArgs,
|
||||
onlyNamed: args.onlyNamed,
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user