Create & Init Project...
This commit is contained in:
86
app/service/main/account-recovery/dao/BUILD
Normal file
86
app/service/main/account-recovery/dao/BUILD
Normal file
@ -0,0 +1,86 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
load(
|
||||
"@io_bazel_rules_go//go:def.bzl",
|
||||
"go_library",
|
||||
"go_test",
|
||||
)
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = [
|
||||
"aes.go",
|
||||
"captcha.go",
|
||||
"dao.go",
|
||||
"email.go",
|
||||
"mid_info.go",
|
||||
"mysql.go",
|
||||
"redis.go",
|
||||
"req_rpc.go",
|
||||
"user_act_log.go",
|
||||
],
|
||||
importpath = "go-common/app/service/main/account-recovery/dao",
|
||||
tags = ["automanaged"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//app/service/main/account-recovery/conf:go_default_library",
|
||||
"//app/service/main/account-recovery/dao/sqlbuilder:go_default_library",
|
||||
"//app/service/main/account-recovery/model:go_default_library",
|
||||
"//app/service/main/account/api:go_default_library",
|
||||
"//app/service/main/location/model:go_default_library",
|
||||
"//app/service/main/location/rpc/client:go_default_library",
|
||||
"//app/service/main/member/api:go_default_library",
|
||||
"//library/cache/redis:go_default_library",
|
||||
"//library/database/elastic:go_default_library",
|
||||
"//library/database/sql:go_default_library",
|
||||
"//library/ecode:go_default_library",
|
||||
"//library/log:go_default_library",
|
||||
"//library/net/http/blademaster:go_default_library",
|
||||
"//library/net/metadata:go_default_library",
|
||||
"//library/time:go_default_library",
|
||||
"//library/xstr:go_default_library",
|
||||
"//vendor/github.com/pkg/errors:go_default_library",
|
||||
"//vendor/gopkg.in/gomail.v2:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "package-srcs",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["automanaged"],
|
||||
visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all-srcs",
|
||||
srcs = [
|
||||
":package-srcs",
|
||||
"//app/service/main/account-recovery/dao/sqlbuilder:all-srcs",
|
||||
],
|
||||
tags = ["automanaged"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "go_default_test",
|
||||
srcs = [
|
||||
"captcha_test.go",
|
||||
"dao_test.go",
|
||||
"email_test.go",
|
||||
"mid_info_test.go",
|
||||
"mysql_test.go",
|
||||
"redis_test.go",
|
||||
"req_rpc_test.go",
|
||||
"user_act_log_test.go",
|
||||
],
|
||||
embed = [":go_default_library"],
|
||||
rundir = ".",
|
||||
tags = ["automanaged"],
|
||||
deps = [
|
||||
"//app/service/main/account-recovery/conf:go_default_library",
|
||||
"//app/service/main/account-recovery/model:go_default_library",
|
||||
"//library/time:go_default_library",
|
||||
"//vendor/github.com/smartystreets/goconvey/convey:go_default_library",
|
||||
"//vendor/gopkg.in/h2non/gock.v1:go_default_library",
|
||||
],
|
||||
)
|
63
app/service/main/account-recovery/dao/aes.go
Normal file
63
app/service/main/account-recovery/dao/aes.go
Normal file
@ -0,0 +1,63 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
)
|
||||
|
||||
//func pad(src []byte) []byte {
|
||||
// padding := aes.BlockSize - len(src)%aes.BlockSize
|
||||
// padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
// return append(src, padText...)
|
||||
//}
|
||||
|
||||
func unpad(src []byte) ([]byte, error) {
|
||||
length := len(src)
|
||||
unpadding := int(src[length-1])
|
||||
|
||||
if unpadding > length {
|
||||
return nil, errors.New("unpad error. This could happen when incorrect encryption key is used")
|
||||
}
|
||||
|
||||
return src[:(length - unpadding)], nil
|
||||
}
|
||||
|
||||
//func (s *Service) encrypt(text string) (string, error) {
|
||||
// msg := pad([]byte(text))
|
||||
// cipherText := make([]byte, aes.BlockSize+len(msg))
|
||||
// iv := cipherText[:aes.BlockSize]
|
||||
// if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
||||
// return "", err
|
||||
// }
|
||||
//
|
||||
// cfb := cipher.NewCFBEncrypter(s.AESBlock, iv)
|
||||
// cfb.XORKeyStream(cipherText[aes.BlockSize:], []byte(msg))
|
||||
// finalMsg := base64.URLEncoding.EncodeToString(cipherText)
|
||||
// return finalMsg, nil
|
||||
//}
|
||||
|
||||
func (d *Dao) decrypt(text string) (string, error) {
|
||||
decodedMsg, err := base64.URLEncoding.DecodeString(text)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if (len(decodedMsg) % aes.BlockSize) != 0 {
|
||||
return "", errors.New("blocksize must be multipe of decoded message length")
|
||||
}
|
||||
|
||||
iv := decodedMsg[:aes.BlockSize]
|
||||
msg := decodedMsg[aes.BlockSize:]
|
||||
|
||||
cfb := cipher.NewCFBDecrypter(d.AESBlock, iv)
|
||||
cfb.XORKeyStream(msg, msg)
|
||||
|
||||
unpadMsg, err := unpad(msg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(unpadMsg), nil
|
||||
}
|
47
app/service/main/account-recovery/dao/captcha.go
Normal file
47
app/service/main/account-recovery/dao/captcha.go
Normal file
@ -0,0 +1,47 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"go-common/app/service/main/account-recovery/model"
|
||||
"go-common/library/ecode"
|
||||
"go-common/library/log"
|
||||
"go-common/library/net/metadata"
|
||||
)
|
||||
|
||||
// GetToken get open token.
|
||||
func (d *Dao) GetToken(c context.Context, bid string) (res *model.TokenResq, err error) {
|
||||
params := url.Values{}
|
||||
params.Add("bid", bid)
|
||||
if err = d.httpClient.Get(c, d.c.CaptchaConf.TokenURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
|
||||
log.Error("GetToken HTTP request err(%v)", err)
|
||||
return
|
||||
}
|
||||
if res.Code != 0 {
|
||||
log.Error("GetToken service return err(%v)", res.Code)
|
||||
err = ecode.Int(int(res.Code))
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Verify verify code.
|
||||
func (d *Dao) Verify(c context.Context, code, token string) (ok bool, err error) {
|
||||
params := url.Values{}
|
||||
params.Add("token", token)
|
||||
params.Add("code", code)
|
||||
res := new(struct {
|
||||
Code int `json:"code"`
|
||||
})
|
||||
if err = d.httpClient.Post(c, d.c.CaptchaConf.VerifyURL, metadata.String(c, metadata.RemoteIP), params, res); err != nil {
|
||||
log.Error("Verify HTTP request err(%v)", err)
|
||||
return
|
||||
}
|
||||
if res.Code != 0 {
|
||||
log.Error("Verify service return err(%v)", res.Code)
|
||||
err = ecode.Int(res.Code)
|
||||
return
|
||||
}
|
||||
return true, nil
|
||||
}
|
22
app/service/main/account-recovery/dao/captcha_test.go
Normal file
22
app/service/main/account-recovery/dao/captcha_test.go
Normal file
@ -0,0 +1,22 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
|
||||
func TestDaoGetToken(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
bid = "account"
|
||||
)
|
||||
convey.Convey("GetToken", t, func(ctx convey.C) {
|
||||
res, err := d.GetToken(c, bid)
|
||||
ctx.Convey("Then err should be nil.res should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(res, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
90
app/service/main/account-recovery/dao/dao.go
Normal file
90
app/service/main/account-recovery/dao/dao.go
Normal file
@ -0,0 +1,90 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
|
||||
"go-common/app/service/main/account-recovery/conf"
|
||||
account "go-common/app/service/main/account/api"
|
||||
location "go-common/app/service/main/location/rpc/client"
|
||||
member "go-common/app/service/main/member/api"
|
||||
"go-common/library/cache/redis"
|
||||
"go-common/library/database/elastic"
|
||||
xsql "go-common/library/database/sql"
|
||||
bm "go-common/library/net/http/blademaster"
|
||||
|
||||
"gopkg.in/gomail.v2"
|
||||
)
|
||||
|
||||
// Dao dao
|
||||
type Dao struct {
|
||||
c *conf.Config
|
||||
redis *redis.Pool
|
||||
db *xsql.DB
|
||||
// httpClient
|
||||
httpClient *bm.Client
|
||||
|
||||
// email
|
||||
email *gomail.Dialer
|
||||
es *elastic.Elastic
|
||||
|
||||
// rpc
|
||||
locRPC *location.Service
|
||||
|
||||
// grpc
|
||||
memberClient member.MemberClient
|
||||
accountClient account.AccountClient
|
||||
|
||||
hashSalt []byte
|
||||
AESBlock cipher.Block
|
||||
}
|
||||
|
||||
// New init mysql db
|
||||
func New(c *conf.Config) (dao *Dao) {
|
||||
dao = &Dao{
|
||||
c: c,
|
||||
redis: redis.NewPool(c.Redis),
|
||||
db: xsql.NewMySQL(c.MySQL),
|
||||
// httpClient
|
||||
httpClient: bm.NewClient(c.HTTPClientConfig),
|
||||
|
||||
email: gomail.NewDialer(c.MailConf.Host, c.MailConf.Port, c.MailConf.Username, c.MailConf.Password),
|
||||
es: elastic.NewElastic(c.Elastic),
|
||||
locRPC: location.New(c.LocationRPC),
|
||||
hashSalt: []byte(c.AESEncode.Salt),
|
||||
}
|
||||
dao.email.TLSConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
dao.AESBlock, _ = aes.NewCipher([]byte(c.AESEncode.AesKey))
|
||||
|
||||
var err error
|
||||
if dao.memberClient, err = member.NewClient(c.MemberGRPC); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if dao.accountClient, err = account.NewClient(c.AccountGRPC); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Close close the resource.
|
||||
func (d *Dao) Close() {
|
||||
d.redis.Close()
|
||||
d.db.Close()
|
||||
}
|
||||
|
||||
// Ping dao ping
|
||||
func (d *Dao) Ping(c context.Context) (err error) {
|
||||
if err = d.db.Ping(c); err != nil {
|
||||
return
|
||||
}
|
||||
if err = d.PingRedis(c); err != nil {
|
||||
return
|
||||
}
|
||||
// TODO: if you need use mc,redis, please add
|
||||
return
|
||||
}
|
46
app/service/main/account-recovery/dao/dao_test.go
Normal file
46
app/service/main/account-recovery/dao/dao_test.go
Normal file
@ -0,0 +1,46 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"go-common/app/service/main/account-recovery/conf"
|
||||
|
||||
"gopkg.in/h2non/gock.v1"
|
||||
)
|
||||
|
||||
var (
|
||||
d *Dao
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if os.Getenv("DEPLOY_ENV") != "" {
|
||||
flag.Set("app_id", "main.account.account-recovery")
|
||||
flag.Set("conf_token", "5fe12bcbf11eb0c368ee0c2d1f567184")
|
||||
flag.Set("tree_id", "55382")
|
||||
flag.Set("conf_version", "docker-1")
|
||||
flag.Set("deploy_env", "uat")
|
||||
flag.Set("conf_host", "config.bilibili.co")
|
||||
flag.Set("conf_path", "/tmp")
|
||||
flag.Set("region", "sh")
|
||||
flag.Set("zone", "sh001")
|
||||
} else {
|
||||
flag.Set("conf", "../cmd/account-recovery-test.toml")
|
||||
}
|
||||
flag.Parse()
|
||||
if err := conf.Init(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
d = New(conf.Conf)
|
||||
d.httpClient.SetTransport(gock.DefaultTransport)
|
||||
m.Run()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func httpMock(method, url string) *gock.Request {
|
||||
r := gock.New(url)
|
||||
r.Method = strings.ToUpper(method)
|
||||
return r
|
||||
}
|
23
app/service/main/account-recovery/dao/email.go
Normal file
23
app/service/main/account-recovery/dao/email.go
Normal file
@ -0,0 +1,23 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"go-common/app/service/main/account-recovery/conf"
|
||||
"go-common/library/log"
|
||||
|
||||
"gopkg.in/gomail.v2"
|
||||
)
|
||||
|
||||
// SendMail send the email.
|
||||
func (d *Dao) SendMail(body string, subject string, send ...string) (err error) {
|
||||
log.Info("send mail send:%v", send)
|
||||
msg := gomail.NewMessage()
|
||||
msg.SetHeader("From", conf.Conf.MailConf.Username)
|
||||
msg.SetHeader("To", send...)
|
||||
msg.SetHeader("Subject", subject)
|
||||
msg.SetBody("text/html", body, gomail.SetPartEncoding(gomail.Base64))
|
||||
if err = d.email.DialAndSend(msg); err != nil {
|
||||
log.Error("s.email.DialAndSend error(%v)", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
23
app/service/main/account-recovery/dao/email_test.go
Normal file
23
app/service/main/account-recovery/dao/email_test.go
Normal file
@ -0,0 +1,23 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
|
||||
func TestDaoSendMail(t *testing.T) {
|
||||
var (
|
||||
body = strconv.Itoa(rand.Intn(100))
|
||||
subject = "邮件测试hyy"
|
||||
send = "2459593393@qq.com"
|
||||
)
|
||||
convey.Convey("SendMail", t, func(ctx convey.C) {
|
||||
err := d.SendMail(body, subject, send)
|
||||
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
})
|
||||
})
|
||||
}
|
272
app/service/main/account-recovery/dao/mid_info.go
Normal file
272
app/service/main/account-recovery/dao/mid_info.go
Normal file
@ -0,0 +1,272 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"go-common/app/service/main/account-recovery/model"
|
||||
"go-common/library/ecode"
|
||||
"go-common/library/log"
|
||||
"go-common/library/net/metadata"
|
||||
)
|
||||
|
||||
// GetMidInfo get mid info by more condition
|
||||
func (d *Dao) GetMidInfo(c context.Context, qType string, qKey string) (v *model.MIDInfo, err error) {
|
||||
params := url.Values{}
|
||||
params.Set("q_type", qType)
|
||||
params.Set("q_key", qKey)
|
||||
res := new(struct {
|
||||
Code int `json:"code"`
|
||||
Data model.MIDInfo `json:"data"`
|
||||
})
|
||||
if err = d.httpClient.Get(c, d.c.AccRecover.MidInfoURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
|
||||
log.Error("GetMidInfo HTTP request err %+v", err)
|
||||
return
|
||||
}
|
||||
if res.Code != 0 {
|
||||
log.Error("GetMidInfo server err_code %d,params: qType=%s,qKey=%s", res.Code, qType, qKey)
|
||||
err = ecode.ServerErr
|
||||
return
|
||||
}
|
||||
log.Info("GetMidInfo url=%v, params: qType=%s,qKey=%s, res: %+v", d.c.AccRecover.MidInfoURL, qType, qKey, res)
|
||||
return &res.Data, nil
|
||||
}
|
||||
|
||||
// GetUserInfo get user info by mid
|
||||
func (d *Dao) GetUserInfo(c context.Context, mid int64) (v *model.UserInfo, err error) {
|
||||
params := url.Values{}
|
||||
params.Add("mid", strconv.Itoa(int(mid)))
|
||||
res := new(struct {
|
||||
Code int `json:"code"`
|
||||
Data model.UserInfo `json:"data"`
|
||||
})
|
||||
if err = d.httpClient.Get(c, d.c.AccRecover.GetUserInfoURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
|
||||
log.Error("GetUserInfo HTTP request err %+v", err)
|
||||
return
|
||||
}
|
||||
if res.Code != 0 {
|
||||
log.Error("GetUserInfo server err_code %d,params: mid=%d", res.Code, mid)
|
||||
err = ecode.ServerErr
|
||||
return
|
||||
}
|
||||
log.Info("GetUserInfo url=%v, params: mid=%d, res: %+v", d.c.AccRecover.GetUserInfoURL, mid, res)
|
||||
return &res.Data, nil
|
||||
}
|
||||
|
||||
// UpdatePwd update password
|
||||
func (d *Dao) UpdatePwd(c context.Context, mid int64, operator string) (user *model.User, err error) {
|
||||
params := url.Values{}
|
||||
params.Set("mid", strconv.Itoa(int(mid)))
|
||||
params.Set("operator", operator)
|
||||
res := new(struct {
|
||||
Code int `json:"code"`
|
||||
Data model.User `json:"data"`
|
||||
})
|
||||
if err = d.httpClient.Post(c, d.c.AccRecover.UpPwdURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
|
||||
log.Error("UpdatePwd HTTP request err %+v", err)
|
||||
return
|
||||
}
|
||||
if res.Code != 0 {
|
||||
log.Error("UpdatePwd server err_code %d,params: mid=%d", res.Code, mid)
|
||||
err = ecode.Int(res.Code)
|
||||
return
|
||||
}
|
||||
log.Info("UpdatePwd url=%v, params: mid=%d,operator=%s, res: %+v", d.c.AccRecover.UpPwdURL, mid, operator, res)
|
||||
return &res.Data, nil
|
||||
}
|
||||
|
||||
// CheckSafe safe info
|
||||
func (d *Dao) CheckSafe(c context.Context, mid int64, question int8, answer string) (check *model.Check, err error) {
|
||||
params := url.Values{}
|
||||
params.Add("mid", strconv.Itoa(int(mid)))
|
||||
params.Add("question", strconv.Itoa(int(question)))
|
||||
params.Add("answer", answer)
|
||||
res := new(struct {
|
||||
Code int `json:"code"`
|
||||
Data model.Check `json:"data"`
|
||||
})
|
||||
if err = d.httpClient.Post(c, d.c.AccRecover.CheckSafeURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
|
||||
log.Error("CheckSafe HTTP request err %+v", err)
|
||||
return
|
||||
}
|
||||
if res.Code != 0 {
|
||||
log.Error("CheckSafe server err_code %d,params: mid=%d,question=%d,answer=%s", res.Code, mid, question, answer)
|
||||
err = ecode.Int(res.Code)
|
||||
return
|
||||
}
|
||||
log.Info("CheckSafe url=%v, params: mid=%d,question=%d,answer=%s, res: %+v", d.c.AccRecover.CheckSafeURL, mid, question, answer, res)
|
||||
return &res.Data, nil
|
||||
}
|
||||
|
||||
// GetUserType get user_type
|
||||
func (d *Dao) GetUserType(c context.Context, mid int64) (gams []*model.Game, err error) {
|
||||
params := url.Values{}
|
||||
params.Add("mid", strconv.Itoa(int(mid)))
|
||||
res := new(struct {
|
||||
Code int `json:"code"`
|
||||
Data []*model.Game `json:"items"`
|
||||
})
|
||||
if err = d.httpClient.Get(c, d.c.AccRecover.GameURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
|
||||
log.Error("GetUserType HTTP request err %+v", err)
|
||||
return
|
||||
}
|
||||
if res.Code != 0 {
|
||||
log.Error("GetUserType server err_code %d,params: mid=%d", res.Code, mid)
|
||||
err = ecode.Int(res.Code)
|
||||
return
|
||||
}
|
||||
log.Info("GetUserType url=%v, params: mid=%d, res: %+v", d.c.AccRecover.GameURL, mid, res)
|
||||
return res.Data, nil
|
||||
}
|
||||
|
||||
// CheckReg check reg info
|
||||
func (d *Dao) CheckReg(c context.Context, mid int64, regTime int64, regType int8, regAddr string) (v *model.Check, err error) {
|
||||
params := url.Values{}
|
||||
params.Add("mid", strconv.Itoa(int(mid)))
|
||||
params.Add("reg_time", strconv.FormatInt(regTime, 10))
|
||||
params.Add("reg_type", strconv.Itoa(int(regType)))
|
||||
params.Add("reg_addr", regAddr)
|
||||
res := new(struct {
|
||||
Code int `json:"code"`
|
||||
Data model.Check `json:"data"`
|
||||
})
|
||||
if err = d.httpClient.Post(c, d.c.AccRecover.CheckRegURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
|
||||
log.Error("CheckReg HTTP request err %+v", err)
|
||||
return
|
||||
}
|
||||
if res.Code != 0 {
|
||||
log.Error("CheckReg server err_code %d,params: mid=%d,regTime=%d,regType=%d,regAddr=%s", res.Code, mid, regTime, regType, regAddr)
|
||||
err = ecode.Int(res.Code)
|
||||
return
|
||||
}
|
||||
log.Info("CheckReg url=%v, params: mid=%d,regTime=%d,regType=%d,regAddr=%s, res: %+v", d.c.AccRecover.CheckRegURL, mid, regTime, regType, regAddr, res)
|
||||
return &res.Data, nil
|
||||
}
|
||||
|
||||
// UpdateBatchPwd batch update password
|
||||
func (d *Dao) UpdateBatchPwd(c context.Context, mids string, operator string) (userMap map[string]*model.User, err error) {
|
||||
params := url.Values{}
|
||||
params.Set("mids", mids)
|
||||
params.Set("operator", operator)
|
||||
res := new(struct {
|
||||
Code int `json:"code"`
|
||||
Data map[string]*model.User `json:"data"`
|
||||
})
|
||||
if err = d.httpClient.Post(c, d.c.AccRecover.UpBatchPwdURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
|
||||
log.Error("UpdateBatchPwd HTTP request err %+v", err)
|
||||
return
|
||||
}
|
||||
if res.Code != 0 {
|
||||
log.Error("UpdateBatchPwd server err_code %d,params: mids=%s", res.Code, mids)
|
||||
err = ecode.Int(res.Code)
|
||||
return
|
||||
}
|
||||
log.Info("UpdateBatchPwd url=%v, params: mid=%s,operator=%s, res: %+v", d.c.AccRecover.UpBatchPwdURL, mids, operator, res)
|
||||
return res.Data, nil
|
||||
}
|
||||
|
||||
// CheckCard check card
|
||||
func (d *Dao) CheckCard(c context.Context, mid int64, cardType int8, cardCode string) (ok bool, err error) {
|
||||
params := url.Values{}
|
||||
params.Set("mid", strconv.FormatInt(mid, 10))
|
||||
params.Set("card_type", strconv.Itoa(int(cardType)))
|
||||
params.Set("card_code", cardCode)
|
||||
res := new(struct {
|
||||
Code int `json:"code"`
|
||||
Data bool `json:"data"`
|
||||
})
|
||||
if err = d.httpClient.Get(c, d.c.AccRecover.CheckCardURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
|
||||
log.Error("CheckCard HTTP request err %+v", err)
|
||||
return
|
||||
}
|
||||
if res.Code != 0 {
|
||||
log.Error("CheckCard server err_code %d,params: mid=%d,cardType=%d,cardCode=%s", res.Code, mid, cardType, cardCode)
|
||||
err = ecode.Int(res.Code)
|
||||
return
|
||||
}
|
||||
log.Info("CheckCard url=%v, params: mid=%d,cardType=%d,cardCode=%s, res: %+v", d.c.AccRecover.CheckCardURL, mid, cardType, cardCode, res)
|
||||
return res.Data, nil
|
||||
}
|
||||
|
||||
// CheckPwds check pwd
|
||||
func (d *Dao) CheckPwds(c context.Context, mid int64, pwds string) (v string, err error) {
|
||||
params := url.Values{}
|
||||
params.Set("mid", strconv.FormatInt(mid, 10))
|
||||
params.Set("pwd", pwds)
|
||||
res := new(struct {
|
||||
Code int `json:"code"`
|
||||
Data string `json:"data"`
|
||||
})
|
||||
if err = d.httpClient.Post(c, d.c.AccRecover.CheckPwdURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
|
||||
log.Error("CheckPwds HTTP request err %+v", err)
|
||||
return
|
||||
}
|
||||
if res.Code != 0 {
|
||||
log.Error("CheckPwds server err_code %d,params: mid=%d,pwds=%s", res.Code, mid, pwds)
|
||||
err = ecode.Int(res.Code)
|
||||
return
|
||||
}
|
||||
log.Info("CheckPwds url=%v, params: mid=%d,pwds=%s, res: %+v", d.c.AccRecover.CheckPwdURL, mid, pwds, res)
|
||||
return res.Data, nil
|
||||
}
|
||||
|
||||
// GetLoginIPs get login ip
|
||||
func (d *Dao) GetLoginIPs(c context.Context, mid int64, limit int64) (ipInfo []*model.LoginIPInfo, err error) {
|
||||
params := url.Values{}
|
||||
params.Set("mid", strconv.FormatInt(mid, 10))
|
||||
params.Set("limit", strconv.FormatInt(limit, 10))
|
||||
res := new(struct {
|
||||
Code int `json:"code"`
|
||||
Data []*model.LoginIPInfo `json:"data"`
|
||||
})
|
||||
if err = d.httpClient.Get(c, d.c.AccRecover.GetLoginIPURL, metadata.String(c, metadata.RemoteIP), params, &res); err != nil {
|
||||
log.Error("GetLoginIPs HTTP request err %+v", err)
|
||||
return
|
||||
}
|
||||
if res.Code != 0 {
|
||||
log.Error("GetLoginIPs server err_code %d,params: mid=%d,limit=%d", res.Code, mid, limit)
|
||||
err = ecode.ServerErr
|
||||
return
|
||||
}
|
||||
log.Info("GetLoginIPs url=%v, params: mid=%d,limit=%d, res: %+v", d.c.AccRecover.GetLoginIPURL, mid, limit, res)
|
||||
return res.Data, nil
|
||||
}
|
||||
|
||||
// GetAddrByIP get addr by ip
|
||||
func (d *Dao) GetAddrByIP(c context.Context, mid int64, limit int64) (addrs string, err error) {
|
||||
ipInfo, err := d.GetLoginIPs(c, mid, limit)
|
||||
if err != nil || len(ipInfo) == 0 {
|
||||
return
|
||||
}
|
||||
var ipLen = len(ipInfo)
|
||||
ips := make([]string, 0, ipLen)
|
||||
//ip去重复和空串
|
||||
for i := 0; i < ipLen; i++ {
|
||||
if (i > 0 && ipInfo[i-1].LoginIP == ipInfo[i].LoginIP) || len(ipInfo[i].LoginIP) == 0 {
|
||||
continue
|
||||
}
|
||||
ips = append(ips, ipInfo[i].LoginIP)
|
||||
}
|
||||
ipMap, err := d.Infos(c, ips)
|
||||
i := 0
|
||||
for _, loc := range ipMap {
|
||||
if loc.Country != "" {
|
||||
addrs += loc.Country + "-"
|
||||
}
|
||||
if loc.Province != "" {
|
||||
addrs += loc.Province + "-"
|
||||
}
|
||||
if loc.City != "" {
|
||||
addrs += loc.City + "-"
|
||||
}
|
||||
addrs = strings.TrimRight(addrs, "-") + ","
|
||||
i++
|
||||
if i >= 3 {
|
||||
break
|
||||
}
|
||||
}
|
||||
addrs = strings.TrimRight(addrs, ",")
|
||||
return
|
||||
}
|
182
app/service/main/account-recovery/dao/mid_info_test.go
Normal file
182
app/service/main/account-recovery/dao/mid_info_test.go
Normal file
@ -0,0 +1,182 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/smartystreets/goconvey/convey"
|
||||
"gopkg.in/h2non/gock.v1"
|
||||
)
|
||||
|
||||
func TestDaoGetMidInfo(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
qType = "1"
|
||||
qKey = "silg@yahoo.cn"
|
||||
)
|
||||
convey.Convey("GetMidInfo", t, func(ctx convey.C) {
|
||||
v, err := d.GetMidInfo(c, qType, qKey)
|
||||
ctx.Convey("Then err should be nil.v should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(v, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoGetUserInfo(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mid = int64(1)
|
||||
)
|
||||
convey.Convey("GetUserInfo", t, func(ctx convey.C) {
|
||||
defer gock.OffAll()
|
||||
httpMock("GET", d.c.AccRecover.GetUserInfoURL).Reply(200).JSON(`{"code":0,"data":{"mid":21,"email":"raiden131@yahoo.cn","telphone":"","join_time":1245902140}}`)
|
||||
v, err := d.GetUserInfo(c, mid)
|
||||
ctx.Convey("Then err should be nil.v should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(v, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoUpdatePwd(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mid = int64(1)
|
||||
)
|
||||
convey.Convey("UpdatePwd", t, func(ctx convey.C) {
|
||||
defer gock.OffAll()
|
||||
httpMock("POST", d.c.AccRecover.UpPwdURL).Reply(200).JSON(`{"code": 0, "data":{"pwd":"d4txsunbb1","userid":"minorin"}}`)
|
||||
user, err := d.UpdatePwd(c, mid, "账号找回服务")
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(user, convey.ShouldNotBeNil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoCheckSafe(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mid = int64(1)
|
||||
question = int8(0)
|
||||
answer = "1"
|
||||
)
|
||||
convey.Convey("CheckSafe", t, func(ctx convey.C) {
|
||||
check, err := d.CheckSafe(c, mid, question, answer)
|
||||
ctx.Convey("Then err should be nil.check should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(check, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
//func httpMock(method, url string) *gock.Request {
|
||||
// r := gock.New(url)
|
||||
// r.Method = strings.ToUpper(method)
|
||||
// return r
|
||||
//}
|
||||
|
||||
//func TestDaoGetUserType(t *testing.T) {
|
||||
// var (
|
||||
// c = context.Background()
|
||||
// mid = int64(2)
|
||||
// )
|
||||
// convey.Convey("When http request gets code != 0", t, func(ctx convey.C) {
|
||||
// defer gock.OffAll()
|
||||
// httpMock("GET", d.c.AccRecover.GameURL).Reply(0).JSON(`{"requestId":"0def8d70b7ef11e8a395fa163e01a2e9","ts":"1535440592","code":0,"items":[{"id":14,"name":"SDK测试2","lastLogin":"1500969010"}]}`)
|
||||
// games, err := d.GetUserType(c, mid)
|
||||
// ctx.So(err, convey.ShouldBeNil)
|
||||
// ctx.So(games, convey.ShouldNotBeNil)
|
||||
// })
|
||||
//}
|
||||
|
||||
func TestDaoCheckReg(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mid = int64(1)
|
||||
regTime = int64(1532441644)
|
||||
regType = int8(0)
|
||||
regAddr = "中国_上海"
|
||||
)
|
||||
convey.Convey("CheckReg", t, func(ctx convey.C) {
|
||||
v, err := d.CheckReg(c, mid, regTime, regType, regAddr)
|
||||
ctx.Convey("Then err should be nil.v should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(v, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoUpdateBatchPwd(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mids = "1,2"
|
||||
)
|
||||
convey.Convey("UpdateBatchPwd", t, func(ctx convey.C) {
|
||||
defer gock.OffAll()
|
||||
httpMock("POST", d.c.AccRecover.UpBatchPwdURL).Reply(200).JSON(`{"code":0,"data":{"6":{"pwd":"tgs52r1st9","userid":"腹黑君"},"7":{"pwd":"g20ahzrf7j","userid":"Tzwcard"}}}`)
|
||||
userMap, err := d.UpdateBatchPwd(c, mids, "账号找回服务")
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(userMap, convey.ShouldNotBeNil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoCheckCard(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mid = int64(1)
|
||||
cardType = int8(1)
|
||||
cardCode = "123"
|
||||
)
|
||||
convey.Convey("CheckCard", t, func(ctx convey.C) {
|
||||
ok, err := d.CheckCard(c, mid, cardType, cardCode)
|
||||
ctx.Convey("Then err should be nil.ok should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(ok, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoCheckPwds(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mid = int64(1)
|
||||
pwds = "123"
|
||||
)
|
||||
convey.Convey("CheckPwds", t, func(ctx convey.C) {
|
||||
v, err := d.CheckPwds(c, mid, pwds)
|
||||
ctx.Convey("Then err should be nil.v should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(v, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoGetLoginIPs(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mid = int64(2)
|
||||
limit = int64(10)
|
||||
)
|
||||
convey.Convey("GetLoginIPs", t, func(ctx convey.C) {
|
||||
ipInfo, err := d.GetLoginIPs(c, mid, limit)
|
||||
ctx.Convey("Then err should be nil.ipInfo should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(ipInfo, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoGetAddrByIP(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mid = int64(111001254)
|
||||
limit = int64(10)
|
||||
)
|
||||
convey.Convey("GetAddrByIP", t, func(ctx convey.C) {
|
||||
addrs, err := d.GetAddrByIP(c, mid, limit)
|
||||
ctx.Convey("Then err should be nil.addrs should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(addrs, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
475
app/service/main/account-recovery/dao/mysql.go
Normal file
475
app/service/main/account-recovery/dao/mysql.go
Normal file
@ -0,0 +1,475 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
rsql "database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"go-common/app/service/main/account-recovery/dao/sqlbuilder"
|
||||
"go-common/app/service/main/account-recovery/model"
|
||||
"go-common/library/database/sql"
|
||||
"go-common/library/log"
|
||||
xtime "go-common/library/time"
|
||||
"go-common/library/xstr"
|
||||
)
|
||||
|
||||
const (
|
||||
_selectCountRecoveryInfo = "select count(rid) from account_recovery_info"
|
||||
_selectRecoveryInfoLimit = "select rid,mid,user_type,status,login_addrs,unames,reg_time,reg_type,reg_addr,pwds,phones,emails,safe_question,safe_answer,card_type,card_id," +
|
||||
"sys_login_addrs,sys_reg,sys_unames,sys_pwds,sys_phones,sys_emails,sys_safe,sys_card," +
|
||||
"link_email,operator,opt_time,remark,ctime,business from account_recovery_info %s"
|
||||
_getSuccessCount = "SELECT count FROM account_recovery_success WHERE mid=?"
|
||||
_batchGetRecoverySuccess = "SELECT mid,count,ctime,mtime FROM account_recovery_success WHERE mid in (%s)"
|
||||
_updateSuccessCount = "INSERT INTO account_recovery_success (mid, count) VALUES (?, 1) ON DUPLICATE KEY UPDATE count = count + 1"
|
||||
_batchUpdateSuccessCount = "INSERT INTO account_recovery_success (mid, count) VALUES %s ON DUPLICATE KEY UPDATE count = count + 1"
|
||||
_updateStatus = "UPDATE account_recovery_info SET status=?,operator=?,opt_time=?,remark=? WHERE rid = ? AND `status`=0"
|
||||
_getNoDeal = "SELECT COUNT(1) FROM account_recovery_info WHERE mid=? AND `status`=0"
|
||||
_updateUserType = "UPDATE account_recovery_info SET user_type=? WHERE rid = ?"
|
||||
_insertRecoveryInfo = "INSERT INTO account_recovery_info(login_addrs,unames,reg_time,reg_type,reg_addr,pwds,phones,emails,safe_question,safe_answer,card_type,card_id,link_email,mid,business,last_suc_count,last_suc_ctime) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"
|
||||
_updateSysInfo = "UPDATE account_recovery_info SET sys_login_addrs=?,sys_reg=?,sys_unames=?,sys_pwds=?,sys_phones=?,sys_emails=?,sys_safe=?,sys_card=?,user_type=? WHERE rid=?"
|
||||
_getUinfoByRid = "SELECT mid,link_email,ctime FROM account_recovery_info WHERE rid=? LIMIT 1"
|
||||
_getUinfoByRidMore = "SELECT rid,mid,link_email,ctime FROM account_recovery_info WHERE rid in (%s)"
|
||||
_selectUnCheckInfo = "SELECT mid,login_addrs,unames,reg_time,reg_type,reg_addr,pwds,phones,emails,safe_question,safe_answer,card_type,card_id FROM account_recovery_info WHERE rid=? AND `status`=0 AND sys_card=''"
|
||||
_getStatusByRid = "SELECT `status` FROM account_recovery_info WHERE rid=?"
|
||||
_getMailStatus = "SELECT mail_status FROM account_recovery_info WHERE rid=?"
|
||||
_updateMailStatus = "UPDATE account_recovery_info SET mail_status=1 WHERE rid=?"
|
||||
_insertRecoveryAddit = "INSERT INTO account_recovery_addit(`rid`, `files`, `extra`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE files=VALUES(files),extra=VALUES(extra)"
|
||||
_updateRecoveryAddit = "UPDATE account_recovery_addit SET `files` = ?,`extra` = ? WHERE rid = ?"
|
||||
_getRecoveryAddit = "SELECT rid, `files`,`extra`, ctime, mtime FROM account_recovery_addit WHERE rid= ?"
|
||||
_batchRecoveryAAdit = "SELECT rid, `files`, `extra`, ctime, mtime FROM account_recovery_addit WHERE rid in (%s)"
|
||||
_batchGetLastSuccess = "SELECT mid,max(ctime) FROM account_recovery_info WHERE mid in (%s) AND `status`=1 GROUP BY mid"
|
||||
_getLastSuccess = "SELECT mid,max(ctime) FROM account_recovery_info WHERE mid = ? AND `status`=1"
|
||||
)
|
||||
|
||||
// GetStatusByRid get status by rid
|
||||
func (dao *Dao) GetStatusByRid(c context.Context, rid int64) (status int64, err error) {
|
||||
res := dao.db.Prepared(_getStatusByRid).QueryRow(c, rid)
|
||||
if err = res.Scan(&status); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
status = -1
|
||||
err = nil
|
||||
} else {
|
||||
log.Error("GetStatusByRid row.Scan error(%v)", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetSuccessCount get success count
|
||||
func (dao *Dao) GetSuccessCount(c context.Context, mid int64) (count int64, err error) {
|
||||
res := dao.db.Prepared(_getSuccessCount).QueryRow(c, mid)
|
||||
if err = res.Scan(&count); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
count = 0
|
||||
err = nil
|
||||
} else {
|
||||
log.Error("GetSuccessCount row.Scan error(%v)", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// BatchGetRecoverySuccess batch get recovery success info
|
||||
func (dao *Dao) BatchGetRecoverySuccess(c context.Context, mids []int64) (successMap map[int64]*model.RecoverySuccess, err error) {
|
||||
rows, err := dao.db.Query(c, fmt.Sprintf(_batchGetRecoverySuccess, xstr.JoinInts(mids)))
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
log.Error("BatchGetRecoverySuccess d.db.Query error(%v)", err)
|
||||
return
|
||||
}
|
||||
successMap = make(map[int64]*model.RecoverySuccess)
|
||||
for rows.Next() {
|
||||
r := new(model.RecoverySuccess)
|
||||
if err = rows.Scan(&r.SuccessMID, &r.SuccessCount, &r.FirstSuccessTime, &r.LastSuccessTime); err != nil {
|
||||
log.Error("BatchGetRecoverySuccess rows.Scan error(%v)", err)
|
||||
continue
|
||||
}
|
||||
successMap[r.SuccessMID] = r
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UpdateSuccessCount insert or update success count
|
||||
func (dao *Dao) UpdateSuccessCount(c context.Context, mid int64) (err error) {
|
||||
_, err = dao.db.Exec(c, _updateSuccessCount, mid)
|
||||
return
|
||||
}
|
||||
|
||||
// BatchUpdateSuccessCount batch insert or update success count
|
||||
func (dao *Dao) BatchUpdateSuccessCount(c context.Context, mids string) (err error) {
|
||||
var s string
|
||||
midArr := strings.Split(mids, ",")
|
||||
for _, mid := range midArr {
|
||||
s = s + fmt.Sprintf(",(%s, 1)", mid)
|
||||
}
|
||||
_, err = dao.db.Exec(c, fmt.Sprintf(_batchUpdateSuccessCount, s[1:]))
|
||||
return
|
||||
}
|
||||
|
||||
// GetNoDeal get no deal record
|
||||
func (dao *Dao) GetNoDeal(c context.Context, mid int64) (count int64, err error) {
|
||||
res := dao.db.Prepared(_getNoDeal).QueryRow(c, mid)
|
||||
if err = res.Scan(&count); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
log.Error("GetNoDeal row.Scan error(%v)", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UpdateStatus update field status.
|
||||
func (dao *Dao) UpdateStatus(c context.Context, status int64, rid int64, operator string, optTime xtime.Time, remark string) (err error) {
|
||||
_, err = dao.db.Exec(c, _updateStatus, status, operator, optTime, remark, rid)
|
||||
return
|
||||
}
|
||||
|
||||
// UpdateUserType update field user_type.
|
||||
func (dao *Dao) UpdateUserType(c context.Context, status int64, rid int64) (err error) {
|
||||
if _, err = dao.db.Exec(c, _updateUserType, status, rid); err != nil {
|
||||
log.Error("dao.db.Exec(%s, %d, %d) error(%v)", _updateUserType, status, rid, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// InsertRecoveryInfo insert data
|
||||
func (dao *Dao) InsertRecoveryInfo(c context.Context, uinfo *model.UserInfoReq) (lastID int64, err error) {
|
||||
var res rsql.Result
|
||||
if res, err = dao.db.Exec(c, _insertRecoveryInfo, uinfo.LoginAddrs, uinfo.Unames, uinfo.RegTime, uinfo.RegType, uinfo.RegAddr,
|
||||
uinfo.Pwds, uinfo.Phones, uinfo.Emails, uinfo.SafeQuestion, uinfo.SafeAnswer, uinfo.CardType, uinfo.CardID, uinfo.LinkMail, uinfo.Mid, uinfo.Business, uinfo.LastSucCount, uinfo.LastSucCTime); err != nil {
|
||||
log.Error("dao.db.Exec(%s, %v) error(%v)", _insertRecoveryInfo, uinfo, err)
|
||||
return
|
||||
}
|
||||
return res.LastInsertId()
|
||||
}
|
||||
|
||||
// UpdateSysInfo update sysinfo and user_type
|
||||
func (dao *Dao) UpdateSysInfo(c context.Context, sys *model.SysInfo, userType int64, rid int64) (err error) {
|
||||
if _, err = dao.db.Exec(c, _updateSysInfo, &sys.SysLoginAddrs, &sys.SysReg, &sys.SysUNames, &sys.SysPwds, &sys.SysPhones,
|
||||
&sys.SysEmails, &sys.SysSafe, &sys.SysCard, userType, rid); err != nil {
|
||||
log.Error("dao.db.Exec(%s, %v) error(%v)", _updateSysInfo, sys, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetAllByCon get a pageData by more condition
|
||||
func (dao *Dao) GetAllByCon(c context.Context, aq *model.QueryRecoveryInfoReq) ([]*model.AccountRecoveryInfo, int64, error) {
|
||||
query := sqlbuilder.NewSelectBuilder().Select("rid,mid,user_type,status,login_addrs,unames,reg_time,reg_type,reg_addr,pwds,phones,emails,safe_question,safe_answer,card_type,card_id,sys_login_addrs,sys_reg,sys_unames,sys_pwds,sys_phones,sys_emails,sys_safe,sys_card,link_email,operator,opt_time,remark,ctime,business,last_suc_count,last_suc_ctime").From("account_recovery_info")
|
||||
|
||||
if aq.Bussiness != "" {
|
||||
query = query.Where(query.Equal("business", aq.Bussiness))
|
||||
}
|
||||
|
||||
if aq.Status != nil {
|
||||
query = query.Where(fmt.Sprintf("status=%d", *aq.Status))
|
||||
}
|
||||
if aq.Game != nil {
|
||||
query = query.Where(fmt.Sprintf("user_type=%d", *aq.Game))
|
||||
}
|
||||
if aq.UID != 0 {
|
||||
query = query.Where(fmt.Sprintf("mid=%d", aq.UID))
|
||||
}
|
||||
if aq.RID != 0 {
|
||||
query = query.Where(fmt.Sprintf("rid=%d", aq.RID))
|
||||
}
|
||||
if aq.StartTime != 0 {
|
||||
query = query.Where(query.GE("ctime", aq.StartTime.Time()))
|
||||
}
|
||||
if aq.EndTime != 0 {
|
||||
query = query.Where(query.LE("ctime", aq.EndTime.Time()))
|
||||
}
|
||||
totalSQL, totalArg := query.Copy().Select("count(1)").Build()
|
||||
log.Info("Build GetAllByCon total count SQL: %s", totalSQL)
|
||||
page := aq.Page
|
||||
if page == 0 {
|
||||
page = 1
|
||||
}
|
||||
size := aq.Size
|
||||
if size == 0 {
|
||||
size = 50
|
||||
}
|
||||
query = query.Limit(int(size)).Offset(int(size * (page - 1))).OrderBy("rid DESC")
|
||||
rawSQL, rawArg := query.Build()
|
||||
log.Info("Build GetAllByCon SQL: %s", rawSQL)
|
||||
|
||||
total := int64(0)
|
||||
row := dao.db.QueryRow(c, totalSQL, totalArg...)
|
||||
if err := row.Scan(&total); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
rows, err := dao.db.Query(c, rawSQL, rawArg...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
resultData := make([]*model.AccountRecoveryInfo, 0)
|
||||
for rows.Next() {
|
||||
r := new(model.AccountRecoveryInfo)
|
||||
if err = rows.Scan(&r.Rid, &r.Mid, &r.UserType, &r.Status, &r.LoginAddr, &r.UNames, &r.RegTime, &r.RegType, &r.RegAddr,
|
||||
&r.Pwd, &r.Phones, &r.Emails, &r.SafeQuestion, &r.SafeAnswer, &r.CardType, &r.CardID,
|
||||
&r.SysLoginAddr, &r.SysReg, &r.SysUNames, &r.SysPwds, &r.SysPhones, &r.SysEmails, &r.SysSafe, &r.SysCard,
|
||||
&r.LinkEmail, &r.Operator, &r.OptTime, &r.Remark, &r.CTime, &r.Bussiness, &r.LastSucCount, &r.LastSucCTime); err != nil {
|
||||
log.Error("GetAllByCon error (%+v)", err)
|
||||
continue
|
||||
}
|
||||
resultData = append(resultData, r)
|
||||
}
|
||||
return resultData, total, err
|
||||
}
|
||||
|
||||
// QueryByID query by rid
|
||||
func (dao *Dao) QueryByID(c context.Context, rid int64, fromTime, endTime xtime.Time) (res *model.AccountRecoveryInfo, err error) {
|
||||
sql1 := "select rid,mid,user_type,status,login_addrs,unames,reg_time,reg_type,reg_addr,pwds,phones,emails,safe_question,safe_answer,card_type,card_id," +
|
||||
"sys_login_addrs,sys_reg,sys_unames,sys_pwds,sys_phones,sys_emails,sys_safe,sys_card," +
|
||||
"link_email,operator,opt_time,remark,ctime,business from account_recovery_info where ctime between ? and ? and rid = ?"
|
||||
res = new(model.AccountRecoveryInfo)
|
||||
row := dao.db.QueryRow(c, sql1, fromTime, endTime, rid)
|
||||
if err = row.Scan(&res.Rid, &res.Mid, &res.UserType, &res.Status, &res.LoginAddr, &res.UNames, &res.RegTime, &res.RegType, &res.RegAddr,
|
||||
&res.Pwd, &res.Phones, &res.Emails, &res.SafeQuestion, &res.SafeAnswer, &res.CardType, &res.CardID,
|
||||
&res.SysLoginAddr, &res.SysReg, &res.SysUNames, &res.SysPwds, &res.SysPhones, &res.SysEmails, &res.SysSafe, &res.SysCard,
|
||||
&res.LinkEmail, &res.Operator, &res.OptTime, &res.Remark, &res.CTime, &res.Bussiness); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
res = nil
|
||||
return
|
||||
}
|
||||
log.Error("QueryByID(%d) error(%v)", rid, err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//QueryInfoByLimit page query through limit m,n
|
||||
func (dao *Dao) QueryInfoByLimit(c context.Context, req *model.DBRecoveryInfoParams) (res []*model.AccountRecoveryInfo, total int64, err error) {
|
||||
p := make([]interface{}, 0)
|
||||
s := " where ctime between ? and ?"
|
||||
p = append(p, req.StartTime)
|
||||
p = append(p, req.EndTime)
|
||||
if req.ExistGame {
|
||||
s = s + " and user_type = ?"
|
||||
p = append(p, req.Game)
|
||||
}
|
||||
if req.ExistStatus {
|
||||
s = s + " and status = ?"
|
||||
p = append(p, req.Status)
|
||||
}
|
||||
if req.ExistMid {
|
||||
s = s + " and mid = ?"
|
||||
p = append(p, req.Mid)
|
||||
}
|
||||
|
||||
var s2 = s + " order by rid desc limit ?,?"
|
||||
p2 := p
|
||||
p2 = append(p2, (req.CurrPage-1)*req.Size, req.Size)
|
||||
var rows *sql.Rows
|
||||
rows, err = dao.db.Query(c, fmt.Sprintf(_selectRecoveryInfoLimit, s2), p2...)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
log.Error("QueryInfo err: d.db.Query error(%v)", err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
res = make([]*model.AccountRecoveryInfo, 0, req.Size)
|
||||
for rows.Next() {
|
||||
r := new(model.AccountRecoveryInfo)
|
||||
if err = rows.Scan(&r.Rid, &r.Mid, &r.UserType, &r.Status, &r.LoginAddr, &r.UNames, &r.RegTime, &r.RegType, &r.RegAddr,
|
||||
&r.Pwd, &r.Phones, &r.Emails, &r.SafeQuestion, &r.SafeAnswer, &r.CardType, &r.CardID,
|
||||
&r.SysLoginAddr, &r.SysReg, &r.SysUNames, &r.SysPwds, &r.SysPhones, &r.SysEmails, &r.SysSafe, &r.SysCard,
|
||||
&r.LinkEmail, &r.Operator, &r.OptTime, &r.Remark, &r.CTime, &r.Bussiness); err != nil {
|
||||
log.Error("QueryInfo (%v) error (%v)", req, err)
|
||||
return
|
||||
}
|
||||
res = append(res, r)
|
||||
}
|
||||
row := dao.db.QueryRow(c, _selectCountRecoveryInfo+s, p...)
|
||||
if err = row.Scan(&total); err != nil {
|
||||
log.Error("QueryInfo total error (%v)", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetUinfoByRid get mid,linkMail by rid
|
||||
func (dao *Dao) GetUinfoByRid(c context.Context, rid int64) (mid int64, linkMail string, ctime string, err error) {
|
||||
res := dao.db.Prepared(_getUinfoByRid).QueryRow(c, rid)
|
||||
req := new(struct {
|
||||
Mid int64
|
||||
LinKMail string
|
||||
Ctime xtime.Time
|
||||
})
|
||||
|
||||
if err = res.Scan(&req.Mid, &req.LinKMail, &req.Ctime); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
req.Mid = 0
|
||||
err = nil
|
||||
} else {
|
||||
log.Error("GetUinfoByRid row.Scan error(%v)", err)
|
||||
}
|
||||
}
|
||||
mid = req.Mid
|
||||
linkMail = req.LinKMail
|
||||
ctime = req.Ctime.Time().Format("2006-01-02 15:04:05")
|
||||
return
|
||||
}
|
||||
|
||||
// GetUinfoByRidMore get list of BatchAppeal by rid
|
||||
func (dao *Dao) GetUinfoByRidMore(c context.Context, ridsStr string) (bathRes []*model.BatchAppeal, err error) {
|
||||
|
||||
rows, err := dao.db.Prepared(fmt.Sprintf(_getUinfoByRidMore, ridsStr)).Query(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
bathRes = make([]*model.BatchAppeal, 0, len(strings.Split(ridsStr, ",")))
|
||||
for rows.Next() {
|
||||
req := &model.BatchAppeal{}
|
||||
if err = rows.Scan(&req.Rid, &req.Mid, &req.LinkMail, &req.Ctime); err != nil {
|
||||
return
|
||||
}
|
||||
bathRes = append(bathRes, req)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetUnCheckInfo get uncheck info
|
||||
func (dao *Dao) GetUnCheckInfo(c context.Context, rid int64) (r *model.UserInfoReq, err error) {
|
||||
row := dao.db.QueryRow(c, _selectUnCheckInfo, rid)
|
||||
r = new(model.UserInfoReq)
|
||||
if err = row.Scan(&r.Mid, &r.LoginAddrs, &r.Unames, &r.RegTime, &r.RegType, &r.RegAddr,
|
||||
&r.Pwds, &r.Phones, &r.Emails, &r.SafeQuestion, &r.SafeAnswer, &r.CardType, &r.CardID); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
log.Error("GetUnCheckInfo (%v) error (%v)", rid, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//BeginTran begin transaction
|
||||
func (dao *Dao) BeginTran(ctx context.Context) (tx *sql.Tx, err error) {
|
||||
if tx, err = dao.db.Begin(ctx); err != nil {
|
||||
log.Error("db: begintran BeginTran d.db.Begin error(%v)", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetMailStatus get mail_status by rid
|
||||
func (dao *Dao) GetMailStatus(c context.Context, rid int64) (mailStatus int64, err error) {
|
||||
res := dao.db.Prepared(_getMailStatus).QueryRow(c, rid)
|
||||
if err = res.Scan(&mailStatus); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
mailStatus = -1
|
||||
err = nil
|
||||
} else {
|
||||
log.Error("GetStatusByRid row.Scan error(%v)", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// UpdateMailStatus update mail_status.
|
||||
func (dao *Dao) UpdateMailStatus(c context.Context, rid int64) (err error) {
|
||||
_, err = dao.db.Exec(c, _updateMailStatus, rid)
|
||||
return
|
||||
}
|
||||
|
||||
// UpdateRecoveryAddit is
|
||||
func (dao *Dao) UpdateRecoveryAddit(c context.Context, rid int64, files []string, extra string) (err error) {
|
||||
_, err = dao.db.Exec(c, _updateRecoveryAddit, strings.Join(files, ","), extra, rid)
|
||||
return
|
||||
}
|
||||
|
||||
// GetRecoveryAddit is
|
||||
func (dao *Dao) GetRecoveryAddit(c context.Context, rid int64) (addit *model.DBAccountRecoveryAddit, err error) {
|
||||
row := dao.db.QueryRow(c, _getRecoveryAddit, rid)
|
||||
addit = new(model.DBAccountRecoveryAddit)
|
||||
if err = row.Scan(&addit.Rid, &addit.Files, &addit.Extra, &addit.Ctime, &addit.Mtime); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
log.Error("GetRecoveryAddit (%v) error (%v)", rid, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// InsertRecoveryAddit is
|
||||
func (dao *Dao) InsertRecoveryAddit(c context.Context, rid int64, files, extra string) (err error) {
|
||||
_, err = dao.db.Exec(c, _insertRecoveryAddit, rid, files, extra)
|
||||
return
|
||||
}
|
||||
|
||||
//BatchGetRecoveryAddit is
|
||||
func (dao *Dao) BatchGetRecoveryAddit(c context.Context, rids []int64) (addits map[int64]*model.DBAccountRecoveryAddit, err error) {
|
||||
|
||||
rows, err := dao.db.Query(c, fmt.Sprintf(_batchRecoveryAAdit, xstr.JoinInts(rids)))
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
log.Error("BatchGetRecoveryAddit d.db.Query error(%v)", err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
addits = make(map[int64]*model.DBAccountRecoveryAddit)
|
||||
for rows.Next() {
|
||||
var addit = new(model.DBAccountRecoveryAddit)
|
||||
if err = rows.Scan(&addit.Rid, &addit.Files, &addit.Extra, &addit.Ctime, &addit.Mtime); err != nil {
|
||||
log.Error("BatchGetRecoveryAddit rows.Scan error(%v)", err)
|
||||
continue
|
||||
}
|
||||
addits[addit.Rid] = addit
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// BatchGetLastSuccess batch get last find success info
|
||||
func (dao *Dao) BatchGetLastSuccess(c context.Context, mids []int64) (lastSuccessMap map[int64]*model.LastSuccessData, err error) {
|
||||
rows, err := dao.db.Query(c, fmt.Sprintf(_batchGetLastSuccess, xstr.JoinInts(mids)))
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
log.Error("BatchGetLastSuccess d.db.Query error(%v)", err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
lastSuccessMap = make(map[int64]*model.LastSuccessData)
|
||||
for rows.Next() {
|
||||
r := new(model.LastSuccessData)
|
||||
if err = rows.Scan(&r.LastApplyMID, &r.LastApplyTime); err != nil {
|
||||
log.Error("BatchGetLastSuccess rows.Scan error(%v)", err)
|
||||
continue
|
||||
}
|
||||
lastSuccessMap[r.LastApplyMID] = r
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetLastSuccess get last find success info
|
||||
func (dao *Dao) GetLastSuccess(c context.Context, mid int64) (lastSuc *model.LastSuccessData, err error) {
|
||||
row := dao.db.QueryRow(c, _getLastSuccess, mid)
|
||||
lastSuc = new(model.LastSuccessData)
|
||||
if err = row.Scan(&lastSuc.LastApplyMID, &lastSuc.LastApplyTime); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
log.Error("GetRecoveryAddit (%v) error (%v)", mid, err)
|
||||
}
|
||||
return
|
||||
}
|
412
app/service/main/account-recovery/dao/mysql_test.go
Normal file
412
app/service/main/account-recovery/dao/mysql_test.go
Normal file
@ -0,0 +1,412 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"go-common/app/service/main/account-recovery/model"
|
||||
xtime "go-common/library/time"
|
||||
|
||||
"github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
|
||||
func TestDaoGetStatusByRid(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
rid = int64(1)
|
||||
)
|
||||
convey.Convey("GetStatusByRid", t, func(ctx convey.C) {
|
||||
status, err := d.GetStatusByRid(c, rid)
|
||||
ctx.Convey("Then err should be nil.status should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(status, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoGetSuccessCount(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mid = int64(1234)
|
||||
)
|
||||
convey.Convey("GetSuccessCount", t, func(ctx convey.C) {
|
||||
count, err := d.GetSuccessCount(c, mid)
|
||||
ctx.Convey("Then err should be nil.count should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(count, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoBatchGetRecoverySuccess(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mids = []int64{1234}
|
||||
)
|
||||
convey.Convey("BatchGetRecoverySuccess", t, func(ctx convey.C) {
|
||||
countMap, err := d.BatchGetRecoverySuccess(c, mids)
|
||||
ctx.Convey("Then err should be nil.countMap should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(countMap, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoUpdateSuccessCount(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mid = int64(1234)
|
||||
)
|
||||
convey.Convey("UpdateSuccessCount", t, func(ctx convey.C) {
|
||||
err := d.UpdateSuccessCount(c, mid)
|
||||
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoBatchUpdateSuccessCount(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mids = "1234"
|
||||
)
|
||||
convey.Convey("BatchUpdateSuccessCount", t, func(ctx convey.C) {
|
||||
err := d.BatchUpdateSuccessCount(c, mids)
|
||||
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoGetNoDeal(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mid = int64(1234)
|
||||
)
|
||||
convey.Convey("GetNoDeal", t, func(ctx convey.C) {
|
||||
count, err := d.GetNoDeal(c, mid)
|
||||
ctx.Convey("Then err should be nil.count should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(count, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoUpdateStatus(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
status = int64(1)
|
||||
rid = int64(1)
|
||||
operator = "abcd"
|
||||
optTime xtime.Time
|
||||
remark = ""
|
||||
)
|
||||
convey.Convey("UpdateStatus", t, func(ctx convey.C) {
|
||||
err := d.UpdateStatus(c, status, rid, operator, optTime, remark)
|
||||
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoUpdateUserType(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
status = int64(1)
|
||||
rid = int64(1)
|
||||
)
|
||||
convey.Convey("UpdateUserType", t, func(ctx convey.C) {
|
||||
err := d.UpdateUserType(c, status, rid)
|
||||
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoInsertRecoveryInfo(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
uinfo = &model.UserInfoReq{
|
||||
LoginAddrs: "中国-福州,中国-上海,澳大利亚",
|
||||
//RegTime:timeS, //变成2018
|
||||
RegTime: 1533206284, //变成2018 //数据库设置为int(11),so数据库必须设置为tiimestamp
|
||||
RegType: int8(1),
|
||||
RegAddr: "中国上海",
|
||||
Unames: "昵称AA,昵称BB,昵称CC",
|
||||
Pwds: "密码1,密码2",
|
||||
Phones: "12345678901,54321678923",
|
||||
Emails: "2456@sina.com,789@qq.com",
|
||||
SafeQuestion: int8(1),
|
||||
SafeAnswer: "心态呀",
|
||||
CardID: "ISN-1234567890-0987",
|
||||
CardType: int8(1),
|
||||
LinkMail: "345678@qq.com",
|
||||
Mid: 1234,
|
||||
}
|
||||
)
|
||||
convey.Convey("InsertRecoveryInfo", t, func(ctx convey.C) {
|
||||
lastID, err := d.InsertRecoveryInfo(c, uinfo)
|
||||
ctx.Convey("Then err should be nil.lastID should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(lastID, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoUpdateSysInfo(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
sys = &model.SysInfo{
|
||||
SysLoginAddrs: "中国-福州,中国-上海,澳大利亚",
|
||||
SysReg: "对",
|
||||
SysUNames: "对,错,错",
|
||||
SysPwds: "对,错",
|
||||
SysPhones: "对,错",
|
||||
SysEmails: "对,错",
|
||||
SysSafe: "对",
|
||||
SysCard: "对",
|
||||
}
|
||||
userType = int64(1)
|
||||
rid = int64(1)
|
||||
)
|
||||
convey.Convey("UpdateSysInfo", t, func(ctx convey.C) {
|
||||
err := d.UpdateSysInfo(c, sys, userType, rid)
|
||||
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoQueryByID(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
rid = int64(1)
|
||||
fromTime xtime.Time = 1533120949
|
||||
endTime xtime.Time = 1535636392
|
||||
)
|
||||
convey.Convey("QueryByID", t, func(ctx convey.C) {
|
||||
res, err := d.QueryByID(c, rid, fromTime, endTime)
|
||||
ctx.Convey("Then err should be nil.res should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(res, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoQueryInfoByLimit(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
req = &model.DBRecoveryInfoParams{
|
||||
ExistGame: false,
|
||||
ExistStatus: false,
|
||||
ExistMid: false,
|
||||
Mid: 0,
|
||||
Game: 0,
|
||||
Status: 1,
|
||||
FirstRid: 20,
|
||||
LastRid: 0,
|
||||
Size: 2,
|
||||
StartTime: 1533120949,
|
||||
EndTime: 1535636392,
|
||||
SubNum: 1,
|
||||
CurrPage: 1,
|
||||
}
|
||||
)
|
||||
convey.Convey("QueryInfoByLimit", t, func(ctx convey.C) {
|
||||
res, total, err := d.QueryInfoByLimit(c, req)
|
||||
ctx.Convey("Then err should be nil.res,total should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(total, convey.ShouldNotBeNil)
|
||||
ctx.So(res, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoGetUinfoByRid(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
rid = int64(240)
|
||||
)
|
||||
convey.Convey("GetUinfoByRid", t, func(ctx convey.C) {
|
||||
mid, linkMail, ctime, err := d.GetUinfoByRid(c, rid)
|
||||
ctx.Convey("Then err should be nil.mid,linkMail,ctime should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(ctime, convey.ShouldNotBeNil)
|
||||
ctx.So(linkMail, convey.ShouldNotBeNil)
|
||||
ctx.So(mid, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoGetUinfoByRidMore(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
ridsStr = "1,2"
|
||||
)
|
||||
convey.Convey("GetUinfoByRidMore", t, func(ctx convey.C) {
|
||||
bathRes, err := d.GetUinfoByRidMore(c, ridsStr)
|
||||
ctx.Convey("Then err should be nil.bathRes should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(bathRes, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoGetUnCheckInfo(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
rid = int64(1)
|
||||
)
|
||||
convey.Convey("GetUnCheckInfo", t, func(ctx convey.C) {
|
||||
r, err := d.GetUnCheckInfo(c, rid)
|
||||
ctx.Convey("Then err should be nil.r should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(r, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoGetMailStatus(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
rid = int64(1)
|
||||
)
|
||||
convey.Convey("GetMailStatus", t, func(ctx convey.C) {
|
||||
mailStatus, err := d.GetMailStatus(c, rid)
|
||||
ctx.Convey("Then err should be nil.mailStatus should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(mailStatus, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoUpdateMailStatus(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
rid = int64(1)
|
||||
)
|
||||
convey.Convey("UpdateMailStatus", t, func(ctx convey.C) {
|
||||
err := d.UpdateMailStatus(c, rid)
|
||||
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoGetAllByCon(t *testing.T) {
|
||||
var (
|
||||
defint int64
|
||||
c = context.Background()
|
||||
aq = &model.QueryRecoveryInfoReq{
|
||||
//RID: 1,
|
||||
//UID:2,
|
||||
Status: &defint,
|
||||
Game: &defint,
|
||||
Size: 10,
|
||||
Page: 1,
|
||||
StartTime: 1533052800,
|
||||
EndTime: 1536924163,
|
||||
|
||||
//StartTime time.Time `json:"start_time" form:"start_time"`
|
||||
//EndTime time.Time `json:"end_time" form:"end_time"`
|
||||
//IsAdvanced bool `json:"-"`
|
||||
//Page int64 `form:"page"`
|
||||
}
|
||||
)
|
||||
convey.Convey("UpdateMailStatus", t, func(ctx convey.C) {
|
||||
resultData, total, err := d.GetAllByCon(c, aq)
|
||||
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
|
||||
ctx.So(resultData, convey.ShouldNotBeNil)
|
||||
ctx.So(total, convey.ShouldNotBeNil)
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.Println(total, " len=", len(resultData))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoInsertRecoveryAddit(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
rid int64 = 1
|
||||
files = "http://uat-i0.hdslb.com/bfs/account/recovery/bca2.zip,http://uat-i0.hdslb.com/bfs/account/recovery/abcd.zip"
|
||||
extra = `{"GameArea":"ios-A服","GameNames":"崩坏3","GamePlay":"1"}`
|
||||
)
|
||||
convey.Convey("InsertRecoveryAddit", t, func(ctx convey.C) {
|
||||
err := d.InsertRecoveryAddit(c, rid, files, extra)
|
||||
err = d.InsertRecoveryAddit(c, 2, files, extra)
|
||||
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoUpdateRecoveryAddit(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
rid int64 = 1
|
||||
files = []string{"http://uat-i0.hdslb.com/bfs/aaaa.zip", "http://uat-i0.hdslb.com/bfs/dddd.zip"}
|
||||
extra = `{"GameArea":"ios-A服","GameNames":"崩坏3","GamePlay":"1"}`
|
||||
)
|
||||
convey.Convey("UpdateMailStatus", t, func(ctx convey.C) {
|
||||
err := d.UpdateRecoveryAddit(c, rid, files, extra)
|
||||
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoGetRecoveryAddit(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
rid int64 = 1
|
||||
)
|
||||
convey.Convey("UpdateMailStatus", t, func(ctx convey.C) {
|
||||
addit, err := d.GetRecoveryAddit(c, rid)
|
||||
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.Println(addit, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoBatchGetRecoveryAddit(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
rids = []int64{1, 2}
|
||||
)
|
||||
convey.Convey("BatchGetRecoveryAddit", t, func(ctx convey.C) {
|
||||
addits, err := d.BatchGetRecoveryAddit(c, rids)
|
||||
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(addits, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBatchGetLastSuccess(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mids = []int64{1234}
|
||||
)
|
||||
convey.Convey("BatchGetLastSuccess", t, func(ctx convey.C) {
|
||||
lastSuccessMap, err := d.BatchGetLastSuccess(c, mids)
|
||||
ctx.Convey("Then err should be nil.count should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(lastSuccessMap, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoGetLastSuccess(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mid = int64(1234)
|
||||
)
|
||||
convey.Convey("GetLastSuccess", t, func(ctx convey.C) {
|
||||
res, err := d.GetLastSuccess(c, mid)
|
||||
ctx.Convey("Then err should be nil.count should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(res, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
92
app/service/main/account-recovery/dao/redis.go
Normal file
92
app/service/main/account-recovery/dao/redis.go
Normal file
@ -0,0 +1,92 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"go-common/library/cache/redis"
|
||||
"go-common/library/log"
|
||||
)
|
||||
|
||||
const (
|
||||
_expire = 30 * 60 // 30 minutes
|
||||
_prefixCaptcha = "recovery:ca_"
|
||||
)
|
||||
|
||||
// SetLinkMailCount set linkMail expire time.
|
||||
func (d *Dao) SetLinkMailCount(c context.Context, linkMail string) (state int64, err error) {
|
||||
conn := d.redis.Get(c)
|
||||
defer conn.Close()
|
||||
//用时间戳去减去时间 当天0点过期,第二天又可以发送10封邮件
|
||||
i, _ := redis.Int(conn.Do("incr", linkMail))
|
||||
conn.Do("expire", linkMail, getSubtime())
|
||||
if i >= 11 { //第10封邮件之后
|
||||
state = 10 //当天邮件发送到达最大次数
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getSubtime() (subtime int64) {
|
||||
timeStr := time.Now().Format("2006-01-02")
|
||||
t, _ := time.Parse("2006-01-02", timeStr)
|
||||
t2 := t.AddDate(0, 0, 1).Unix()
|
||||
curr := time.Now().Unix() //当前时间
|
||||
subtime = t2 - curr
|
||||
return
|
||||
}
|
||||
|
||||
// keyCaptcha
|
||||
func keyCaptcha(mid int64, linkMail string) string {
|
||||
return _prefixCaptcha + strconv.FormatInt(mid, 10) + "_" + linkMail
|
||||
}
|
||||
|
||||
// SetCaptcha set linkMail expire time.
|
||||
func (d *Dao) SetCaptcha(c context.Context, code string, mid int64, linkMail string) (err error) {
|
||||
conn := d.redis.Get(c)
|
||||
defer conn.Close()
|
||||
key := keyCaptcha(mid, linkMail)
|
||||
//验证码30分钟内有效
|
||||
if _, err = conn.Do("SETEX", key, _expire, code); err != nil {
|
||||
log.Error("conn.Do(SETEX, %d, %v, %s) error(%v)", mid, _expire, code, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetEMailCode get captcha from redis
|
||||
func (d *Dao) GetEMailCode(c context.Context, mid int64, linkMail string) (code string, err error) {
|
||||
key := keyCaptcha(mid, linkMail)
|
||||
conn := d.redis.Get(c)
|
||||
defer conn.Close()
|
||||
code, err = redis.String(conn.Do("GET", key))
|
||||
if err != nil {
|
||||
if err == redis.ErrNil {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
log.Error("conn.Do(GET, %s, ), err (%v)", key, err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DelEMailCode del captcha from redis 提交:校验验证之后就删除验证码(保证只能提交一次)
|
||||
func (d *Dao) DelEMailCode(c context.Context, mid int64, linkMail string) (err error) {
|
||||
key := keyCaptcha(mid, linkMail)
|
||||
conn := d.redis.Get(c)
|
||||
defer conn.Close()
|
||||
if _, err = conn.Do("DEL", key); err != nil {
|
||||
log.Error("conn.Do(DEL, %s, ), err (%v)", key, err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// PingRedis check connection success.
|
||||
func (d *Dao) PingRedis(c context.Context) (err error) {
|
||||
conn := d.redis.Get(c)
|
||||
defer conn.Close()
|
||||
_, err = conn.Do("GET", "PING")
|
||||
return
|
||||
}
|
100
app/service/main/account-recovery/dao/redis_test.go
Normal file
100
app/service/main/account-recovery/dao/redis_test.go
Normal file
@ -0,0 +1,100 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
|
||||
func TestDaoSetLinkMailCount(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
linkMail = "2459593393@qq.com"
|
||||
)
|
||||
convey.Convey("SetLinkMailCount", t, func(ctx convey.C) {
|
||||
state, err := d.SetLinkMailCount(c, linkMail)
|
||||
ctx.Convey("Then err should be nil.state should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(state, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaogetSubtime(t *testing.T) {
|
||||
convey.Convey("getSubtime", t, func(ctx convey.C) {
|
||||
subtime := getSubtime()
|
||||
ctx.Convey("Then subtime should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(subtime, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaokeyCaptcha(t *testing.T) {
|
||||
var (
|
||||
mid = int64(0)
|
||||
linkMail = "2459593393@qq.com"
|
||||
)
|
||||
convey.Convey("keyCaptcha", t, func(ctx convey.C) {
|
||||
p1 := keyCaptcha(mid, linkMail)
|
||||
ctx.Convey("Then p1 should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(p1, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoSetCaptcha(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
code = "1234"
|
||||
mid = int64(1)
|
||||
linkMail = "2459593393@qq.com"
|
||||
)
|
||||
convey.Convey("SetCaptcha", t, func(ctx convey.C) {
|
||||
err := d.SetCaptcha(c, code, mid, linkMail)
|
||||
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoGetEMailCode(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mid = int64(1)
|
||||
linkMail = "2459593393@qq.com"
|
||||
)
|
||||
convey.Convey("GetEMailCode", t, func(ctx convey.C) {
|
||||
code, err := d.GetEMailCode(c, mid, linkMail)
|
||||
ctx.Convey("Then err should be nil.code should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(code, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoDelEMailCode(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mid = int64(1)
|
||||
linkMail = "2459593393@qq.com"
|
||||
)
|
||||
convey.Convey("GetEMailCode", t, func(ctx convey.C) {
|
||||
err := d.DelEMailCode(c, mid, linkMail)
|
||||
ctx.Convey("Then err should be nil.code should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoPingRedis(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
)
|
||||
convey.Convey("PingRedis", t, func(ctx convey.C) {
|
||||
err := d.PingRedis(c)
|
||||
ctx.Convey("Then err should be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
})
|
||||
})
|
||||
}
|
49
app/service/main/account-recovery/dao/req_rpc.go
Normal file
49
app/service/main/account-recovery/dao/req_rpc.go
Normal file
@ -0,0 +1,49 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
account "go-common/app/service/main/account/api"
|
||||
location "go-common/app/service/main/location/model"
|
||||
member "go-common/app/service/main/member/api"
|
||||
"go-common/library/log"
|
||||
"go-common/library/net/metadata"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Info3 get info by mid
|
||||
func (d *Dao) Info3(c context.Context, mid int64) (info *account.Info, err error) {
|
||||
var (
|
||||
arg = &account.MidReq{
|
||||
Mid: mid,
|
||||
RealIp: metadata.String(c, metadata.RemoteIP),
|
||||
}
|
||||
res *account.InfoReply
|
||||
)
|
||||
if res, err = d.accountClient.Info3(c, arg); err != nil {
|
||||
err = errors.Wrapf(err, "%v", arg)
|
||||
return nil, err
|
||||
}
|
||||
return res.Info, nil
|
||||
}
|
||||
|
||||
// Infos get the ips info.
|
||||
func (d *Dao) Infos(c context.Context, ipList []string) (res map[string]*location.Info, err error) {
|
||||
if res, err = d.locRPC.Infos(c, ipList); err != nil {
|
||||
log.Error("s.locaRPC err(%v)", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CheckRealnameStatus realname status
|
||||
func (d *Dao) CheckRealnameStatus(c context.Context, mid int64) (status int8, err error) {
|
||||
var (
|
||||
relnameStatus *member.RealnameStatusReply
|
||||
)
|
||||
if relnameStatus, err = d.memberClient.RealnameStatus(c, &member.MemberMidReq{Mid: mid, RemoteIP: metadata.String(c, metadata.RemoteIP)}); err != nil {
|
||||
log.Error("s.memberSvr.RealnameStatus err(%v)", err)
|
||||
return
|
||||
}
|
||||
return relnameStatus.RealnameStatus, nil
|
||||
}
|
50
app/service/main/account-recovery/dao/req_rpc_test.go
Normal file
50
app/service/main/account-recovery/dao/req_rpc_test.go
Normal file
@ -0,0 +1,50 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
|
||||
func TestDaoInfo3(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mid = int64(2)
|
||||
)
|
||||
convey.Convey("Info3", t, func(ctx convey.C) {
|
||||
res, err := d.Info3(c, mid)
|
||||
ctx.Convey("Then err should be nil.res should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(res, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoInfos(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
ipList = []string{"127.0.0.1"}
|
||||
)
|
||||
convey.Convey("Infos", t, func(ctx convey.C) {
|
||||
res, err := d.Infos(c, ipList)
|
||||
ctx.Convey("Then err should be nil.res should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(res, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoCheckRealnameStatus(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
mid int64 = 1
|
||||
)
|
||||
convey.Convey("CheckRealnameStatus", t, func(ctx convey.C) {
|
||||
res, err := d.CheckRealnameStatus(c, mid)
|
||||
ctx.Convey("Then err should be nil.res should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(res, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
51
app/service/main/account-recovery/dao/sqlbuilder/BUILD
Normal file
51
app/service/main/account-recovery/dao/sqlbuilder/BUILD
Normal file
@ -0,0 +1,51 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
load(
|
||||
"@io_bazel_rules_go//go:def.bzl",
|
||||
"go_test",
|
||||
"go_library",
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "go_default_test",
|
||||
srcs = [
|
||||
"args_test.go",
|
||||
"builder_test.go",
|
||||
"cond_test.go",
|
||||
"flavor_test.go",
|
||||
"modifiers_test.go",
|
||||
"select_test.go",
|
||||
],
|
||||
embed = [":go_default_library"],
|
||||
rundir = ".",
|
||||
tags = ["automanaged"],
|
||||
)
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = [
|
||||
"args.go",
|
||||
"builder.go",
|
||||
"cond.go",
|
||||
"flavor.go",
|
||||
"modifiers.go",
|
||||
"select.go",
|
||||
],
|
||||
importpath = "go-common/app/service/main/account-recovery/dao/sqlbuilder",
|
||||
tags = ["automanaged"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "package-srcs",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["automanaged"],
|
||||
visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all-srcs",
|
||||
srcs = [":package-srcs"],
|
||||
tags = ["automanaged"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
236
app/service/main/account-recovery/dao/sqlbuilder/args.go
Normal file
236
app/service/main/account-recovery/dao/sqlbuilder/args.go
Normal file
@ -0,0 +1,236 @@
|
||||
// Copyright 2018 Huan Du. All rights reserved.
|
||||
// Licensed under the MIT license that can be found in the LICENSE file.
|
||||
|
||||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Args stores arguments associated with a SQL.
|
||||
type Args struct {
|
||||
// The default flavor used by `Args#Compile`
|
||||
Flavor Flavor
|
||||
|
||||
args []interface{}
|
||||
namedArgs map[string]int
|
||||
sqlNamedArgs map[string]int
|
||||
onlyNamed bool
|
||||
}
|
||||
|
||||
// Add adds an arg to Args and returns a placeholder.
|
||||
func (args *Args) Add(arg interface{}) string {
|
||||
return fmt.Sprintf("$%v", args.add(arg))
|
||||
}
|
||||
|
||||
func (args *Args) add(arg interface{}) int {
|
||||
idx := len(args.args)
|
||||
|
||||
switch a := arg.(type) {
|
||||
case sql.NamedArg:
|
||||
if args.sqlNamedArgs == nil {
|
||||
args.sqlNamedArgs = map[string]int{}
|
||||
}
|
||||
|
||||
if p, ok := args.sqlNamedArgs[a.Name]; ok {
|
||||
arg = args.args[p]
|
||||
break
|
||||
}
|
||||
|
||||
args.sqlNamedArgs[a.Name] = idx
|
||||
case namedArgs:
|
||||
if args.namedArgs == nil {
|
||||
args.namedArgs = map[string]int{}
|
||||
}
|
||||
|
||||
if p, ok := args.namedArgs[a.name]; ok {
|
||||
arg = args.args[p]
|
||||
break
|
||||
}
|
||||
|
||||
// Find out the real arg and add it to args.
|
||||
idx = args.add(a.arg)
|
||||
args.namedArgs[a.name] = idx
|
||||
return idx
|
||||
}
|
||||
|
||||
args.args = append(args.args, arg)
|
||||
return idx
|
||||
}
|
||||
|
||||
// Compile compiles builder's format to standard sql and returns associated args.
|
||||
//
|
||||
// The format string uses a special syntax to represent arguments.
|
||||
//
|
||||
// $? refers successive arguments passed in the call. It works similar as `%v` in `fmt.Sprintf`.
|
||||
// $0 $1 ... $n refers nth-argument passed in the call. Next $? will use arguments n+1.
|
||||
// ${name} refers a named argument created by `Named` with `name`.
|
||||
// $$ is a "$" string.
|
||||
func (args *Args) Compile(format string, intialValue ...interface{}) (query string, values []interface{}) {
|
||||
return args.CompileWithFlavor(format, args.Flavor, intialValue...)
|
||||
}
|
||||
|
||||
// CompileWithFlavor compiles builder's format to standard sql with flavor and returns associated args.
|
||||
//
|
||||
// See doc for `Compile` to learn details.
|
||||
func (args *Args) CompileWithFlavor(format string, flavor Flavor, intialValue ...interface{}) (query string, values []interface{}) {
|
||||
buf := &bytes.Buffer{}
|
||||
idx := strings.IndexRune(format, '$')
|
||||
offset := 0
|
||||
values = intialValue
|
||||
|
||||
if flavor == invalidFlavor {
|
||||
flavor = DefaultFlavor
|
||||
}
|
||||
|
||||
for idx >= 0 && len(format) > 0 {
|
||||
if idx > 0 {
|
||||
buf.WriteString(format[:idx])
|
||||
}
|
||||
|
||||
format = format[idx+1:]
|
||||
|
||||
// Should not happen.
|
||||
if len(format) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if format[0] == '$' {
|
||||
buf.WriteRune('$')
|
||||
format = format[1:]
|
||||
} else if format[0] == '{' {
|
||||
format, values = args.compileNamed(buf, flavor, format, values)
|
||||
} else if !args.onlyNamed && '0' <= format[0] && format[0] <= '9' {
|
||||
format, values, offset = args.compileDigits(buf, flavor, format, values, offset)
|
||||
} else if !args.onlyNamed && format[0] == '?' {
|
||||
format, values, offset = args.compileSuccessive(buf, flavor, format[1:], values, offset)
|
||||
}
|
||||
|
||||
idx = strings.IndexRune(format, '$')
|
||||
}
|
||||
|
||||
if len(format) > 0 {
|
||||
buf.WriteString(format)
|
||||
}
|
||||
|
||||
query = buf.String()
|
||||
|
||||
if len(args.sqlNamedArgs) > 0 {
|
||||
// Stabilize the sequence to make it easier to write test cases.
|
||||
ints := make([]int, 0, len(args.sqlNamedArgs))
|
||||
|
||||
for _, p := range args.sqlNamedArgs {
|
||||
ints = append(ints, p)
|
||||
}
|
||||
|
||||
sort.Ints(ints)
|
||||
|
||||
for _, i := range ints {
|
||||
values = append(values, args.args[i])
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (args *Args) compileNamed(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}) (string, []interface{}) {
|
||||
i := 1
|
||||
|
||||
for ; i < len(format) && format[i] != '}'; i++ {
|
||||
// Nothing.
|
||||
}
|
||||
|
||||
// Invalid $ format. Ignore it.
|
||||
if i == len(format) {
|
||||
return format, values
|
||||
}
|
||||
|
||||
name := format[1:i]
|
||||
format = format[i+1:]
|
||||
|
||||
if p, ok := args.namedArgs[name]; ok {
|
||||
format, values, _ = args.compileSuccessive(buf, flavor, format, values, p)
|
||||
}
|
||||
|
||||
return format, values
|
||||
}
|
||||
|
||||
func (args *Args) compileDigits(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) {
|
||||
i := 1
|
||||
|
||||
for ; i < len(format) && '0' <= format[i] && format[i] <= '9'; i++ {
|
||||
// Nothing.
|
||||
}
|
||||
|
||||
digits := format[:i]
|
||||
format = format[i:]
|
||||
|
||||
if pointer, err := strconv.Atoi(digits); err == nil {
|
||||
return args.compileSuccessive(buf, flavor, format, values, pointer)
|
||||
}
|
||||
|
||||
return format, values, offset
|
||||
}
|
||||
|
||||
func (args *Args) compileSuccessive(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) {
|
||||
if offset >= len(args.args) {
|
||||
return format, values, offset
|
||||
}
|
||||
|
||||
arg := args.args[offset]
|
||||
values = args.compileArg(buf, flavor, values, arg)
|
||||
|
||||
return format, values, offset + 1
|
||||
}
|
||||
|
||||
func (args *Args) compileArg(buf *bytes.Buffer, flavor Flavor, values []interface{}, arg interface{}) []interface{} {
|
||||
switch a := arg.(type) {
|
||||
case Builder:
|
||||
var s string
|
||||
s, values = a.BuildWithFlavor(flavor, values...)
|
||||
buf.WriteString(s)
|
||||
case sql.NamedArg:
|
||||
buf.WriteRune('@')
|
||||
buf.WriteString(a.Name)
|
||||
case rawArgs:
|
||||
buf.WriteString(a.expr)
|
||||
case listArgs:
|
||||
if len(a.args) > 0 {
|
||||
values = args.compileArg(buf, flavor, values, a.args[0])
|
||||
}
|
||||
|
||||
for i := 1; i < len(a.args); i++ {
|
||||
buf.WriteString(", ")
|
||||
values = args.compileArg(buf, flavor, values, a.args[i])
|
||||
}
|
||||
default:
|
||||
switch flavor {
|
||||
case MySQL:
|
||||
buf.WriteRune('?')
|
||||
case PostgreSQL:
|
||||
fmt.Fprintf(buf, "$%v", len(values)+1)
|
||||
default:
|
||||
panic(fmt.Errorf("Args.CompileWithFlavor: invalid flavor %v (%v)", flavor, int(flavor)))
|
||||
}
|
||||
|
||||
values = append(values, arg)
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
// Copy is
|
||||
func (args *Args) Copy() *Args {
|
||||
return &Args{
|
||||
Flavor: args.Flavor,
|
||||
args: args.args,
|
||||
namedArgs: args.namedArgs,
|
||||
sqlNamedArgs: args.sqlNamedArgs,
|
||||
onlyNamed: args.onlyNamed,
|
||||
}
|
||||
}
|
@ -0,0 +1,76 @@
|
||||
// Copyright 2018 Huan Du. All rights reserved.
|
||||
// Licensed under the MIT license that can be found in the LICENSE file.
|
||||
|
||||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestArgs(t *testing.T) {
|
||||
cases := map[string][]interface{}{
|
||||
"abc ? def\n[123]": {"abc $? def", 123},
|
||||
"abc ? def\n[456]": {"abc $0 def", 456},
|
||||
"abc def\n[]": {"abc $1 def", 123},
|
||||
"abc ? def\n[789]": {"abc ${s} def", Named("s", 789)},
|
||||
"abc def \n[]": {"abc ${unknown} def ", 123},
|
||||
"abc $ def\n[]": {"abc $$ def", 123},
|
||||
"abcdef\n[]": {"abcdef$", 123},
|
||||
"abc ? ? ? ? def\n[123 456 123 456]": {"abc $? $? $0 $? def", 123, 456, 789},
|
||||
"abc ? raw ? raw def\n[123 123]": {"abc $? $? $0 $? def", 123, Raw("raw"), 789},
|
||||
}
|
||||
|
||||
for expected, c := range cases {
|
||||
args := new(Args)
|
||||
|
||||
for i := 1; i < len(c); i++ {
|
||||
args.Add(c[i])
|
||||
}
|
||||
|
||||
sql, values := args.Compile(c[0].(string))
|
||||
actual := fmt.Sprintf("%v\n%v", sql, values)
|
||||
|
||||
if actual != expected {
|
||||
t.Fatalf("invalid compile result. [expected:%v] [actual:%v]", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
old := DefaultFlavor
|
||||
DefaultFlavor = PostgreSQL
|
||||
defer func() {
|
||||
DefaultFlavor = old
|
||||
}()
|
||||
|
||||
// PostgreSQL flavor compiled sql.
|
||||
for expected, c := range cases {
|
||||
args := new(Args)
|
||||
|
||||
for i := 1; i < len(c); i++ {
|
||||
args.Add(c[i])
|
||||
}
|
||||
|
||||
sql, values := args.Compile(c[0].(string))
|
||||
actual := fmt.Sprintf("%v\n%v", sql, values)
|
||||
expected = toPostgreSQL(expected)
|
||||
|
||||
if actual != expected {
|
||||
t.Fatalf("invalid compile result. [expected:%v] [actual:%v]", expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func toPostgreSQL(sql string) string {
|
||||
parts := strings.Split(sql, "?")
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString(parts[0])
|
||||
|
||||
for i, p := range parts[1:] {
|
||||
fmt.Fprintf(buf, "$%v", i+1)
|
||||
buf.WriteString(p)
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
105
app/service/main/account-recovery/dao/sqlbuilder/builder.go
Normal file
105
app/service/main/account-recovery/dao/sqlbuilder/builder.go
Normal file
@ -0,0 +1,105 @@
|
||||
// Copyright 2018 Huan Du. All rights reserved.
|
||||
// Licensed under the MIT license that can be found in the LICENSE file.
|
||||
|
||||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Builder is a general SQL builder.
|
||||
// It's used by Args to create nested SQL like the `IN` expression in
|
||||
// `SELECT * FROM t1 WHERE id IN (SELECT id FROM t2)`.
|
||||
type Builder interface {
|
||||
Build() (sql string, args []interface{})
|
||||
BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{})
|
||||
}
|
||||
|
||||
type compiledBuilder struct {
|
||||
args *Args
|
||||
format string
|
||||
}
|
||||
|
||||
func (cb *compiledBuilder) Build() (sql string, args []interface{}) {
|
||||
return cb.args.Compile(cb.format)
|
||||
}
|
||||
|
||||
func (cb *compiledBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {
|
||||
return cb.args.CompileWithFlavor(cb.format, flavor, initialArg...)
|
||||
}
|
||||
|
||||
type flavoredBuilder struct {
|
||||
builder Builder
|
||||
flavor Flavor
|
||||
}
|
||||
|
||||
func (fb *flavoredBuilder) Build() (sql string, args []interface{}) {
|
||||
return fb.builder.BuildWithFlavor(fb.flavor)
|
||||
}
|
||||
|
||||
func (fb *flavoredBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {
|
||||
return fb.builder.BuildWithFlavor(flavor, initialArg...)
|
||||
}
|
||||
|
||||
// WithFlavor creates a new Builder based on builder with a default flavor.
|
||||
func WithFlavor(builder Builder, flavor Flavor) Builder {
|
||||
return &flavoredBuilder{
|
||||
builder: builder,
|
||||
flavor: flavor,
|
||||
}
|
||||
}
|
||||
|
||||
// Buildf creates a Builder from a format string using `fmt.Sprintf`-like syntax.
|
||||
// As all arguments will be converted to a string internally, e.g. "$0",
|
||||
// only `%v` and `%s` are valid.
|
||||
func Buildf(format string, arg ...interface{}) Builder {
|
||||
args := &Args{
|
||||
Flavor: DefaultFlavor,
|
||||
}
|
||||
vars := make([]interface{}, 0, len(arg))
|
||||
|
||||
for _, a := range arg {
|
||||
vars = append(vars, args.Add(a))
|
||||
}
|
||||
|
||||
return &compiledBuilder{
|
||||
args: args,
|
||||
format: fmt.Sprintf(Escape(format), vars...),
|
||||
}
|
||||
}
|
||||
|
||||
// Build creates a Builder from a format string.
|
||||
// The format string uses special syntax to represent arguments.
|
||||
// See doc in `Args#Compile` for syntax details.
|
||||
func Build(format string, arg ...interface{}) Builder {
|
||||
args := &Args{
|
||||
Flavor: DefaultFlavor,
|
||||
}
|
||||
|
||||
for _, a := range arg {
|
||||
args.Add(a)
|
||||
}
|
||||
|
||||
return &compiledBuilder{
|
||||
args: args,
|
||||
format: format,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildNamed creates a Builder from a format string.
|
||||
// The format string uses `${key}` to refer the value of named by key.
|
||||
func BuildNamed(format string, named map[string]interface{}) Builder {
|
||||
args := &Args{
|
||||
Flavor: DefaultFlavor,
|
||||
onlyNamed: true,
|
||||
}
|
||||
|
||||
for n, v := range named {
|
||||
args.Add(Named(n, v))
|
||||
}
|
||||
|
||||
return &compiledBuilder{
|
||||
args: args,
|
||||
format: format,
|
||||
}
|
||||
}
|
113
app/service/main/account-recovery/dao/sqlbuilder/builder_test.go
Normal file
113
app/service/main/account-recovery/dao/sqlbuilder/builder_test.go
Normal file
@ -0,0 +1,113 @@
|
||||
// Copyright 2018 Huan Du. All rights reserved.
|
||||
// Licensed under the MIT license that can be found in the LICENSE file.
|
||||
|
||||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func ExampleBuildf() {
|
||||
sb := NewSelectBuilder()
|
||||
sb.Select("id").From("user")
|
||||
|
||||
explain := Buildf("EXPLAIN %v LEFT JOIN SELECT * FROM banned WHERE state IN (%v, %v)", sb, 1, 2)
|
||||
sql, args := explain.Build()
|
||||
fmt.Println(sql)
|
||||
fmt.Println(args)
|
||||
|
||||
// Output:
|
||||
// EXPLAIN SELECT id FROM user LEFT JOIN SELECT * FROM banned WHERE state IN (?, ?)
|
||||
// [1 2]
|
||||
}
|
||||
|
||||
func ExampleBuild() {
|
||||
sb := NewSelectBuilder()
|
||||
sb.Select("id").From("user").Where(sb.In("status", 1, 2))
|
||||
|
||||
b := Build("EXPLAIN $? LEFT JOIN SELECT * FROM $? WHERE created_at > $? AND state IN (${states}) AND modified_at BETWEEN $2 AND $?",
|
||||
sb, Raw("banned"), 1514458225, 1514544625, Named("states", List([]int{3, 4, 5})))
|
||||
sql, args := b.Build()
|
||||
|
||||
fmt.Println(sql)
|
||||
fmt.Println(args)
|
||||
|
||||
// Output:
|
||||
// EXPLAIN SELECT id FROM user WHERE status IN (?, ?) LEFT JOIN SELECT * FROM banned WHERE created_at > ? AND state IN (?, ?, ?) AND modified_at BETWEEN ? AND ?
|
||||
// [1 2 1514458225 3 4 5 1514458225 1514544625]
|
||||
}
|
||||
|
||||
func ExampleBuildNamed() {
|
||||
b := BuildNamed("SELECT * FROM ${table} WHERE status IN (${status}) AND name LIKE ${name} AND created_at > ${time} AND modified_at < ${time} + 86400",
|
||||
map[string]interface{}{
|
||||
"time": sql.Named("start", 1234567890),
|
||||
"status": List([]int{1, 2, 5}),
|
||||
"name": "Huan%",
|
||||
"table": Raw("user"),
|
||||
})
|
||||
sql, args := b.Build()
|
||||
|
||||
fmt.Println(sql)
|
||||
fmt.Println(args)
|
||||
|
||||
// Output:
|
||||
// SELECT * FROM user WHERE status IN (?, ?, ?) AND name LIKE ? AND created_at > @start AND modified_at < @start + 86400
|
||||
// [1 2 5 Huan% {{} start 1234567890}]
|
||||
}
|
||||
|
||||
func ExampleWithFlavor() {
|
||||
sql, args := WithFlavor(Buildf("SELECT * FROM foo WHERE id = %v", 1234), PostgreSQL).Build()
|
||||
|
||||
fmt.Println(sql)
|
||||
fmt.Println(args)
|
||||
|
||||
// Explicitly use MySQL as the flavor.
|
||||
sql, args = WithFlavor(Buildf("SELECT * FROM foo WHERE id = %v", 1234), PostgreSQL).BuildWithFlavor(MySQL)
|
||||
|
||||
fmt.Println(sql)
|
||||
fmt.Println(args)
|
||||
|
||||
// Output:
|
||||
// SELECT * FROM foo WHERE id = $1
|
||||
// [1234]
|
||||
// SELECT * FROM foo WHERE id = ?
|
||||
// [1234]
|
||||
}
|
||||
|
||||
func TestBuildWithPostgreSQL(t *testing.T) {
|
||||
sb1 := PostgreSQL.NewSelectBuilder()
|
||||
sb1.Select("col1", "col2").From("t1").Where(sb1.E("id", 1234), sb1.G("level", 2))
|
||||
|
||||
sb2 := PostgreSQL.NewSelectBuilder()
|
||||
sb2.Select("col3", "col4").From("t2").Where(sb2.E("id", 4567), sb2.LE("level", 5))
|
||||
|
||||
// Use DefaultFlavor (MySQL) instead of PostgreSQL.
|
||||
sql, args := Build("SELECT $1 AS col5 LEFT JOIN $0 LEFT JOIN $2", sb1, 7890, sb2).Build()
|
||||
|
||||
if expected := "SELECT ? AS col5 LEFT JOIN SELECT col1, col2 FROM t1 WHERE id = ? AND level > ? LEFT JOIN SELECT col3, col4 FROM t2 WHERE id = ? AND level <= ?"; sql != expected {
|
||||
t.Fatalf("invalid sql. [expected:%v] [actual:%v]", expected, sql)
|
||||
}
|
||||
|
||||
if expected := []interface{}{7890, 1234, 2, 4567, 5}; !reflect.DeepEqual(args, expected) {
|
||||
t.Fatalf("invalid args. [expected:%v] [actual:%v]", expected, args)
|
||||
}
|
||||
|
||||
old := DefaultFlavor
|
||||
DefaultFlavor = PostgreSQL
|
||||
defer func() {
|
||||
DefaultFlavor = old
|
||||
}()
|
||||
|
||||
sql, args = Build("SELECT $1 AS col5 LEFT JOIN $0 LEFT JOIN $2", sb1, 7890, sb2).Build()
|
||||
|
||||
if expected := "SELECT $1 AS col5 LEFT JOIN SELECT col1, col2 FROM t1 WHERE id = $2 AND level > $3 LEFT JOIN SELECT col3, col4 FROM t2 WHERE id = $4 AND level <= $5"; sql != expected {
|
||||
t.Fatalf("invalid sql. [expected:%v] [actual:%v]", expected, sql)
|
||||
}
|
||||
|
||||
if expected := []interface{}{7890, 1234, 2, 4567, 5}; !reflect.DeepEqual(args, expected) {
|
||||
t.Fatalf("invalid args. [expected:%v] [actual:%v]", expected, args)
|
||||
}
|
||||
}
|
136
app/service/main/account-recovery/dao/sqlbuilder/cond.go
Normal file
136
app/service/main/account-recovery/dao/sqlbuilder/cond.go
Normal file
@ -0,0 +1,136 @@
|
||||
// Copyright 2018 Huan Du. All rights reserved.
|
||||
// Licensed under the MIT license that can be found in the LICENSE file.
|
||||
|
||||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Cond provides several helper methods to build conditions.
|
||||
type Cond struct {
|
||||
Args *Args
|
||||
}
|
||||
|
||||
// Equal represents "field = value".
|
||||
func (c *Cond) Equal(field string, value interface{}) string {
|
||||
return fmt.Sprintf("%v = %v", Escape(field), c.Args.Add(value))
|
||||
}
|
||||
|
||||
// E is an alias of Equal.
|
||||
func (c *Cond) E(field string, value interface{}) string {
|
||||
return c.Equal(field, value)
|
||||
}
|
||||
|
||||
// NotEqual represents "field != value".
|
||||
func (c *Cond) NotEqual(field string, value interface{}) string {
|
||||
return fmt.Sprintf("%v <> %v", Escape(field), c.Args.Add(value))
|
||||
}
|
||||
|
||||
// NE is an alias of NotEqual.
|
||||
func (c *Cond) NE(field string, value interface{}) string {
|
||||
return c.NotEqual(field, value)
|
||||
}
|
||||
|
||||
// GreaterThan represents "field > value".
|
||||
func (c *Cond) GreaterThan(field string, value interface{}) string {
|
||||
return fmt.Sprintf("%v > %v", Escape(field), c.Args.Add(value))
|
||||
}
|
||||
|
||||
// G is an alias of GreaterThan.
|
||||
func (c *Cond) G(field string, value interface{}) string {
|
||||
return c.GreaterThan(field, value)
|
||||
}
|
||||
|
||||
// GreaterEqualThan represents "field >= value".
|
||||
func (c *Cond) GreaterEqualThan(field string, value interface{}) string {
|
||||
return fmt.Sprintf("%v >= %v", Escape(field), c.Args.Add(value))
|
||||
}
|
||||
|
||||
// GE is an alias of GreaterEqualThan.
|
||||
func (c *Cond) GE(field string, value interface{}) string {
|
||||
return c.GreaterEqualThan(field, value)
|
||||
}
|
||||
|
||||
// LessThan represents "field < value".
|
||||
func (c *Cond) LessThan(field string, value interface{}) string {
|
||||
return fmt.Sprintf("%v < %v", Escape(field), c.Args.Add(value))
|
||||
}
|
||||
|
||||
// L is an alias of LessThan.
|
||||
func (c *Cond) L(field string, value interface{}) string {
|
||||
return c.LessThan(field, value)
|
||||
}
|
||||
|
||||
// LessEqualThan represents "field <= value".
|
||||
func (c *Cond) LessEqualThan(field string, value interface{}) string {
|
||||
return fmt.Sprintf("%v <= %v", Escape(field), c.Args.Add(value))
|
||||
}
|
||||
|
||||
// LE is an alias of LessEqualThan.
|
||||
func (c *Cond) LE(field string, value interface{}) string {
|
||||
return c.LessEqualThan(field, value)
|
||||
}
|
||||
|
||||
// In represents "field IN (value...)".
|
||||
func (c *Cond) In(field string, value ...interface{}) string {
|
||||
vs := make([]string, 0, len(value))
|
||||
|
||||
for _, v := range value {
|
||||
vs = append(vs, c.Args.Add(v))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v IN (%v)", Escape(field), strings.Join(vs, ", "))
|
||||
}
|
||||
|
||||
// NotIn represents "field NOT IN (value...)".
|
||||
func (c *Cond) NotIn(field string, value ...interface{}) string {
|
||||
vs := make([]string, 0, len(value))
|
||||
|
||||
for _, v := range value {
|
||||
vs = append(vs, c.Args.Add(v))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v NOT IN (%v)", Escape(field), strings.Join(vs, ", "))
|
||||
}
|
||||
|
||||
// Like represents "field LIKE value".
|
||||
func (c *Cond) Like(field string, value interface{}) string {
|
||||
return fmt.Sprintf("%v LIKE %v", Escape(field), c.Args.Add(value))
|
||||
}
|
||||
|
||||
// NotLike represents "field NOT LIKE value".
|
||||
func (c *Cond) NotLike(field string, value interface{}) string {
|
||||
return fmt.Sprintf("%v NOT LIKE %v", Escape(field), c.Args.Add(value))
|
||||
}
|
||||
|
||||
// IsNull represents "field IS NULL".
|
||||
func (c *Cond) IsNull(field string) string {
|
||||
return fmt.Sprintf("%v IS NULL", Escape(field))
|
||||
}
|
||||
|
||||
// IsNotNull represents "field IS NOT NULL".
|
||||
func (c *Cond) IsNotNull(field string) string {
|
||||
return fmt.Sprintf("%v IS NOT NULL", Escape(field))
|
||||
}
|
||||
|
||||
// Between represents "field BETWEEN lower AND upper".
|
||||
func (c *Cond) Between(field string, lower, upper interface{}) string {
|
||||
return fmt.Sprintf("%v BETWEEN %v AND %v", Escape(field), c.Args.Add(lower), c.Args.Add(upper))
|
||||
}
|
||||
|
||||
// NotBetween represents "field NOT BETWEEN lower AND upper".
|
||||
func (c *Cond) NotBetween(field string, lower, upper interface{}) string {
|
||||
return fmt.Sprintf("%v NOT BETWEEN %v AND %v", Escape(field), c.Args.Add(lower), c.Args.Add(upper))
|
||||
}
|
||||
|
||||
// Or represents OR logic like "expr1 OR expr2 OR expr3".
|
||||
func (c *Cond) Or(orExpr ...string) string {
|
||||
return fmt.Sprintf("(%v)", strings.Join(orExpr, " OR "))
|
||||
}
|
||||
|
||||
// Var returns a placeholder for value.
|
||||
func (c *Cond) Var(value interface{}) string {
|
||||
return c.Args.Add(value)
|
||||
}
|
@ -0,0 +1,47 @@
|
||||
// Copyright 2018 Huan Du. All rights reserved.
|
||||
// Licensed under the MIT license that can be found in the LICENSE file.
|
||||
|
||||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCond(t *testing.T) {
|
||||
cases := map[string]func() string{
|
||||
"$$a = $0": func() string { return newTestCond().Equal("$a", 123) },
|
||||
"$$b = $0": func() string { return newTestCond().E("$b", 123) },
|
||||
"$$a <> $0": func() string { return newTestCond().NotEqual("$a", 123) },
|
||||
"$$b <> $0": func() string { return newTestCond().NE("$b", 123) },
|
||||
"$$a > $0": func() string { return newTestCond().GreaterThan("$a", 123) },
|
||||
"$$b > $0": func() string { return newTestCond().G("$b", 123) },
|
||||
"$$a >= $0": func() string { return newTestCond().GreaterEqualThan("$a", 123) },
|
||||
"$$b >= $0": func() string { return newTestCond().GE("$b", 123) },
|
||||
"$$a < $0": func() string { return newTestCond().LessThan("$a", 123) },
|
||||
"$$b < $0": func() string { return newTestCond().L("$b", 123) },
|
||||
"$$a <= $0": func() string { return newTestCond().LessEqualThan("$a", 123) },
|
||||
"$$b <= $0": func() string { return newTestCond().LE("$b", 123) },
|
||||
"$$a IN ($0, $1, $2)": func() string { return newTestCond().In("$a", 1, 2, 3) },
|
||||
"$$a NOT IN ($0, $1, $2)": func() string { return newTestCond().NotIn("$a", 1, 2, 3) },
|
||||
"$$a LIKE $0": func() string { return newTestCond().Like("$a", "%Huan%") },
|
||||
"$$a NOT LIKE $0": func() string { return newTestCond().NotLike("$a", "%Huan%") },
|
||||
"$$a IS NULL": func() string { return newTestCond().IsNull("$a") },
|
||||
"$$a IS NOT NULL": func() string { return newTestCond().IsNotNull("$a") },
|
||||
"$$a BETWEEN $0 AND $1": func() string { return newTestCond().Between("$a", 123, 456) },
|
||||
"$$a NOT BETWEEN $0 AND $1": func() string { return newTestCond().NotBetween("$a", 123, 456) },
|
||||
"(1 = 1 OR 2 = 2 OR 3 = 3)": func() string { return newTestCond().Or("1 = 1", "2 = 2", "3 = 3") },
|
||||
"$0": func() string { return newTestCond().Var(123) },
|
||||
}
|
||||
|
||||
for expected, f := range cases {
|
||||
if actual := f(); expected != actual {
|
||||
t.Fatalf("invalid result. [expected:%v] [actual:%v]", expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newTestCond() *Cond {
|
||||
return &Cond{
|
||||
Args: &Args{},
|
||||
}
|
||||
}
|
57
app/service/main/account-recovery/dao/sqlbuilder/flavor.go
Normal file
57
app/service/main/account-recovery/dao/sqlbuilder/flavor.go
Normal file
@ -0,0 +1,57 @@
|
||||
// Copyright 2018 Huan Du. All rights reserved.
|
||||
// Licensed under the MIT license that can be found in the LICENSE file.
|
||||
|
||||
package sqlbuilder
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Supported flavors.
|
||||
const (
|
||||
invalidFlavor Flavor = iota
|
||||
|
||||
MySQL
|
||||
PostgreSQL
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultFlavor is the default flavor for all builders.
|
||||
DefaultFlavor = MySQL
|
||||
)
|
||||
|
||||
// Flavor is the flag to control the format of compiled sql.
|
||||
type Flavor int
|
||||
|
||||
// String returns the name of f.
|
||||
func (f Flavor) String() string {
|
||||
switch f {
|
||||
case MySQL:
|
||||
return "MySQL"
|
||||
case PostgreSQL:
|
||||
return "PostgreSQL"
|
||||
}
|
||||
|
||||
return "<invalid>"
|
||||
}
|
||||
|
||||
// NewSelectBuilder creates a new SELECT builder with flavor.
|
||||
func (f Flavor) NewSelectBuilder() *SelectBuilder {
|
||||
b := newSelectBuilder()
|
||||
b.SetFlavor(f)
|
||||
return b
|
||||
}
|
||||
|
||||
// Quote adds quote for name to make sure the name can be used safely
|
||||
// as table name or field name.
|
||||
//
|
||||
// * For MySQL, use back quote (`) to quote name;
|
||||
// * For PostgreSQL, use double quote (") to quote name.
|
||||
func (f Flavor) Quote(name string) string {
|
||||
switch f {
|
||||
case MySQL:
|
||||
return fmt.Sprintf("`%v`", name)
|
||||
case PostgreSQL:
|
||||
return fmt.Sprintf(`"%v"`, name)
|
||||
}
|
||||
|
||||
return name
|
||||
}
|
@ -0,0 +1,40 @@
|
||||
// Copyright 2018 Huan Du. All rights reserved.
|
||||
// Licensed under the MIT license that can be found in the LICENSE file.
|
||||
|
||||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFlavor(t *testing.T) {
|
||||
cases := map[Flavor]string{
|
||||
0: "<invalid>",
|
||||
MySQL: "MySQL",
|
||||
PostgreSQL: "PostgreSQL",
|
||||
}
|
||||
|
||||
for f, expected := range cases {
|
||||
if actual := f.String(); actual != expected {
|
||||
t.Fatalf("invalid flavor name. [expected:%v] [actual:%v]", expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleFlavor() {
|
||||
// Create a flavored builder.
|
||||
sb := PostgreSQL.NewSelectBuilder()
|
||||
sb.Select("name").From("user").Where(
|
||||
sb.E("id", 1234),
|
||||
sb.G("rank", 3),
|
||||
)
|
||||
sql, args := sb.Build()
|
||||
|
||||
fmt.Println(sql)
|
||||
fmt.Println(args)
|
||||
|
||||
// Output:
|
||||
// SELECT name FROM user WHERE id = $1 AND rank > $2
|
||||
// [1234 3]
|
||||
}
|
@ -0,0 +1,98 @@
|
||||
// Copyright 2018 Huan Du. All rights reserved.
|
||||
// Licensed under the MIT license that can be found in the LICENSE file.
|
||||
|
||||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Escape replaces `$` with `$$` in ident.
|
||||
func Escape(ident string) string {
|
||||
return strings.Replace(ident, "$", "$$", -1)
|
||||
}
|
||||
|
||||
// EscapeAll replaces `$` with `$$` in all strings of ident.
|
||||
func EscapeAll(ident ...string) []string {
|
||||
escaped := make([]string, 0, len(ident))
|
||||
|
||||
for _, i := range ident {
|
||||
escaped = append(escaped, Escape(i))
|
||||
}
|
||||
|
||||
return escaped
|
||||
}
|
||||
|
||||
// Flatten recursively extracts values in slices and returns
|
||||
// a flattened []interface{} with all values.
|
||||
// If slices is not a slice, return `[]interface{}{slices}`.
|
||||
func Flatten(slices interface{}) (flattened []interface{}) {
|
||||
v := reflect.ValueOf(slices)
|
||||
slices, flattened = flatten(v)
|
||||
|
||||
if slices != nil {
|
||||
return []interface{}{slices}
|
||||
}
|
||||
|
||||
return flattened
|
||||
}
|
||||
|
||||
func flatten(v reflect.Value) (elem interface{}, flattened []interface{}) {
|
||||
k := v.Kind()
|
||||
|
||||
for k == reflect.Interface {
|
||||
v = v.Elem()
|
||||
k = v.Kind()
|
||||
}
|
||||
|
||||
if k != reflect.Slice && k != reflect.Array {
|
||||
return v.Interface(), nil
|
||||
}
|
||||
|
||||
for i, l := 0, v.Len(); i < l; i++ {
|
||||
e, f := flatten(v.Index(i))
|
||||
|
||||
if e == nil {
|
||||
flattened = append(flattened, f...)
|
||||
} else {
|
||||
flattened = append(flattened, e)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type rawArgs struct {
|
||||
expr string
|
||||
}
|
||||
|
||||
// Raw marks the expr as a raw value which will not be added to args.
|
||||
func Raw(expr string) interface{} {
|
||||
return rawArgs{expr}
|
||||
}
|
||||
|
||||
type listArgs struct {
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
// List marks arg as a list of data.
|
||||
// If arg is `[]int{1, 2, 3}`, it will be compiled to `?, ?, ?` with args `[1 2 3]`.
|
||||
func List(arg interface{}) interface{} {
|
||||
return listArgs{Flatten(arg)}
|
||||
}
|
||||
|
||||
type namedArgs struct {
|
||||
name string
|
||||
arg interface{}
|
||||
}
|
||||
|
||||
// Named creates a named argument.
|
||||
// Unlike `sql.Named`, this named argument works only with `Build` or `BuildNamed` for convenience
|
||||
// and will be replaced to a `?` after `Compile`.
|
||||
func Named(name string, arg interface{}) interface{} {
|
||||
return namedArgs{
|
||||
name: name,
|
||||
arg: arg,
|
||||
}
|
||||
}
|
@ -0,0 +1,59 @@
|
||||
// Copyright 2018 Huan Du. All rights reserved.
|
||||
// Licensed under the MIT license that can be found in the LICENSE file.
|
||||
|
||||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEscape(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"foo": "foo",
|
||||
"$foo": "$$foo",
|
||||
"$$$": "$$$$$$",
|
||||
}
|
||||
var inputs, expects []string
|
||||
|
||||
for s, expected := range cases {
|
||||
inputs = append(inputs, s)
|
||||
expects = append(expects, expected)
|
||||
|
||||
if actual := Escape(s); actual != expected {
|
||||
t.Fatalf("invalid escape result. [expected:%v] [actual:%v]", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
actuals := EscapeAll(inputs...)
|
||||
|
||||
if !reflect.DeepEqual(expects, actuals) {
|
||||
t.Fatalf("invalid escape result. [expected:%v] [actual:%v]", expects, actuals)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlatten(t *testing.T) {
|
||||
cases := [][2]interface{}{
|
||||
{
|
||||
"foo",
|
||||
[]interface{}{"foo"},
|
||||
},
|
||||
{
|
||||
[]int{1, 2, 3},
|
||||
[]interface{}{1, 2, 3},
|
||||
},
|
||||
{
|
||||
[]interface{}{"abc", []int{1, 2, 3}, [3]string{"def", "ghi"}},
|
||||
[]interface{}{"abc", 1, 2, 3, "def", "ghi", ""},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
input, expected := c[0], c[1]
|
||||
actual := Flatten(input)
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Fatalf("invalid flatten result. [expected:%v] [actual:%v]", expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
212
app/service/main/account-recovery/dao/sqlbuilder/select.go
Normal file
212
app/service/main/account-recovery/dao/sqlbuilder/select.go
Normal file
@ -0,0 +1,212 @@
|
||||
// Copyright 2018 Huan Du. All rights reserved.
|
||||
// Licensed under the MIT license that can be found in the LICENSE file.
|
||||
|
||||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// NewSelectBuilder creates a new SELECT builder.
|
||||
func NewSelectBuilder() *SelectBuilder {
|
||||
return DefaultFlavor.NewSelectBuilder()
|
||||
}
|
||||
|
||||
func newSelectBuilder() *SelectBuilder {
|
||||
args := &Args{}
|
||||
return &SelectBuilder{
|
||||
Cond: Cond{
|
||||
Args: args,
|
||||
},
|
||||
limit: -1,
|
||||
offset: -1,
|
||||
args: args,
|
||||
}
|
||||
}
|
||||
|
||||
// SelectBuilder is a builder to build SELECT.
|
||||
type SelectBuilder struct {
|
||||
Cond
|
||||
|
||||
distinct bool
|
||||
tables []string
|
||||
selectCols []string
|
||||
whereExprs []string
|
||||
havingExprs []string
|
||||
groupByCols []string
|
||||
orderByCols []string
|
||||
order string
|
||||
limit int
|
||||
offset int
|
||||
|
||||
args *Args
|
||||
}
|
||||
|
||||
// Distinct marks this SELECT as DISTINCT.
|
||||
func (sb *SelectBuilder) Distinct() *SelectBuilder {
|
||||
sb.distinct = true
|
||||
return sb
|
||||
}
|
||||
|
||||
// Select sets columns in SELECT.
|
||||
func (sb *SelectBuilder) Select(col ...string) *SelectBuilder {
|
||||
sb.selectCols = EscapeAll(col...)
|
||||
return sb
|
||||
}
|
||||
|
||||
// From sets table names in SELECT.
|
||||
func (sb *SelectBuilder) From(table ...string) *SelectBuilder {
|
||||
sb.tables = table
|
||||
return sb
|
||||
}
|
||||
|
||||
// Where sets expressions of WHERE in SELECT.
|
||||
func (sb *SelectBuilder) Where(andExpr ...string) *SelectBuilder {
|
||||
sb.whereExprs = append(sb.whereExprs, andExpr...)
|
||||
return sb
|
||||
}
|
||||
|
||||
// Having sets expressions of HAVING in SELECT.
|
||||
func (sb *SelectBuilder) Having(andExpr ...string) *SelectBuilder {
|
||||
sb.havingExprs = append(sb.havingExprs, andExpr...)
|
||||
return sb
|
||||
}
|
||||
|
||||
// GroupBy sets columns of GROUP BY in SELECT.
|
||||
func (sb *SelectBuilder) GroupBy(col ...string) *SelectBuilder {
|
||||
sb.groupByCols = EscapeAll(col...)
|
||||
return sb
|
||||
}
|
||||
|
||||
// OrderBy sets columns of ORDER BY in SELECT.
|
||||
func (sb *SelectBuilder) OrderBy(col ...string) *SelectBuilder {
|
||||
sb.orderByCols = EscapeAll(col...)
|
||||
return sb
|
||||
}
|
||||
|
||||
// Asc sets order of ORDER BY to ASC.
|
||||
func (sb *SelectBuilder) Asc() *SelectBuilder {
|
||||
sb.order = "ASC"
|
||||
return sb
|
||||
}
|
||||
|
||||
// Desc sets order of ORDER BY to DESC.
|
||||
func (sb *SelectBuilder) Desc() *SelectBuilder {
|
||||
sb.order = "DESC"
|
||||
return sb
|
||||
}
|
||||
|
||||
// Limit sets the LIMIT in SELECT.
|
||||
func (sb *SelectBuilder) Limit(limit int) *SelectBuilder {
|
||||
sb.limit = limit
|
||||
return sb
|
||||
}
|
||||
|
||||
// Offset sets the LIMIT offset in SELECT.
|
||||
func (sb *SelectBuilder) Offset(offset int) *SelectBuilder {
|
||||
sb.offset = offset
|
||||
return sb
|
||||
}
|
||||
|
||||
// As returns an AS expression.
|
||||
func (sb *SelectBuilder) As(col, alias string) string {
|
||||
return fmt.Sprintf("%v AS %v", col, Escape(alias))
|
||||
}
|
||||
|
||||
// BuilderAs returns an AS expression wrapping a complex SQL.
|
||||
// According to SQL syntax, SQL built by builder is surrounded by parens.
|
||||
func (sb *SelectBuilder) BuilderAs(builder Builder, alias string) string {
|
||||
return fmt.Sprintf("(%v) AS %v", sb.Var(builder), Escape(alias))
|
||||
}
|
||||
|
||||
// String returns the compiled SELECT string.
|
||||
func (sb *SelectBuilder) String() string {
|
||||
s, _ := sb.Build()
|
||||
return s
|
||||
}
|
||||
|
||||
// Build returns compiled SELECT string and args.
|
||||
// They can be used in `DB#Query` of package `database/sql` directly.
|
||||
func (sb *SelectBuilder) Build() (sql string, args []interface{}) {
|
||||
return sb.BuildWithFlavor(sb.args.Flavor)
|
||||
}
|
||||
|
||||
// BuildWithFlavor returns compiled SELECT string and args with flavor and initial args.
|
||||
// They can be used in `DB#Query` of package `database/sql` directly.
|
||||
func (sb *SelectBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString("SELECT ")
|
||||
|
||||
if sb.distinct {
|
||||
buf.WriteString("DISTINCT ")
|
||||
}
|
||||
|
||||
buf.WriteString(strings.Join(sb.selectCols, ", "))
|
||||
buf.WriteString(" FROM ")
|
||||
buf.WriteString(strings.Join(sb.tables, ", "))
|
||||
|
||||
if len(sb.whereExprs) > 0 {
|
||||
buf.WriteString(" WHERE ")
|
||||
buf.WriteString(strings.Join(sb.whereExprs, " AND "))
|
||||
}
|
||||
|
||||
if len(sb.groupByCols) > 0 {
|
||||
buf.WriteString(" GROUP BY ")
|
||||
buf.WriteString(strings.Join(sb.groupByCols, ", "))
|
||||
|
||||
if len(sb.havingExprs) > 0 {
|
||||
buf.WriteString(" HAVING ")
|
||||
buf.WriteString(strings.Join(sb.havingExprs, " AND "))
|
||||
}
|
||||
}
|
||||
|
||||
if len(sb.orderByCols) > 0 {
|
||||
buf.WriteString(" ORDER BY ")
|
||||
buf.WriteString(strings.Join(sb.orderByCols, ", "))
|
||||
|
||||
if sb.order != "" {
|
||||
buf.WriteRune(' ')
|
||||
buf.WriteString(sb.order)
|
||||
}
|
||||
}
|
||||
|
||||
if sb.limit >= 0 {
|
||||
buf.WriteString(" LIMIT ")
|
||||
buf.WriteString(strconv.Itoa(sb.limit))
|
||||
|
||||
if sb.offset >= 0 {
|
||||
buf.WriteString(" OFFSET ")
|
||||
buf.WriteString(strconv.Itoa(sb.offset))
|
||||
}
|
||||
}
|
||||
|
||||
return sb.Args.CompileWithFlavor(buf.String(), flavor, initialArg...)
|
||||
}
|
||||
|
||||
// SetFlavor sets the flavor of compiled sql.
|
||||
func (sb *SelectBuilder) SetFlavor(flavor Flavor) (old Flavor) {
|
||||
old = sb.args.Flavor
|
||||
sb.args.Flavor = flavor
|
||||
return
|
||||
}
|
||||
|
||||
// Copy the builder
|
||||
func (sb *SelectBuilder) Copy() *SelectBuilder {
|
||||
return &SelectBuilder{
|
||||
Cond: sb.Cond,
|
||||
distinct: sb.distinct,
|
||||
tables: sb.tables,
|
||||
selectCols: sb.selectCols,
|
||||
whereExprs: sb.whereExprs,
|
||||
havingExprs: sb.havingExprs,
|
||||
groupByCols: sb.groupByCols,
|
||||
orderByCols: sb.orderByCols,
|
||||
order: sb.order,
|
||||
limit: sb.limit,
|
||||
offset: sb.offset,
|
||||
args: sb.args.Copy(),
|
||||
}
|
||||
}
|
@ -0,0 +1,67 @@
|
||||
// Copyright 2018 Huan Du. All rights reserved.
|
||||
// Licensed under the MIT license that can be found in the LICENSE file.
|
||||
|
||||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func ExampleSelectBuilder() {
|
||||
sb := NewSelectBuilder()
|
||||
sb.Distinct().Select("id", "name", sb.As("COUNT(*)", "t"))
|
||||
sb.From("demo.user")
|
||||
sb.Where(
|
||||
sb.GreaterThan("id", 1234),
|
||||
sb.Like("name", "%Du"),
|
||||
sb.Or(
|
||||
sb.IsNull("id_card"),
|
||||
sb.In("status", 1, 2, 5),
|
||||
),
|
||||
sb.NotIn(
|
||||
"id",
|
||||
NewSelectBuilder().Select("id").From("banned"),
|
||||
), // Nested SELECT.
|
||||
"modified_at > created_at + "+sb.Var(86400), // It's allowed to write arbitrary SQL.
|
||||
)
|
||||
sb.GroupBy("status").Having(sb.NotIn("status", 4, 5))
|
||||
sb.OrderBy("modified_at").Asc()
|
||||
sb.Limit(10).Offset(5)
|
||||
|
||||
sql, args := sb.Build()
|
||||
fmt.Println(sql)
|
||||
fmt.Println(args)
|
||||
|
||||
// Output:
|
||||
// SELECT DISTINCT id, name, COUNT(*) AS t FROM demo.user WHERE id > ? AND name LIKE ? AND (id_card IS NULL OR status IN (?, ?, ?)) AND id NOT IN (SELECT id FROM banned) AND modified_at > created_at + ? GROUP BY status HAVING status NOT IN (?, ?) ORDER BY modified_at ASC LIMIT 10 OFFSET 5
|
||||
// [1234 %Du 1 2 5 86400 4 5]
|
||||
}
|
||||
|
||||
func ExampleSelectBuilder_advancedUsage() {
|
||||
sb := NewSelectBuilder()
|
||||
innerSb := NewSelectBuilder()
|
||||
|
||||
sb.Select("id", "name")
|
||||
sb.From(
|
||||
sb.BuilderAs(innerSb, "user"),
|
||||
)
|
||||
sb.Where(
|
||||
sb.In("status", Flatten([]int{1, 2, 3})...),
|
||||
sb.Between("created_at", sql.Named("start", 1234567890), sql.Named("end", 1234599999)),
|
||||
)
|
||||
|
||||
innerSb.Select("*")
|
||||
innerSb.From("banned")
|
||||
innerSb.Where(
|
||||
innerSb.NotIn("name", Flatten([]string{"Huan Du", "Charmy Liu"})...),
|
||||
)
|
||||
|
||||
sql, args := sb.Build()
|
||||
fmt.Println(sql)
|
||||
fmt.Println(args)
|
||||
|
||||
// Output:
|
||||
// SELECT id, name FROM (SELECT * FROM banned WHERE name NOT IN (?, ?)) AS user WHERE status IN (?, ?, ?) AND created_at BETWEEN @start AND @end
|
||||
// [Huan Du Charmy Liu 1 2 3 {{} start 1234567890} {{} end 1234599999}]
|
||||
}
|
140
app/service/main/account-recovery/dao/user_act_log.go
Normal file
140
app/service/main/account-recovery/dao/user_act_log.go
Normal file
@ -0,0 +1,140 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"go-common/app/service/main/account-recovery/model"
|
||||
"go-common/library/database/elastic"
|
||||
"go-common/library/log"
|
||||
)
|
||||
|
||||
// NickNameLog NickNameLog
|
||||
func (d *Dao) NickNameLog(c context.Context, nickNameReq *model.NickNameReq) (res *model.NickNameLogRes, err error) {
|
||||
nowYear := time.Now().Year()
|
||||
index1 := "log_user_action_14_" + strconv.Itoa(nowYear)
|
||||
index2 := "log_user_action_14_" + strconv.Itoa(nowYear-1)
|
||||
|
||||
r := d.es.NewRequest("log_user_action").Fields("str_0", "str_1").Index(index1, index2)
|
||||
r.Order("ctime", elastic.OrderDesc).Order("mid", elastic.OrderDesc).Pn(nickNameReq.Page).Ps(nickNameReq.Size)
|
||||
|
||||
if nickNameReq.Mid != 0 {
|
||||
r.WhereEq("mid", nickNameReq.Mid)
|
||||
}
|
||||
|
||||
if nickNameReq.From != 0 && nickNameReq.To != 0 {
|
||||
ftm := time.Unix(nickNameReq.From, 0)
|
||||
sf := ftm.Format("2006-01-02 15:04:05")
|
||||
|
||||
ttm := time.Unix(nickNameReq.To, 0)
|
||||
tf := ttm.Format("2006-01-02 15:04:05")
|
||||
|
||||
r.WhereRange("ctime", sf, tf, elastic.RangeScopeLoRo)
|
||||
}
|
||||
|
||||
esres := new(model.NickESRes)
|
||||
if err = r.Scan(context.TODO(), &esres); err != nil {
|
||||
log.Error("nickNameLog search error(%v)", err)
|
||||
return
|
||||
}
|
||||
|
||||
var nickNames = make([]*model.NickNameInfo, 0)
|
||||
for _, value := range esres.Result {
|
||||
ulr := model.NickNameInfo{OldName: value.OldName, NewName: value.NewName}
|
||||
nickNames = append(nickNames, &ulr)
|
||||
}
|
||||
res = &model.NickNameLogRes{Page: esres.Page, Result: nickNames}
|
||||
return
|
||||
}
|
||||
|
||||
type userLogsExtra struct {
|
||||
EncryptTel string `json:"tel"`
|
||||
EncryptEmail string `json:"email"`
|
||||
}
|
||||
|
||||
// UserBindLog User bind log
|
||||
func (d *Dao) UserBindLog(c context.Context, userActLogReq *model.UserBindLogReq) (res *model.UserBindLogRes, err error) {
|
||||
e := d.es
|
||||
nowYear := time.Now().Year()
|
||||
|
||||
var count = 2 //默认查询两年
|
||||
//2016年就有了手机历史记录,此处需要循环建立索引 , 2018年才有邮箱这个功能
|
||||
if userActLogReq.Action == "telBindLog" {
|
||||
count = nowYear - 2015
|
||||
}
|
||||
if userActLogReq.Action == "emailBindLog" {
|
||||
count = nowYear - 2017
|
||||
}
|
||||
indexs := make([]string, count)
|
||||
for i := 0; i < count; i++ {
|
||||
indexs[i] = "log_user_action_54_" + strconv.Itoa(nowYear-i)
|
||||
}
|
||||
|
||||
r := e.NewRequest("log_user_action").Fields("mid", "str_0", "extra_data", "ctime").Index(indexs...)
|
||||
r.Order("ctime", elastic.OrderDesc).Order("mid", elastic.OrderDesc).Pn(userActLogReq.Page).Ps(userActLogReq.Size)
|
||||
|
||||
if userActLogReq.Mid != 0 {
|
||||
r.WhereEq("mid", userActLogReq.Mid)
|
||||
}
|
||||
|
||||
if userActLogReq.Query != "" {
|
||||
hash := sha1.New()
|
||||
hash.Write([]byte(userActLogReq.Query))
|
||||
telHash := base64.StdEncoding.EncodeToString(hash.Sum(d.hashSalt))
|
||||
r.WhereEq("str_0", telHash)
|
||||
}
|
||||
|
||||
if userActLogReq.Action != "" {
|
||||
r.WhereEq("action", userActLogReq.Action)
|
||||
}
|
||||
|
||||
if userActLogReq.From != 0 && userActLogReq.To != 0 {
|
||||
ftm := time.Unix(userActLogReq.From, 0)
|
||||
sf := ftm.Format("2006-01-02 15:04:05")
|
||||
|
||||
ttm := time.Unix(userActLogReq.To, 0)
|
||||
tf := ttm.Format("2006-01-02 15:04:05")
|
||||
|
||||
r.WhereRange("ctime", sf, tf, elastic.RangeScopeLoRo)
|
||||
}
|
||||
|
||||
esres := new(model.EsRes)
|
||||
if err = r.Scan(context.Background(), &esres); err != nil {
|
||||
log.Error("userActLogs search error(%v)", err)
|
||||
return
|
||||
}
|
||||
|
||||
var userBindLogs = make([]*model.UserBindLog, 0)
|
||||
for _, value := range esres.Result {
|
||||
var email, tel string
|
||||
//var model.UserBindLog
|
||||
userLogExtra := userLogsExtra{}
|
||||
err = json.Unmarshal([]byte(value.ExtraData), &userLogExtra)
|
||||
if err != nil {
|
||||
log.Error("cannot convert json(%s) to struct,err(%+v) ", value.ExtraData, err)
|
||||
continue
|
||||
}
|
||||
if userLogExtra.EncryptEmail != "" {
|
||||
email, err = d.decrypt(userLogExtra.EncryptEmail)
|
||||
if err != nil {
|
||||
log.Error("EncryptEmail decode err(%v)", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if userLogExtra.EncryptTel != "" {
|
||||
tel, err = d.decrypt(userLogExtra.EncryptTel)
|
||||
if err != nil {
|
||||
log.Error("EncryptTel decode err(%v)", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
ulr := model.UserBindLog{Mid: value.Mid, Email: email, Phone: tel, Time: value.CTime}
|
||||
userBindLogs = append(userBindLogs, &ulr)
|
||||
}
|
||||
res = &model.UserBindLogRes{Page: esres.Page, Result: userBindLogs}
|
||||
return
|
||||
}
|
42
app/service/main/account-recovery/dao/user_act_log_test.go
Normal file
42
app/service/main/account-recovery/dao/user_act_log_test.go
Normal file
@ -0,0 +1,42 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"go-common/app/service/main/account-recovery/model"
|
||||
|
||||
"github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
|
||||
func TestDaoNickNameLog(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
nickNameReq = &model.NickNameReq{Mid: 111001254}
|
||||
)
|
||||
convey.Convey("NickNameLog", t, func(ctx convey.C) {
|
||||
res, err := d.NickNameLog(c, nickNameReq)
|
||||
ctx.Convey("Then err should be nil.res should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(res, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDaoUserBindLog(t *testing.T) {
|
||||
var (
|
||||
c = context.Background()
|
||||
emailReq = &model.UserBindLogReq{
|
||||
Action: "emailBindLog",
|
||||
Mid: 23,
|
||||
Size: 100,
|
||||
}
|
||||
)
|
||||
convey.Convey("UserBindLog", t, func(ctx convey.C) {
|
||||
res, err := d.UserBindLog(c, emailReq)
|
||||
ctx.Convey("Then err should be nil.res should not be nil.", func(ctx convey.C) {
|
||||
ctx.So(err, convey.ShouldBeNil)
|
||||
ctx.So(res, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user