Create & Init Project...

This commit is contained in:
2019-04-22 18:49:16 +08:00
commit fc4fa37393
25440 changed files with 4054998 additions and 0 deletions

View File

@ -0,0 +1,86 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_library(
name = "go_default_library",
srcs = [
"aes.go",
"captcha.go",
"dao.go",
"email.go",
"mid_info.go",
"mysql.go",
"redis.go",
"req_rpc.go",
"user_act_log.go",
],
importpath = "go-common/app/service/main/account-recovery/dao",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//app/service/main/account-recovery/conf:go_default_library",
"//app/service/main/account-recovery/dao/sqlbuilder:go_default_library",
"//app/service/main/account-recovery/model:go_default_library",
"//app/service/main/account/api:go_default_library",
"//app/service/main/location/model:go_default_library",
"//app/service/main/location/rpc/client:go_default_library",
"//app/service/main/member/api:go_default_library",
"//library/cache/redis:go_default_library",
"//library/database/elastic:go_default_library",
"//library/database/sql:go_default_library",
"//library/ecode:go_default_library",
"//library/log:go_default_library",
"//library/net/http/blademaster:go_default_library",
"//library/net/metadata:go_default_library",
"//library/time:go_default_library",
"//library/xstr:go_default_library",
"//vendor/github.com/pkg/errors:go_default_library",
"//vendor/gopkg.in/gomail.v2:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [
":package-srcs",
"//app/service/main/account-recovery/dao/sqlbuilder:all-srcs",
],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)
go_test(
name = "go_default_test",
srcs = [
"captcha_test.go",
"dao_test.go",
"email_test.go",
"mid_info_test.go",
"mysql_test.go",
"redis_test.go",
"req_rpc_test.go",
"user_act_log_test.go",
],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
deps = [
"//app/service/main/account-recovery/conf:go_default_library",
"//app/service/main/account-recovery/model:go_default_library",
"//library/time:go_default_library",
"//vendor/github.com/smartystreets/goconvey/convey:go_default_library",
"//vendor/gopkg.in/h2non/gock.v1:go_default_library",
],
)

View File

@ -0,0 +1,63 @@
package dao
import (
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"errors"
)
//func pad(src []byte) []byte {
// padding := aes.BlockSize - len(src)%aes.BlockSize
// padText := bytes.Repeat([]byte{byte(padding)}, padding)
// return append(src, padText...)
//}
func unpad(src []byte) ([]byte, error) {
length := len(src)
unpadding := int(src[length-1])
if unpadding > length {
return nil, errors.New("unpad error. This could happen when incorrect encryption key is used")
}
return src[:(length - unpadding)], nil
}
//func (s *Service) encrypt(text string) (string, error) {
// msg := pad([]byte(text))
// cipherText := make([]byte, aes.BlockSize+len(msg))
// iv := cipherText[:aes.BlockSize]
// if _, err := io.ReadFull(rand.Reader, iv); err != nil {
// return "", err
// }
//
// cfb := cipher.NewCFBEncrypter(s.AESBlock, iv)
// cfb.XORKeyStream(cipherText[aes.BlockSize:], []byte(msg))
// finalMsg := base64.URLEncoding.EncodeToString(cipherText)
// return finalMsg, nil
//}
func (d *Dao) decrypt(text string) (string, error) {
decodedMsg, err := base64.URLEncoding.DecodeString(text)
if err != nil {
return "", err
}
if (len(decodedMsg) % aes.BlockSize) != 0 {
return "", errors.New("blocksize must be multipe of decoded message length")
}
iv := decodedMsg[:aes.BlockSize]
msg := decodedMsg[aes.BlockSize:]
cfb := cipher.NewCFBDecrypter(d.AESBlock, iv)
cfb.XORKeyStream(msg, msg)
unpadMsg, err := unpad(msg)
if err != nil {
return "", err
}
return string(unpadMsg), nil
}

View File

@ -0,0 +1,47 @@
package dao
import (
"context"
"net/url"
"go-common/app/service/main/account-recovery/model"
"go-common/library/ecode"
"go-common/library/log"
"go-common/library/net/metadata"
)
// GetToken get open token.
func (d *Dao) GetToken(c context.Context, bid string) (res *model.TokenResq, err error) {
params := url.Values{}
params.Add("bid", bid)
if err = d.httpClient.Get(c, d.c.CaptchaConf.TokenURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
log.Error("GetToken HTTP request err(%v)", err)
return
}
if res.Code != 0 {
log.Error("GetToken service return err(%v)", res.Code)
err = ecode.Int(int(res.Code))
return
}
return
}
// Verify verify code.
func (d *Dao) Verify(c context.Context, code, token string) (ok bool, err error) {
params := url.Values{}
params.Add("token", token)
params.Add("code", code)
res := new(struct {
Code int `json:"code"`
})
if err = d.httpClient.Post(c, d.c.CaptchaConf.VerifyURL, metadata.String(c, metadata.RemoteIP), params, res); err != nil {
log.Error("Verify HTTP request err(%v)", err)
return
}
if res.Code != 0 {
log.Error("Verify service return err(%v)", res.Code)
err = ecode.Int(res.Code)
return
}
return true, nil
}

View File

@ -0,0 +1,22 @@
package dao
import (
"context"
"testing"
"github.com/smartystreets/goconvey/convey"
)
func TestDaoGetToken(t *testing.T) {
var (
c = context.Background()
bid = "account"
)
convey.Convey("GetToken", t, func(ctx convey.C) {
res, err := d.GetToken(c, bid)
ctx.Convey("Then err should be nil.res should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(res, convey.ShouldNotBeNil)
})
})
}

View File

@ -0,0 +1,90 @@
package dao
import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/tls"
"go-common/app/service/main/account-recovery/conf"
account "go-common/app/service/main/account/api"
location "go-common/app/service/main/location/rpc/client"
member "go-common/app/service/main/member/api"
"go-common/library/cache/redis"
"go-common/library/database/elastic"
xsql "go-common/library/database/sql"
bm "go-common/library/net/http/blademaster"
"gopkg.in/gomail.v2"
)
// Dao dao
type Dao struct {
c *conf.Config
redis *redis.Pool
db *xsql.DB
// httpClient
httpClient *bm.Client
// email
email *gomail.Dialer
es *elastic.Elastic
// rpc
locRPC *location.Service
// grpc
memberClient member.MemberClient
accountClient account.AccountClient
hashSalt []byte
AESBlock cipher.Block
}
// New init mysql db
func New(c *conf.Config) (dao *Dao) {
dao = &Dao{
c: c,
redis: redis.NewPool(c.Redis),
db: xsql.NewMySQL(c.MySQL),
// httpClient
httpClient: bm.NewClient(c.HTTPClientConfig),
email: gomail.NewDialer(c.MailConf.Host, c.MailConf.Port, c.MailConf.Username, c.MailConf.Password),
es: elastic.NewElastic(c.Elastic),
locRPC: location.New(c.LocationRPC),
hashSalt: []byte(c.AESEncode.Salt),
}
dao.email.TLSConfig = &tls.Config{
InsecureSkipVerify: true,
}
dao.AESBlock, _ = aes.NewCipher([]byte(c.AESEncode.AesKey))
var err error
if dao.memberClient, err = member.NewClient(c.MemberGRPC); err != nil {
panic(err)
}
if dao.accountClient, err = account.NewClient(c.AccountGRPC); err != nil {
panic(err)
}
return
}
// Close close the resource.
func (d *Dao) Close() {
d.redis.Close()
d.db.Close()
}
// Ping dao ping
func (d *Dao) Ping(c context.Context) (err error) {
if err = d.db.Ping(c); err != nil {
return
}
if err = d.PingRedis(c); err != nil {
return
}
// TODO: if you need use mc,redis, please add
return
}

View File

@ -0,0 +1,46 @@
package dao
import (
"flag"
"os"
"strings"
"testing"
"go-common/app/service/main/account-recovery/conf"
"gopkg.in/h2non/gock.v1"
)
var (
d *Dao
)
func TestMain(m *testing.M) {
if os.Getenv("DEPLOY_ENV") != "" {
flag.Set("app_id", "main.account.account-recovery")
flag.Set("conf_token", "5fe12bcbf11eb0c368ee0c2d1f567184")
flag.Set("tree_id", "55382")
flag.Set("conf_version", "docker-1")
flag.Set("deploy_env", "uat")
flag.Set("conf_host", "config.bilibili.co")
flag.Set("conf_path", "/tmp")
flag.Set("region", "sh")
flag.Set("zone", "sh001")
} else {
flag.Set("conf", "../cmd/account-recovery-test.toml")
}
flag.Parse()
if err := conf.Init(); err != nil {
panic(err)
}
d = New(conf.Conf)
d.httpClient.SetTransport(gock.DefaultTransport)
m.Run()
os.Exit(0)
}
func httpMock(method, url string) *gock.Request {
r := gock.New(url)
r.Method = strings.ToUpper(method)
return r
}

View File

@ -0,0 +1,23 @@
package dao
import (
"go-common/app/service/main/account-recovery/conf"
"go-common/library/log"
"gopkg.in/gomail.v2"
)
// SendMail send the email.
func (d *Dao) SendMail(body string, subject string, send ...string) (err error) {
log.Info("send mail send:%v", send)
msg := gomail.NewMessage()
msg.SetHeader("From", conf.Conf.MailConf.Username)
msg.SetHeader("To", send...)
msg.SetHeader("Subject", subject)
msg.SetBody("text/html", body, gomail.SetPartEncoding(gomail.Base64))
if err = d.email.DialAndSend(msg); err != nil {
log.Error("s.email.DialAndSend error(%v)", err)
return
}
return
}

View File

@ -0,0 +1,23 @@
package dao
import (
"math/rand"
"strconv"
"testing"
"github.com/smartystreets/goconvey/convey"
)
func TestDaoSendMail(t *testing.T) {
var (
body = strconv.Itoa(rand.Intn(100))
subject = "邮件测试hyy"
send = "2459593393@qq.com"
)
convey.Convey("SendMail", t, func(ctx convey.C) {
err := d.SendMail(body, subject, send)
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
})
})
}

View File

@ -0,0 +1,272 @@
package dao
import (
"context"
"net/url"
"strconv"
"strings"
"go-common/app/service/main/account-recovery/model"
"go-common/library/ecode"
"go-common/library/log"
"go-common/library/net/metadata"
)
// GetMidInfo get mid info by more condition
func (d *Dao) GetMidInfo(c context.Context, qType string, qKey string) (v *model.MIDInfo, err error) {
params := url.Values{}
params.Set("q_type", qType)
params.Set("q_key", qKey)
res := new(struct {
Code int `json:"code"`
Data model.MIDInfo `json:"data"`
})
if err = d.httpClient.Get(c, d.c.AccRecover.MidInfoURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
log.Error("GetMidInfo HTTP request err %+v", err)
return
}
if res.Code != 0 {
log.Error("GetMidInfo server err_code %d,params: qType=%s,qKey=%s", res.Code, qType, qKey)
err = ecode.ServerErr
return
}
log.Info("GetMidInfo url=%v, params: qType=%s,qKey=%s, res: %+v", d.c.AccRecover.MidInfoURL, qType, qKey, res)
return &res.Data, nil
}
// GetUserInfo get user info by mid
func (d *Dao) GetUserInfo(c context.Context, mid int64) (v *model.UserInfo, err error) {
params := url.Values{}
params.Add("mid", strconv.Itoa(int(mid)))
res := new(struct {
Code int `json:"code"`
Data model.UserInfo `json:"data"`
})
if err = d.httpClient.Get(c, d.c.AccRecover.GetUserInfoURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
log.Error("GetUserInfo HTTP request err %+v", err)
return
}
if res.Code != 0 {
log.Error("GetUserInfo server err_code %d,params: mid=%d", res.Code, mid)
err = ecode.ServerErr
return
}
log.Info("GetUserInfo url=%v, params: mid=%d, res: %+v", d.c.AccRecover.GetUserInfoURL, mid, res)
return &res.Data, nil
}
// UpdatePwd update password
func (d *Dao) UpdatePwd(c context.Context, mid int64, operator string) (user *model.User, err error) {
params := url.Values{}
params.Set("mid", strconv.Itoa(int(mid)))
params.Set("operator", operator)
res := new(struct {
Code int `json:"code"`
Data model.User `json:"data"`
})
if err = d.httpClient.Post(c, d.c.AccRecover.UpPwdURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
log.Error("UpdatePwd HTTP request err %+v", err)
return
}
if res.Code != 0 {
log.Error("UpdatePwd server err_code %d,params: mid=%d", res.Code, mid)
err = ecode.Int(res.Code)
return
}
log.Info("UpdatePwd url=%v, params: mid=%d,operator=%s, res: %+v", d.c.AccRecover.UpPwdURL, mid, operator, res)
return &res.Data, nil
}
// CheckSafe safe info
func (d *Dao) CheckSafe(c context.Context, mid int64, question int8, answer string) (check *model.Check, err error) {
params := url.Values{}
params.Add("mid", strconv.Itoa(int(mid)))
params.Add("question", strconv.Itoa(int(question)))
params.Add("answer", answer)
res := new(struct {
Code int `json:"code"`
Data model.Check `json:"data"`
})
if err = d.httpClient.Post(c, d.c.AccRecover.CheckSafeURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
log.Error("CheckSafe HTTP request err %+v", err)
return
}
if res.Code != 0 {
log.Error("CheckSafe server err_code %d,params: mid=%d,question=%d,answer=%s", res.Code, mid, question, answer)
err = ecode.Int(res.Code)
return
}
log.Info("CheckSafe url=%v, params: mid=%d,question=%d,answer=%s, res: %+v", d.c.AccRecover.CheckSafeURL, mid, question, answer, res)
return &res.Data, nil
}
// GetUserType get user_type
func (d *Dao) GetUserType(c context.Context, mid int64) (gams []*model.Game, err error) {
params := url.Values{}
params.Add("mid", strconv.Itoa(int(mid)))
res := new(struct {
Code int `json:"code"`
Data []*model.Game `json:"items"`
})
if err = d.httpClient.Get(c, d.c.AccRecover.GameURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
log.Error("GetUserType HTTP request err %+v", err)
return
}
if res.Code != 0 {
log.Error("GetUserType server err_code %d,params: mid=%d", res.Code, mid)
err = ecode.Int(res.Code)
return
}
log.Info("GetUserType url=%v, params: mid=%d, res: %+v", d.c.AccRecover.GameURL, mid, res)
return res.Data, nil
}
// CheckReg check reg info
func (d *Dao) CheckReg(c context.Context, mid int64, regTime int64, regType int8, regAddr string) (v *model.Check, err error) {
params := url.Values{}
params.Add("mid", strconv.Itoa(int(mid)))
params.Add("reg_time", strconv.FormatInt(regTime, 10))
params.Add("reg_type", strconv.Itoa(int(regType)))
params.Add("reg_addr", regAddr)
res := new(struct {
Code int `json:"code"`
Data model.Check `json:"data"`
})
if err = d.httpClient.Post(c, d.c.AccRecover.CheckRegURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
log.Error("CheckReg HTTP request err %+v", err)
return
}
if res.Code != 0 {
log.Error("CheckReg server err_code %d,params: mid=%d,regTime=%d,regType=%d,regAddr=%s", res.Code, mid, regTime, regType, regAddr)
err = ecode.Int(res.Code)
return
}
log.Info("CheckReg url=%v, params: mid=%d,regTime=%d,regType=%d,regAddr=%s, res: %+v", d.c.AccRecover.CheckRegURL, mid, regTime, regType, regAddr, res)
return &res.Data, nil
}
// UpdateBatchPwd batch update password
func (d *Dao) UpdateBatchPwd(c context.Context, mids string, operator string) (userMap map[string]*model.User, err error) {
params := url.Values{}
params.Set("mids", mids)
params.Set("operator", operator)
res := new(struct {
Code int `json:"code"`
Data map[string]*model.User `json:"data"`
})
if err = d.httpClient.Post(c, d.c.AccRecover.UpBatchPwdURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
log.Error("UpdateBatchPwd HTTP request err %+v", err)
return
}
if res.Code != 0 {
log.Error("UpdateBatchPwd server err_code %d,params: mids=%s", res.Code, mids)
err = ecode.Int(res.Code)
return
}
log.Info("UpdateBatchPwd url=%v, params: mid=%s,operator=%s, res: %+v", d.c.AccRecover.UpBatchPwdURL, mids, operator, res)
return res.Data, nil
}
// CheckCard check card
func (d *Dao) CheckCard(c context.Context, mid int64, cardType int8, cardCode string) (ok bool, err error) {
params := url.Values{}
params.Set("mid", strconv.FormatInt(mid, 10))
params.Set("card_type", strconv.Itoa(int(cardType)))
params.Set("card_code", cardCode)
res := new(struct {
Code int `json:"code"`
Data bool `json:"data"`
})
if err = d.httpClient.Get(c, d.c.AccRecover.CheckCardURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
log.Error("CheckCard HTTP request err %+v", err)
return
}
if res.Code != 0 {
log.Error("CheckCard server err_code %d,params: mid=%d,cardType=%d,cardCode=%s", res.Code, mid, cardType, cardCode)
err = ecode.Int(res.Code)
return
}
log.Info("CheckCard url=%v, params: mid=%d,cardType=%d,cardCode=%s, res: %+v", d.c.AccRecover.CheckCardURL, mid, cardType, cardCode, res)
return res.Data, nil
}
// CheckPwds check pwd
func (d *Dao) CheckPwds(c context.Context, mid int64, pwds string) (v string, err error) {
params := url.Values{}
params.Set("mid", strconv.FormatInt(mid, 10))
params.Set("pwd", pwds)
res := new(struct {
Code int `json:"code"`
Data string `json:"data"`
})
if err = d.httpClient.Post(c, d.c.AccRecover.CheckPwdURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
log.Error("CheckPwds HTTP request err %+v", err)
return
}
if res.Code != 0 {
log.Error("CheckPwds server err_code %d,params: mid=%d,pwds=%s", res.Code, mid, pwds)
err = ecode.Int(res.Code)
return
}
log.Info("CheckPwds url=%v, params: mid=%d,pwds=%s, res: %+v", d.c.AccRecover.CheckPwdURL, mid, pwds, res)
return res.Data, nil
}
// GetLoginIPs get login ip
func (d *Dao) GetLoginIPs(c context.Context, mid int64, limit int64) (ipInfo []*model.LoginIPInfo, err error) {
params := url.Values{}
params.Set("mid", strconv.FormatInt(mid, 10))
params.Set("limit", strconv.FormatInt(limit, 10))
res := new(struct {
Code int `json:"code"`
Data []*model.LoginIPInfo `json:"data"`
})
if err = d.httpClient.Get(c, d.c.AccRecover.GetLoginIPURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
log.Error("GetLoginIPs HTTP request err %+v", err)
return
}
if res.Code != 0 {
log.Error("GetLoginIPs server err_code %d,params: mid=%d,limit=%d", res.Code, mid, limit)
err = ecode.ServerErr
return
}
log.Info("GetLoginIPs url=%v, params: mid=%d,limit=%d, res: %+v", d.c.AccRecover.GetLoginIPURL, mid, limit, res)
return res.Data, nil
}
// GetAddrByIP get addr by ip
func (d *Dao) GetAddrByIP(c context.Context, mid int64, limit int64) (addrs string, err error) {
ipInfo, err := d.GetLoginIPs(c, mid, limit)
if err != nil || len(ipInfo) == 0 {
return
}
var ipLen = len(ipInfo)
ips := make([]string, 0, ipLen)
//ip去重复和空串
for i := 0; i < ipLen; i++ {
if (i > 0 && ipInfo[i-1].LoginIP == ipInfo[i].LoginIP) || len(ipInfo[i].LoginIP) == 0 {
continue
}
ips = append(ips, ipInfo[i].LoginIP)
}
ipMap, err := d.Infos(c, ips)
i := 0
for _, loc := range ipMap {
if loc.Country != "" {
addrs += loc.Country + "-"
}
if loc.Province != "" {
addrs += loc.Province + "-"
}
if loc.City != "" {
addrs += loc.City + "-"
}
addrs = strings.TrimRight(addrs, "-") + ","
i++
if i >= 3 {
break
}
}
addrs = strings.TrimRight(addrs, ",")
return
}

View File

@ -0,0 +1,182 @@
package dao
import (
"context"
"testing"
"github.com/smartystreets/goconvey/convey"
"gopkg.in/h2non/gock.v1"
)
func TestDaoGetMidInfo(t *testing.T) {
var (
c = context.Background()
qType = "1"
qKey = "silg@yahoo.cn"
)
convey.Convey("GetMidInfo", t, func(ctx convey.C) {
v, err := d.GetMidInfo(c, qType, qKey)
ctx.Convey("Then err should be nil.v should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(v, convey.ShouldNotBeNil)
})
})
}
func TestDaoGetUserInfo(t *testing.T) {
var (
c = context.Background()
mid = int64(1)
)
convey.Convey("GetUserInfo", t, func(ctx convey.C) {
defer gock.OffAll()
httpMock("GET", d.c.AccRecover.GetUserInfoURL).Reply(200).JSON(`{"code":0,"data":{"mid":21,"email":"raiden131@yahoo.cn","telphone":"","join_time":1245902140}}`)
v, err := d.GetUserInfo(c, mid)
ctx.Convey("Then err should be nil.v should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(v, convey.ShouldNotBeNil)
})
})
}
func TestDaoUpdatePwd(t *testing.T) {
var (
c = context.Background()
mid = int64(1)
)
convey.Convey("UpdatePwd", t, func(ctx convey.C) {
defer gock.OffAll()
httpMock("POST", d.c.AccRecover.UpPwdURL).Reply(200).JSON(`{"code": 0, "data":{"pwd":"d4txsunbb1","userid":"minorin"}}`)
user, err := d.UpdatePwd(c, mid, "账号找回服务")
ctx.So(err, convey.ShouldBeNil)
ctx.So(user, convey.ShouldNotBeNil)
})
}
func TestDaoCheckSafe(t *testing.T) {
var (
c = context.Background()
mid = int64(1)
question = int8(0)
answer = "1"
)
convey.Convey("CheckSafe", t, func(ctx convey.C) {
check, err := d.CheckSafe(c, mid, question, answer)
ctx.Convey("Then err should be nil.check should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(check, convey.ShouldNotBeNil)
})
})
}
//func httpMock(method, url string) *gock.Request {
// r := gock.New(url)
// r.Method = strings.ToUpper(method)
// return r
//}
//func TestDaoGetUserType(t *testing.T) {
// var (
// c = context.Background()
// mid = int64(2)
// )
// convey.Convey("When http request gets code != 0", t, func(ctx convey.C) {
// defer gock.OffAll()
// httpMock("GET", d.c.AccRecover.GameURL).Reply(0).JSON(`{"requestId":"0def8d70b7ef11e8a395fa163e01a2e9","ts":"1535440592","code":0,"items":[{"id":14,"name":"SDK测试2","lastLogin":"1500969010"}]}`)
// games, err := d.GetUserType(c, mid)
// ctx.So(err, convey.ShouldBeNil)
// ctx.So(games, convey.ShouldNotBeNil)
// })
//}
func TestDaoCheckReg(t *testing.T) {
var (
c = context.Background()
mid = int64(1)
regTime = int64(1532441644)
regType = int8(0)
regAddr = "中国_上海"
)
convey.Convey("CheckReg", t, func(ctx convey.C) {
v, err := d.CheckReg(c, mid, regTime, regType, regAddr)
ctx.Convey("Then err should be nil.v should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(v, convey.ShouldNotBeNil)
})
})
}
func TestDaoUpdateBatchPwd(t *testing.T) {
var (
c = context.Background()
mids = "1,2"
)
convey.Convey("UpdateBatchPwd", t, func(ctx convey.C) {
defer gock.OffAll()
httpMock("POST", d.c.AccRecover.UpBatchPwdURL).Reply(200).JSON(`{"code":0,"data":{"6":{"pwd":"tgs52r1st9","userid":"腹黑君"},"7":{"pwd":"g20ahzrf7j","userid":"Tzwcard"}}}`)
userMap, err := d.UpdateBatchPwd(c, mids, "账号找回服务")
ctx.So(err, convey.ShouldBeNil)
ctx.So(userMap, convey.ShouldNotBeNil)
})
}
func TestDaoCheckCard(t *testing.T) {
var (
c = context.Background()
mid = int64(1)
cardType = int8(1)
cardCode = "123"
)
convey.Convey("CheckCard", t, func(ctx convey.C) {
ok, err := d.CheckCard(c, mid, cardType, cardCode)
ctx.Convey("Then err should be nil.ok should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(ok, convey.ShouldNotBeNil)
})
})
}
func TestDaoCheckPwds(t *testing.T) {
var (
c = context.Background()
mid = int64(1)
pwds = "123"
)
convey.Convey("CheckPwds", t, func(ctx convey.C) {
v, err := d.CheckPwds(c, mid, pwds)
ctx.Convey("Then err should be nil.v should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(v, convey.ShouldNotBeNil)
})
})
}
func TestDaoGetLoginIPs(t *testing.T) {
var (
c = context.Background()
mid = int64(2)
limit = int64(10)
)
convey.Convey("GetLoginIPs", t, func(ctx convey.C) {
ipInfo, err := d.GetLoginIPs(c, mid, limit)
ctx.Convey("Then err should be nil.ipInfo should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(ipInfo, convey.ShouldNotBeNil)
})
})
}
func TestDaoGetAddrByIP(t *testing.T) {
var (
c = context.Background()
mid = int64(111001254)
limit = int64(10)
)
convey.Convey("GetAddrByIP", t, func(ctx convey.C) {
addrs, err := d.GetAddrByIP(c, mid, limit)
ctx.Convey("Then err should be nil.addrs should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(addrs, convey.ShouldNotBeNil)
})
})
}

View File

@ -0,0 +1,475 @@
package dao
import (
"context"
rsql "database/sql"
"fmt"
"strings"
"go-common/app/service/main/account-recovery/dao/sqlbuilder"
"go-common/app/service/main/account-recovery/model"
"go-common/library/database/sql"
"go-common/library/log"
xtime "go-common/library/time"
"go-common/library/xstr"
)
const (
_selectCountRecoveryInfo = "select count(rid) from account_recovery_info"
_selectRecoveryInfoLimit = "select rid,mid,user_type,status,login_addrs,unames,reg_time,reg_type,reg_addr,pwds,phones,emails,safe_question,safe_answer,card_type,card_id," +
"sys_login_addrs,sys_reg,sys_unames,sys_pwds,sys_phones,sys_emails,sys_safe,sys_card," +
"link_email,operator,opt_time,remark,ctime,business from account_recovery_info %s"
_getSuccessCount = "SELECT count FROM account_recovery_success WHERE mid=?"
_batchGetRecoverySuccess = "SELECT mid,count,ctime,mtime FROM account_recovery_success WHERE mid in (%s)"
_updateSuccessCount = "INSERT INTO account_recovery_success (mid, count) VALUES (?, 1) ON DUPLICATE KEY UPDATE count = count + 1"
_batchUpdateSuccessCount = "INSERT INTO account_recovery_success (mid, count) VALUES %s ON DUPLICATE KEY UPDATE count = count + 1"
_updateStatus = "UPDATE account_recovery_info SET status=?,operator=?,opt_time=?,remark=? WHERE rid = ? AND `status`=0"
_getNoDeal = "SELECT COUNT(1) FROM account_recovery_info WHERE mid=? AND `status`=0"
_updateUserType = "UPDATE account_recovery_info SET user_type=? WHERE rid = ?"
_insertRecoveryInfo = "INSERT INTO account_recovery_info(login_addrs,unames,reg_time,reg_type,reg_addr,pwds,phones,emails,safe_question,safe_answer,card_type,card_id,link_email,mid,business,last_suc_count,last_suc_ctime) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"
_updateSysInfo = "UPDATE account_recovery_info SET sys_login_addrs=?,sys_reg=?,sys_unames=?,sys_pwds=?,sys_phones=?,sys_emails=?,sys_safe=?,sys_card=?,user_type=? WHERE rid=?"
_getUinfoByRid = "SELECT mid,link_email,ctime FROM account_recovery_info WHERE rid=? LIMIT 1"
_getUinfoByRidMore = "SELECT rid,mid,link_email,ctime FROM account_recovery_info WHERE rid in (%s)"
_selectUnCheckInfo = "SELECT mid,login_addrs,unames,reg_time,reg_type,reg_addr,pwds,phones,emails,safe_question,safe_answer,card_type,card_id FROM account_recovery_info WHERE rid=? AND `status`=0 AND sys_card=''"
_getStatusByRid = "SELECT `status` FROM account_recovery_info WHERE rid=?"
_getMailStatus = "SELECT mail_status FROM account_recovery_info WHERE rid=?"
_updateMailStatus = "UPDATE account_recovery_info SET mail_status=1 WHERE rid=?"
_insertRecoveryAddit = "INSERT INTO account_recovery_addit(`rid`, `files`, `extra`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE files=VALUES(files),extra=VALUES(extra)"
_updateRecoveryAddit = "UPDATE account_recovery_addit SET `files` = ?,`extra` = ? WHERE rid = ?"
_getRecoveryAddit = "SELECT rid, `files`,`extra`, ctime, mtime FROM account_recovery_addit WHERE rid= ?"
_batchRecoveryAAdit = "SELECT rid, `files`, `extra`, ctime, mtime FROM account_recovery_addit WHERE rid in (%s)"
_batchGetLastSuccess = "SELECT mid,max(ctime) FROM account_recovery_info WHERE mid in (%s) AND `status`=1 GROUP BY mid"
_getLastSuccess = "SELECT mid,max(ctime) FROM account_recovery_info WHERE mid = ? AND `status`=1"
)
// GetStatusByRid get status by rid
func (dao *Dao) GetStatusByRid(c context.Context, rid int64) (status int64, err error) {
res := dao.db.Prepared(_getStatusByRid).QueryRow(c, rid)
if err = res.Scan(&status); err != nil {
if err == sql.ErrNoRows {
status = -1
err = nil
} else {
log.Error("GetStatusByRid row.Scan error(%v)", err)
}
}
return
}
// GetSuccessCount get success count
func (dao *Dao) GetSuccessCount(c context.Context, mid int64) (count int64, err error) {
res := dao.db.Prepared(_getSuccessCount).QueryRow(c, mid)
if err = res.Scan(&count); err != nil {
if err == sql.ErrNoRows {
count = 0
err = nil
} else {
log.Error("GetSuccessCount row.Scan error(%v)", err)
}
}
return
}
// BatchGetRecoverySuccess batch get recovery success info
func (dao *Dao) BatchGetRecoverySuccess(c context.Context, mids []int64) (successMap map[int64]*model.RecoverySuccess, err error) {
rows, err := dao.db.Query(c, fmt.Sprintf(_batchGetRecoverySuccess, xstr.JoinInts(mids)))
if err != nil {
if err == sql.ErrNoRows {
err = nil
return
}
log.Error("BatchGetRecoverySuccess d.db.Query error(%v)", err)
return
}
successMap = make(map[int64]*model.RecoverySuccess)
for rows.Next() {
r := new(model.RecoverySuccess)
if err = rows.Scan(&r.SuccessMID, &r.SuccessCount, &r.FirstSuccessTime, &r.LastSuccessTime); err != nil {
log.Error("BatchGetRecoverySuccess rows.Scan error(%v)", err)
continue
}
successMap[r.SuccessMID] = r
}
return
}
// UpdateSuccessCount insert or update success count
func (dao *Dao) UpdateSuccessCount(c context.Context, mid int64) (err error) {
_, err = dao.db.Exec(c, _updateSuccessCount, mid)
return
}
// BatchUpdateSuccessCount batch insert or update success count
func (dao *Dao) BatchUpdateSuccessCount(c context.Context, mids string) (err error) {
var s string
midArr := strings.Split(mids, ",")
for _, mid := range midArr {
s = s + fmt.Sprintf(",(%s, 1)", mid)
}
_, err = dao.db.Exec(c, fmt.Sprintf(_batchUpdateSuccessCount, s[1:]))
return
}
// GetNoDeal get no deal record
func (dao *Dao) GetNoDeal(c context.Context, mid int64) (count int64, err error) {
res := dao.db.Prepared(_getNoDeal).QueryRow(c, mid)
if err = res.Scan(&count); err != nil {
if err == sql.ErrNoRows {
err = nil
return
}
log.Error("GetNoDeal row.Scan error(%v)", err)
return
}
return
}
// UpdateStatus update field status.
func (dao *Dao) UpdateStatus(c context.Context, status int64, rid int64, operator string, optTime xtime.Time, remark string) (err error) {
_, err = dao.db.Exec(c, _updateStatus, status, operator, optTime, remark, rid)
return
}
// UpdateUserType update field user_type.
func (dao *Dao) UpdateUserType(c context.Context, status int64, rid int64) (err error) {
if _, err = dao.db.Exec(c, _updateUserType, status, rid); err != nil {
log.Error("dao.db.Exec(%s, %d, %d) error(%v)", _updateUserType, status, rid, err)
}
return
}
// InsertRecoveryInfo insert data
func (dao *Dao) InsertRecoveryInfo(c context.Context, uinfo *model.UserInfoReq) (lastID int64, err error) {
var res rsql.Result
if res, err = dao.db.Exec(c, _insertRecoveryInfo, uinfo.LoginAddrs, uinfo.Unames, uinfo.RegTime, uinfo.RegType, uinfo.RegAddr,
uinfo.Pwds, uinfo.Phones, uinfo.Emails, uinfo.SafeQuestion, uinfo.SafeAnswer, uinfo.CardType, uinfo.CardID, uinfo.LinkMail, uinfo.Mid, uinfo.Business, uinfo.LastSucCount, uinfo.LastSucCTime); err != nil {
log.Error("dao.db.Exec(%s, %v) error(%v)", _insertRecoveryInfo, uinfo, err)
return
}
return res.LastInsertId()
}
// UpdateSysInfo update sysinfo and user_type
func (dao *Dao) UpdateSysInfo(c context.Context, sys *model.SysInfo, userType int64, rid int64) (err error) {
if _, err = dao.db.Exec(c, _updateSysInfo, &sys.SysLoginAddrs, &sys.SysReg, &sys.SysUNames, &sys.SysPwds, &sys.SysPhones,
&sys.SysEmails, &sys.SysSafe, &sys.SysCard, userType, rid); err != nil {
log.Error("dao.db.Exec(%s, %v) error(%v)", _updateSysInfo, sys, err)
}
return
}
// GetAllByCon get a pageData by more condition
func (dao *Dao) GetAllByCon(c context.Context, aq *model.QueryRecoveryInfoReq) ([]*model.AccountRecoveryInfo, int64, error) {
query := sqlbuilder.NewSelectBuilder().Select("rid,mid,user_type,status,login_addrs,unames,reg_time,reg_type,reg_addr,pwds,phones,emails,safe_question,safe_answer,card_type,card_id,sys_login_addrs,sys_reg,sys_unames,sys_pwds,sys_phones,sys_emails,sys_safe,sys_card,link_email,operator,opt_time,remark,ctime,business,last_suc_count,last_suc_ctime").From("account_recovery_info")
if aq.Bussiness != "" {
query = query.Where(query.Equal("business", aq.Bussiness))
}
if aq.Status != nil {
query = query.Where(fmt.Sprintf("status=%d", *aq.Status))
}
if aq.Game != nil {
query = query.Where(fmt.Sprintf("user_type=%d", *aq.Game))
}
if aq.UID != 0 {
query = query.Where(fmt.Sprintf("mid=%d", aq.UID))
}
if aq.RID != 0 {
query = query.Where(fmt.Sprintf("rid=%d", aq.RID))
}
if aq.StartTime != 0 {
query = query.Where(query.GE("ctime", aq.StartTime.Time()))
}
if aq.EndTime != 0 {
query = query.Where(query.LE("ctime", aq.EndTime.Time()))
}
totalSQL, totalArg := query.Copy().Select("count(1)").Build()
log.Info("Build GetAllByCon total count SQL: %s", totalSQL)
page := aq.Page
if page == 0 {
page = 1
}
size := aq.Size
if size == 0 {
size = 50
}
query = query.Limit(int(size)).Offset(int(size * (page - 1))).OrderBy("rid DESC")
rawSQL, rawArg := query.Build()
log.Info("Build GetAllByCon SQL: %s", rawSQL)
total := int64(0)
row := dao.db.QueryRow(c, totalSQL, totalArg...)
if err := row.Scan(&total); err != nil {
return nil, 0, err
}
rows, err := dao.db.Query(c, rawSQL, rawArg...)
if err != nil {
return nil, 0, err
}
defer rows.Close()
resultData := make([]*model.AccountRecoveryInfo, 0)
for rows.Next() {
r := new(model.AccountRecoveryInfo)
if err = rows.Scan(&r.Rid, &r.Mid, &r.UserType, &r.Status, &r.LoginAddr, &r.UNames, &r.RegTime, &r.RegType, &r.RegAddr,
&r.Pwd, &r.Phones, &r.Emails, &r.SafeQuestion, &r.SafeAnswer, &r.CardType, &r.CardID,
&r.SysLoginAddr, &r.SysReg, &r.SysUNames, &r.SysPwds, &r.SysPhones, &r.SysEmails, &r.SysSafe, &r.SysCard,
&r.LinkEmail, &r.Operator, &r.OptTime, &r.Remark, &r.CTime, &r.Bussiness, &r.LastSucCount, &r.LastSucCTime); err != nil {
log.Error("GetAllByCon error (%+v)", err)
continue
}
resultData = append(resultData, r)
}
return resultData, total, err
}
// QueryByID query by rid
func (dao *Dao) QueryByID(c context.Context, rid int64, fromTime, endTime xtime.Time) (res *model.AccountRecoveryInfo, err error) {
sql1 := "select rid,mid,user_type,status,login_addrs,unames,reg_time,reg_type,reg_addr,pwds,phones,emails,safe_question,safe_answer,card_type,card_id," +
"sys_login_addrs,sys_reg,sys_unames,sys_pwds,sys_phones,sys_emails,sys_safe,sys_card," +
"link_email,operator,opt_time,remark,ctime,business from account_recovery_info where ctime between ? and ? and rid = ?"
res = new(model.AccountRecoveryInfo)
row := dao.db.QueryRow(c, sql1, fromTime, endTime, rid)
if err = row.Scan(&res.Rid, &res.Mid, &res.UserType, &res.Status, &res.LoginAddr, &res.UNames, &res.RegTime, &res.RegType, &res.RegAddr,
&res.Pwd, &res.Phones, &res.Emails, &res.SafeQuestion, &res.SafeAnswer, &res.CardType, &res.CardID,
&res.SysLoginAddr, &res.SysReg, &res.SysUNames, &res.SysPwds, &res.SysPhones, &res.SysEmails, &res.SysSafe, &res.SysCard,
&res.LinkEmail, &res.Operator, &res.OptTime, &res.Remark, &res.CTime, &res.Bussiness); err != nil {
if err == sql.ErrNoRows {
err = nil
res = nil
return
}
log.Error("QueryByID(%d) error(%v)", rid, err)
return
}
return
}
//QueryInfoByLimit page query through limit m,n
func (dao *Dao) QueryInfoByLimit(c context.Context, req *model.DBRecoveryInfoParams) (res []*model.AccountRecoveryInfo, total int64, err error) {
p := make([]interface{}, 0)
s := " where ctime between ? and ?"
p = append(p, req.StartTime)
p = append(p, req.EndTime)
if req.ExistGame {
s = s + " and user_type = ?"
p = append(p, req.Game)
}
if req.ExistStatus {
s = s + " and status = ?"
p = append(p, req.Status)
}
if req.ExistMid {
s = s + " and mid = ?"
p = append(p, req.Mid)
}
var s2 = s + " order by rid desc limit ?,?"
p2 := p
p2 = append(p2, (req.CurrPage-1)*req.Size, req.Size)
var rows *sql.Rows
rows, err = dao.db.Query(c, fmt.Sprintf(_selectRecoveryInfoLimit, s2), p2...)
if err != nil {
if err == sql.ErrNoRows {
err = nil
return
}
log.Error("QueryInfo err: d.db.Query error(%v)", err)
return
}
defer rows.Close()
res = make([]*model.AccountRecoveryInfo, 0, req.Size)
for rows.Next() {
r := new(model.AccountRecoveryInfo)
if err = rows.Scan(&r.Rid, &r.Mid, &r.UserType, &r.Status, &r.LoginAddr, &r.UNames, &r.RegTime, &r.RegType, &r.RegAddr,
&r.Pwd, &r.Phones, &r.Emails, &r.SafeQuestion, &r.SafeAnswer, &r.CardType, &r.CardID,
&r.SysLoginAddr, &r.SysReg, &r.SysUNames, &r.SysPwds, &r.SysPhones, &r.SysEmails, &r.SysSafe, &r.SysCard,
&r.LinkEmail, &r.Operator, &r.OptTime, &r.Remark, &r.CTime, &r.Bussiness); err != nil {
log.Error("QueryInfo (%v) error (%v)", req, err)
return
}
res = append(res, r)
}
row := dao.db.QueryRow(c, _selectCountRecoveryInfo+s, p...)
if err = row.Scan(&total); err != nil {
log.Error("QueryInfo total error (%v)", err)
return
}
return
}
// GetUinfoByRid get mid,linkMail by rid
func (dao *Dao) GetUinfoByRid(c context.Context, rid int64) (mid int64, linkMail string, ctime string, err error) {
res := dao.db.Prepared(_getUinfoByRid).QueryRow(c, rid)
req := new(struct {
Mid int64
LinKMail string
Ctime xtime.Time
})
if err = res.Scan(&req.Mid, &req.LinKMail, &req.Ctime); err != nil {
if err == sql.ErrNoRows {
req.Mid = 0
err = nil
} else {
log.Error("GetUinfoByRid row.Scan error(%v)", err)
}
}
mid = req.Mid
linkMail = req.LinKMail
ctime = req.Ctime.Time().Format("2006-01-02 15:04:05")
return
}
// GetUinfoByRidMore get list of BatchAppeal by rid
func (dao *Dao) GetUinfoByRidMore(c context.Context, ridsStr string) (bathRes []*model.BatchAppeal, err error) {
rows, err := dao.db.Prepared(fmt.Sprintf(_getUinfoByRidMore, ridsStr)).Query(c)
if err != nil {
return nil, err
}
defer rows.Close()
bathRes = make([]*model.BatchAppeal, 0, len(strings.Split(ridsStr, ",")))
for rows.Next() {
req := &model.BatchAppeal{}
if err = rows.Scan(&req.Rid, &req.Mid, &req.LinkMail, &req.Ctime); err != nil {
return
}
bathRes = append(bathRes, req)
}
return
}
// GetUnCheckInfo get uncheck info
func (dao *Dao) GetUnCheckInfo(c context.Context, rid int64) (r *model.UserInfoReq, err error) {
row := dao.db.QueryRow(c, _selectUnCheckInfo, rid)
r = new(model.UserInfoReq)
if err = row.Scan(&r.Mid, &r.LoginAddrs, &r.Unames, &r.RegTime, &r.RegType, &r.RegAddr,
&r.Pwds, &r.Phones, &r.Emails, &r.SafeQuestion, &r.SafeAnswer, &r.CardType, &r.CardID); err != nil {
if err == sql.ErrNoRows {
err = nil
return
}
log.Error("GetUnCheckInfo (%v) error (%v)", rid, err)
}
return
}
//BeginTran begin transaction
func (dao *Dao) BeginTran(ctx context.Context) (tx *sql.Tx, err error) {
if tx, err = dao.db.Begin(ctx); err != nil {
log.Error("db: begintran BeginTran d.db.Begin error(%v)", err)
}
return
}
// GetMailStatus get mail_status by rid
func (dao *Dao) GetMailStatus(c context.Context, rid int64) (mailStatus int64, err error) {
res := dao.db.Prepared(_getMailStatus).QueryRow(c, rid)
if err = res.Scan(&mailStatus); err != nil {
if err == sql.ErrNoRows {
mailStatus = -1
err = nil
} else {
log.Error("GetStatusByRid row.Scan error(%v)", err)
}
}
return
}
// UpdateMailStatus update mail_status.
func (dao *Dao) UpdateMailStatus(c context.Context, rid int64) (err error) {
_, err = dao.db.Exec(c, _updateMailStatus, rid)
return
}
// UpdateRecoveryAddit is
func (dao *Dao) UpdateRecoveryAddit(c context.Context, rid int64, files []string, extra string) (err error) {
_, err = dao.db.Exec(c, _updateRecoveryAddit, strings.Join(files, ","), extra, rid)
return
}
// GetRecoveryAddit is
func (dao *Dao) GetRecoveryAddit(c context.Context, rid int64) (addit *model.DBAccountRecoveryAddit, err error) {
row := dao.db.QueryRow(c, _getRecoveryAddit, rid)
addit = new(model.DBAccountRecoveryAddit)
if err = row.Scan(&addit.Rid, &addit.Files, &addit.Extra, &addit.Ctime, &addit.Mtime); err != nil {
if err == sql.ErrNoRows {
err = nil
return
}
log.Error("GetRecoveryAddit (%v) error (%v)", rid, err)
}
return
}
// InsertRecoveryAddit is
func (dao *Dao) InsertRecoveryAddit(c context.Context, rid int64, files, extra string) (err error) {
_, err = dao.db.Exec(c, _insertRecoveryAddit, rid, files, extra)
return
}
//BatchGetRecoveryAddit is
func (dao *Dao) BatchGetRecoveryAddit(c context.Context, rids []int64) (addits map[int64]*model.DBAccountRecoveryAddit, err error) {
rows, err := dao.db.Query(c, fmt.Sprintf(_batchRecoveryAAdit, xstr.JoinInts(rids)))
if err != nil {
if err == sql.ErrNoRows {
err = nil
return
}
log.Error("BatchGetRecoveryAddit d.db.Query error(%v)", err)
return
}
defer rows.Close()
addits = make(map[int64]*model.DBAccountRecoveryAddit)
for rows.Next() {
var addit = new(model.DBAccountRecoveryAddit)
if err = rows.Scan(&addit.Rid, &addit.Files, &addit.Extra, &addit.Ctime, &addit.Mtime); err != nil {
log.Error("BatchGetRecoveryAddit rows.Scan error(%v)", err)
continue
}
addits[addit.Rid] = addit
}
return
}
// BatchGetLastSuccess batch get last find success info
func (dao *Dao) BatchGetLastSuccess(c context.Context, mids []int64) (lastSuccessMap map[int64]*model.LastSuccessData, err error) {
rows, err := dao.db.Query(c, fmt.Sprintf(_batchGetLastSuccess, xstr.JoinInts(mids)))
if err != nil {
if err == sql.ErrNoRows {
err = nil
return
}
log.Error("BatchGetLastSuccess d.db.Query error(%v)", err)
return
}
defer rows.Close()
lastSuccessMap = make(map[int64]*model.LastSuccessData)
for rows.Next() {
r := new(model.LastSuccessData)
if err = rows.Scan(&r.LastApplyMID, &r.LastApplyTime); err != nil {
log.Error("BatchGetLastSuccess rows.Scan error(%v)", err)
continue
}
lastSuccessMap[r.LastApplyMID] = r
}
return
}
// GetLastSuccess get last find success info
func (dao *Dao) GetLastSuccess(c context.Context, mid int64) (lastSuc *model.LastSuccessData, err error) {
row := dao.db.QueryRow(c, _getLastSuccess, mid)
lastSuc = new(model.LastSuccessData)
if err = row.Scan(&lastSuc.LastApplyMID, &lastSuc.LastApplyTime); err != nil {
if err == sql.ErrNoRows {
err = nil
return
}
log.Error("GetRecoveryAddit (%v) error (%v)", mid, err)
}
return
}

View File

@ -0,0 +1,412 @@
package dao
import (
"context"
"testing"
"go-common/app/service/main/account-recovery/model"
xtime "go-common/library/time"
"github.com/smartystreets/goconvey/convey"
)
func TestDaoGetStatusByRid(t *testing.T) {
var (
c = context.Background()
rid = int64(1)
)
convey.Convey("GetStatusByRid", t, func(ctx convey.C) {
status, err := d.GetStatusByRid(c, rid)
ctx.Convey("Then err should be nil.status should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(status, convey.ShouldNotBeNil)
})
})
}
func TestDaoGetSuccessCount(t *testing.T) {
var (
c = context.Background()
mid = int64(1234)
)
convey.Convey("GetSuccessCount", t, func(ctx convey.C) {
count, err := d.GetSuccessCount(c, mid)
ctx.Convey("Then err should be nil.count should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(count, convey.ShouldNotBeNil)
})
})
}
func TestDaoBatchGetRecoverySuccess(t *testing.T) {
var (
c = context.Background()
mids = []int64{1234}
)
convey.Convey("BatchGetRecoverySuccess", t, func(ctx convey.C) {
countMap, err := d.BatchGetRecoverySuccess(c, mids)
ctx.Convey("Then err should be nil.countMap should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(countMap, convey.ShouldNotBeNil)
})
})
}
func TestDaoUpdateSuccessCount(t *testing.T) {
var (
c = context.Background()
mid = int64(1234)
)
convey.Convey("UpdateSuccessCount", t, func(ctx convey.C) {
err := d.UpdateSuccessCount(c, mid)
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
})
})
}
func TestDaoBatchUpdateSuccessCount(t *testing.T) {
var (
c = context.Background()
mids = "1234"
)
convey.Convey("BatchUpdateSuccessCount", t, func(ctx convey.C) {
err := d.BatchUpdateSuccessCount(c, mids)
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
})
})
}
func TestDaoGetNoDeal(t *testing.T) {
var (
c = context.Background()
mid = int64(1234)
)
convey.Convey("GetNoDeal", t, func(ctx convey.C) {
count, err := d.GetNoDeal(c, mid)
ctx.Convey("Then err should be nil.count should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(count, convey.ShouldNotBeNil)
})
})
}
func TestDaoUpdateStatus(t *testing.T) {
var (
c = context.Background()
status = int64(1)
rid = int64(1)
operator = "abcd"
optTime xtime.Time
remark = ""
)
convey.Convey("UpdateStatus", t, func(ctx convey.C) {
err := d.UpdateStatus(c, status, rid, operator, optTime, remark)
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
})
})
}
func TestDaoUpdateUserType(t *testing.T) {
var (
c = context.Background()
status = int64(1)
rid = int64(1)
)
convey.Convey("UpdateUserType", t, func(ctx convey.C) {
err := d.UpdateUserType(c, status, rid)
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
})
})
}
func TestDaoInsertRecoveryInfo(t *testing.T) {
var (
c = context.Background()
uinfo = &model.UserInfoReq{
LoginAddrs: "中国-福州,中国-上海,澳大利亚",
//RegTime:timeS, //变成2018
RegTime: 1533206284, //变成2018 //数据库设置为int(11),so数据库必须设置为tiimestamp
RegType: int8(1),
RegAddr: "中国上海",
Unames: "昵称AA,昵称BB,昵称CC",
Pwds: "密码1,密码2",
Phones: "12345678901,54321678923",
Emails: "2456@sina.com,789@qq.com",
SafeQuestion: int8(1),
SafeAnswer: "心态呀",
CardID: "ISN-1234567890-0987",
CardType: int8(1),
LinkMail: "345678@qq.com",
Mid: 1234,
}
)
convey.Convey("InsertRecoveryInfo", t, func(ctx convey.C) {
lastID, err := d.InsertRecoveryInfo(c, uinfo)
ctx.Convey("Then err should be nil.lastID should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(lastID, convey.ShouldNotBeNil)
})
})
}
func TestDaoUpdateSysInfo(t *testing.T) {
var (
c = context.Background()
sys = &model.SysInfo{
SysLoginAddrs: "中国-福州,中国-上海,澳大利亚",
SysReg: "对",
SysUNames: "对,错,错",
SysPwds: "对,错",
SysPhones: "对,错",
SysEmails: "对,错",
SysSafe: "对",
SysCard: "对",
}
userType = int64(1)
rid = int64(1)
)
convey.Convey("UpdateSysInfo", t, func(ctx convey.C) {
err := d.UpdateSysInfo(c, sys, userType, rid)
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
})
})
}
func TestDaoQueryByID(t *testing.T) {
var (
c = context.Background()
rid = int64(1)
fromTime xtime.Time = 1533120949
endTime xtime.Time = 1535636392
)
convey.Convey("QueryByID", t, func(ctx convey.C) {
res, err := d.QueryByID(c, rid, fromTime, endTime)
ctx.Convey("Then err should be nil.res should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(res, convey.ShouldNotBeNil)
})
})
}
func TestDaoQueryInfoByLimit(t *testing.T) {
var (
c = context.Background()
req = &model.DBRecoveryInfoParams{
ExistGame: false,
ExistStatus: false,
ExistMid: false,
Mid: 0,
Game: 0,
Status: 1,
FirstRid: 20,
LastRid: 0,
Size: 2,
StartTime: 1533120949,
EndTime: 1535636392,
SubNum: 1,
CurrPage: 1,
}
)
convey.Convey("QueryInfoByLimit", t, func(ctx convey.C) {
res, total, err := d.QueryInfoByLimit(c, req)
ctx.Convey("Then err should be nil.res,total should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(total, convey.ShouldNotBeNil)
ctx.So(res, convey.ShouldNotBeNil)
})
})
}
func TestDaoGetUinfoByRid(t *testing.T) {
var (
c = context.Background()
rid = int64(240)
)
convey.Convey("GetUinfoByRid", t, func(ctx convey.C) {
mid, linkMail, ctime, err := d.GetUinfoByRid(c, rid)
ctx.Convey("Then err should be nil.mid,linkMail,ctime should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(ctime, convey.ShouldNotBeNil)
ctx.So(linkMail, convey.ShouldNotBeNil)
ctx.So(mid, convey.ShouldNotBeNil)
})
})
}
func TestDaoGetUinfoByRidMore(t *testing.T) {
var (
c = context.Background()
ridsStr = "1,2"
)
convey.Convey("GetUinfoByRidMore", t, func(ctx convey.C) {
bathRes, err := d.GetUinfoByRidMore(c, ridsStr)
ctx.Convey("Then err should be nil.bathRes should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(bathRes, convey.ShouldNotBeNil)
})
})
}
func TestDaoGetUnCheckInfo(t *testing.T) {
var (
c = context.Background()
rid = int64(1)
)
convey.Convey("GetUnCheckInfo", t, func(ctx convey.C) {
r, err := d.GetUnCheckInfo(c, rid)
ctx.Convey("Then err should be nil.r should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(r, convey.ShouldNotBeNil)
})
})
}
func TestDaoGetMailStatus(t *testing.T) {
var (
c = context.Background()
rid = int64(1)
)
convey.Convey("GetMailStatus", t, func(ctx convey.C) {
mailStatus, err := d.GetMailStatus(c, rid)
ctx.Convey("Then err should be nil.mailStatus should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(mailStatus, convey.ShouldNotBeNil)
})
})
}
func TestDaoUpdateMailStatus(t *testing.T) {
var (
c = context.Background()
rid = int64(1)
)
convey.Convey("UpdateMailStatus", t, func(ctx convey.C) {
err := d.UpdateMailStatus(c, rid)
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
})
})
}
func TestDaoGetAllByCon(t *testing.T) {
var (
defint int64
c = context.Background()
aq = &model.QueryRecoveryInfoReq{
//RID: 1,
//UID:2,
Status: &defint,
Game: &defint,
Size: 10,
Page: 1,
StartTime: 1533052800,
EndTime: 1536924163,
//StartTime time.Time `json:"start_time" form:"start_time"`
//EndTime time.Time `json:"end_time" form:"end_time"`
//IsAdvanced bool `json:"-"`
//Page int64 `form:"page"`
}
)
convey.Convey("UpdateMailStatus", t, func(ctx convey.C) {
resultData, total, err := d.GetAllByCon(c, aq)
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
ctx.So(resultData, convey.ShouldNotBeNil)
ctx.So(total, convey.ShouldNotBeNil)
ctx.So(err, convey.ShouldBeNil)
ctx.Println(total, " len=", len(resultData))
})
})
}
func TestDaoInsertRecoveryAddit(t *testing.T) {
var (
c = context.Background()
rid int64 = 1
files = "http://uat-i0.hdslb.com/bfs/account/recovery/bca2.zip,http://uat-i0.hdslb.com/bfs/account/recovery/abcd.zip"
extra = `{"GameArea":"ios-A服","GameNames":"崩坏3","GamePlay":"1"}`
)
convey.Convey("InsertRecoveryAddit", t, func(ctx convey.C) {
err := d.InsertRecoveryAddit(c, rid, files, extra)
err = d.InsertRecoveryAddit(c, 2, files, extra)
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
})
})
}
func TestDaoUpdateRecoveryAddit(t *testing.T) {
var (
c = context.Background()
rid int64 = 1
files = []string{"http://uat-i0.hdslb.com/bfs/aaaa.zip", "http://uat-i0.hdslb.com/bfs/dddd.zip"}
extra = `{"GameArea":"ios-A服","GameNames":"崩坏3","GamePlay":"1"}`
)
convey.Convey("UpdateMailStatus", t, func(ctx convey.C) {
err := d.UpdateRecoveryAddit(c, rid, files, extra)
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
})
})
}
func TestDaoGetRecoveryAddit(t *testing.T) {
var (
c = context.Background()
rid int64 = 1
)
convey.Convey("UpdateMailStatus", t, func(ctx convey.C) {
addit, err := d.GetRecoveryAddit(c, rid)
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.Println(addit, convey.ShouldNotBeNil)
})
})
}
func TestDaoBatchGetRecoveryAddit(t *testing.T) {
var (
c = context.Background()
rids = []int64{1, 2}
)
convey.Convey("BatchGetRecoveryAddit", t, func(ctx convey.C) {
addits, err := d.BatchGetRecoveryAddit(c, rids)
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(addits, convey.ShouldNotBeNil)
})
})
}
func TestBatchGetLastSuccess(t *testing.T) {
var (
c = context.Background()
mids = []int64{1234}
)
convey.Convey("BatchGetLastSuccess", t, func(ctx convey.C) {
lastSuccessMap, err := d.BatchGetLastSuccess(c, mids)
ctx.Convey("Then err should be nil.count should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(lastSuccessMap, convey.ShouldNotBeNil)
})
})
}
func TestDaoGetLastSuccess(t *testing.T) {
var (
c = context.Background()
mid = int64(1234)
)
convey.Convey("GetLastSuccess", t, func(ctx convey.C) {
res, err := d.GetLastSuccess(c, mid)
ctx.Convey("Then err should be nil.count should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(res, convey.ShouldNotBeNil)
})
})
}

View File

@ -0,0 +1,92 @@
package dao
import (
"context"
"strconv"
"time"
"go-common/library/cache/redis"
"go-common/library/log"
)
const (
_expire = 30 * 60 // 30 minutes
_prefixCaptcha = "recovery:ca_"
)
// SetLinkMailCount set linkMail expire time.
func (d *Dao) SetLinkMailCount(c context.Context, linkMail string) (state int64, err error) {
conn := d.redis.Get(c)
defer conn.Close()
//用时间戳去减去时间 当天0点过期第二天又可以发送10封邮件
i, _ := redis.Int(conn.Do("incr", linkMail))
conn.Do("expire", linkMail, getSubtime())
if i >= 11 { //第10封邮件之后
state = 10 //当天邮件发送到达最大次数
return
}
return
}
func getSubtime() (subtime int64) {
timeStr := time.Now().Format("2006-01-02")
t, _ := time.Parse("2006-01-02", timeStr)
t2 := t.AddDate(0, 0, 1).Unix()
curr := time.Now().Unix() //当前时间
subtime = t2 - curr
return
}
// keyCaptcha
func keyCaptcha(mid int64, linkMail string) string {
return _prefixCaptcha + strconv.FormatInt(mid, 10) + "_" + linkMail
}
// SetCaptcha set linkMail expire time.
func (d *Dao) SetCaptcha(c context.Context, code string, mid int64, linkMail string) (err error) {
conn := d.redis.Get(c)
defer conn.Close()
key := keyCaptcha(mid, linkMail)
//验证码30分钟内有效
if _, err = conn.Do("SETEX", key, _expire, code); err != nil {
log.Error("conn.Do(SETEX, %d, %v, %s) error(%v)", mid, _expire, code, err)
}
return
}
// GetEMailCode get captcha from redis
func (d *Dao) GetEMailCode(c context.Context, mid int64, linkMail string) (code string, err error) {
key := keyCaptcha(mid, linkMail)
conn := d.redis.Get(c)
defer conn.Close()
code, err = redis.String(conn.Do("GET", key))
if err != nil {
if err == redis.ErrNil {
err = nil
return
}
log.Error("conn.Do(GET, %s, ), err (%v)", key, err)
return
}
return
}
// DelEMailCode del captcha from redis 提交:校验验证之后就删除验证码(保证只能提交一次)
func (d *Dao) DelEMailCode(c context.Context, mid int64, linkMail string) (err error) {
key := keyCaptcha(mid, linkMail)
conn := d.redis.Get(c)
defer conn.Close()
if _, err = conn.Do("DEL", key); err != nil {
log.Error("conn.Do(DEL, %s, ), err (%v)", key, err)
return
}
return
}
// PingRedis check connection success.
func (d *Dao) PingRedis(c context.Context) (err error) {
conn := d.redis.Get(c)
defer conn.Close()
_, err = conn.Do("GET", "PING")
return
}

View File

@ -0,0 +1,100 @@
package dao
import (
"context"
"testing"
"github.com/smartystreets/goconvey/convey"
)
func TestDaoSetLinkMailCount(t *testing.T) {
var (
c = context.Background()
linkMail = "2459593393@qq.com"
)
convey.Convey("SetLinkMailCount", t, func(ctx convey.C) {
state, err := d.SetLinkMailCount(c, linkMail)
ctx.Convey("Then err should be nil.state should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(state, convey.ShouldNotBeNil)
})
})
}
func TestDaogetSubtime(t *testing.T) {
convey.Convey("getSubtime", t, func(ctx convey.C) {
subtime := getSubtime()
ctx.Convey("Then subtime should not be nil.", func(ctx convey.C) {
ctx.So(subtime, convey.ShouldNotBeNil)
})
})
}
func TestDaokeyCaptcha(t *testing.T) {
var (
mid = int64(0)
linkMail = "2459593393@qq.com"
)
convey.Convey("keyCaptcha", t, func(ctx convey.C) {
p1 := keyCaptcha(mid, linkMail)
ctx.Convey("Then p1 should not be nil.", func(ctx convey.C) {
ctx.So(p1, convey.ShouldNotBeNil)
})
})
}
func TestDaoSetCaptcha(t *testing.T) {
var (
c = context.Background()
code = "1234"
mid = int64(1)
linkMail = "2459593393@qq.com"
)
convey.Convey("SetCaptcha", t, func(ctx convey.C) {
err := d.SetCaptcha(c, code, mid, linkMail)
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
})
})
}
func TestDaoGetEMailCode(t *testing.T) {
var (
c = context.Background()
mid = int64(1)
linkMail = "2459593393@qq.com"
)
convey.Convey("GetEMailCode", t, func(ctx convey.C) {
code, err := d.GetEMailCode(c, mid, linkMail)
ctx.Convey("Then err should be nil.code should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(code, convey.ShouldNotBeNil)
})
})
}
func TestDaoDelEMailCode(t *testing.T) {
var (
c = context.Background()
mid = int64(1)
linkMail = "2459593393@qq.com"
)
convey.Convey("GetEMailCode", t, func(ctx convey.C) {
err := d.DelEMailCode(c, mid, linkMail)
ctx.Convey("Then err should be nil.code should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
})
})
}
func TestDaoPingRedis(t *testing.T) {
var (
c = context.Background()
)
convey.Convey("PingRedis", t, func(ctx convey.C) {
err := d.PingRedis(c)
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
})
})
}

View File

@ -0,0 +1,49 @@
package dao
import (
"context"
account "go-common/app/service/main/account/api"
location "go-common/app/service/main/location/model"
member "go-common/app/service/main/member/api"
"go-common/library/log"
"go-common/library/net/metadata"
"github.com/pkg/errors"
)
// Info3 get info by mid
func (d *Dao) Info3(c context.Context, mid int64) (info *account.Info, err error) {
var (
arg = &account.MidReq{
Mid: mid,
RealIp: metadata.String(c, metadata.RemoteIP),
}
res *account.InfoReply
)
if res, err = d.accountClient.Info3(c, arg); err != nil {
err = errors.Wrapf(err, "%v", arg)
return nil, err
}
return res.Info, nil
}
// Infos get the ips info.
func (d *Dao) Infos(c context.Context, ipList []string) (res map[string]*location.Info, err error) {
if res, err = d.locRPC.Infos(c, ipList); err != nil {
log.Error("s.locaRPC err(%v)", err)
}
return
}
// CheckRealnameStatus realname status
func (d *Dao) CheckRealnameStatus(c context.Context, mid int64) (status int8, err error) {
var (
relnameStatus *member.RealnameStatusReply
)
if relnameStatus, err = d.memberClient.RealnameStatus(c, &member.MemberMidReq{Mid: mid, RemoteIP: metadata.String(c, metadata.RemoteIP)}); err != nil {
log.Error("s.memberSvr.RealnameStatus err(%v)", err)
return
}
return relnameStatus.RealnameStatus, nil
}

View File

@ -0,0 +1,50 @@
package dao
import (
"context"
"testing"
"github.com/smartystreets/goconvey/convey"
)
func TestDaoInfo3(t *testing.T) {
var (
c = context.Background()
mid = int64(2)
)
convey.Convey("Info3", t, func(ctx convey.C) {
res, err := d.Info3(c, mid)
ctx.Convey("Then err should be nil.res should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(res, convey.ShouldNotBeNil)
})
})
}
func TestDaoInfos(t *testing.T) {
var (
c = context.Background()
ipList = []string{"127.0.0.1"}
)
convey.Convey("Infos", t, func(ctx convey.C) {
res, err := d.Infos(c, ipList)
ctx.Convey("Then err should be nil.res should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(res, convey.ShouldNotBeNil)
})
})
}
func TestDaoCheckRealnameStatus(t *testing.T) {
var (
c = context.Background()
mid int64 = 1
)
convey.Convey("CheckRealnameStatus", t, func(ctx convey.C) {
res, err := d.CheckRealnameStatus(c, mid)
ctx.Convey("Then err should be nil.res should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(res, convey.ShouldNotBeNil)
})
})
}

View File

@ -0,0 +1,51 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_test",
"go_library",
)
go_test(
name = "go_default_test",
srcs = [
"args_test.go",
"builder_test.go",
"cond_test.go",
"flavor_test.go",
"modifiers_test.go",
"select_test.go",
],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
)
go_library(
name = "go_default_library",
srcs = [
"args.go",
"builder.go",
"cond.go",
"flavor.go",
"modifiers.go",
"select.go",
],
importpath = "go-common/app/service/main/account-recovery/dao/sqlbuilder",
tags = ["automanaged"],
visibility = ["//visibility:public"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@ -0,0 +1,236 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"bytes"
"database/sql"
"fmt"
"sort"
"strconv"
"strings"
)
// Args stores arguments associated with a SQL.
type Args struct {
// The default flavor used by `Args#Compile`
Flavor Flavor
args []interface{}
namedArgs map[string]int
sqlNamedArgs map[string]int
onlyNamed bool
}
// Add adds an arg to Args and returns a placeholder.
func (args *Args) Add(arg interface{}) string {
return fmt.Sprintf("$%v", args.add(arg))
}
func (args *Args) add(arg interface{}) int {
idx := len(args.args)
switch a := arg.(type) {
case sql.NamedArg:
if args.sqlNamedArgs == nil {
args.sqlNamedArgs = map[string]int{}
}
if p, ok := args.sqlNamedArgs[a.Name]; ok {
arg = args.args[p]
break
}
args.sqlNamedArgs[a.Name] = idx
case namedArgs:
if args.namedArgs == nil {
args.namedArgs = map[string]int{}
}
if p, ok := args.namedArgs[a.name]; ok {
arg = args.args[p]
break
}
// Find out the real arg and add it to args.
idx = args.add(a.arg)
args.namedArgs[a.name] = idx
return idx
}
args.args = append(args.args, arg)
return idx
}
// Compile compiles builder's format to standard sql and returns associated args.
//
// The format string uses a special syntax to represent arguments.
//
// $? refers successive arguments passed in the call. It works similar as `%v` in `fmt.Sprintf`.
// $0 $1 ... $n refers nth-argument passed in the call. Next $? will use arguments n+1.
// ${name} refers a named argument created by `Named` with `name`.
// $$ is a "$" string.
func (args *Args) Compile(format string, intialValue ...interface{}) (query string, values []interface{}) {
return args.CompileWithFlavor(format, args.Flavor, intialValue...)
}
// CompileWithFlavor compiles builder's format to standard sql with flavor and returns associated args.
//
// See doc for `Compile` to learn details.
func (args *Args) CompileWithFlavor(format string, flavor Flavor, intialValue ...interface{}) (query string, values []interface{}) {
buf := &bytes.Buffer{}
idx := strings.IndexRune(format, '$')
offset := 0
values = intialValue
if flavor == invalidFlavor {
flavor = DefaultFlavor
}
for idx >= 0 && len(format) > 0 {
if idx > 0 {
buf.WriteString(format[:idx])
}
format = format[idx+1:]
// Should not happen.
if len(format) == 0 {
break
}
if format[0] == '$' {
buf.WriteRune('$')
format = format[1:]
} else if format[0] == '{' {
format, values = args.compileNamed(buf, flavor, format, values)
} else if !args.onlyNamed && '0' <= format[0] && format[0] <= '9' {
format, values, offset = args.compileDigits(buf, flavor, format, values, offset)
} else if !args.onlyNamed && format[0] == '?' {
format, values, offset = args.compileSuccessive(buf, flavor, format[1:], values, offset)
}
idx = strings.IndexRune(format, '$')
}
if len(format) > 0 {
buf.WriteString(format)
}
query = buf.String()
if len(args.sqlNamedArgs) > 0 {
// Stabilize the sequence to make it easier to write test cases.
ints := make([]int, 0, len(args.sqlNamedArgs))
for _, p := range args.sqlNamedArgs {
ints = append(ints, p)
}
sort.Ints(ints)
for _, i := range ints {
values = append(values, args.args[i])
}
}
return
}
func (args *Args) compileNamed(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}) (string, []interface{}) {
i := 1
for ; i < len(format) && format[i] != '}'; i++ {
// Nothing.
}
// Invalid $ format. Ignore it.
if i == len(format) {
return format, values
}
name := format[1:i]
format = format[i+1:]
if p, ok := args.namedArgs[name]; ok {
format, values, _ = args.compileSuccessive(buf, flavor, format, values, p)
}
return format, values
}
func (args *Args) compileDigits(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) {
i := 1
for ; i < len(format) && '0' <= format[i] && format[i] <= '9'; i++ {
// Nothing.
}
digits := format[:i]
format = format[i:]
if pointer, err := strconv.Atoi(digits); err == nil {
return args.compileSuccessive(buf, flavor, format, values, pointer)
}
return format, values, offset
}
func (args *Args) compileSuccessive(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) {
if offset >= len(args.args) {
return format, values, offset
}
arg := args.args[offset]
values = args.compileArg(buf, flavor, values, arg)
return format, values, offset + 1
}
func (args *Args) compileArg(buf *bytes.Buffer, flavor Flavor, values []interface{}, arg interface{}) []interface{} {
switch a := arg.(type) {
case Builder:
var s string
s, values = a.BuildWithFlavor(flavor, values...)
buf.WriteString(s)
case sql.NamedArg:
buf.WriteRune('@')
buf.WriteString(a.Name)
case rawArgs:
buf.WriteString(a.expr)
case listArgs:
if len(a.args) > 0 {
values = args.compileArg(buf, flavor, values, a.args[0])
}
for i := 1; i < len(a.args); i++ {
buf.WriteString(", ")
values = args.compileArg(buf, flavor, values, a.args[i])
}
default:
switch flavor {
case MySQL:
buf.WriteRune('?')
case PostgreSQL:
fmt.Fprintf(buf, "$%v", len(values)+1)
default:
panic(fmt.Errorf("Args.CompileWithFlavor: invalid flavor %v (%v)", flavor, int(flavor)))
}
values = append(values, arg)
}
return values
}
// Copy is
func (args *Args) Copy() *Args {
return &Args{
Flavor: args.Flavor,
args: args.args,
namedArgs: args.namedArgs,
sqlNamedArgs: args.sqlNamedArgs,
onlyNamed: args.onlyNamed,
}
}

View File

@ -0,0 +1,76 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"bytes"
"fmt"
"strings"
"testing"
)
func TestArgs(t *testing.T) {
cases := map[string][]interface{}{
"abc ? def\n[123]": {"abc $? def", 123},
"abc ? def\n[456]": {"abc $0 def", 456},
"abc def\n[]": {"abc $1 def", 123},
"abc ? def\n[789]": {"abc ${s} def", Named("s", 789)},
"abc def \n[]": {"abc ${unknown} def ", 123},
"abc $ def\n[]": {"abc $$ def", 123},
"abcdef\n[]": {"abcdef$", 123},
"abc ? ? ? ? def\n[123 456 123 456]": {"abc $? $? $0 $? def", 123, 456, 789},
"abc ? raw ? raw def\n[123 123]": {"abc $? $? $0 $? def", 123, Raw("raw"), 789},
}
for expected, c := range cases {
args := new(Args)
for i := 1; i < len(c); i++ {
args.Add(c[i])
}
sql, values := args.Compile(c[0].(string))
actual := fmt.Sprintf("%v\n%v", sql, values)
if actual != expected {
t.Fatalf("invalid compile result. [expected:%v] [actual:%v]", expected, actual)
}
}
old := DefaultFlavor
DefaultFlavor = PostgreSQL
defer func() {
DefaultFlavor = old
}()
// PostgreSQL flavor compiled sql.
for expected, c := range cases {
args := new(Args)
for i := 1; i < len(c); i++ {
args.Add(c[i])
}
sql, values := args.Compile(c[0].(string))
actual := fmt.Sprintf("%v\n%v", sql, values)
expected = toPostgreSQL(expected)
if actual != expected {
t.Fatalf("invalid compile result. [expected:%v] [actual:%v]", expected, actual)
}
}
}
func toPostgreSQL(sql string) string {
parts := strings.Split(sql, "?")
buf := &bytes.Buffer{}
buf.WriteString(parts[0])
for i, p := range parts[1:] {
fmt.Fprintf(buf, "$%v", i+1)
buf.WriteString(p)
}
return buf.String()
}

View File

@ -0,0 +1,105 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"fmt"
)
// Builder is a general SQL builder.
// It's used by Args to create nested SQL like the `IN` expression in
// `SELECT * FROM t1 WHERE id IN (SELECT id FROM t2)`.
type Builder interface {
Build() (sql string, args []interface{})
BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{})
}
type compiledBuilder struct {
args *Args
format string
}
func (cb *compiledBuilder) Build() (sql string, args []interface{}) {
return cb.args.Compile(cb.format)
}
func (cb *compiledBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {
return cb.args.CompileWithFlavor(cb.format, flavor, initialArg...)
}
type flavoredBuilder struct {
builder Builder
flavor Flavor
}
func (fb *flavoredBuilder) Build() (sql string, args []interface{}) {
return fb.builder.BuildWithFlavor(fb.flavor)
}
func (fb *flavoredBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {
return fb.builder.BuildWithFlavor(flavor, initialArg...)
}
// WithFlavor creates a new Builder based on builder with a default flavor.
func WithFlavor(builder Builder, flavor Flavor) Builder {
return &flavoredBuilder{
builder: builder,
flavor: flavor,
}
}
// Buildf creates a Builder from a format string using `fmt.Sprintf`-like syntax.
// As all arguments will be converted to a string internally, e.g. "$0",
// only `%v` and `%s` are valid.
func Buildf(format string, arg ...interface{}) Builder {
args := &Args{
Flavor: DefaultFlavor,
}
vars := make([]interface{}, 0, len(arg))
for _, a := range arg {
vars = append(vars, args.Add(a))
}
return &compiledBuilder{
args: args,
format: fmt.Sprintf(Escape(format), vars...),
}
}
// Build creates a Builder from a format string.
// The format string uses special syntax to represent arguments.
// See doc in `Args#Compile` for syntax details.
func Build(format string, arg ...interface{}) Builder {
args := &Args{
Flavor: DefaultFlavor,
}
for _, a := range arg {
args.Add(a)
}
return &compiledBuilder{
args: args,
format: format,
}
}
// BuildNamed creates a Builder from a format string.
// The format string uses `${key}` to refer the value of named by key.
func BuildNamed(format string, named map[string]interface{}) Builder {
args := &Args{
Flavor: DefaultFlavor,
onlyNamed: true,
}
for n, v := range named {
args.Add(Named(n, v))
}
return &compiledBuilder{
args: args,
format: format,
}
}

View File

@ -0,0 +1,113 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"database/sql"
"fmt"
"reflect"
"testing"
)
func ExampleBuildf() {
sb := NewSelectBuilder()
sb.Select("id").From("user")
explain := Buildf("EXPLAIN %v LEFT JOIN SELECT * FROM banned WHERE state IN (%v, %v)", sb, 1, 2)
sql, args := explain.Build()
fmt.Println(sql)
fmt.Println(args)
// Output:
// EXPLAIN SELECT id FROM user LEFT JOIN SELECT * FROM banned WHERE state IN (?, ?)
// [1 2]
}
func ExampleBuild() {
sb := NewSelectBuilder()
sb.Select("id").From("user").Where(sb.In("status", 1, 2))
b := Build("EXPLAIN $? LEFT JOIN SELECT * FROM $? WHERE created_at > $? AND state IN (${states}) AND modified_at BETWEEN $2 AND $?",
sb, Raw("banned"), 1514458225, 1514544625, Named("states", List([]int{3, 4, 5})))
sql, args := b.Build()
fmt.Println(sql)
fmt.Println(args)
// Output:
// EXPLAIN SELECT id FROM user WHERE status IN (?, ?) LEFT JOIN SELECT * FROM banned WHERE created_at > ? AND state IN (?, ?, ?) AND modified_at BETWEEN ? AND ?
// [1 2 1514458225 3 4 5 1514458225 1514544625]
}
func ExampleBuildNamed() {
b := BuildNamed("SELECT * FROM ${table} WHERE status IN (${status}) AND name LIKE ${name} AND created_at > ${time} AND modified_at < ${time} + 86400",
map[string]interface{}{
"time": sql.Named("start", 1234567890),
"status": List([]int{1, 2, 5}),
"name": "Huan%",
"table": Raw("user"),
})
sql, args := b.Build()
fmt.Println(sql)
fmt.Println(args)
// Output:
// SELECT * FROM user WHERE status IN (?, ?, ?) AND name LIKE ? AND created_at > @start AND modified_at < @start + 86400
// [1 2 5 Huan% {{} start 1234567890}]
}
func ExampleWithFlavor() {
sql, args := WithFlavor(Buildf("SELECT * FROM foo WHERE id = %v", 1234), PostgreSQL).Build()
fmt.Println(sql)
fmt.Println(args)
// Explicitly use MySQL as the flavor.
sql, args = WithFlavor(Buildf("SELECT * FROM foo WHERE id = %v", 1234), PostgreSQL).BuildWithFlavor(MySQL)
fmt.Println(sql)
fmt.Println(args)
// Output:
// SELECT * FROM foo WHERE id = $1
// [1234]
// SELECT * FROM foo WHERE id = ?
// [1234]
}
func TestBuildWithPostgreSQL(t *testing.T) {
sb1 := PostgreSQL.NewSelectBuilder()
sb1.Select("col1", "col2").From("t1").Where(sb1.E("id", 1234), sb1.G("level", 2))
sb2 := PostgreSQL.NewSelectBuilder()
sb2.Select("col3", "col4").From("t2").Where(sb2.E("id", 4567), sb2.LE("level", 5))
// Use DefaultFlavor (MySQL) instead of PostgreSQL.
sql, args := Build("SELECT $1 AS col5 LEFT JOIN $0 LEFT JOIN $2", sb1, 7890, sb2).Build()
if expected := "SELECT ? AS col5 LEFT JOIN SELECT col1, col2 FROM t1 WHERE id = ? AND level > ? LEFT JOIN SELECT col3, col4 FROM t2 WHERE id = ? AND level <= ?"; sql != expected {
t.Fatalf("invalid sql. [expected:%v] [actual:%v]", expected, sql)
}
if expected := []interface{}{7890, 1234, 2, 4567, 5}; !reflect.DeepEqual(args, expected) {
t.Fatalf("invalid args. [expected:%v] [actual:%v]", expected, args)
}
old := DefaultFlavor
DefaultFlavor = PostgreSQL
defer func() {
DefaultFlavor = old
}()
sql, args = Build("SELECT $1 AS col5 LEFT JOIN $0 LEFT JOIN $2", sb1, 7890, sb2).Build()
if expected := "SELECT $1 AS col5 LEFT JOIN SELECT col1, col2 FROM t1 WHERE id = $2 AND level > $3 LEFT JOIN SELECT col3, col4 FROM t2 WHERE id = $4 AND level <= $5"; sql != expected {
t.Fatalf("invalid sql. [expected:%v] [actual:%v]", expected, sql)
}
if expected := []interface{}{7890, 1234, 2, 4567, 5}; !reflect.DeepEqual(args, expected) {
t.Fatalf("invalid args. [expected:%v] [actual:%v]", expected, args)
}
}

View File

@ -0,0 +1,136 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"fmt"
"strings"
)
// Cond provides several helper methods to build conditions.
type Cond struct {
Args *Args
}
// Equal represents "field = value".
func (c *Cond) Equal(field string, value interface{}) string {
return fmt.Sprintf("%v = %v", Escape(field), c.Args.Add(value))
}
// E is an alias of Equal.
func (c *Cond) E(field string, value interface{}) string {
return c.Equal(field, value)
}
// NotEqual represents "field != value".
func (c *Cond) NotEqual(field string, value interface{}) string {
return fmt.Sprintf("%v <> %v", Escape(field), c.Args.Add(value))
}
// NE is an alias of NotEqual.
func (c *Cond) NE(field string, value interface{}) string {
return c.NotEqual(field, value)
}
// GreaterThan represents "field > value".
func (c *Cond) GreaterThan(field string, value interface{}) string {
return fmt.Sprintf("%v > %v", Escape(field), c.Args.Add(value))
}
// G is an alias of GreaterThan.
func (c *Cond) G(field string, value interface{}) string {
return c.GreaterThan(field, value)
}
// GreaterEqualThan represents "field >= value".
func (c *Cond) GreaterEqualThan(field string, value interface{}) string {
return fmt.Sprintf("%v >= %v", Escape(field), c.Args.Add(value))
}
// GE is an alias of GreaterEqualThan.
func (c *Cond) GE(field string, value interface{}) string {
return c.GreaterEqualThan(field, value)
}
// LessThan represents "field < value".
func (c *Cond) LessThan(field string, value interface{}) string {
return fmt.Sprintf("%v < %v", Escape(field), c.Args.Add(value))
}
// L is an alias of LessThan.
func (c *Cond) L(field string, value interface{}) string {
return c.LessThan(field, value)
}
// LessEqualThan represents "field <= value".
func (c *Cond) LessEqualThan(field string, value interface{}) string {
return fmt.Sprintf("%v <= %v", Escape(field), c.Args.Add(value))
}
// LE is an alias of LessEqualThan.
func (c *Cond) LE(field string, value interface{}) string {
return c.LessEqualThan(field, value)
}
// In represents "field IN (value...)".
func (c *Cond) In(field string, value ...interface{}) string {
vs := make([]string, 0, len(value))
for _, v := range value {
vs = append(vs, c.Args.Add(v))
}
return fmt.Sprintf("%v IN (%v)", Escape(field), strings.Join(vs, ", "))
}
// NotIn represents "field NOT IN (value...)".
func (c *Cond) NotIn(field string, value ...interface{}) string {
vs := make([]string, 0, len(value))
for _, v := range value {
vs = append(vs, c.Args.Add(v))
}
return fmt.Sprintf("%v NOT IN (%v)", Escape(field), strings.Join(vs, ", "))
}
// Like represents "field LIKE value".
func (c *Cond) Like(field string, value interface{}) string {
return fmt.Sprintf("%v LIKE %v", Escape(field), c.Args.Add(value))
}
// NotLike represents "field NOT LIKE value".
func (c *Cond) NotLike(field string, value interface{}) string {
return fmt.Sprintf("%v NOT LIKE %v", Escape(field), c.Args.Add(value))
}
// IsNull represents "field IS NULL".
func (c *Cond) IsNull(field string) string {
return fmt.Sprintf("%v IS NULL", Escape(field))
}
// IsNotNull represents "field IS NOT NULL".
func (c *Cond) IsNotNull(field string) string {
return fmt.Sprintf("%v IS NOT NULL", Escape(field))
}
// Between represents "field BETWEEN lower AND upper".
func (c *Cond) Between(field string, lower, upper interface{}) string {
return fmt.Sprintf("%v BETWEEN %v AND %v", Escape(field), c.Args.Add(lower), c.Args.Add(upper))
}
// NotBetween represents "field NOT BETWEEN lower AND upper".
func (c *Cond) NotBetween(field string, lower, upper interface{}) string {
return fmt.Sprintf("%v NOT BETWEEN %v AND %v", Escape(field), c.Args.Add(lower), c.Args.Add(upper))
}
// Or represents OR logic like "expr1 OR expr2 OR expr3".
func (c *Cond) Or(orExpr ...string) string {
return fmt.Sprintf("(%v)", strings.Join(orExpr, " OR "))
}
// Var returns a placeholder for value.
func (c *Cond) Var(value interface{}) string {
return c.Args.Add(value)
}

View File

@ -0,0 +1,47 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"testing"
)
func TestCond(t *testing.T) {
cases := map[string]func() string{
"$$a = $0": func() string { return newTestCond().Equal("$a", 123) },
"$$b = $0": func() string { return newTestCond().E("$b", 123) },
"$$a <> $0": func() string { return newTestCond().NotEqual("$a", 123) },
"$$b <> $0": func() string { return newTestCond().NE("$b", 123) },
"$$a > $0": func() string { return newTestCond().GreaterThan("$a", 123) },
"$$b > $0": func() string { return newTestCond().G("$b", 123) },
"$$a >= $0": func() string { return newTestCond().GreaterEqualThan("$a", 123) },
"$$b >= $0": func() string { return newTestCond().GE("$b", 123) },
"$$a < $0": func() string { return newTestCond().LessThan("$a", 123) },
"$$b < $0": func() string { return newTestCond().L("$b", 123) },
"$$a <= $0": func() string { return newTestCond().LessEqualThan("$a", 123) },
"$$b <= $0": func() string { return newTestCond().LE("$b", 123) },
"$$a IN ($0, $1, $2)": func() string { return newTestCond().In("$a", 1, 2, 3) },
"$$a NOT IN ($0, $1, $2)": func() string { return newTestCond().NotIn("$a", 1, 2, 3) },
"$$a LIKE $0": func() string { return newTestCond().Like("$a", "%Huan%") },
"$$a NOT LIKE $0": func() string { return newTestCond().NotLike("$a", "%Huan%") },
"$$a IS NULL": func() string { return newTestCond().IsNull("$a") },
"$$a IS NOT NULL": func() string { return newTestCond().IsNotNull("$a") },
"$$a BETWEEN $0 AND $1": func() string { return newTestCond().Between("$a", 123, 456) },
"$$a NOT BETWEEN $0 AND $1": func() string { return newTestCond().NotBetween("$a", 123, 456) },
"(1 = 1 OR 2 = 2 OR 3 = 3)": func() string { return newTestCond().Or("1 = 1", "2 = 2", "3 = 3") },
"$0": func() string { return newTestCond().Var(123) },
}
for expected, f := range cases {
if actual := f(); expected != actual {
t.Fatalf("invalid result. [expected:%v] [actual:%v]", expected, actual)
}
}
}
func newTestCond() *Cond {
return &Cond{
Args: &Args{},
}
}

View File

@ -0,0 +1,57 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import "fmt"
// Supported flavors.
const (
invalidFlavor Flavor = iota
MySQL
PostgreSQL
)
var (
// DefaultFlavor is the default flavor for all builders.
DefaultFlavor = MySQL
)
// Flavor is the flag to control the format of compiled sql.
type Flavor int
// String returns the name of f.
func (f Flavor) String() string {
switch f {
case MySQL:
return "MySQL"
case PostgreSQL:
return "PostgreSQL"
}
return "<invalid>"
}
// NewSelectBuilder creates a new SELECT builder with flavor.
func (f Flavor) NewSelectBuilder() *SelectBuilder {
b := newSelectBuilder()
b.SetFlavor(f)
return b
}
// Quote adds quote for name to make sure the name can be used safely
// as table name or field name.
//
// * For MySQL, use back quote (`) to quote name;
// * For PostgreSQL, use double quote (") to quote name.
func (f Flavor) Quote(name string) string {
switch f {
case MySQL:
return fmt.Sprintf("`%v`", name)
case PostgreSQL:
return fmt.Sprintf(`"%v"`, name)
}
return name
}

View File

@ -0,0 +1,40 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"fmt"
"testing"
)
func TestFlavor(t *testing.T) {
cases := map[Flavor]string{
0: "<invalid>",
MySQL: "MySQL",
PostgreSQL: "PostgreSQL",
}
for f, expected := range cases {
if actual := f.String(); actual != expected {
t.Fatalf("invalid flavor name. [expected:%v] [actual:%v]", expected, actual)
}
}
}
func ExampleFlavor() {
// Create a flavored builder.
sb := PostgreSQL.NewSelectBuilder()
sb.Select("name").From("user").Where(
sb.E("id", 1234),
sb.G("rank", 3),
)
sql, args := sb.Build()
fmt.Println(sql)
fmt.Println(args)
// Output:
// SELECT name FROM user WHERE id = $1 AND rank > $2
// [1234 3]
}

View File

@ -0,0 +1,98 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"reflect"
"strings"
)
// Escape replaces `$` with `$$` in ident.
func Escape(ident string) string {
return strings.Replace(ident, "$", "$$", -1)
}
// EscapeAll replaces `$` with `$$` in all strings of ident.
func EscapeAll(ident ...string) []string {
escaped := make([]string, 0, len(ident))
for _, i := range ident {
escaped = append(escaped, Escape(i))
}
return escaped
}
// Flatten recursively extracts values in slices and returns
// a flattened []interface{} with all values.
// If slices is not a slice, return `[]interface{}{slices}`.
func Flatten(slices interface{}) (flattened []interface{}) {
v := reflect.ValueOf(slices)
slices, flattened = flatten(v)
if slices != nil {
return []interface{}{slices}
}
return flattened
}
func flatten(v reflect.Value) (elem interface{}, flattened []interface{}) {
k := v.Kind()
for k == reflect.Interface {
v = v.Elem()
k = v.Kind()
}
if k != reflect.Slice && k != reflect.Array {
return v.Interface(), nil
}
for i, l := 0, v.Len(); i < l; i++ {
e, f := flatten(v.Index(i))
if e == nil {
flattened = append(flattened, f...)
} else {
flattened = append(flattened, e)
}
}
return
}
type rawArgs struct {
expr string
}
// Raw marks the expr as a raw value which will not be added to args.
func Raw(expr string) interface{} {
return rawArgs{expr}
}
type listArgs struct {
args []interface{}
}
// List marks arg as a list of data.
// If arg is `[]int{1, 2, 3}`, it will be compiled to `?, ?, ?` with args `[1 2 3]`.
func List(arg interface{}) interface{} {
return listArgs{Flatten(arg)}
}
type namedArgs struct {
name string
arg interface{}
}
// Named creates a named argument.
// Unlike `sql.Named`, this named argument works only with `Build` or `BuildNamed` for convenience
// and will be replaced to a `?` after `Compile`.
func Named(name string, arg interface{}) interface{} {
return namedArgs{
name: name,
arg: arg,
}
}

View File

@ -0,0 +1,59 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"reflect"
"testing"
)
func TestEscape(t *testing.T) {
cases := map[string]string{
"foo": "foo",
"$foo": "$$foo",
"$$$": "$$$$$$",
}
var inputs, expects []string
for s, expected := range cases {
inputs = append(inputs, s)
expects = append(expects, expected)
if actual := Escape(s); actual != expected {
t.Fatalf("invalid escape result. [expected:%v] [actual:%v]", expected, actual)
}
}
actuals := EscapeAll(inputs...)
if !reflect.DeepEqual(expects, actuals) {
t.Fatalf("invalid escape result. [expected:%v] [actual:%v]", expects, actuals)
}
}
func TestFlatten(t *testing.T) {
cases := [][2]interface{}{
{
"foo",
[]interface{}{"foo"},
},
{
[]int{1, 2, 3},
[]interface{}{1, 2, 3},
},
{
[]interface{}{"abc", []int{1, 2, 3}, [3]string{"def", "ghi"}},
[]interface{}{"abc", 1, 2, 3, "def", "ghi", ""},
},
}
for _, c := range cases {
input, expected := c[0], c[1]
actual := Flatten(input)
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("invalid flatten result. [expected:%v] [actual:%v]", expected, actual)
}
}
}

View File

@ -0,0 +1,212 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"bytes"
"fmt"
"strconv"
"strings"
)
// NewSelectBuilder creates a new SELECT builder.
func NewSelectBuilder() *SelectBuilder {
return DefaultFlavor.NewSelectBuilder()
}
func newSelectBuilder() *SelectBuilder {
args := &Args{}
return &SelectBuilder{
Cond: Cond{
Args: args,
},
limit: -1,
offset: -1,
args: args,
}
}
// SelectBuilder is a builder to build SELECT.
type SelectBuilder struct {
Cond
distinct bool
tables []string
selectCols []string
whereExprs []string
havingExprs []string
groupByCols []string
orderByCols []string
order string
limit int
offset int
args *Args
}
// Distinct marks this SELECT as DISTINCT.
func (sb *SelectBuilder) Distinct() *SelectBuilder {
sb.distinct = true
return sb
}
// Select sets columns in SELECT.
func (sb *SelectBuilder) Select(col ...string) *SelectBuilder {
sb.selectCols = EscapeAll(col...)
return sb
}
// From sets table names in SELECT.
func (sb *SelectBuilder) From(table ...string) *SelectBuilder {
sb.tables = table
return sb
}
// Where sets expressions of WHERE in SELECT.
func (sb *SelectBuilder) Where(andExpr ...string) *SelectBuilder {
sb.whereExprs = append(sb.whereExprs, andExpr...)
return sb
}
// Having sets expressions of HAVING in SELECT.
func (sb *SelectBuilder) Having(andExpr ...string) *SelectBuilder {
sb.havingExprs = append(sb.havingExprs, andExpr...)
return sb
}
// GroupBy sets columns of GROUP BY in SELECT.
func (sb *SelectBuilder) GroupBy(col ...string) *SelectBuilder {
sb.groupByCols = EscapeAll(col...)
return sb
}
// OrderBy sets columns of ORDER BY in SELECT.
func (sb *SelectBuilder) OrderBy(col ...string) *SelectBuilder {
sb.orderByCols = EscapeAll(col...)
return sb
}
// Asc sets order of ORDER BY to ASC.
func (sb *SelectBuilder) Asc() *SelectBuilder {
sb.order = "ASC"
return sb
}
// Desc sets order of ORDER BY to DESC.
func (sb *SelectBuilder) Desc() *SelectBuilder {
sb.order = "DESC"
return sb
}
// Limit sets the LIMIT in SELECT.
func (sb *SelectBuilder) Limit(limit int) *SelectBuilder {
sb.limit = limit
return sb
}
// Offset sets the LIMIT offset in SELECT.
func (sb *SelectBuilder) Offset(offset int) *SelectBuilder {
sb.offset = offset
return sb
}
// As returns an AS expression.
func (sb *SelectBuilder) As(col, alias string) string {
return fmt.Sprintf("%v AS %v", col, Escape(alias))
}
// BuilderAs returns an AS expression wrapping a complex SQL.
// According to SQL syntax, SQL built by builder is surrounded by parens.
func (sb *SelectBuilder) BuilderAs(builder Builder, alias string) string {
return fmt.Sprintf("(%v) AS %v", sb.Var(builder), Escape(alias))
}
// String returns the compiled SELECT string.
func (sb *SelectBuilder) String() string {
s, _ := sb.Build()
return s
}
// Build returns compiled SELECT string and args.
// They can be used in `DB#Query` of package `database/sql` directly.
func (sb *SelectBuilder) Build() (sql string, args []interface{}) {
return sb.BuildWithFlavor(sb.args.Flavor)
}
// BuildWithFlavor returns compiled SELECT string and args with flavor and initial args.
// They can be used in `DB#Query` of package `database/sql` directly.
func (sb *SelectBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {
buf := &bytes.Buffer{}
buf.WriteString("SELECT ")
if sb.distinct {
buf.WriteString("DISTINCT ")
}
buf.WriteString(strings.Join(sb.selectCols, ", "))
buf.WriteString(" FROM ")
buf.WriteString(strings.Join(sb.tables, ", "))
if len(sb.whereExprs) > 0 {
buf.WriteString(" WHERE ")
buf.WriteString(strings.Join(sb.whereExprs, " AND "))
}
if len(sb.groupByCols) > 0 {
buf.WriteString(" GROUP BY ")
buf.WriteString(strings.Join(sb.groupByCols, ", "))
if len(sb.havingExprs) > 0 {
buf.WriteString(" HAVING ")
buf.WriteString(strings.Join(sb.havingExprs, " AND "))
}
}
if len(sb.orderByCols) > 0 {
buf.WriteString(" ORDER BY ")
buf.WriteString(strings.Join(sb.orderByCols, ", "))
if sb.order != "" {
buf.WriteRune(' ')
buf.WriteString(sb.order)
}
}
if sb.limit >= 0 {
buf.WriteString(" LIMIT ")
buf.WriteString(strconv.Itoa(sb.limit))
if sb.offset >= 0 {
buf.WriteString(" OFFSET ")
buf.WriteString(strconv.Itoa(sb.offset))
}
}
return sb.Args.CompileWithFlavor(buf.String(), flavor, initialArg...)
}
// SetFlavor sets the flavor of compiled sql.
func (sb *SelectBuilder) SetFlavor(flavor Flavor) (old Flavor) {
old = sb.args.Flavor
sb.args.Flavor = flavor
return
}
// Copy the builder
func (sb *SelectBuilder) Copy() *SelectBuilder {
return &SelectBuilder{
Cond: sb.Cond,
distinct: sb.distinct,
tables: sb.tables,
selectCols: sb.selectCols,
whereExprs: sb.whereExprs,
havingExprs: sb.havingExprs,
groupByCols: sb.groupByCols,
orderByCols: sb.orderByCols,
order: sb.order,
limit: sb.limit,
offset: sb.offset,
args: sb.args.Copy(),
}
}

View File

@ -0,0 +1,67 @@
// Copyright 2018 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.
package sqlbuilder
import (
"database/sql"
"fmt"
)
func ExampleSelectBuilder() {
sb := NewSelectBuilder()
sb.Distinct().Select("id", "name", sb.As("COUNT(*)", "t"))
sb.From("demo.user")
sb.Where(
sb.GreaterThan("id", 1234),
sb.Like("name", "%Du"),
sb.Or(
sb.IsNull("id_card"),
sb.In("status", 1, 2, 5),
),
sb.NotIn(
"id",
NewSelectBuilder().Select("id").From("banned"),
), // Nested SELECT.
"modified_at > created_at + "+sb.Var(86400), // It's allowed to write arbitrary SQL.
)
sb.GroupBy("status").Having(sb.NotIn("status", 4, 5))
sb.OrderBy("modified_at").Asc()
sb.Limit(10).Offset(5)
sql, args := sb.Build()
fmt.Println(sql)
fmt.Println(args)
// Output:
// SELECT DISTINCT id, name, COUNT(*) AS t FROM demo.user WHERE id > ? AND name LIKE ? AND (id_card IS NULL OR status IN (?, ?, ?)) AND id NOT IN (SELECT id FROM banned) AND modified_at > created_at + ? GROUP BY status HAVING status NOT IN (?, ?) ORDER BY modified_at ASC LIMIT 10 OFFSET 5
// [1234 %Du 1 2 5 86400 4 5]
}
func ExampleSelectBuilder_advancedUsage() {
sb := NewSelectBuilder()
innerSb := NewSelectBuilder()
sb.Select("id", "name")
sb.From(
sb.BuilderAs(innerSb, "user"),
)
sb.Where(
sb.In("status", Flatten([]int{1, 2, 3})...),
sb.Between("created_at", sql.Named("start", 1234567890), sql.Named("end", 1234599999)),
)
innerSb.Select("*")
innerSb.From("banned")
innerSb.Where(
innerSb.NotIn("name", Flatten([]string{"Huan Du", "Charmy Liu"})...),
)
sql, args := sb.Build()
fmt.Println(sql)
fmt.Println(args)
// Output:
// SELECT id, name FROM (SELECT * FROM banned WHERE name NOT IN (?, ?)) AS user WHERE status IN (?, ?, ?) AND created_at BETWEEN @start AND @end
// [Huan Du Charmy Liu 1 2 3 {{} start 1234567890} {{} end 1234599999}]
}

View File

@ -0,0 +1,140 @@
package dao
import (
"context"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"strconv"
"time"
"go-common/app/service/main/account-recovery/model"
"go-common/library/database/elastic"
"go-common/library/log"
)
// NickNameLog NickNameLog
func (d *Dao) NickNameLog(c context.Context, nickNameReq *model.NickNameReq) (res *model.NickNameLogRes, err error) {
nowYear := time.Now().Year()
index1 := "log_user_action_14_" + strconv.Itoa(nowYear)
index2 := "log_user_action_14_" + strconv.Itoa(nowYear-1)
r := d.es.NewRequest("log_user_action").Fields("str_0", "str_1").Index(index1, index2)
r.Order("ctime", elastic.OrderDesc).Order("mid", elastic.OrderDesc).Pn(nickNameReq.Page).Ps(nickNameReq.Size)
if nickNameReq.Mid != 0 {
r.WhereEq("mid", nickNameReq.Mid)
}
if nickNameReq.From != 0 && nickNameReq.To != 0 {
ftm := time.Unix(nickNameReq.From, 0)
sf := ftm.Format("2006-01-02 15:04:05")
ttm := time.Unix(nickNameReq.To, 0)
tf := ttm.Format("2006-01-02 15:04:05")
r.WhereRange("ctime", sf, tf, elastic.RangeScopeLoRo)
}
esres := new(model.NickESRes)
if err = r.Scan(context.TODO(), &esres); err != nil {
log.Error("nickNameLog search error(%v)", err)
return
}
var nickNames = make([]*model.NickNameInfo, 0)
for _, value := range esres.Result {
ulr := model.NickNameInfo{OldName: value.OldName, NewName: value.NewName}
nickNames = append(nickNames, &ulr)
}
res = &model.NickNameLogRes{Page: esres.Page, Result: nickNames}
return
}
type userLogsExtra struct {
EncryptTel string `json:"tel"`
EncryptEmail string `json:"email"`
}
// UserBindLog User bind log
func (d *Dao) UserBindLog(c context.Context, userActLogReq *model.UserBindLogReq) (res *model.UserBindLogRes, err error) {
e := d.es
nowYear := time.Now().Year()
var count = 2 //默认查询两年
//2016年就有了手机历史记录此处需要循环建立索引 , 2018年才有邮箱这个功能
if userActLogReq.Action == "telBindLog" {
count = nowYear - 2015
}
if userActLogReq.Action == "emailBindLog" {
count = nowYear - 2017
}
indexs := make([]string, count)
for i := 0; i < count; i++ {
indexs[i] = "log_user_action_54_" + strconv.Itoa(nowYear-i)
}
r := e.NewRequest("log_user_action").Fields("mid", "str_0", "extra_data", "ctime").Index(indexs...)
r.Order("ctime", elastic.OrderDesc).Order("mid", elastic.OrderDesc).Pn(userActLogReq.Page).Ps(userActLogReq.Size)
if userActLogReq.Mid != 0 {
r.WhereEq("mid", userActLogReq.Mid)
}
if userActLogReq.Query != "" {
hash := sha1.New()
hash.Write([]byte(userActLogReq.Query))
telHash := base64.StdEncoding.EncodeToString(hash.Sum(d.hashSalt))
r.WhereEq("str_0", telHash)
}
if userActLogReq.Action != "" {
r.WhereEq("action", userActLogReq.Action)
}
if userActLogReq.From != 0 && userActLogReq.To != 0 {
ftm := time.Unix(userActLogReq.From, 0)
sf := ftm.Format("2006-01-02 15:04:05")
ttm := time.Unix(userActLogReq.To, 0)
tf := ttm.Format("2006-01-02 15:04:05")
r.WhereRange("ctime", sf, tf, elastic.RangeScopeLoRo)
}
esres := new(model.EsRes)
if err = r.Scan(context.Background(), &esres); err != nil {
log.Error("userActLogs search error(%v)", err)
return
}
var userBindLogs = make([]*model.UserBindLog, 0)
for _, value := range esres.Result {
var email, tel string
//var model.UserBindLog
userLogExtra := userLogsExtra{}
err = json.Unmarshal([]byte(value.ExtraData), &userLogExtra)
if err != nil {
log.Error("cannot convert json(%s) to struct,err(%+v) ", value.ExtraData, err)
continue
}
if userLogExtra.EncryptEmail != "" {
email, err = d.decrypt(userLogExtra.EncryptEmail)
if err != nil {
log.Error("EncryptEmail decode err(%v)", err)
continue
}
}
if userLogExtra.EncryptTel != "" {
tel, err = d.decrypt(userLogExtra.EncryptTel)
if err != nil {
log.Error("EncryptTel decode err(%v)", err)
continue
}
}
ulr := model.UserBindLog{Mid: value.Mid, Email: email, Phone: tel, Time: value.CTime}
userBindLogs = append(userBindLogs, &ulr)
}
res = &model.UserBindLogRes{Page: esres.Page, Result: userBindLogs}
return
}

View File

@ -0,0 +1,42 @@
package dao
import (
"context"
"testing"
"go-common/app/service/main/account-recovery/model"
"github.com/smartystreets/goconvey/convey"
)
func TestDaoNickNameLog(t *testing.T) {
var (
c = context.Background()
nickNameReq = &model.NickNameReq{Mid: 111001254}
)
convey.Convey("NickNameLog", t, func(ctx convey.C) {
res, err := d.NickNameLog(c, nickNameReq)
ctx.Convey("Then err should be nil.res should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(res, convey.ShouldNotBeNil)
})
})
}
func TestDaoUserBindLog(t *testing.T) {
var (
c = context.Background()
emailReq = &model.UserBindLogReq{
Action: "emailBindLog",
Mid: 23,
Size: 100,
}
)
convey.Convey("UserBindLog", t, func(ctx convey.C) {
res, err := d.UserBindLog(c, emailReq)
ctx.Convey("Then err should be nil.res should not be nil.", func(ctx convey.C) {
ctx.So(err, convey.ShouldBeNil)
ctx.So(res, convey.ShouldNotBeNil)
})
})
}