Create & Init Project...

This commit is contained in:
2019-04-22 18:49:16 +08:00
commit fc4fa37393
25440 changed files with 4054998 additions and 0 deletions

View File

@@ -0,0 +1,51 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_test",
"go_library",
)
go_test(
name = "go_default_test",
srcs = [
"args_test.go",
"builder_test.go",
"cond_test.go",
"flavor_test.go",
"modifiers_test.go",
"select_test.go",
],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
)
go_library(
name = "go_default_library",
srcs = [
"args.go",
"builder.go",
"cond.go",
"flavor.go",
"modifiers.go",
"select.go",
],
importpath = "go-common/app/service/main/account-recovery/dao/sqlbuilder",
tags = ["automanaged"],
visibility = ["//visibility:public"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,236 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"bytes"
"database/sql"
"fmt"
"sort"
"strconv"
"strings"
)
// Args stores arguments associated with a SQL.
type Args struct {
// The default flavor used by `Args#Compile`
Flavor Flavor
args []interface{}
namedArgs map[string]int
sqlNamedArgs map[string]int
onlyNamed bool
}
// Add adds an arg to Args and returns a placeholder.
func (args *Args) Add(arg interface{}) string {
return fmt.Sprintf("$%v", args.add(arg))
}
func (args *Args) add(arg interface{}) int {
idx := len(args.args)
switch a := arg.(type) {
case sql.NamedArg:
if args.sqlNamedArgs == nil {
args.sqlNamedArgs = map[string]int{}
}
if p, ok := args.sqlNamedArgs[a.Name]; ok {
arg = args.args[p]
break
}
args.sqlNamedArgs[a.Name] = idx
case namedArgs:
if args.namedArgs == nil {
args.namedArgs = map[string]int{}
}
if p, ok := args.namedArgs[a.name]; ok {
arg = args.args[p]
break
}
// Find out the real arg and add it to args.
idx = args.add(a.arg)
args.namedArgs[a.name] = idx
return idx
}
args.args = append(args.args, arg)
return idx
}
// Compile compiles builder's format to standard sql and returns associated args.
//
// The format string uses a special syntax to represent arguments.
//
// $? refers successive arguments passed in the call. It works similar as `%v` in `fmt.Sprintf`.
// $0 $1 ... $n refers nth-argument passed in the call. Next $? will use arguments n+1.
// ${name} refers a named argument created by `Named` with `name`.
// $$ is a "$" string.
func (args *Args) Compile(format string, intialValue ...interface{}) (query string, values []interface{}) {
return args.CompileWithFlavor(format, args.Flavor, intialValue...)
}
// CompileWithFlavor compiles builder's format to standard sql with flavor and returns associated args.
//
// See doc for `Compile` to learn details.
func (args *Args) CompileWithFlavor(format string, flavor Flavor, intialValue ...interface{}) (query string, values []interface{}) {
buf := &bytes.Buffer{}
idx := strings.IndexRune(format, '$')
offset := 0
values = intialValue
if flavor == invalidFlavor {
flavor = DefaultFlavor
}
for idx >= 0 && len(format) > 0 {
if idx > 0 {
buf.WriteString(format[:idx])
}
format = format[idx+1:]
// Should not happen.
if len(format) == 0 {
break
}
if format[0] == '$' {
buf.WriteRune('$')
format = format[1:]
} else if format[0] == '{' {
format, values = args.compileNamed(buf, flavor, format, values)
} else if !args.onlyNamed && '0' <= format[0] && format[0] <= '9' {
format, values, offset = args.compileDigits(buf, flavor, format, values, offset)
} else if !args.onlyNamed && format[0] == '?' {
format, values, offset = args.compileSuccessive(buf, flavor, format[1:], values, offset)
}
idx = strings.IndexRune(format, '$')
}
if len(format) > 0 {
buf.WriteString(format)
}
query = buf.String()
if len(args.sqlNamedArgs) > 0 {
// Stabilize the sequence to make it easier to write test cases.
ints := make([]int, 0, len(args.sqlNamedArgs))
for _, p := range args.sqlNamedArgs {
ints = append(ints, p)
}
sort.Ints(ints)
for _, i := range ints {
values = append(values, args.args[i])
}
}
return
}
func (args *Args) compileNamed(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}) (string, []interface{}) {
i := 1
for ; i < len(format) && format[i] != '}'; i++ {
// Nothing.
}
// Invalid $ format. Ignore it.
if i == len(format) {
return format, values
}
name := format[1:i]
format = format[i+1:]
if p, ok := args.namedArgs[name]; ok {
format, values, _ = args.compileSuccessive(buf, flavor, format, values, p)
}
return format, values
}
func (args *Args) compileDigits(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) {
i := 1
for ; i < len(format) && '0' <= format[i] && format[i] <= '9'; i++ {
// Nothing.
}
digits := format[:i]
format = format[i:]
if pointer, err := strconv.Atoi(digits); err == nil {
return args.compileSuccessive(buf, flavor, format, values, pointer)
}
return format, values, offset
}
func (args *Args) compileSuccessive(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) {
if offset >= len(args.args) {
return format, values, offset
}
arg := args.args[offset]
values = args.compileArg(buf, flavor, values, arg)
return format, values, offset + 1
}
func (args *Args) compileArg(buf *bytes.Buffer, flavor Flavor, values []interface{}, arg interface{}) []interface{} {
switch a := arg.(type) {
case Builder:
var s string
s, values = a.BuildWithFlavor(flavor, values...)
buf.WriteString(s)
case sql.NamedArg:
buf.WriteRune('@')
buf.WriteString(a.Name)
case rawArgs:
buf.WriteString(a.expr)
case listArgs:
if len(a.args) > 0 {
values = args.compileArg(buf, flavor, values, a.args[0])
}
for i := 1; i < len(a.args); i++ {
buf.WriteString(", ")
values = args.compileArg(buf, flavor, values, a.args[i])
}
default:
switch flavor {
case MySQL:
buf.WriteRune('?')
case PostgreSQL:
fmt.Fprintf(buf, "$%v", len(values)+1)
default:
panic(fmt.Errorf("Args.CompileWithFlavor: invalid flavor %v (%v)", flavor, int(flavor)))
}
values = append(values, arg)
}
return values
}
// Copy is
func (args *Args) Copy() *Args {
return &Args{
Flavor: args.Flavor,
args: args.args,
namedArgs: args.namedArgs,
sqlNamedArgs: args.sqlNamedArgs,
onlyNamed: args.onlyNamed,
}
}

View File

@@ -0,0 +1,76 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"bytes"
"fmt"
"strings"
"testing"
)
func TestArgs(t *testing.T) {
cases := map[string][]interface{}{
"abc ? def\n[123]": {"abc $? def", 123},
"abc ? def\n[456]": {"abc $0 def", 456},
"abc def\n[]": {"abc $1 def", 123},
"abc ? def\n[789]": {"abc ${s} def", Named("s", 789)},
"abc def \n[]": {"abc ${unknown} def ", 123},
"abc $ def\n[]": {"abc $$ def", 123},
"abcdef\n[]": {"abcdef$", 123},
"abc ? ? ? ? def\n[123 456 123 456]": {"abc $? $? $0 $? def", 123, 456, 789},
"abc ? raw ? raw def\n[123 123]": {"abc $? $? $0 $? def", 123, Raw("raw"), 789},
}
for expected, c := range cases {
args := new(Args)
for i := 1; i < len(c); i++ {
args.Add(c[i])
}
sql, values := args.Compile(c[0].(string))
actual := fmt.Sprintf("%v\n%v", sql, values)
if actual != expected {
t.Fatalf("invalid compile result. [expected:%v] [actual:%v]", expected, actual)
}
}
old := DefaultFlavor
DefaultFlavor = PostgreSQL
defer func() {
DefaultFlavor = old
}()
// PostgreSQL flavor compiled sql.
for expected, c := range cases {
args := new(Args)
for i := 1; i < len(c); i++ {
args.Add(c[i])
}
sql, values := args.Compile(c[0].(string))
actual := fmt.Sprintf("%v\n%v", sql, values)
expected = toPostgreSQL(expected)
if actual != expected {
t.Fatalf("invalid compile result. [expected:%v] [actual:%v]", expected, actual)
}
}
}
func toPostgreSQL(sql string) string {
parts := strings.Split(sql, "?")
buf := &bytes.Buffer{}
buf.WriteString(parts[0])
for i, p := range parts[1:] {
fmt.Fprintf(buf, "$%v", i+1)
buf.WriteString(p)
}
return buf.String()
}

View File

@@ -0,0 +1,105 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"fmt"
)
// Builder is a general SQL builder.
// It's used by Args to create nested SQL like the `IN` expression in
// `SELECT * FROM t1 WHERE id IN (SELECT id FROM t2)`.
type Builder interface {
Build() (sql string, args []interface{})
BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{})
}
type compiledBuilder struct {
args *Args
format string
}
func (cb *compiledBuilder) Build() (sql string, args []interface{}) {
return cb.args.Compile(cb.format)
}
func (cb *compiledBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {
return cb.args.CompileWithFlavor(cb.format, flavor, initialArg...)
}
type flavoredBuilder struct {
builder Builder
flavor Flavor
}
func (fb *flavoredBuilder) Build() (sql string, args []interface{}) {
return fb.builder.BuildWithFlavor(fb.flavor)
}
func (fb *flavoredBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {
return fb.builder.BuildWithFlavor(flavor, initialArg...)
}
// WithFlavor creates a new Builder based on builder with a default flavor.
func WithFlavor(builder Builder, flavor Flavor) Builder {
return &flavoredBuilder{
builder: builder,
flavor: flavor,
}
}
// Buildf creates a Builder from a format string using `fmt.Sprintf`-like syntax.
// As all arguments will be converted to a string internally, e.g. "$0",
// only `%v` and `%s` are valid.
func Buildf(format string, arg ...interface{}) Builder {
args := &Args{
Flavor: DefaultFlavor,
}
vars := make([]interface{}, 0, len(arg))
for _, a := range arg {
vars = append(vars, args.Add(a))
}
return &compiledBuilder{
args: args,
format: fmt.Sprintf(Escape(format), vars...),
}
}
// Build creates a Builder from a format string.
// The format string uses special syntax to represent arguments.
// See doc in `Args#Compile` for syntax details.
func Build(format string, arg ...interface{}) Builder {
args := &Args{
Flavor: DefaultFlavor,
}
for _, a := range arg {
args.Add(a)
}
return &compiledBuilder{
args: args,
format: format,
}
}
// BuildNamed creates a Builder from a format string.
// The format string uses `${key}` to refer the value of named by key.
func BuildNamed(format string, named map[string]interface{}) Builder {
args := &Args{
Flavor: DefaultFlavor,
onlyNamed: true,
}
for n, v := range named {
args.Add(Named(n, v))
}
return &compiledBuilder{
args: args,
format: format,
}
}

View File

@@ -0,0 +1,113 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"database/sql"
"fmt"
"reflect"
"testing"
)
func ExampleBuildf() {
sb := NewSelectBuilder()
sb.Select("id").From("user")
explain := Buildf("EXPLAIN %v LEFT JOIN SELECT * FROM banned WHERE state IN (%v, %v)", sb, 1, 2)
sql, args := explain.Build()
fmt.Println(sql)
fmt.Println(args)
// Output:
// EXPLAIN SELECT id FROM user LEFT JOIN SELECT * FROM banned WHERE state IN (?, ?)
// [1 2]
}
func ExampleBuild() {
sb := NewSelectBuilder()
sb.Select("id").From("user").Where(sb.In("status", 1, 2))
b := Build("EXPLAIN $? LEFT JOIN SELECT * FROM $? WHERE created_at > $? AND state IN (${states}) AND modified_at BETWEEN $2 AND $?",
sb, Raw("banned"), 1514458225, 1514544625, Named("states", List([]int{3, 4, 5})))
sql, args := b.Build()
fmt.Println(sql)
fmt.Println(args)
// Output:
// EXPLAIN SELECT id FROM user WHERE status IN (?, ?) LEFT JOIN SELECT * FROM banned WHERE created_at > ? AND state IN (?, ?, ?) AND modified_at BETWEEN ? AND ?
// [1 2 1514458225 3 4 5 1514458225 1514544625]
}
func ExampleBuildNamed() {
b := BuildNamed("SELECT * FROM ${table} WHERE status IN (${status}) AND name LIKE ${name} AND created_at > ${time} AND modified_at < ${time} + 86400",
map[string]interface{}{
"time": sql.Named("start", 1234567890),
"status": List([]int{1, 2, 5}),
"name": "Huan%",
"table": Raw("user"),
})
sql, args := b.Build()
fmt.Println(sql)
fmt.Println(args)
// Output:
// SELECT * FROM user WHERE status IN (?, ?, ?) AND name LIKE ? AND created_at > @start AND modified_at < @start + 86400
// [1 2 5 Huan% {{} start 1234567890}]
}
func ExampleWithFlavor() {
sql, args := WithFlavor(Buildf("SELECT * FROM foo WHERE id = %v", 1234), PostgreSQL).Build()
fmt.Println(sql)
fmt.Println(args)
// Explicitly use MySQL as the flavor.
sql, args = WithFlavor(Buildf("SELECT * FROM foo WHERE id = %v", 1234), PostgreSQL).BuildWithFlavor(MySQL)
fmt.Println(sql)
fmt.Println(args)
// Output:
// SELECT * FROM foo WHERE id = $1
// [1234]
// SELECT * FROM foo WHERE id = ?
// [1234]
}
func TestBuildWithPostgreSQL(t *testing.T) {
sb1 := PostgreSQL.NewSelectBuilder()
sb1.Select("col1", "col2").From("t1").Where(sb1.E("id", 1234), sb1.G("level", 2))
sb2 := PostgreSQL.NewSelectBuilder()
sb2.Select("col3", "col4").From("t2").Where(sb2.E("id", 4567), sb2.LE("level", 5))
// Use DefaultFlavor (MySQL) instead of PostgreSQL.
sql, args := Build("SELECT $1 AS col5 LEFT JOIN $0 LEFT JOIN $2", sb1, 7890, sb2).Build()
if expected := "SELECT ? AS col5 LEFT JOIN SELECT col1, col2 FROM t1 WHERE id = ? AND level > ? LEFT JOIN SELECT col3, col4 FROM t2 WHERE id = ? AND level <= ?"; sql != expected {
t.Fatalf("invalid sql. [expected:%v] [actual:%v]", expected, sql)
}
if expected := []interface{}{7890, 1234, 2, 4567, 5}; !reflect.DeepEqual(args, expected) {
t.Fatalf("invalid args. [expected:%v] [actual:%v]", expected, args)
}
old := DefaultFlavor
DefaultFlavor = PostgreSQL
defer func() {
DefaultFlavor = old
}()
sql, args = Build("SELECT $1 AS col5 LEFT JOIN $0 LEFT JOIN $2", sb1, 7890, sb2).Build()
if expected := "SELECT $1 AS col5 LEFT JOIN SELECT col1, col2 FROM t1 WHERE id = $2 AND level > $3 LEFT JOIN SELECT col3, col4 FROM t2 WHERE id = $4 AND level <= $5"; sql != expected {
t.Fatalf("invalid sql. [expected:%v] [actual:%v]", expected, sql)
}
if expected := []interface{}{7890, 1234, 2, 4567, 5}; !reflect.DeepEqual(args, expected) {
t.Fatalf("invalid args. [expected:%v] [actual:%v]", expected, args)
}
}

View File

@@ -0,0 +1,136 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"fmt"
"strings"
)
// Cond provides several helper methods to build conditions.
type Cond struct {
Args *Args
}
// Equal represents "field = value".
func (c *Cond) Equal(field string, value interface{}) string {
return fmt.Sprintf("%v = %v", Escape(field), c.Args.Add(value))
}
// E is an alias of Equal.
func (c *Cond) E(field string, value interface{}) string {
return c.Equal(field, value)
}
// NotEqual represents "field != value".
func (c *Cond) NotEqual(field string, value interface{}) string {
return fmt.Sprintf("%v <> %v", Escape(field), c.Args.Add(value))
}
// NE is an alias of NotEqual.
func (c *Cond) NE(field string, value interface{}) string {
return c.NotEqual(field, value)
}
// GreaterThan represents "field > value".
func (c *Cond) GreaterThan(field string, value interface{}) string {
return fmt.Sprintf("%v > %v", Escape(field), c.Args.Add(value))
}
// G is an alias of GreaterThan.
func (c *Cond) G(field string, value interface{}) string {
return c.GreaterThan(field, value)
}
// GreaterEqualThan represents "field >= value".
func (c *Cond) GreaterEqualThan(field string, value interface{}) string {
return fmt.Sprintf("%v >= %v", Escape(field), c.Args.Add(value))
}
// GE is an alias of GreaterEqualThan.
func (c *Cond) GE(field string, value interface{}) string {
return c.GreaterEqualThan(field, value)
}
// LessThan represents "field < value".
func (c *Cond) LessThan(field string, value interface{}) string {
return fmt.Sprintf("%v < %v", Escape(field), c.Args.Add(value))
}
// L is an alias of LessThan.
func (c *Cond) L(field string, value interface{}) string {
return c.LessThan(field, value)
}
// LessEqualThan represents "field <= value".
func (c *Cond) LessEqualThan(field string, value interface{}) string {
return fmt.Sprintf("%v <= %v", Escape(field), c.Args.Add(value))
}
// LE is an alias of LessEqualThan.
func (c *Cond) LE(field string, value interface{}) string {
return c.LessEqualThan(field, value)
}
// In represents "field IN (value...)".
func (c *Cond) In(field string, value ...interface{}) string {
vs := make([]string, 0, len(value))
for _, v := range value {
vs = append(vs, c.Args.Add(v))
}
return fmt.Sprintf("%v IN (%v)", Escape(field), strings.Join(vs, ", "))
}
// NotIn represents "field NOT IN (value...)".
func (c *Cond) NotIn(field string, value ...interface{}) string {
vs := make([]string, 0, len(value))
for _, v := range value {
vs = append(vs, c.Args.Add(v))
}
return fmt.Sprintf("%v NOT IN (%v)", Escape(field), strings.Join(vs, ", "))
}
// Like represents "field LIKE value".
func (c *Cond) Like(field string, value interface{}) string {
return fmt.Sprintf("%v LIKE %v", Escape(field), c.Args.Add(value))
}
// NotLike represents "field NOT LIKE value".
func (c *Cond) NotLike(field string, value interface{}) string {
return fmt.Sprintf("%v NOT LIKE %v", Escape(field), c.Args.Add(value))
}
// IsNull represents "field IS NULL".
func (c *Cond) IsNull(field string) string {
return fmt.Sprintf("%v IS NULL", Escape(field))
}
// IsNotNull represents "field IS NOT NULL".
func (c *Cond) IsNotNull(field string) string {
return fmt.Sprintf("%v IS NOT NULL", Escape(field))
}
// Between represents "field BETWEEN lower AND upper".
func (c *Cond) Between(field string, lower, upper interface{}) string {
return fmt.Sprintf("%v BETWEEN %v AND %v", Escape(field), c.Args.Add(lower), c.Args.Add(upper))
}
// NotBetween represents "field NOT BETWEEN lower AND upper".
func (c *Cond) NotBetween(field string, lower, upper interface{}) string {
return fmt.Sprintf("%v NOT BETWEEN %v AND %v", Escape(field), c.Args.Add(lower), c.Args.Add(upper))
}
// Or represents OR logic like "expr1 OR expr2 OR expr3".
func (c *Cond) Or(orExpr ...string) string {
return fmt.Sprintf("(%v)", strings.Join(orExpr, " OR "))
}
// Var returns a placeholder for value.
func (c *Cond) Var(value interface{}) string {
return c.Args.Add(value)
}

View File

@@ -0,0 +1,47 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"testing"
)
func TestCond(t *testing.T) {
cases := map[string]func() string{
"$$a = $0": func() string { return newTestCond().Equal("$a", 123) },
"$$b = $0": func() string { return newTestCond().E("$b", 123) },
"$$a <> $0": func() string { return newTestCond().NotEqual("$a", 123) },
"$$b <> $0": func() string { return newTestCond().NE("$b", 123) },
"$$a > $0": func() string { return newTestCond().GreaterThan("$a", 123) },
"$$b > $0": func() string { return newTestCond().G("$b", 123) },
"$$a >= $0": func() string { return newTestCond().GreaterEqualThan("$a", 123) },
"$$b >= $0": func() string { return newTestCond().GE("$b", 123) },
"$$a < $0": func() string { return newTestCond().LessThan("$a", 123) },
"$$b < $0": func() string { return newTestCond().L("$b", 123) },
"$$a <= $0": func() string { return newTestCond().LessEqualThan("$a", 123) },
"$$b <= $0": func() string { return newTestCond().LE("$b", 123) },
"$$a IN ($0, $1, $2)": func() string { return newTestCond().In("$a", 1, 2, 3) },
"$$a NOT IN ($0, $1, $2)": func() string { return newTestCond().NotIn("$a", 1, 2, 3) },
"$$a LIKE $0": func() string { return newTestCond().Like("$a", "%Huan%") },
"$$a NOT LIKE $0": func() string { return newTestCond().NotLike("$a", "%Huan%") },
"$$a IS NULL": func() string { return newTestCond().IsNull("$a") },
"$$a IS NOT NULL": func() string { return newTestCond().IsNotNull("$a") },
"$$a BETWEEN $0 AND $1": func() string { return newTestCond().Between("$a", 123, 456) },
"$$a NOT BETWEEN $0 AND $1": func() string { return newTestCond().NotBetween("$a", 123, 456) },
"(1 = 1 OR 2 = 2 OR 3 = 3)": func() string { return newTestCond().Or("1 = 1", "2 = 2", "3 = 3") },
"$0": func() string { return newTestCond().Var(123) },
}
for expected, f := range cases {
if actual := f(); expected != actual {
t.Fatalf("invalid result. [expected:%v] [actual:%v]", expected, actual)
}
}
}
func newTestCond() *Cond {
return &Cond{
Args: &Args{},
}
}

View File

@@ -0,0 +1,57 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import "fmt"
// Supported flavors.
const (
invalidFlavor Flavor = iota
MySQL
PostgreSQL
)
var (
// DefaultFlavor is the default flavor for all builders.
DefaultFlavor = MySQL
)
// Flavor is the flag to control the format of compiled sql.
type Flavor int
// String returns the name of f.
func (f Flavor) String() string {
switch f {
case MySQL:
return "MySQL"
case PostgreSQL:
return "PostgreSQL"
}
return "<invalid>"
}
// NewSelectBuilder creates a new SELECT builder with flavor.
func (f Flavor) NewSelectBuilder() *SelectBuilder {
b := newSelectBuilder()
b.SetFlavor(f)
return b
}
// Quote adds quote for name to make sure the name can be used safely
// as table name or field name.
//
// * For MySQL, use back quote (`) to quote name;
// * For PostgreSQL, use double quote (") to quote name.
func (f Flavor) Quote(name string) string {
switch f {
case MySQL:
return fmt.Sprintf("`%v`", name)
case PostgreSQL:
return fmt.Sprintf(`"%v"`, name)
}
return name
}

View File

@@ -0,0 +1,40 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"fmt"
"testing"
)
func TestFlavor(t *testing.T) {
cases := map[Flavor]string{
0: "<invalid>",
MySQL: "MySQL",
PostgreSQL: "PostgreSQL",
}
for f, expected := range cases {
if actual := f.String(); actual != expected {
t.Fatalf("invalid flavor name. [expected:%v] [actual:%v]", expected, actual)
}
}
}
func ExampleFlavor() {
// Create a flavored builder.
sb := PostgreSQL.NewSelectBuilder()
sb.Select("name").From("user").Where(
sb.E("id", 1234),
sb.G("rank", 3),
)
sql, args := sb.Build()
fmt.Println(sql)
fmt.Println(args)
// Output:
// SELECT name FROM user WHERE id = $1 AND rank > $2
// [1234 3]
}

View File

@@ -0,0 +1,98 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"reflect"
"strings"
)
// Escape replaces `$` with `$$` in ident.
func Escape(ident string) string {
return strings.Replace(ident, "$", "$$", -1)
}
// EscapeAll replaces `$` with `$$` in all strings of ident.
func EscapeAll(ident ...string) []string {
escaped := make([]string, 0, len(ident))
for _, i := range ident {
escaped = append(escaped, Escape(i))
}
return escaped
}
// Flatten recursively extracts values in slices and returns
// a flattened []interface{} with all values.
// If slices is not a slice, return `[]interface{}{slices}`.
func Flatten(slices interface{}) (flattened []interface{}) {
v := reflect.ValueOf(slices)
slices, flattened = flatten(v)
if slices != nil {
return []interface{}{slices}
}
return flattened
}
func flatten(v reflect.Value) (elem interface{}, flattened []interface{}) {
k := v.Kind()
for k == reflect.Interface {
v = v.Elem()
k = v.Kind()
}
if k != reflect.Slice && k != reflect.Array {
return v.Interface(), nil
}
for i, l := 0, v.Len(); i < l; i++ {
e, f := flatten(v.Index(i))
if e == nil {
flattened = append(flattened, f...)
} else {
flattened = append(flattened, e)
}
}
return
}
type rawArgs struct {
expr string
}
// Raw marks the expr as a raw value which will not be added to args.
func Raw(expr string) interface{} {
return rawArgs{expr}
}
type listArgs struct {
args []interface{}
}
// List marks arg as a list of data.
// If arg is `[]int{1, 2, 3}`, it will be compiled to `?, ?, ?` with args `[1 2 3]`.
func List(arg interface{}) interface{} {
return listArgs{Flatten(arg)}
}
type namedArgs struct {
name string
arg interface{}
}
// Named creates a named argument.
// Unlike `sql.Named`, this named argument works only with `Build` or `BuildNamed` for convenience
// and will be replaced to a `?` after `Compile`.
func Named(name string, arg interface{}) interface{} {
return namedArgs{
name: name,
arg: arg,
}
}

View File

@@ -0,0 +1,59 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"reflect"
"testing"
)
func TestEscape(t *testing.T) {
cases := map[string]string{
"foo": "foo",
"$foo": "$$foo",
"$$$": "$$$$$$",
}
var inputs, expects []string
for s, expected := range cases {
inputs = append(inputs, s)
expects = append(expects, expected)
if actual := Escape(s); actual != expected {
t.Fatalf("invalid escape result. [expected:%v] [actual:%v]", expected, actual)
}
}
actuals := EscapeAll(inputs...)
if !reflect.DeepEqual(expects, actuals) {
t.Fatalf("invalid escape result. [expected:%v] [actual:%v]", expects, actuals)
}
}
func TestFlatten(t *testing.T) {
cases := [][2]interface{}{
{
"foo",
[]interface{}{"foo"},
},
{
[]int{1, 2, 3},
[]interface{}{1, 2, 3},
},
{
[]interface{}{"abc", []int{1, 2, 3}, [3]string{"def", "ghi"}},
[]interface{}{"abc", 1, 2, 3, "def", "ghi", ""},
},
}
for _, c := range cases {
input, expected := c[0], c[1]
actual := Flatten(input)
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("invalid flatten result. [expected:%v] [actual:%v]", expected, actual)
}
}
}

View File

@@ -0,0 +1,212 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"bytes"
"fmt"
"strconv"
"strings"
)
// NewSelectBuilder creates a new SELECT builder.
func NewSelectBuilder() *SelectBuilder {
return DefaultFlavor.NewSelectBuilder()
}
func newSelectBuilder() *SelectBuilder {
args := &Args{}
return &SelectBuilder{
Cond: Cond{
Args: args,
},
limit: -1,
offset: -1,
args: args,
}
}
// SelectBuilder is a builder to build SELECT.
type SelectBuilder struct {
Cond
distinct bool
tables []string
selectCols []string
whereExprs []string
havingExprs []string
groupByCols []string
orderByCols []string
order string
limit int
offset int
args *Args
}
// Distinct marks this SELECT as DISTINCT.
func (sb *SelectBuilder) Distinct() *SelectBuilder {
sb.distinct = true
return sb
}
// Select sets columns in SELECT.
func (sb *SelectBuilder) Select(col ...string) *SelectBuilder {
sb.selectCols = EscapeAll(col...)
return sb
}
// From sets table names in SELECT.
func (sb *SelectBuilder) From(table ...string) *SelectBuilder {
sb.tables = table
return sb
}
// Where sets expressions of WHERE in SELECT.
func (sb *SelectBuilder) Where(andExpr ...string) *SelectBuilder {
sb.whereExprs = append(sb.whereExprs, andExpr...)
return sb
}
// Having sets expressions of HAVING in SELECT.
func (sb *SelectBuilder) Having(andExpr ...string) *SelectBuilder {
sb.havingExprs = append(sb.havingExprs, andExpr...)
return sb
}
// GroupBy sets columns of GROUP BY in SELECT.
func (sb *SelectBuilder) GroupBy(col ...string) *SelectBuilder {
sb.groupByCols = EscapeAll(col...)
return sb
}
// OrderBy sets columns of ORDER BY in SELECT.
func (sb *SelectBuilder) OrderBy(col ...string) *SelectBuilder {
sb.orderByCols = EscapeAll(col...)
return sb
}
// Asc sets order of ORDER BY to ASC.
func (sb *SelectBuilder) Asc() *SelectBuilder {
sb.order = "ASC"
return sb
}
// Desc sets order of ORDER BY to DESC.
func (sb *SelectBuilder) Desc() *SelectBuilder {
sb.order = "DESC"
return sb
}
// Limit sets the LIMIT in SELECT.
func (sb *SelectBuilder) Limit(limit int) *SelectBuilder {
sb.limit = limit
return sb
}
// Offset sets the LIMIT offset in SELECT.
func (sb *SelectBuilder) Offset(offset int) *SelectBuilder {
sb.offset = offset
return sb
}
// As returns an AS expression.
func (sb *SelectBuilder) As(col, alias string) string {
return fmt.Sprintf("%v AS %v", col, Escape(alias))
}
// BuilderAs returns an AS expression wrapping a complex SQL.
// According to SQL syntax, SQL built by builder is surrounded by parens.
func (sb *SelectBuilder) BuilderAs(builder Builder, alias string) string {
return fmt.Sprintf("(%v) AS %v", sb.Var(builder), Escape(alias))
}
// String returns the compiled SELECT string.
func (sb *SelectBuilder) String() string {
s, _ := sb.Build()
return s
}
// Build returns compiled SELECT string and args.
// They can be used in `DB#Query` of package `database/sql` directly.
func (sb *SelectBuilder) Build() (sql string, args []interface{}) {
return sb.BuildWithFlavor(sb.args.Flavor)
}
// BuildWithFlavor returns compiled SELECT string and args with flavor and initial args.
// They can be used in `DB#Query` of package `database/sql` directly.
func (sb *SelectBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {
buf := &bytes.Buffer{}
buf.WriteString("SELECT ")
if sb.distinct {
buf.WriteString("DISTINCT ")
}
buf.WriteString(strings.Join(sb.selectCols, ", "))
buf.WriteString(" FROM ")
buf.WriteString(strings.Join(sb.tables, ", "))
if len(sb.whereExprs) > 0 {
buf.WriteString(" WHERE ")
buf.WriteString(strings.Join(sb.whereExprs, " AND "))
}
if len(sb.groupByCols) > 0 {
buf.WriteString(" GROUP BY ")
buf.WriteString(strings.Join(sb.groupByCols, ", "))
if len(sb.havingExprs) > 0 {
buf.WriteString(" HAVING ")
buf.WriteString(strings.Join(sb.havingExprs, " AND "))
}
}
if len(sb.orderByCols) > 0 {
buf.WriteString(" ORDER BY ")
buf.WriteString(strings.Join(sb.orderByCols, ", "))
if sb.order != "" {
buf.WriteRune(' ')
buf.WriteString(sb.order)
}
}
if sb.limit >= 0 {
buf.WriteString(" LIMIT ")
buf.WriteString(strconv.Itoa(sb.limit))
if sb.offset >= 0 {
buf.WriteString(" OFFSET ")
buf.WriteString(strconv.Itoa(sb.offset))
}
}
return sb.Args.CompileWithFlavor(buf.String(), flavor, initialArg...)
}
// SetFlavor sets the flavor of compiled sql.
func (sb *SelectBuilder) SetFlavor(flavor Flavor) (old Flavor) {
old = sb.args.Flavor
sb.args.Flavor = flavor
return
}
// Copy the builder
func (sb *SelectBuilder) Copy() *SelectBuilder {
return &SelectBuilder{
Cond: sb.Cond,
distinct: sb.distinct,
tables: sb.tables,
selectCols: sb.selectCols,
whereExprs: sb.whereExprs,
havingExprs: sb.havingExprs,
groupByCols: sb.groupByCols,
orderByCols: sb.orderByCols,
order: sb.order,
limit: sb.limit,
offset: sb.offset,
args: sb.args.Copy(),
}
}

View File

@@ -0,0 +1,67 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"database/sql"
"fmt"
)
func ExampleSelectBuilder() {
sb := NewSelectBuilder()
sb.Distinct().Select("id", "name", sb.As("COUNT(*)", "t"))
sb.From("demo.user")
sb.Where(
sb.GreaterThan("id", 1234),
sb.Like("name", "%Du"),
sb.Or(
sb.IsNull("id_card"),
sb.In("status", 1, 2, 5),
),
sb.NotIn(
"id",
NewSelectBuilder().Select("id").From("banned"),
), // Nested SELECT.
"modified_at > created_at + "+sb.Var(86400), // It's allowed to write arbitrary SQL.
)
sb.GroupBy("status").Having(sb.NotIn("status", 4, 5))
sb.OrderBy("modified_at").Asc()
sb.Limit(10).Offset(5)
sql, args := sb.Build()
fmt.Println(sql)
fmt.Println(args)
// Output:
// SELECT DISTINCT id, name, COUNT(*) AS t FROM demo.user WHERE id > ? AND name LIKE ? AND (id_card IS NULL OR status IN (?, ?, ?)) AND id NOT IN (SELECT id FROM banned) AND modified_at > created_at + ? GROUP BY status HAVING status NOT IN (?, ?) ORDER BY modified_at ASC LIMIT 10 OFFSET 5
// [1234 %Du 1 2 5 86400 4 5]
}
func ExampleSelectBuilder_advancedUsage() {
sb := NewSelectBuilder()
innerSb := NewSelectBuilder()
sb.Select("id", "name")
sb.From(
sb.BuilderAs(innerSb, "user"),
)
sb.Where(
sb.In("status", Flatten([]int{1, 2, 3})...),
sb.Between("created_at", sql.Named("start", 1234567890), sql.Named("end", 1234599999)),
)
innerSb.Select("*")
innerSb.From("banned")
innerSb.Where(
innerSb.NotIn("name", Flatten([]string{"Huan Du", "Charmy Liu"})...),
)
sql, args := sb.Build()
fmt.Println(sql)
fmt.Println(args)
// Output:
// SELECT id, name FROM (SELECT * FROM banned WHERE name NOT IN (?, ?)) AS user WHERE status IN (?, ?, ?) AND created_at BETWEEN @start AND @end
// [Huan Du Charmy Liu 1 2 3 {{} start 1234567890} {{} end 1234599999}]
}