Create & Init Project...

This commit is contained in:
2019-04-22 18:49:16 +08:00
commit fc4fa37393
25440 changed files with 4054998 additions and 0 deletions

33
library/BUILD Normal file
View File

@@ -0,0 +1,33 @@
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [
":package-srcs",
"//library/cache:all-srcs",
"//library/conf:all-srcs",
"//library/container:all-srcs",
"//library/database:all-srcs",
"//library/ecode:all-srcs",
"//library/exp/feature:all-srcs",
"//library/log:all-srcs",
"//library/naming:all-srcs",
"//library/net:all-srcs",
"//library/os:all-srcs",
"//library/queue:all-srcs",
"//library/rate:all-srcs",
"//library/stat:all-srcs",
"//library/sync:all-srcs",
"//library/syscall:all-srcs",
"//library/text/translate/chinese:all-srcs",
"//library/time:all-srcs",
"//library/xstr:all-srcs",
],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

7
library/OWNERS Normal file
View File

@@ -0,0 +1,7 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- haoguanwei
- maojian
labels:
- library

44
library/cache/BUILD vendored Normal file
View File

@@ -0,0 +1,44 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_test",
"go_library",
)
go_test(
name = "go_default_test",
srcs = ["cache_test.go"],
embed = [":go_default_library"],
tags = ["automanaged"],
)
go_library(
name = "go_default_library",
srcs = ["cache.go"],
importpath = "go-common/library/cache",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//library/log:go_default_library",
"//library/stat/prom:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [
":package-srcs",
"//library/cache/memcache:all-srcs",
"//library/cache/redis:all-srcs",
],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

88
library/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,88 @@
package cache
import (
"errors"
"runtime"
"sync"
"go-common/library/log"
"go-common/library/stat/prom"
)
var (
// ErrFull cache internal chan full.
ErrFull = errors.New("cache chan full")
stats = prom.BusinessInfoCount
)
// Cache async save data by chan.
type Cache struct {
ch chan func()
worker int
waiter sync.WaitGroup
}
// Deprecated: use library/sync/pipeline/fanout instead.
func New(worker, size int) *Cache {
if worker <= 0 {
worker = 1
}
c := &Cache{
ch: make(chan func(), size),
worker: worker,
}
c.waiter.Add(worker)
for i := 0; i < worker; i++ {
go c.proc()
}
return c
}
func (c *Cache) proc() {
defer c.waiter.Done()
for {
f := <-c.ch
if f == nil {
return
}
wrapFunc(f)()
stats.State("cache_channel", int64(len(c.ch)))
}
}
func wrapFunc(f func()) (res func()) {
res = func() {
defer func() {
if r := recover(); r != nil {
buf := make([]byte, 64*1024)
buf = buf[:runtime.Stack(buf, false)]
log.Error("panic in cache proc, err: %s, stack: %s", r, buf)
}
}()
f()
}
return
}
// Save save a callback cache func.
func (c *Cache) Save(f func()) (err error) {
if f == nil {
return
}
select {
case c.ch <- f:
default:
err = ErrFull
}
stats.State("cache_channel", int64(len(c.ch)))
return
}
// Close close cache
func (c *Cache) Close() (err error) {
for i := 0; i < c.worker; i++ {
c.ch <- nil
}
c.waiter.Wait()
return
}

20
library/cache/cache_test.go vendored Normal file
View File

@@ -0,0 +1,20 @@
package cache
import (
"testing"
"time"
)
func TestCache_Save(t *testing.T) {
ca := New(1, 1024)
var run bool
ca.Save(func() {
run = true
panic("error")
})
time.Sleep(time.Millisecond * 50)
t.Log("don't panic")
if !run {
t.Fatal("expect run be true")
}
}

69
library/cache/memcache/BUILD vendored Normal file
View File

@@ -0,0 +1,69 @@
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
package(default_visibility = ["//visibility:public"])
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [
":package-srcs",
"//library/cache/memcache/test:all-srcs",
],
tags = ["automanaged"],
)
go_library(
name = "go_default_library",
srcs = [
"client.go",
"conn.go",
"errors.go",
"memcache.go",
"mock.go",
"pool.go",
"trace.go",
"util.go",
],
importpath = "go-common/library/cache/memcache",
tags = ["automanaged"],
deps = [
"//library/container/pool:go_default_library",
"//library/net/trace:go_default_library",
"//library/stat:go_default_library",
"//library/time:go_default_library",
"//vendor/github.com/pkg/errors:go_default_library",
"@com_github_gogo_protobuf//proto:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = [
"client_test.go",
"conn_test.go",
"memcache_test.go",
"pool_test.go",
"util_test.go",
],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
deps = [
"//library/cache/memcache/test:go_default_library",
"//library/container/pool:go_default_library",
"//library/time:go_default_library",
"//vendor/github.com/bouk/monkey:go_default_library",
"//vendor/github.com/stretchr/testify/assert:go_default_library",
"@com_github_gogo_protobuf//proto:go_default_library",
],
)

34
library/cache/memcache/CHANGELOG.md vendored Normal file
View File

@@ -0,0 +1,34 @@
### memcache client
##### Version 1.5.1
> 1.修复parse reply时如果有err不关闭连接问题
##### Version 1.5.0
> 1.支持cache和cache conn的写法
##### Version 1.4.0
> 1.add memcache mock conn
##### Version 1.3.2
> 1.修复判断是否合法key
##### Version 1.3.1
> 1.修复pool放回连接的bug
##### Version 1.3.0
> 1.修改memcache pool的实现方式引用container/pool
> 2.pool支持context传入超时以及Get connection WaitTimeout
##### Version 1.2.0
> 1. 增加pkg errors
##### Version 1.1.2
> 1. 修复gzip writer默认压缩level为0的bug
##### Version 1.1.1
> 1. fix populateOne error
##### Version 1.1.0
> 1. memcache添加largevalue支持
##### Version 1.0.0
> 1. 修改decode时protobuf bug,补全测试

10
library/cache/memcache/CONTRIBUTORS.md vendored Normal file
View File

@@ -0,0 +1,10 @@
# Owner
maojian
# Author
maojian
chenshangqiang
zhoujixiang
# Reviewer
maojian

13
library/cache/memcache/OWNERS vendored Normal file
View File

@@ -0,0 +1,13 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- chenshangqiang
- maojian
- zhoujixiang
labels:
- library
- library/cache/memcache
reviewers:
- chenshangqiang
- maojian
- zhoujixiang

32
library/cache/memcache/README.md vendored Normal file
View File

@@ -0,0 +1,32 @@
# go-common/cache/memcache
##### 项目简介
> 1. 提供protobufgobjson序列化方式gzip的memcache接口
##### 编译环境
> 1. 请只用golang v1.7.x以上版本编译执行。
##### 测试
> 1. 执行当前目录下所有测试文件,测试所有功能
##### 特别说明
> 1. 使用protobuf需要在pb文件目录下运行business/make.sh脚本生成go文件才能使用
#### 使用方式
```golang
// 初始化
mc := memcache.New(&memcache.Config{})
// 增加 key
err = mc.Set(c, &memcache.Item{})
// 删除key
err := mc.Delete(c,key)
// 获得某个key的内容
err := mc.Get(c,key).Scan(&v)
// 获取多个key的内容
replies, err := mc.GetMulti(c, keys)
for _, key := range replies.Keys() {
if err = rows.Scan(key, &v); err != nil {
return
}
}
```

188
library/cache/memcache/client.go vendored Normal file
View File

@@ -0,0 +1,188 @@
package memcache
import (
"context"
)
// Memcache memcache client
type Memcache struct {
pool *Pool
}
// Reply is the result of Get
type Reply struct {
err error
item *Item
conn Conn
closed bool
}
// Replies is the result of GetMulti
type Replies struct {
err error
items map[string]*Item
usedItems map[string]struct{}
conn Conn
closed bool
}
// New get a memcache client
func New(c *Config) *Memcache {
return &Memcache{pool: NewPool(c)}
}
// Close close connection pool
func (mc *Memcache) Close() error {
return mc.pool.Close()
}
// Conn direct get a connection
func (mc *Memcache) Conn(c context.Context) Conn {
return mc.pool.Get(c)
}
// Set writes the given item, unconditionally.
func (mc *Memcache) Set(c context.Context, item *Item) (err error) {
conn := mc.pool.Get(c)
err = conn.Set(item)
conn.Close()
return
}
// Add writes the given item, if no value already exists for its key.
// ErrNotStored is returned if that condition is not met.
func (mc *Memcache) Add(c context.Context, item *Item) (err error) {
conn := mc.pool.Get(c)
err = conn.Add(item)
conn.Close()
return
}
// Replace writes the given item, but only if the server *does* already hold data for this key.
func (mc *Memcache) Replace(c context.Context, item *Item) (err error) {
conn := mc.pool.Get(c)
err = conn.Replace(item)
conn.Close()
return
}
// CompareAndSwap writes the given item that was previously returned by Get
func (mc *Memcache) CompareAndSwap(c context.Context, item *Item) (err error) {
conn := mc.pool.Get(c)
err = conn.CompareAndSwap(item)
conn.Close()
return
}
// Get sends a command to the server for gets data.
func (mc *Memcache) Get(c context.Context, key string) *Reply {
conn := mc.pool.Get(c)
item, err := conn.Get(key)
if err != nil {
conn.Close()
}
return &Reply{err: err, item: item, conn: conn}
}
// Item get raw Item
func (r *Reply) Item() *Item {
return r.item
}
// Scan converts value, read from the memcache
func (r *Reply) Scan(v interface{}) (err error) {
if r.err != nil {
return r.err
}
err = r.conn.Scan(r.item, v)
if !r.closed {
r.conn.Close()
r.closed = true
}
return
}
// GetMulti is a batch version of Get
func (mc *Memcache) GetMulti(c context.Context, keys []string) (*Replies, error) {
conn := mc.pool.Get(c)
items, err := conn.GetMulti(keys)
rs := &Replies{err: err, items: items, conn: conn, usedItems: make(map[string]struct{}, len(keys))}
if err != nil {
conn.Close()
rs.closed = true
}
return rs, err
}
// Close close rows.
func (rs *Replies) Close() (err error) {
if !rs.closed {
err = rs.conn.Close()
rs.closed = true
}
return
}
// Item get Item from rows
func (rs *Replies) Item(key string) *Item {
return rs.items[key]
}
// Scan converts value, read from key in rows
func (rs *Replies) Scan(key string, v interface{}) (err error) {
if rs.err != nil {
return rs.err
}
item, ok := rs.items[key]
if !ok {
return ErrNotFound
}
rs.usedItems[key] = struct{}{}
err = rs.conn.Scan(item, v)
shouldClose := len(rs.items) == len(rs.usedItems)
if shouldClose {
rs.Close()
}
return
}
// Keys keys of result
func (rs *Replies) Keys() (keys []string) {
keys = make([]string, 0, len(rs.items))
for key := range rs.items {
keys = append(keys, key)
}
return
}
// Touch updates the expiry for the given key.
func (mc *Memcache) Touch(c context.Context, key string, timeout int32) (err error) {
conn := mc.pool.Get(c)
err = conn.Touch(key, timeout)
conn.Close()
return
}
// Delete deletes the item with the provided key.
func (mc *Memcache) Delete(c context.Context, key string) (err error) {
conn := mc.pool.Get(c)
err = conn.Delete(key)
conn.Close()
return
}
// Increment atomically increments key by delta.
func (mc *Memcache) Increment(c context.Context, key string, delta uint64) (newValue uint64, err error) {
conn := mc.pool.Get(c)
newValue, err = conn.Increment(key, delta)
conn.Close()
return
}
// Decrement atomically decrements key by delta.
func (mc *Memcache) Decrement(c context.Context, key string, delta uint64) (newValue uint64, err error) {
conn := mc.pool.Get(c)
newValue, err = conn.Decrement(key, delta)
conn.Close()
return
}

302
library/cache/memcache/client_test.go vendored Normal file
View File

@@ -0,0 +1,302 @@
package memcache
import (
"context"
"fmt"
"reflect"
"testing"
"time"
)
var testClient *Memcache
func Test_client_Set(t *testing.T) {
type args struct {
c context.Context
item *Item
}
tests := []struct {
name string
args args
wantErr bool
}{
{name: "set value", args: args{c: context.Background(), item: &Item{Key: "Test_client_Set", Value: []byte("abc")}}, wantErr: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := testClient.Set(tt.args.c, tt.args.item); (err != nil) != tt.wantErr {
t.Errorf("client.Set() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_client_Add(t *testing.T) {
type args struct {
c context.Context
item *Item
}
key := fmt.Sprintf("Test_client_Add_%d", time.Now().Unix())
tests := []struct {
name string
args args
wantErr bool
}{
{name: "add not exist value", args: args{c: context.Background(), item: &Item{Key: key, Value: []byte("abc")}}, wantErr: false},
{name: "add exist value", args: args{c: context.Background(), item: &Item{Key: key, Value: []byte("abc")}}, wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := testClient.Add(tt.args.c, tt.args.item); (err != nil) != tt.wantErr {
t.Errorf("client.Add() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_client_Replace(t *testing.T) {
key := fmt.Sprintf("Test_client_Replace_%d", time.Now().Unix())
ekey := "Test_client_Replace_exist"
testClient.Set(context.Background(), &Item{Key: ekey, Value: []byte("ok")})
type args struct {
c context.Context
item *Item
}
tests := []struct {
name string
args args
wantErr bool
}{
{name: "not exist value", args: args{c: context.Background(), item: &Item{Key: key, Value: []byte("abc")}}, wantErr: true},
{name: "exist value", args: args{c: context.Background(), item: &Item{Key: ekey, Value: []byte("abc")}}, wantErr: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := testClient.Replace(tt.args.c, tt.args.item); (err != nil) != tt.wantErr {
t.Errorf("client.Replace() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_client_CompareAndSwap(t *testing.T) {
key := fmt.Sprintf("Test_client_CompareAndSwap_%d", time.Now().Unix())
ekey := "Test_client_CompareAndSwap_k"
testClient.Set(context.Background(), &Item{Key: ekey, Value: []byte("old")})
cas := testClient.Get(context.Background(), ekey).Item().cas
type args struct {
c context.Context
item *Item
}
tests := []struct {
name string
args args
wantErr bool
}{
{name: "not exist value", args: args{c: context.Background(), item: &Item{Key: key, Value: []byte("abc")}}, wantErr: true},
{name: "exist value", args: args{c: context.Background(), item: &Item{Key: ekey, cas: cas, Value: []byte("abc")}}, wantErr: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := testClient.CompareAndSwap(tt.args.c, tt.args.item); (err != nil) != tt.wantErr {
t.Errorf("client.CompareAndSwap() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_client_Get(t *testing.T) {
key := fmt.Sprintf("Test_client_Get_%d", time.Now().Unix())
ekey := "Test_client_Get_k"
testClient.Set(context.Background(), &Item{Key: ekey, Value: []byte("old")})
type args struct {
c context.Context
key string
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{name: "not exist value", args: args{c: context.Background(), key: key}, wantErr: true},
{name: "exist value", args: args{c: context.Background(), key: ekey}, wantErr: false, want: "old"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var res string
if err := testClient.Get(tt.args.c, tt.args.key).Scan(&res); (err != nil) != tt.wantErr || res != tt.want {
t.Errorf("client.Get() = %v, want %v, got err: %v, want err: %v", err, tt.want, err, tt.wantErr)
}
})
}
}
func Test_client_Touch(t *testing.T) {
key := fmt.Sprintf("Test_client_Touch_%d", time.Now().Unix())
ekey := "Test_client_Touch_k"
testClient.Set(context.Background(), &Item{Key: ekey, Value: []byte("old")})
type args struct {
c context.Context
key string
timeout int32
}
tests := []struct {
name string
args args
wantErr bool
}{
{name: "not exist value", args: args{c: context.Background(), key: key, timeout: 100000}, wantErr: true},
{name: "exist value", args: args{c: context.Background(), key: ekey, timeout: 100000}, wantErr: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := testClient.Touch(tt.args.c, tt.args.key, tt.args.timeout); (err != nil) != tt.wantErr {
t.Errorf("client.Touch() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_client_Delete(t *testing.T) {
key := fmt.Sprintf("Test_client_Delete_%d", time.Now().Unix())
ekey := "Test_client_Delete_k"
testClient.Set(context.Background(), &Item{Key: ekey, Value: []byte("old")})
type args struct {
c context.Context
key string
}
tests := []struct {
name string
args args
wantErr bool
}{
{name: "not exist value", args: args{c: context.Background(), key: key}, wantErr: true},
{name: "exist value", args: args{c: context.Background(), key: ekey}, wantErr: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := testClient.Delete(tt.args.c, tt.args.key); (err != nil) != tt.wantErr {
t.Errorf("client.Delete() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_client_Increment(t *testing.T) {
key := fmt.Sprintf("Test_client_Increment_%d", time.Now().Unix())
ekey := "Test_client_Increment_k"
testClient.Set(context.Background(), &Item{Key: ekey, Value: []byte("1")})
type args struct {
c context.Context
key string
delta uint64
}
tests := []struct {
name string
args args
wantNewValue uint64
wantErr bool
}{
{name: "not exist value", args: args{c: context.Background(), key: key, delta: 10}, wantErr: true},
{name: "exist value", args: args{c: context.Background(), key: ekey, delta: 10}, wantErr: false, wantNewValue: 11},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotNewValue, err := testClient.Increment(tt.args.c, tt.args.key, tt.args.delta)
if (err != nil) != tt.wantErr {
t.Errorf("client.Increment() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotNewValue != tt.wantNewValue {
t.Errorf("client.Increment() = %v, want %v", gotNewValue, tt.wantNewValue)
}
})
}
}
func Test_client_Decrement(t *testing.T) {
key := fmt.Sprintf("Test_client_Decrement_%d", time.Now().Unix())
ekey := "Test_client_Decrement_k"
testClient.Set(context.Background(), &Item{Key: ekey, Value: []byte("100")})
type args struct {
c context.Context
key string
delta uint64
}
tests := []struct {
name string
args args
wantNewValue uint64
wantErr bool
}{
{name: "not exist value", args: args{c: context.Background(), key: key, delta: 10}, wantErr: true},
{name: "exist value", args: args{c: context.Background(), key: ekey, delta: 10}, wantErr: false, wantNewValue: 90},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotNewValue, err := testClient.Decrement(tt.args.c, tt.args.key, tt.args.delta)
if (err != nil) != tt.wantErr {
t.Errorf("client.Decrement() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotNewValue != tt.wantNewValue {
t.Errorf("client.Decrement() = %v, want %v", gotNewValue, tt.wantNewValue)
}
})
}
}
func Test_client_GetMulti(t *testing.T) {
key := fmt.Sprintf("Test_client_GetMulti_%d", time.Now().Unix())
ekey1 := "Test_client_GetMulti_k1"
ekey2 := "Test_client_GetMulti_k2"
testClient.Set(context.Background(), &Item{Key: ekey1, Value: []byte("1")})
testClient.Set(context.Background(), &Item{Key: ekey2, Value: []byte("2")})
keys := []string{key, ekey1, ekey2}
rows, err := testClient.GetMulti(context.Background(), keys)
if err != nil {
t.Errorf("client.GetMulti() error = %v, wantErr %v", err, nil)
}
tests := []struct {
key string
wantNewValue string
wantErr bool
nilItem bool
}{
{key: key, wantErr: true, nilItem: true},
{key: ekey1, wantErr: false, wantNewValue: "1", nilItem: false},
{key: ekey2, wantErr: false, wantNewValue: "2", nilItem: false},
}
if reflect.DeepEqual(keys, rows.Keys()) {
t.Errorf("got %v, expect: %v", rows.Keys(), keys)
}
for _, tt := range tests {
t.Run(tt.key, func(t *testing.T) {
var gotNewValue string
err = rows.Scan(tt.key, &gotNewValue)
if (err != nil) != tt.wantErr {
t.Errorf("rows.Scan() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotNewValue != tt.wantNewValue {
t.Errorf("rows.Value() = %v, want %v", gotNewValue, tt.wantNewValue)
}
if (rows.Item(tt.key) == nil) != tt.nilItem {
t.Errorf("rows.Item() = %v, want %v", rows.Item(tt.key) == nil, tt.nilItem)
}
})
}
err = rows.Close()
if err != nil {
t.Errorf("client.Replies.Close() error = %v, wantErr %v", err, nil)
}
}
func Test_client_Conn(t *testing.T) {
conn := testClient.Conn(context.Background())
defer conn.Close()
if conn == nil {
t.Errorf("expect get conn, get nil")
}
}

685
library/cache/memcache/conn.go vendored Normal file
View File

@@ -0,0 +1,685 @@
package memcache
import (
"bufio"
"bytes"
"compress/gzip"
"context"
"encoding/gob"
"encoding/json"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/gogo/protobuf/proto"
pkgerr "github.com/pkg/errors"
)
var (
crlf = []byte("\r\n")
spaceStr = string(" ")
replyOK = []byte("OK\r\n")
replyStored = []byte("STORED\r\n")
replyNotStored = []byte("NOT_STORED\r\n")
replyExists = []byte("EXISTS\r\n")
replyNotFound = []byte("NOT_FOUND\r\n")
replyDeleted = []byte("DELETED\r\n")
replyEnd = []byte("END\r\n")
replyTouched = []byte("TOUCHED\r\n")
replyValueStr = "VALUE"
replyClientErrorPrefix = []byte("CLIENT_ERROR ")
replyServerErrorPrefix = []byte("SERVER_ERROR ")
)
const (
_encodeBuf = 4096 // 4kb
// 1024*1024 - 1, set error???
_largeValue = 1000 * 1000 // 1MB
)
type reader struct {
io.Reader
}
func (r *reader) Reset(rd io.Reader) {
r.Reader = rd
}
// conn is the low-level implementation of Conn
type conn struct {
// Shared
mu sync.Mutex
err error
conn net.Conn
// Read & Write
readTimeout time.Duration
writeTimeout time.Duration
rw *bufio.ReadWriter
// Item Reader
ir bytes.Reader
// Compress
gr gzip.Reader
gw *gzip.Writer
cb bytes.Buffer
// Encoding
edb bytes.Buffer
// json
jr reader
jd *json.Decoder
je *json.Encoder
// protobuffer
ped *proto.Buffer
}
// DialOption specifies an option for dialing a Memcache server.
type DialOption struct {
f func(*dialOptions)
}
type dialOptions struct {
readTimeout time.Duration
writeTimeout time.Duration
dial func(network, addr string) (net.Conn, error)
}
// DialReadTimeout specifies the timeout for reading a single command reply.
func DialReadTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.readTimeout = d
}}
}
// DialWriteTimeout specifies the timeout for writing a single command.
func DialWriteTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.writeTimeout = d
}}
}
// DialConnectTimeout specifies the timeout for connecting to the Memcache server.
func DialConnectTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
dialer := net.Dialer{Timeout: d}
do.dial = dialer.Dial
}}
}
// DialNetDial specifies a custom dial function for creating TCP
// connections. If this option is left out, then net.Dial is
// used. DialNetDial overrides DialConnectTimeout.
func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
return DialOption{func(do *dialOptions) {
do.dial = dial
}}
}
// Dial connects to the Memcache server at the given network and
// address using the specified options.
func Dial(network, address string, options ...DialOption) (Conn, error) {
do := dialOptions{
dial: net.Dial,
}
for _, option := range options {
option.f(&do)
}
netConn, err := do.dial(network, address)
if err != nil {
return nil, pkgerr.WithStack(err)
}
return NewConn(netConn, do.readTimeout, do.writeTimeout), nil
}
// NewConn returns a new memcache connection for the given net connection.
func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
if writeTimeout <= 0 || readTimeout <= 0 {
panic("must config memcache timeout")
}
c := &conn{
conn: netConn,
rw: bufio.NewReadWriter(bufio.NewReader(netConn),
bufio.NewWriter(netConn)),
readTimeout: readTimeout,
writeTimeout: writeTimeout,
}
c.jd = json.NewDecoder(&c.jr)
c.je = json.NewEncoder(&c.edb)
c.gw = gzip.NewWriter(&c.cb)
c.edb.Grow(_encodeBuf)
// NOTE reuse bytes.Buffer internal buf
// DON'T concurrency call Scan
c.ped = proto.NewBuffer(c.edb.Bytes())
return c
}
func (c *conn) Close() error {
c.mu.Lock()
err := c.err
if c.err == nil {
c.err = pkgerr.New("memcache: closed")
err = c.conn.Close()
}
c.mu.Unlock()
return err
}
func (c *conn) fatal(err error) error {
c.mu.Lock()
if c.err == nil {
c.err = pkgerr.WithStack(err)
// Close connection to force errors on subsequent calls and to unblock
// other reader or writer.
c.conn.Close()
}
c.mu.Unlock()
return c.err
}
func (c *conn) Err() error {
c.mu.Lock()
err := c.err
c.mu.Unlock()
return err
}
func (c *conn) Add(item *Item) error {
return c.populate("add", item)
}
func (c *conn) Set(item *Item) error {
return c.populate("set", item)
}
func (c *conn) Replace(item *Item) error {
return c.populate("replace", item)
}
func (c *conn) CompareAndSwap(item *Item) error {
return c.populate("cas", item)
}
func (c *conn) populate(cmd string, item *Item) (err error) {
if !legalKey(item.Key) {
return pkgerr.WithStack(ErrMalformedKey)
}
var res []byte
if res, err = c.encode(item); err != nil {
return
}
l := len(res)
count := l/(_largeValue) + 1
if count == 1 {
item.Value = res
return c.populateOne(cmd, item)
}
nItem := &Item{
Key: item.Key,
Value: []byte(strconv.Itoa(l)),
Expiration: item.Expiration,
Flags: item.Flags | flagLargeValue,
}
err = c.populateOne(cmd, nItem)
if err != nil {
return
}
k := item.Key
nItem.Flags = item.Flags
for i := 1; i <= count; i++ {
if i == count {
nItem.Value = res[_largeValue*(count-1):]
} else {
nItem.Value = res[_largeValue*(i-1) : _largeValue*i]
}
nItem.Key = fmt.Sprintf("%s%d", k, i)
if err = c.populateOne(cmd, nItem); err != nil {
return
}
}
return
}
func (c *conn) populateOne(cmd string, item *Item) (err error) {
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
// <command name> <key> <flags> <exptime> <bytes> [noreply]\r\n
if cmd == "cas" {
_, err = fmt.Fprintf(c.rw, "%s %s %d %d %d %d\r\n",
cmd, item.Key, item.Flags, item.Expiration, len(item.Value), item.cas)
} else {
_, err = fmt.Fprintf(c.rw, "%s %s %d %d %d\r\n",
cmd, item.Key, item.Flags, item.Expiration, len(item.Value))
}
if err != nil {
return c.fatal(err)
}
c.rw.Write(item.Value)
c.rw.Write(crlf)
if err = c.rw.Flush(); err != nil {
return c.fatal(err)
}
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
line, err := c.rw.ReadSlice('\n')
if err != nil {
return c.fatal(err)
}
switch {
case bytes.Equal(line, replyStored):
return nil
case bytes.Equal(line, replyNotStored):
return ErrNotStored
case bytes.Equal(line, replyExists):
return ErrCASConflict
case bytes.Equal(line, replyNotFound):
return ErrNotFound
}
return pkgerr.WithStack(protocolError(string(line)))
}
func (c *conn) Get(key string) (r *Item, err error) {
if !legalKey(key) {
return nil, pkgerr.WithStack(ErrMalformedKey)
}
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", key); err != nil {
return nil, c.fatal(err)
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(err)
}
if err = c.parseGetReply(func(it *Item) {
r = it
}); err != nil {
return
}
if r == nil {
err = ErrNotFound
return
}
if r.Flags&flagLargeValue != flagLargeValue {
return
}
if r, err = c.getLargeValue(r); err != nil {
return
}
return
}
func (c *conn) GetMulti(keys []string) (res map[string]*Item, err error) {
for _, key := range keys {
if !legalKey(key) {
return nil, pkgerr.WithStack(ErrMalformedKey)
}
}
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil {
return nil, c.fatal(err)
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(err)
}
res = make(map[string]*Item, len(keys))
if err = c.parseGetReply(func(it *Item) {
res[it.Key] = it
}); err != nil {
return
}
for k, v := range res {
if v.Flags&flagLargeValue != flagLargeValue {
continue
}
r, err := c.getLargeValue(v)
if err != nil {
return res, err
}
res[k] = r
}
return
}
func (c *conn) getMulti(keys []string) (res map[string]*Item, err error) {
for _, key := range keys {
if !legalKey(key) {
return nil, pkgerr.WithStack(ErrMalformedKey)
}
}
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil {
return nil, c.fatal(err)
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(err)
}
res = make(map[string]*Item, len(keys))
err = c.parseGetReply(func(it *Item) {
res[it.Key] = it
})
return
}
func (c *conn) getLargeValue(it *Item) (r *Item, err error) {
l, err := strconv.Atoi(string(it.Value))
if err != nil {
return
}
count := l/_largeValue + 1
keys := make([]string, 0, count)
for i := 1; i <= count; i++ {
keys = append(keys, fmt.Sprintf("%s%d", it.Key, i))
}
items, err := c.getMulti(keys)
if err != nil {
return
}
if len(items) < count {
err = ErrNotFound
return
}
v := make([]byte, 0, l)
for _, k := range keys {
if items[k] == nil || items[k].Value == nil {
err = ErrNotFound
return
}
v = append(v, items[k].Value...)
}
it.Value = v
it.Flags = it.Flags ^ flagLargeValue
r = it
return
}
func (c *conn) parseGetReply(f func(*Item)) error {
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
for {
line, err := c.rw.ReadSlice('\n')
if err != nil {
return c.fatal(err)
}
if bytes.Equal(line, replyEnd) {
return nil
}
if bytes.HasPrefix(line, replyServerErrorPrefix) {
errMsg := line[len(replyServerErrorPrefix):]
return c.fatal(protocolError(errMsg))
}
it := new(Item)
size, err := scanGetReply(line, it)
if err != nil {
return c.fatal(err)
}
it.Value = make([]byte, size+2)
if _, err = io.ReadFull(c.rw, it.Value); err != nil {
return c.fatal(err)
}
if !bytes.HasSuffix(it.Value, crlf) {
return c.fatal(protocolError("corrupt get reply, no except CRLF"))
}
it.Value = it.Value[:size]
f(it)
}
}
func scanGetReply(line []byte, item *Item) (size int, err error) {
if !bytes.HasSuffix(line, crlf) {
return 0, protocolError("corrupt get reply, no except CRLF")
}
// VALUE <key> <flags> <bytes> [<cas unique>]
chunks := strings.Split(string(line[:len(line)-2]), spaceStr)
if len(chunks) < 4 {
return 0, protocolError("corrupt get reply")
}
if chunks[0] != replyValueStr {
return 0, protocolError("corrupt get reply, no except VALUE")
}
item.Key = chunks[1]
flags64, err := strconv.ParseUint(chunks[2], 10, 32)
if err != nil {
return 0, err
}
item.Flags = uint32(flags64)
if size, err = strconv.Atoi(chunks[3]); err != nil {
return
}
if len(chunks) > 4 {
item.cas, err = strconv.ParseUint(chunks[4], 10, 64)
}
return
}
func (c *conn) Touch(key string, expire int32) (err error) {
if !legalKey(key) {
return pkgerr.WithStack(ErrMalformedKey)
}
line, err := c.writeReadLine("touch %s %d\r\n", key, expire)
if err != nil {
return err
}
switch {
case bytes.Equal(line, replyTouched):
return nil
case bytes.Equal(line, replyNotFound):
return ErrNotFound
default:
return pkgerr.WithStack(protocolError(string(line)))
}
}
func (c *conn) Increment(key string, delta uint64) (uint64, error) {
return c.incrDecr("incr", key, delta)
}
func (c *conn) Decrement(key string, delta uint64) (newValue uint64, err error) {
return c.incrDecr("decr", key, delta)
}
func (c *conn) incrDecr(cmd, key string, delta uint64) (uint64, error) {
if !legalKey(key) {
return 0, pkgerr.WithStack(ErrMalformedKey)
}
line, err := c.writeReadLine("%s %s %d\r\n", cmd, key, delta)
if err != nil {
return 0, err
}
switch {
case bytes.Equal(line, replyNotFound):
return 0, ErrNotFound
case bytes.HasPrefix(line, replyClientErrorPrefix):
errMsg := line[len(replyClientErrorPrefix):]
return 0, pkgerr.WithStack(protocolError(errMsg))
}
val, err := strconv.ParseUint(string(line[:len(line)-2]), 10, 64)
if err != nil {
return 0, err
}
return val, nil
}
func (c *conn) Delete(key string) (err error) {
if !legalKey(key) {
return pkgerr.WithStack(ErrMalformedKey)
}
line, err := c.writeReadLine("delete %s\r\n", key)
if err != nil {
return err
}
switch {
case bytes.Equal(line, replyOK):
return nil
case bytes.Equal(line, replyDeleted):
return nil
case bytes.Equal(line, replyNotStored):
return ErrNotStored
case bytes.Equal(line, replyExists):
return ErrCASConflict
case bytes.Equal(line, replyNotFound):
return ErrNotFound
}
return pkgerr.WithStack(protocolError(string(line)))
}
func (c *conn) writeReadLine(format string, args ...interface{}) ([]byte, error) {
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
_, err := fmt.Fprintf(c.rw, format, args...)
if err != nil {
return nil, c.fatal(pkgerr.WithStack(err))
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(pkgerr.WithStack(err))
}
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
line, err := c.rw.ReadSlice('\n')
if err != nil {
return line, c.fatal(pkgerr.WithStack(err))
}
return line, nil
}
func (c *conn) Scan(item *Item, v interface{}) (err error) {
c.ir.Reset(item.Value)
if item.Flags&FlagGzip == FlagGzip {
if err = c.gr.Reset(&c.ir); err != nil {
return
}
if err = c.decode(&c.gr, item, v); err != nil {
err = pkgerr.WithStack(err)
return
}
err = c.gr.Close()
} else {
err = c.decode(&c.ir, item, v)
}
err = pkgerr.WithStack(err)
return
}
func (c *conn) WithContext(ctx context.Context) Conn {
// FIXME: implement WithContext
return c
}
func (c *conn) encode(item *Item) (data []byte, err error) {
if (item.Flags | _flagEncoding) == _flagEncoding {
if item.Value == nil {
return nil, ErrItem
}
} else if item.Object == nil {
return nil, ErrItem
}
// encoding
switch {
case item.Flags&FlagGOB == FlagGOB:
c.edb.Reset()
if err = gob.NewEncoder(&c.edb).Encode(item.Object); err != nil {
return
}
data = c.edb.Bytes()
case item.Flags&FlagProtobuf == FlagProtobuf:
c.edb.Reset()
c.ped.SetBuf(c.edb.Bytes())
pb, ok := item.Object.(proto.Message)
if !ok {
err = ErrItemObject
return
}
if err = c.ped.Marshal(pb); err != nil {
return
}
data = c.ped.Bytes()
case item.Flags&FlagJSON == FlagJSON:
c.edb.Reset()
if err = c.je.Encode(item.Object); err != nil {
return
}
data = c.edb.Bytes()
default:
data = item.Value
}
// compress
if item.Flags&FlagGzip == FlagGzip {
c.cb.Reset()
c.gw.Reset(&c.cb)
if _, err = c.gw.Write(data); err != nil {
return
}
if err = c.gw.Close(); err != nil {
return
}
data = c.cb.Bytes()
}
if len(data) > 8000000 {
err = ErrValueSize
}
return
}
func (c *conn) decode(rd io.Reader, item *Item, v interface{}) (err error) {
var data []byte
switch {
case item.Flags&FlagGOB == FlagGOB:
err = gob.NewDecoder(rd).Decode(v)
case item.Flags&FlagJSON == FlagJSON:
c.jr.Reset(rd)
err = c.jd.Decode(v)
default:
data = item.Value
if item.Flags&FlagGzip == FlagGzip {
c.edb.Reset()
if _, err = io.Copy(&c.edb, rd); err != nil {
return
}
data = c.edb.Bytes()
}
if item.Flags&FlagProtobuf == FlagProtobuf {
m, ok := v.(proto.Message)
if !ok {
err = ErrItemObject
return
}
c.ped.SetBuf(data)
err = c.ped.Unmarshal(m)
} else {
switch v.(type) {
case *[]byte:
d := v.(*[]byte)
*d = data
case *string:
d := v.(*string)
*d = string(data)
case interface{}:
err = json.Unmarshal(data, v)
}
}
}
return
}
func legalKey(key string) bool {
if len(key) > 250 || len(key) == 0 {
return false
}
for i := 0; i < len(key); i++ {
if key[i] <= ' ' || key[i] == 0x7f {
return false
}
}
return true
}

524
library/cache/memcache/conn_test.go vendored Normal file
View File

@@ -0,0 +1,524 @@
package memcache
import (
"bytes"
"encoding/json"
"errors"
test "go-common/library/cache/memcache/test"
"testing"
"time"
"github.com/bouk/monkey"
"github.com/gogo/protobuf/proto"
)
var s = []string{"test", "test1"}
var c Conn
var item = &Item{
Key: "test",
Value: []byte("test"),
Flags: FlagRAW,
Expiration: 60,
cas: 0,
}
var item2 = &Item{
Key: "test1",
Value: []byte("test"),
Flags: 0,
Expiration: 1000,
cas: 0,
}
var item3 = &Item{
Key: "test2",
Value: []byte("0"),
Flags: 0,
Expiration: 60,
cas: 0,
}
type TestItem struct {
Name string
Age int
}
func (t *TestItem) Compare(nt *TestItem) bool {
return t.Name == nt.Name && t.Age == nt.Age
}
func prepareEnv(t *testing.T) {
if c != nil {
return
}
var err error
cnop := DialConnectTimeout(time.Duration(2 * time.Second))
rdop := DialReadTimeout(time.Duration(2 * time.Second))
wrop := DialWriteTimeout(time.Duration(2 * time.Second))
c, err = Dial("tcp", testMemcacheAddr, cnop, rdop, wrop)
if err != nil {
t.Errorf("Dial() error(%v)", err)
t.FailNow()
}
c.Delete("test")
c.Delete("test1")
c.Delete("test2")
}
func TestRaw(t *testing.T) {
prepareEnv(t)
if err := c.Set(item); err != nil {
t.Errorf("conn.Store() error(%v)", err)
}
}
func TestAdd(t *testing.T) {
var (
key = "test_add"
item = &Item{
Key: key,
Value: []byte("0"),
Flags: 0,
Expiration: 60,
cas: 0,
}
)
prepareEnv(t)
c.Delete(key)
if err := c.Add(item); err != nil {
t.Errorf("c.Add() error(%v)", err)
}
if err := c.Add(item); err != ErrNotStored {
t.Errorf("c.Add() error(%v)", err)
}
}
func TestSetErr(t *testing.T) {
prepareEnv(t)
//set
st := &TestItem{Name: "jsongzip", Age: 10}
itemx := &Item{Key: "jsongzip", Object: st}
if err := c.Set(itemx); err != ErrItem {
t.Errorf("conn.Set() error(%v)", err)
}
}
func TestSetErr2(t *testing.T) {
prepareEnv(t)
//set
itemx := &Item{Key: "jsongzip", Flags: FlagJSON | FlagGzip}
if err := c.Set(itemx); err != ErrItem {
t.Errorf("conn.Set() error(%v)", err)
}
}
func TestSetErr3(t *testing.T) {
prepareEnv(t)
//set
itemx := &Item{Key: "jsongzip", Value: []byte("test"), Flags: FlagJSON}
if err := c.Set(itemx); err != ErrItem {
t.Errorf("conn.Set() error(%v)", err)
}
}
func TestJSONGzip(t *testing.T) {
prepareEnv(t)
//set
st := &TestItem{Name: "jsongzip", Age: 10}
itemx := &Item{Key: "jsongzip", Object: st, Flags: FlagJSON | FlagGzip}
if err := c.Set(itemx); err != nil {
t.Errorf("conn.Set() error(%v)", err)
}
if r, err := c.Get("jsongzip"); err != nil {
t.Errorf("conn.Get() error(%v)", err)
} else {
var nst TestItem
scanAndCompare(t, r, st, &nst)
}
}
func TestJSON(t *testing.T) {
prepareEnv(t)
st := &TestItem{Name: "json", Age: 10}
itemx := &Item{Key: "json", Object: st, Flags: FlagJSON}
if err := c.Set(itemx); err != nil {
t.Errorf("conn.Set() error(%v)", err)
}
if r, err := c.Get("json"); err != nil {
t.Errorf("conn.Get() error(%v)", err)
} else {
var nst TestItem
scanAndCompare(t, r, st, &nst)
}
}
func BenchmarkJSON(b *testing.B) {
st := &TestItem{Name: "json", Age: 10}
itemx := &Item{Key: "json", Object: st, Flags: FlagJSON}
var (
eb bytes.Buffer
je *json.Encoder
ir bytes.Reader
jd *json.Decoder
jr reader
nst test.TestItem
)
jd = json.NewDecoder(&jr)
je = json.NewEncoder(&eb)
eb.Grow(_encodeBuf)
// NOTE reuse bytes.Buffer internal buf
// DON'T concurrency call Scan
b.ResetTimer()
for i := 0; i < b.N; i++ {
eb.Reset()
if err := je.Encode(itemx.Object); err != nil {
return
}
data := eb.Bytes()
ir.Reset(data)
jr.Reset(&ir)
jd.Decode(&nst)
}
}
func BenchmarkProtobuf(b *testing.B) {
st := &test.TestItem{Name: "protobuf", Age: 10}
itemx := &Item{Key: "protobuf", Object: st, Flags: FlagJSON}
var (
eb bytes.Buffer
nst test.TestItem
ped *proto.Buffer
)
ped = proto.NewBuffer(eb.Bytes())
eb.Grow(_encodeBuf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
ped.Reset()
pb, ok := itemx.Object.(proto.Message)
if !ok {
return
}
if err := ped.Marshal(pb); err != nil {
return
}
data := ped.Bytes()
ped.SetBuf(data)
ped.Unmarshal(&nst)
}
}
func TestGob(t *testing.T) {
prepareEnv(t)
st := &TestItem{Name: "gob", Age: 10}
itemx := &Item{Key: "gob", Object: st, Flags: FlagGOB}
if err := c.Set(itemx); err != nil {
t.Errorf("conn.Set() error(%v)", err)
}
if r, err := c.Get("gob"); err != nil {
t.Errorf("conn.Get() error(%v)", err)
} else {
var nst TestItem
scanAndCompare(t, r, st, &nst)
}
}
func TestGobGzip(t *testing.T) {
prepareEnv(t)
st := &TestItem{Name: "gobgzip", Age: 10}
itemx := &Item{Key: "gobgzip", Object: st, Flags: FlagGOB | FlagGzip}
if err := c.Set(itemx); err != nil {
t.Errorf("conn.Set() error(%v)", err)
}
if r, err := c.Get("gobgzip"); err != nil {
t.Errorf("conn.Get() error(%v)", err)
} else {
var nst TestItem
scanAndCompare(t, r, st, &nst)
}
}
func TestGzip(t *testing.T) {
prepareEnv(t)
st := &TestItem{Name: "gzip", Age: 123}
itemx := &Item{Key: "gzip", Object: st, Flags: FlagGOB | FlagGzip}
if err := c.Set(itemx); err != nil {
t.Errorf("conn.Set() error(%v)", err)
}
if r, err := c.Get("gzip"); err != nil {
t.Errorf("conn.Get() error(%v)", err)
} else {
var nst TestItem
scanAndCompare(t, r, st, &nst)
}
}
func TestProtobuf(t *testing.T) {
prepareEnv(t)
var (
err error
// value []byte
r *Item
nst test.TestItem
)
st := &test.TestItem{Name: "proto", Age: 3021}
itemx := &Item{Key: "proto", Object: st, Flags: FlagProtobuf}
if err = c.Set(itemx); err != nil {
t.Errorf("conn.Set() error(%v)", err)
}
if r, err = c.Get("proto"); err != nil {
t.Errorf("conn.Get() error(%v)", err)
}
if err = c.Scan(r, &nst); err != nil {
t.Errorf("decode() error(%v)", err)
t.FailNow()
} else {
scanAndCompare2(t, r, st, &nst)
}
}
func TestProtobufGzip(t *testing.T) {
prepareEnv(t)
var (
err error
// value []byte
r *Item
nst test.TestItem
)
st := &test.TestItem{Name: "protogzip", Age: 3021}
itemx := &Item{Key: "protogzip", Object: st, Flags: FlagProtobuf | FlagGzip}
if err = c.Set(itemx); err != nil {
t.Errorf("conn.Set() error(%v)", err)
}
if r, err = c.Get("protogzip"); err != nil {
t.Errorf("conn.Get() error(%v)", err)
}
if err = c.Scan(r, &nst); err != nil {
t.Errorf("decode() error(%v)", err)
t.FailNow()
} else {
scanAndCompare2(t, r, st, &nst)
}
}
func TestGet(t *testing.T) {
prepareEnv(t)
// get
if r, err := c.Get("test"); err != nil {
t.Errorf("conn.Get() error(%v)", err)
} else if r.Key != "test" || !bytes.Equal(r.Value, []byte("test")) || r.Flags != 0 {
t.Error("conn.Get() error, value")
}
}
func TestGetHasErr(t *testing.T) {
prepareEnv(t)
st := &TestItem{Name: "json", Age: 10}
itemx := &Item{Key: "test", Object: st, Flags: FlagJSON}
c.Set(itemx)
expected := errors.New("some error")
monkey.Patch(scanGetReply, func(line []byte, item *Item) (size int, err error) {
return 0, expected
})
if _, err := c.Get("test"); err.Error() != expected.Error() {
t.Errorf("conn.Get() unexpected error(%v)", err)
}
if err := c.(*conn).err; err.Error() != expected.Error() {
t.Errorf("unexpected error(%v)", err)
}
}
func TestGet2(t *testing.T) {
prepareEnv(t)
// get not exist
if _, err := c.Get("not_exist"); err != ErrNotFound {
t.Errorf("conn.Get() error(%v)", err)
}
}
func TestGetMulti(t *testing.T) {
prepareEnv(t)
// getMulti
if _, err := c.GetMulti(s); err != nil {
t.Errorf("conn.GetMulti() error(%v)", err)
}
}
func TestGetMulti2(t *testing.T) {
prepareEnv(t)
//set
if err := c.Set(item); err != nil {
t.Errorf("conn.Set() error(%v)", err)
}
if err := c.Set(item2); err != nil {
t.Errorf("conn.Set() error(%v)", err)
}
if res, err := c.GetMulti(s); err != nil {
t.Errorf("conn.Get() error(%v)", err)
} else {
if len(res) != 2 {
t.Error("conn.Get() error, length", len(res))
}
reply := res["test"]
compareItem2(t, reply, item)
reply = res["test1"]
compareItem2(t, reply, item2)
}
}
func TestIncrement(t *testing.T) {
// set
if err := c.Set(item3); err != nil {
t.Errorf("conn.Set() error(%v)", err)
}
// incr
if d, err := c.Increment("test2", 4); err != nil {
t.Errorf("conn.Set() error(%v)", err)
} else {
if d != 4 {
t.Error("conn.IncrDecr value error")
}
}
}
func TestDecrement(t *testing.T) {
// decr
if d, err := c.Decrement("test2", 3); err != nil {
t.Errorf("conn.Store() error(%v)", err)
} else {
if d != 1 {
t.Error("conn.IncrDecr value error", d)
}
}
}
func TestTouch(t *testing.T) {
// touch
if err := c.Touch("test2", 1); err != nil {
t.Errorf("conn.Touch error(%v)", err)
}
}
func TestCompareAndSwap(t *testing.T) {
prepareEnv(t)
if err := c.Set(item3); err != nil {
t.Errorf("conn.Set() error(%v)", err)
}
//cas
if r, err := c.Get("test2"); err != nil {
t.Errorf("conn.Get() error(%v)", err)
} else {
r.Value = []byte("fuck")
if err := c.CompareAndSwap(r); err != nil {
t.Errorf("conn.CompareAndSwap() error(%v)", err)
}
if r, err := c.Get("test2"); err != nil {
t.Errorf("conn.Get() error(%v)", err)
} else {
itemx := &Item{Key: "test2", Value: []byte("fuck"), Flags: 0}
compareItem2(t, r, itemx)
}
}
}
func TestReplace(t *testing.T) {
prepareEnv(t)
if err := c.Set(item); err != nil {
t.Errorf("conn.Set() error(%v)", err)
}
if r, err := c.Get("test"); err != nil {
t.Errorf("conn.Get() error(%v)", err)
} else {
r.Value = []byte("go")
if err := c.Replace(r); err != nil {
t.Errorf("conn.CompareAndSwap() error(%v)", err)
}
if r, err := c.Get("test"); err != nil {
t.Errorf("conn.Get() error(%v)", err)
} else {
itemx := &Item{Key: "test", Value: []byte("go"), Flags: 0}
compareItem2(t, r, itemx)
}
}
}
func scanAndCompare(t *testing.T, item *Item, st *TestItem, nst *TestItem) {
if err := c.Scan(item, nst); err != nil {
t.Errorf("decode() error(%v)", err)
t.FailNow()
}
if !st.Compare(nst) {
t.Errorf("st: %v, use of closed network connection nst: %v", st, &nst)
t.FailNow()
}
}
func scanAndCompare2(t *testing.T, item *Item, st *test.TestItem, nst *test.TestItem) {
if err := c.Scan(item, nst); err != nil {
t.Errorf("decode() error(%v)", err)
t.FailNow()
}
if st.Age != nst.Age || st.Name != nst.Name {
t.Errorf("st: %v, use of closed network connection nst: %v", st, &nst)
t.FailNow()
}
}
func compareItem2(t *testing.T, r, item *Item) {
if r.Key != item.Key || !bytes.Equal(r.Value, item.Value) || r.Flags != item.Flags {
t.Error("conn.Get() error, value")
}
}
func Test_legalKey(t *testing.T) {
type args struct {
key string
}
tests := []struct {
name string
args args
want bool
}{
{
name: "test empty key",
want: false,
},
{
name: "test too large key",
args: args{func() string {
var data []byte
for i := 0; i < 255; i++ {
data = append(data, 'k')
}
return string(data)
}()},
want: false,
},
{
name: "test invalid char",
args: args{"hello world"},
want: false,
},
{
name: "test invalid char",
args: args{string([]byte{0x7f})},
want: false,
},
{
name: "test normal key",
args: args{"hello"},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := legalKey(tt.args.key); got != tt.want {
t.Errorf("legalKey() = %v, want %v", got, tt.want)
}
})
}
}

76
library/cache/memcache/errors.go vendored Normal file
View File

@@ -0,0 +1,76 @@
package memcache
import (
"errors"
"fmt"
"strings"
pkgerr "github.com/pkg/errors"
)
var (
// ErrNotFound not found
ErrNotFound = errors.New("memcache: key not found")
// ErrExists exists
ErrExists = errors.New("memcache: key exists")
// ErrNotStored not stored
ErrNotStored = errors.New("memcache: key not stored")
// ErrCASConflict means that a CompareAndSwap call failed due to the
// cached value being modified between the Get and the CompareAndSwap.
// If the cached value was simply evicted rather than replaced,
// ErrNotStored will be returned instead.
ErrCASConflict = errors.New("memcache: compare-and-swap conflict")
// ErrPoolExhausted is returned from a pool connection method (Store, Get,
// Delete, IncrDecr, Err) when the maximum number of database connections
// in the pool has been reached.
ErrPoolExhausted = errors.New("memcache: connection pool exhausted")
// ErrPoolClosed pool closed
ErrPoolClosed = errors.New("memcache: connection pool closed")
// ErrConnClosed conn closed
ErrConnClosed = errors.New("memcache: connection closed")
// ErrMalformedKey is returned when an invalid key is used.
// Keys must be at maximum 250 bytes long and not
// contain whitespace or control characters.
ErrMalformedKey = errors.New("memcache: malformed key is too long or contains invalid characters")
// ErrValueSize item value size must less than 1mb
ErrValueSize = errors.New("memcache: item value size must not greater than 1mb")
// ErrStat stat error for monitor
ErrStat = errors.New("memcache unexpected errors")
// ErrItem item nil.
ErrItem = errors.New("memcache: item object nil")
// ErrItemObject object type Assertion failed
ErrItemObject = errors.New("memcache: item object protobuf type assertion failed")
)
type protocolError string
func (pe protocolError) Error() string {
return fmt.Sprintf("memcache: %s (possible server error or unsupported concurrent read by application)", string(pe))
}
func formatErr(err error) string {
e := pkgerr.Cause(err)
switch e {
case ErrNotFound, ErrExists, ErrNotStored, nil:
return ""
default:
es := e.Error()
switch {
case strings.HasPrefix(es, "read"):
return "read timeout"
case strings.HasPrefix(es, "dial"):
return "dial timeout"
case strings.HasPrefix(es, "write"):
return "write timeout"
case strings.Contains(es, "EOF"):
return "eof"
case strings.Contains(es, "reset"):
return "reset"
case strings.Contains(es, "broken"):
return "broken pipe"
default:
return "unexpected err"
}
}
}

136
library/cache/memcache/memcache.go vendored Normal file
View File

@@ -0,0 +1,136 @@
package memcache
import (
"context"
)
// Error represents an error returned in a command reply.
type Error string
func (err Error) Error() string { return string(err) }
const (
// Flag, 15(encoding) bit+ 17(compress) bit
// FlagRAW default flag.
FlagRAW = uint32(0)
// FlagGOB gob encoding.
FlagGOB = uint32(1) << 0
// FlagJSON json encoding.
FlagJSON = uint32(1) << 1
// FlagProtobuf protobuf
FlagProtobuf = uint32(1) << 2
_flagEncoding = uint32(0xFFFF8000)
// FlagGzip gzip compress.
FlagGzip = uint32(1) << 15
// left mv 31??? not work!!!
flagLargeValue = uint32(1) << 30
)
// Item is an reply to be got or stored in a memcached server.
type Item struct {
// Key is the Item's key (250 bytes maximum).
Key string
// Value is the Item's value.
Value []byte
// Object is the Item's object for use codec.
Object interface{}
// Flags are server-opaque flags whose semantics are entirely
// up to the app.
Flags uint32
// Expiration is the cache expiration time, in seconds: either a relative
// time from now (up to 1 month), or an absolute Unix epoch time.
// Zero means the Item has no expiration time.
Expiration int32
// Compare and swap ID.
cas uint64
}
// Conn represents a connection to a Memcache server.
// Command Reference: https://github.com/memcached/memcached/wiki/Commands
type Conn interface {
// Close closes the connection.
Close() error
// Err returns a non-nil value if the connection is broken. The returned
// value is either the first non-nil value returned from the underlying
// network connection or a protocol parsing error. Applications should
// close broken connections.
Err() error
// Add writes the given item, if no value already exists for its key.
// ErrNotStored is returned if that condition is not met.
Add(item *Item) error
// Set writes the given item, unconditionally.
Set(item *Item) error
// Replace writes the given item, but only if the server *does* already
// hold data for this key.
Replace(item *Item) error
// Get sends a command to the server for gets data.
Get(key string) (*Item, error)
// GetMulti is a batch version of Get. The returned map from keys to items
// may have fewer elements than the input slice, due to memcache cache
// misses. Each key must be at most 250 bytes in length.
// If no error is returned, the returned map will also be non-nil.
GetMulti(keys []string) (map[string]*Item, error)
// Delete deletes the item with the provided key.
// The error ErrCacheMiss is returned if the item didn't already exist in
// the cache.
Delete(key string) error
// Increment atomically increments key by delta. The return value is the
// new value after being incremented or an error. If the value didn't exist
// in memcached the error is ErrCacheMiss. The value in memcached must be
// an decimal number, or an error will be returned.
// On 64-bit overflow, the new value wraps around.
Increment(key string, delta uint64) (newValue uint64, err error)
// Decrement atomically decrements key by delta. The return value is the
// new value after being decremented or an error. If the value didn't exist
// in memcached the error is ErrCacheMiss. The value in memcached must be
// an decimal number, or an error will be returned. On underflow, the new
// value is capped at zero and does not wrap around.
Decrement(key string, delta uint64) (newValue uint64, err error)
// CompareAndSwap writes the given item that was previously returned by
// Get, if the value was neither modified or evicted between the Get and
// the CompareAndSwap calls. The item's Key should not change between calls
// but all other item fields may differ. ErrCASConflict is returned if the
// value was modified in between the calls.
// ErrNotStored is returned if the value was evicted in between the calls.
CompareAndSwap(item *Item) error
// Touch updates the expiry for the given key. The seconds parameter is
// either a Unix timestamp or, if seconds is less than 1 month, the number
// of seconds into the future at which time the item will expire.
//ErrCacheMiss is returned if the key is not in the cache. The key must be
// at most 250 bytes in length.
Touch(key string, seconds int32) (err error)
// Scan converts value read from the memcache into the following
// common Go types and special types:
//
// *string
// *[]byte
// *interface{}
//
Scan(item *Item, v interface{}) (err error)
// WithContext return a Conn with its context changed to ctx
// the context controls the entire lifetime of Conn before you change it
// NOTE: this method is not thread-safe
WithContext(ctx context.Context) Conn
}

210
library/cache/memcache/memcache_test.go vendored Normal file
View File

@@ -0,0 +1,210 @@
package memcache
import (
"encoding/json"
"fmt"
"os"
"testing"
"time"
"go-common/library/container/pool"
xtime "go-common/library/time"
)
var testMemcacheAddr = "127.0.0.1:11211"
var testConfig = &Config{
Config: &pool.Config{
Active: 10,
Idle: 10,
IdleTimeout: xtime.Duration(time.Second),
WaitTimeout: xtime.Duration(time.Second),
Wait: false,
},
Proto: "tcp",
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}
func init() {
if addr := os.Getenv("TEST_MEMCACHE_ADDR"); addr != "" {
testMemcacheAddr = addr
}
testConfig.Addr = testMemcacheAddr
}
func TestMain(m *testing.M) {
testClient = New(testConfig)
m.Run()
testClient.Close()
os.Exit(0)
}
func ExampleConn_set() {
var (
err error
value []byte
conn Conn
expire int32 = 100
p = struct {
Name string
Age int64
}{"golang", 10}
)
cnop := DialConnectTimeout(time.Duration(time.Second))
rdop := DialReadTimeout(time.Duration(time.Second))
wrop := DialWriteTimeout(time.Duration(time.Second))
if value, err = json.Marshal(p); err != nil {
fmt.Println(err)
return
}
if conn, err = Dial("tcp", testMemcacheAddr, cnop, rdop, wrop); err != nil {
fmt.Println(err)
return
}
// FlagRAW test
itemRaw := &Item{
Key: "test_raw",
Value: value,
Expiration: expire,
}
if err = conn.Set(itemRaw); err != nil {
fmt.Println(err)
return
}
// FlagGzip
itemGZip := &Item{
Key: "test_gzip",
Value: value,
Flags: FlagGzip,
Expiration: expire,
}
if err = conn.Set(itemGZip); err != nil {
fmt.Println(err)
return
}
// FlagGOB
itemGOB := &Item{
Key: "test_gob",
Object: p,
Flags: FlagGOB,
Expiration: expire,
}
if err = conn.Set(itemGOB); err != nil {
fmt.Println(err)
return
}
// FlagJSON
itemJSON := &Item{
Key: "test_json",
Object: p,
Flags: FlagJSON,
Expiration: expire,
}
if err = conn.Set(itemJSON); err != nil {
fmt.Println(err)
return
}
// FlagJSON | FlagGzip
itemJSONGzip := &Item{
Key: "test_jsonGzip",
Object: p,
Flags: FlagJSON | FlagGzip,
Expiration: expire,
}
if err = conn.Set(itemJSONGzip); err != nil {
fmt.Println(err)
return
}
// Output:
}
func ExampleConn_get() {
var (
err error
item2 *Item
conn Conn
p struct {
Name string
Age int64
}
)
cnop := DialConnectTimeout(time.Duration(time.Second))
rdop := DialReadTimeout(time.Duration(time.Second))
wrop := DialWriteTimeout(time.Duration(time.Second))
if conn, err = Dial("tcp", testMemcacheAddr, cnop, rdop, wrop); err != nil {
fmt.Println(err)
return
}
if item2, err = conn.Get("test_raw"); err != nil {
fmt.Println(err)
} else {
if err = conn.Scan(item2, &p); err != nil {
fmt.Printf("FlagRAW conn.Scan error(%v)\n", err)
return
}
}
// FlagGZip
if item2, err = conn.Get("test_gzip"); err != nil {
fmt.Println(err)
} else {
if err = conn.Scan(item2, &p); err != nil {
fmt.Printf("FlagGZip conn.Scan error(%v)\n", err)
return
}
}
// FlagGOB
if item2, err = conn.Get("test_gob"); err != nil {
fmt.Println(err)
} else {
if err = conn.Scan(item2, &p); err != nil {
fmt.Printf("FlagGOB conn.Scan error(%v)\n", err)
return
}
}
// FlagJSON
if item2, err = conn.Get("test_json"); err != nil {
fmt.Println(err)
} else {
if err = conn.Scan(item2, &p); err != nil {
fmt.Printf("FlagJSON conn.Scan error(%v)\n", err)
return
}
}
// Output:
}
func ExampleConn_getMulti() {
var (
err error
conn Conn
res map[string]*Item
keys = []string{"test_raw", "test_gzip"}
p struct {
Name string
Age int64
}
)
cnop := DialConnectTimeout(time.Duration(time.Second))
rdop := DialReadTimeout(time.Duration(time.Second))
wrop := DialWriteTimeout(time.Duration(time.Second))
if conn, err = Dial("tcp", testMemcacheAddr, cnop, rdop, wrop); err != nil {
fmt.Println(err)
return
}
if res, err = conn.GetMulti(keys); err != nil {
fmt.Printf("conn.GetMulti(%v) error(%v)", keys, err)
return
}
for _, v := range res {
if err = conn.Scan(v, &p); err != nil {
fmt.Printf("conn.Scan error(%v)\n", err)
return
}
fmt.Println(p)
}
// Output:
//{golang 10}
//{golang 10}
}

59
library/cache/memcache/mock.go vendored Normal file
View File

@@ -0,0 +1,59 @@
package memcache
import (
"context"
)
// MockErr for unit test.
type MockErr struct {
Error error
}
var _ Conn = MockErr{}
// MockWith return a mock conn.
func MockWith(err error) MockErr {
return MockErr{Error: err}
}
// Err .
func (m MockErr) Err() error { return m.Error }
// Close .
func (m MockErr) Close() error { return m.Error }
// Add .
func (m MockErr) Add(item *Item) error { return m.Error }
// Set .
func (m MockErr) Set(item *Item) error { return m.Error }
// Replace .
func (m MockErr) Replace(item *Item) error { return m.Error }
// CompareAndSwap .
func (m MockErr) CompareAndSwap(item *Item) error { return m.Error }
// Get .
func (m MockErr) Get(key string) (*Item, error) { return nil, m.Error }
// GetMulti .
func (m MockErr) GetMulti(keys []string) (map[string]*Item, error) { return nil, m.Error }
// Touch .
func (m MockErr) Touch(key string, timeout int32) error { return m.Error }
// Delete .
func (m MockErr) Delete(key string) error { return m.Error }
// Increment .
func (m MockErr) Increment(key string, delta uint64) (uint64, error) { return 0, m.Error }
// Decrement .
func (m MockErr) Decrement(key string, delta uint64) (uint64, error) { return 0, m.Error }
// Scan .
func (m MockErr) Scan(item *Item, v interface{}) error { return m.Error }
// WithContext .
func (m MockErr) WithContext(ctx context.Context) Conn { return m }

197
library/cache/memcache/pool.go vendored Normal file
View File

@@ -0,0 +1,197 @@
package memcache
import (
"context"
"io"
"time"
"go-common/library/container/pool"
"go-common/library/stat"
xtime "go-common/library/time"
)
var stats = stat.Cache
// Config memcache config.
type Config struct {
*pool.Config
Name string // memcache name, for trace
Proto string
Addr string
DialTimeout xtime.Duration
ReadTimeout xtime.Duration
WriteTimeout xtime.Duration
}
// Pool memcache connection pool struct.
type Pool struct {
p pool.Pool
c *Config
}
// NewPool new a memcache conn pool.
func NewPool(c *Config) (p *Pool) {
if c.DialTimeout <= 0 || c.ReadTimeout <= 0 || c.WriteTimeout <= 0 {
panic("must config memcache timeout")
}
p1 := pool.NewList(c.Config)
cnop := DialConnectTimeout(time.Duration(c.DialTimeout))
rdop := DialReadTimeout(time.Duration(c.ReadTimeout))
wrop := DialWriteTimeout(time.Duration(c.WriteTimeout))
p1.New = func(ctx context.Context) (io.Closer, error) {
conn, err := Dial(c.Proto, c.Addr, cnop, rdop, wrop)
return &traceConn{Conn: conn, address: c.Addr}, err
}
p = &Pool{p: p1, c: c}
return
}
// Get gets a connection. The application must close the returned connection.
// This method always returns a valid connection so that applications can defer
// error handling to the first use of the connection. If there is an error
// getting an underlying connection, then the connection Err, Do, Send, Flush
// and Receive methods return that error.
func (p *Pool) Get(ctx context.Context) Conn {
c, err := p.p.Get(ctx)
if err != nil {
return errorConnection{err}
}
c1, _ := c.(Conn)
return &pooledConnection{p: p, c: c1.WithContext(ctx), ctx: ctx}
}
// Close release the resources used by the pool.
func (p *Pool) Close() error {
return p.p.Close()
}
type pooledConnection struct {
p *Pool
c Conn
ctx context.Context
}
func pstat(key string, t time.Time, err error) {
stats.Timing(key, int64(time.Since(t)/time.Millisecond))
if err != nil {
if msg := formatErr(err); msg != "" {
stats.Incr("memcache", msg)
}
}
}
func (pc *pooledConnection) Close() error {
c := pc.c
if _, ok := c.(errorConnection); ok {
return nil
}
pc.c = errorConnection{ErrConnClosed}
pc.p.p.Put(context.Background(), c, c.Err() != nil)
return nil
}
func (pc *pooledConnection) Err() error {
return pc.c.Err()
}
func (pc *pooledConnection) Set(item *Item) (err error) {
now := time.Now()
err = pc.c.Set(item)
pstat("memcache:set", now, err)
return
}
func (pc *pooledConnection) Add(item *Item) (err error) {
now := time.Now()
err = pc.c.Add(item)
pstat("memcache:add", now, err)
return
}
func (pc *pooledConnection) Replace(item *Item) (err error) {
now := time.Now()
err = pc.c.Replace(item)
pstat("memcache:replace", now, err)
return
}
func (pc *pooledConnection) CompareAndSwap(item *Item) (err error) {
now := time.Now()
err = pc.c.CompareAndSwap(item)
pstat("memcache:cas", now, err)
return
}
func (pc *pooledConnection) Get(key string) (r *Item, err error) {
now := time.Now()
r, err = pc.c.Get(key)
pstat("memcache:get", now, err)
return
}
func (pc *pooledConnection) GetMulti(keys []string) (res map[string]*Item, err error) {
// if keys is empty slice returns empty map direct
if len(keys) == 0 {
return make(map[string]*Item), nil
}
now := time.Now()
res, err = pc.c.GetMulti(keys)
pstat("memcache:gets", now, err)
return
}
func (pc *pooledConnection) Touch(key string, timeout int32) (err error) {
now := time.Now()
err = pc.c.Touch(key, timeout)
pstat("memcache:touch", now, err)
return
}
func (pc *pooledConnection) Scan(item *Item, v interface{}) error {
return pc.c.Scan(item, v)
}
func (pc *pooledConnection) WithContext(ctx context.Context) Conn {
// TODO: set context
pc.ctx = ctx
return pc
}
func (pc *pooledConnection) Delete(key string) (err error) {
now := time.Now()
err = pc.c.Delete(key)
pstat("memcache:delete", now, err)
return
}
func (pc *pooledConnection) Increment(key string, delta uint64) (newValue uint64, err error) {
now := time.Now()
newValue, err = pc.c.Increment(key, delta)
pstat("memcache:increment", now, err)
return
}
func (pc *pooledConnection) Decrement(key string, delta uint64) (newValue uint64, err error) {
now := time.Now()
newValue, err = pc.c.Decrement(key, delta)
pstat("memcache:decrement", now, err)
return
}
type errorConnection struct{ err error }
func (ec errorConnection) Err() error { return ec.err }
func (ec errorConnection) Close() error { return ec.err }
func (ec errorConnection) Add(item *Item) error { return ec.err }
func (ec errorConnection) Set(item *Item) error { return ec.err }
func (ec errorConnection) Replace(item *Item) error { return ec.err }
func (ec errorConnection) CompareAndSwap(item *Item) error { return ec.err }
func (ec errorConnection) Get(key string) (*Item, error) { return nil, ec.err }
func (ec errorConnection) GetMulti(keys []string) (map[string]*Item, error) { return nil, ec.err }
func (ec errorConnection) Touch(key string, timeout int32) error { return ec.err }
func (ec errorConnection) Delete(key string) error { return ec.err }
func (ec errorConnection) Increment(key string, delta uint64) (uint64, error) { return 0, ec.err }
func (ec errorConnection) Decrement(key string, delta uint64) (uint64, error) { return 0, ec.err }
func (ec errorConnection) Scan(item *Item, v interface{}) error { return ec.err }
func (ec errorConnection) WithContext(ctx context.Context) Conn { return ec }

361
library/cache/memcache/pool_test.go vendored Normal file
View File

@@ -0,0 +1,361 @@
package memcache
import (
"bytes"
"context"
"os"
"testing"
"time"
"go-common/library/container/pool"
xtime "go-common/library/time"
)
var p *Pool
var config *Config
func init() {
testMemcacheAddr := "127.0.0.1:11211"
if addr := os.Getenv("TEST_MEMCACHE_ADDR"); addr != "" {
testMemcacheAddr = addr
}
config = &Config{
Name: "test",
Proto: "tcp",
Addr: testMemcacheAddr,
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}
config.Config = &pool.Config{
Active: 10,
Idle: 5,
IdleTimeout: xtime.Duration(90 * time.Second),
}
}
var itempool = &Item{
Key: "testpool",
Value: []byte("testpool"),
Flags: 0,
Expiration: 60,
cas: 0,
}
var itempool2 = &Item{
Key: "test_count",
Value: []byte("0"),
Flags: 0,
Expiration: 1000,
cas: 0,
}
type testObject struct {
Mid int64
Value []byte
}
var largeValue = &Item{
Key: "large_value",
Flags: FlagGOB | FlagGzip,
Expiration: 1000,
cas: 0,
}
var largeValueBoundary = &Item{
Key: "large_value",
Flags: FlagGOB | FlagGzip,
Expiration: 1000,
cas: 0,
}
func prepareEnv2() {
if p != nil {
return
}
p = NewPool(config)
}
func TestPoolSet(t *testing.T) {
prepareEnv2()
conn := p.Get(context.Background())
defer conn.Close()
// set
if err := conn.Set(itempool); err != nil {
t.Errorf("memcache: set error(%v)", err)
} else {
t.Logf("memcache: set value: %s", item.Value)
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolGet(t *testing.T) {
prepareEnv2()
key := "testpool"
conn := p.Get(context.Background())
defer conn.Close()
// get
if res, err := conn.Get(key); err != nil {
t.Errorf("memcache: get error(%v)", err)
} else {
t.Logf("memcache: get value: %s", res.Value)
}
if _, err := conn.Get("not_found"); err != ErrNotFound {
t.Errorf("memcache: expceted err is not found but got: %v", err)
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolGetMulti(t *testing.T) {
prepareEnv2()
conn := p.Get(context.Background())
defer conn.Close()
s := []string{"testpool", "test1"}
// get
if res, err := conn.GetMulti(s); err != nil {
t.Errorf("memcache: gets error(%v)", err)
} else {
t.Logf("memcache: gets value: %d", len(res))
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolTouch(t *testing.T) {
prepareEnv2()
key := "testpool"
conn := p.Get(context.Background())
defer conn.Close()
// touch
if err := conn.Touch(key, 10); err != nil {
t.Errorf("memcache: touch error(%v)", err)
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolIncrement(t *testing.T) {
prepareEnv2()
key := "test_count"
conn := p.Get(context.Background())
defer conn.Close()
// set
if err := conn.Set(itempool2); err != nil {
t.Errorf("memcache: set error(%v)", err)
} else {
t.Logf("memcache: set value: 0")
}
// incr
if res, err := conn.Increment(key, 1); err != nil {
t.Errorf("memcache: incr error(%v)", err)
} else {
t.Logf("memcache: incr n: %d", res)
if res != 1 {
t.Errorf("memcache: expected res=1 but got %d", res)
}
}
// decr
if res, err := conn.Decrement(key, 1); err != nil {
t.Errorf("memcache: decr error(%v)", err)
} else {
t.Logf("memcache: decr n: %d", res)
if res != 0 {
t.Errorf("memcache: expected res=0 but got %d", res)
}
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolErr(t *testing.T) {
prepareEnv2()
conn := p.Get(context.Background())
defer conn.Close()
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
if err := conn.Err(); err == nil {
t.Errorf("memcache: err not nil")
} else {
t.Logf("memcache: err: %v", err)
}
}
func TestPoolCompareAndSwap(t *testing.T) {
prepareEnv2()
conn := p.Get(context.Background())
defer conn.Close()
key := "testpool"
//cas
if r, err := conn.Get(key); err != nil {
t.Errorf("conn.Get() error(%v)", err)
} else {
r.Value = []byte("shit")
if err := conn.CompareAndSwap(r); err != nil {
t.Errorf("conn.Get() error(%v)", err)
}
r, _ := conn.Get("testpool")
if r.Key != "testpool" || !bytes.Equal(r.Value, []byte("shit")) || r.Flags != 0 {
t.Error("conn.Get() error, value")
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
}
func TestPoolDel(t *testing.T) {
prepareEnv2()
key := "testpool"
conn := p.Get(context.Background())
defer conn.Close()
// delete
if err := conn.Delete(key); err != nil {
t.Errorf("memcache: delete error(%v)", err)
} else {
t.Logf("memcache: delete key: %s", key)
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func BenchmarkMemcache(b *testing.B) {
c := &Config{
Name: "test",
Proto: "tcp",
Addr: "127.0.0.1:11211",
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}
c.Config = &pool.Config{
Active: 10,
Idle: 5,
IdleTimeout: xtime.Duration(90 * time.Second),
}
p = NewPool(c)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
conn := p.Get(context.Background())
if err := conn.Close(); err != nil {
b.Errorf("memcache: close error(%v)", err)
}
}
})
if err := p.Close(); err != nil {
b.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolSetLargeValue(t *testing.T) {
var b bytes.Buffer
for i := 0; i < 4000000; i++ {
b.WriteByte(1)
}
obj := &testObject{}
obj.Mid = 1000
obj.Value = b.Bytes()
largeValue.Object = obj
prepareEnv2()
conn := p.Get(context.Background())
defer conn.Close()
// set
if err := conn.Set(largeValue); err != nil {
t.Errorf("memcache: set error(%v)", err)
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolGetLargeValue(t *testing.T) {
prepareEnv2()
key := largeValue.Key
conn := p.Get(context.Background())
defer conn.Close()
// get
var err error
if _, err = conn.Get(key); err != nil {
t.Errorf("memcache: large get error(%+v)", err)
}
}
func TestPoolGetMultiLargeValue(t *testing.T) {
prepareEnv2()
conn := p.Get(context.Background())
defer conn.Close()
s := []string{largeValue.Key, largeValue.Key}
// get
if res, err := conn.GetMulti(s); err != nil {
t.Errorf("memcache: gets error(%v)", err)
} else {
t.Logf("memcache: gets value: %d", len(res))
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolSetLargeValueBoundary(t *testing.T) {
var b bytes.Buffer
for i := 0; i < _largeValue; i++ {
b.WriteByte(1)
}
obj := &testObject{}
obj.Mid = 1000
obj.Value = b.Bytes()
largeValueBoundary.Object = obj
prepareEnv2()
conn := p.Get(context.Background())
defer conn.Close()
// set
if err := conn.Set(largeValueBoundary); err != nil {
t.Errorf("memcache: set error(%v)", err)
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolGetLargeValueBoundary(t *testing.T) {
prepareEnv2()
key := largeValueBoundary.Key
conn := p.Get(context.Background())
defer conn.Close()
// get
var err error
if _, err = conn.Get(key); err != nil {
t.Errorf("memcache: large get error(%v)", err)
}
}
func TestPoolAdd(t *testing.T) {
var (
key = "test_add"
item = &Item{
Key: key,
Value: []byte("0"),
Flags: 0,
Expiration: 60,
cas: 0,
}
conn = p.Get(context.Background())
)
defer conn.Close()
prepareEnv2()
conn.Delete(key)
if err := conn.Add(item); err != nil {
t.Errorf("memcache: add error(%v)", err)
}
if err := conn.Add(item); err != ErrNotStored {
t.Errorf("memcache: add error(%v)", err)
}
}

48
library/cache/memcache/test/BUILD.bazel vendored Normal file
View File

@@ -0,0 +1,48 @@
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
)
load(
"@io_bazel_rules_go//proto:def.bzl",
"go_proto_library",
)
proto_library(
name = "proto_proto",
srcs = ["test.proto"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)
go_proto_library(
name = "proto_go_proto",
compilers = ["@io_bazel_rules_go//proto:go_proto"],
importpath = "go-common/library/cache/memcache/test",
proto = ":proto_proto",
tags = ["automanaged"],
visibility = ["//visibility:public"],
)
go_library(
name = "go_default_library",
srcs = [],
embed = [":proto_go_proto"],
importpath = "go-common/library/cache/memcache/test",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = ["@com_github_golang_protobuf//proto:go_default_library"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

375
library/cache/memcache/test/test.pb.go vendored Normal file
View File

@@ -0,0 +1,375 @@
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: test.proto
/*
Package proto is a generated protocol buffer package.
It is generated from these files:
test.proto
It has these top-level messages:
TestItem
*/
package proto
import proto1 "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import io "io"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto1.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto1.ProtoPackageIsVersion2 // please upgrade the proto package
type FOO int32
const (
FOO_X FOO = 0
)
var FOO_name = map[int32]string{
0: "X",
}
var FOO_value = map[string]int32{
"X": 0,
}
func (x FOO) String() string {
return proto1.EnumName(FOO_name, int32(x))
}
func (FOO) EnumDescriptor() ([]byte, []int) { return fileDescriptorTest, []int{0} }
type TestItem struct {
Name string `protobuf:"bytes,1,opt,name=Name,proto3" json:"Name,omitempty"`
Age int32 `protobuf:"varint,2,opt,name=Age,proto3" json:"Age,omitempty"`
}
func (m *TestItem) Reset() { *m = TestItem{} }
func (m *TestItem) String() string { return proto1.CompactTextString(m) }
func (*TestItem) ProtoMessage() {}
func (*TestItem) Descriptor() ([]byte, []int) { return fileDescriptorTest, []int{0} }
func (m *TestItem) GetName() string {
if m != nil {
return m.Name
}
return ""
}
func (m *TestItem) GetAge() int32 {
if m != nil {
return m.Age
}
return 0
}
func init() {
proto1.RegisterType((*TestItem)(nil), "proto.TestItem")
proto1.RegisterEnum("proto.FOO", FOO_name, FOO_value)
}
func (m *TestItem) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalTo(dAtA)
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *TestItem) MarshalTo(dAtA []byte) (int, error) {
var i int
_ = i
var l int
_ = l
if len(m.Name) > 0 {
dAtA[i] = 0xa
i++
i = encodeVarintTest(dAtA, i, uint64(len(m.Name)))
i += copy(dAtA[i:], m.Name)
}
if m.Age != 0 {
dAtA[i] = 0x10
i++
i = encodeVarintTest(dAtA, i, uint64(m.Age))
}
return i, nil
}
func encodeFixed64Test(dAtA []byte, offset int, v uint64) int {
dAtA[offset] = uint8(v)
dAtA[offset+1] = uint8(v >> 8)
dAtA[offset+2] = uint8(v >> 16)
dAtA[offset+3] = uint8(v >> 24)
dAtA[offset+4] = uint8(v >> 32)
dAtA[offset+5] = uint8(v >> 40)
dAtA[offset+6] = uint8(v >> 48)
dAtA[offset+7] = uint8(v >> 56)
return offset + 8
}
func encodeFixed32Test(dAtA []byte, offset int, v uint32) int {
dAtA[offset] = uint8(v)
dAtA[offset+1] = uint8(v >> 8)
dAtA[offset+2] = uint8(v >> 16)
dAtA[offset+3] = uint8(v >> 24)
return offset + 4
}
func encodeVarintTest(dAtA []byte, offset int, v uint64) int {
for v >= 1<<7 {
dAtA[offset] = uint8(v&0x7f | 0x80)
v >>= 7
offset++
}
dAtA[offset] = uint8(v)
return offset + 1
}
func (m *TestItem) Size() (n int) {
var l int
_ = l
l = len(m.Name)
if l > 0 {
n += 1 + l + sovTest(uint64(l))
}
if m.Age != 0 {
n += 1 + sovTest(uint64(m.Age))
}
return n
}
func sovTest(x uint64) (n int) {
for {
n++
x >>= 7
if x == 0 {
break
}
}
return n
}
func sozTest(x uint64) (n int) {
return sovTest(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
func (m *TestItem) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowTest
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: TestItem: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: TestItem: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Name", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowTest
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
intStringLen := int(stringLen)
if intStringLen < 0 {
return ErrInvalidLengthTest
}
postIndex := iNdEx + intStringLen
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Name = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
case 2:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Age", wireType)
}
m.Age = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowTest
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Age |= (int32(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
default:
iNdEx = preIndex
skippy, err := skipTest(dAtA[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthTest
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func skipTest(dAtA []byte) (n int, err error) {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowTest
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
wireType := int(wire & 0x7)
switch wireType {
case 0:
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowTest
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
iNdEx++
if dAtA[iNdEx-1] < 0x80 {
break
}
}
return iNdEx, nil
case 1:
iNdEx += 8
return iNdEx, nil
case 2:
var length int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowTest
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
length |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
iNdEx += length
if length < 0 {
return 0, ErrInvalidLengthTest
}
return iNdEx, nil
case 3:
for {
var innerWire uint64
var start int = iNdEx
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowTest
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
innerWire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
innerWireType := int(innerWire & 0x7)
if innerWireType == 4 {
break
}
next, err := skipTest(dAtA[start:])
if err != nil {
return 0, err
}
iNdEx = start + next
}
return iNdEx, nil
case 4:
return iNdEx, nil
case 5:
iNdEx += 4
return iNdEx, nil
default:
return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
}
}
panic("unreachable")
}
var (
ErrInvalidLengthTest = fmt.Errorf("proto: negative length found during unmarshaling")
ErrIntOverflowTest = fmt.Errorf("proto: integer overflow")
)
func init() { proto1.RegisterFile("test.proto", fileDescriptorTest) }
var fileDescriptorTest = []byte{
// 122 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x49, 0x2d, 0x2e,
0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x05, 0x53, 0x4a, 0x06, 0x5c, 0x1c, 0x21, 0xa9,
0xc5, 0x25, 0x9e, 0x25, 0xa9, 0xb9, 0x42, 0x42, 0x5c, 0x2c, 0x7e, 0x89, 0xb9, 0xa9, 0x12, 0x8c,
0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x60, 0xb6, 0x90, 0x00, 0x17, 0xb3, 0x63, 0x7a, 0xaa, 0x04, 0x93,
0x02, 0xa3, 0x06, 0x6b, 0x10, 0x88, 0xa9, 0xc5, 0xc3, 0xc5, 0xec, 0xe6, 0xef, 0x2f, 0xc4, 0xca,
0xc5, 0x18, 0x21, 0xc0, 0xe0, 0x24, 0x70, 0xe2, 0x91, 0x1c, 0xe3, 0x85, 0x47, 0x72, 0x8c, 0x0f,
0x1e, 0xc9, 0x31, 0xce, 0x78, 0x2c, 0xc7, 0x90, 0xc4, 0x06, 0x36, 0xd8, 0x18, 0x10, 0x00, 0x00,
0xff, 0xff, 0x16, 0x80, 0x60, 0x15, 0x6d, 0x00, 0x00, 0x00,
}

12
library/cache/memcache/test/test.proto vendored Normal file
View File

@@ -0,0 +1,12 @@
syntax = "proto3";
package proto;
enum FOO
{
X = 0;
};
message TestItem{
string Name = 1;
int32 Age = 2;
}

101
library/cache/memcache/trace.go vendored Normal file
View File

@@ -0,0 +1,101 @@
package memcache
import (
"context"
"strconv"
"strings"
"go-common/library/net/trace"
)
const (
_traceFamily = "memcache"
_traceSpanKind = "client"
_traceComponentName = "library/cache/memcache"
_tracePeerService = "memcache"
)
type traceConn struct {
Conn
ctx context.Context
address string
}
func (t *traceConn) setTrace(action, statement string) func(error) error {
parent, ok := trace.FromContext(t.ctx)
if !ok {
return func(err error) error { return err }
}
span := parent.Fork(_traceFamily, "Memcache:"+action)
span.SetTag(
trace.String(trace.TagSpanKind, _traceSpanKind),
trace.String(trace.TagComponent, _traceComponentName),
trace.String(trace.TagPeerService, _tracePeerService),
trace.String(trace.TagPeerAddress, t.address),
trace.String(trace.TagDBStatement, action+" "+statement),
)
return func(err error) error {
span.Finish(&err)
return err
}
}
func (t *traceConn) WithContext(ctx context.Context) Conn {
t.ctx = ctx
t.Conn = t.Conn.WithContext(ctx)
return t
}
func (t *traceConn) Add(item *Item) error {
finishFn := t.setTrace("Add", item.Key)
return finishFn(t.Conn.Add(item))
}
func (t *traceConn) Set(item *Item) error {
finishFn := t.setTrace("Set", item.Key)
return finishFn(t.Conn.Set(item))
}
func (t *traceConn) Replace(item *Item) error {
finishFn := t.setTrace("Replace", item.Key)
return finishFn(t.Conn.Replace(item))
}
func (t *traceConn) Get(key string) (*Item, error) {
finishFn := t.setTrace("Get", key)
item, err := t.Conn.Get(key)
return item, finishFn(err)
}
func (t *traceConn) GetMulti(keys []string) (map[string]*Item, error) {
finishFn := t.setTrace("GetMulti", strings.Join(keys, " "))
items, err := t.Conn.GetMulti(keys)
return items, finishFn(err)
}
func (t *traceConn) Delete(key string) error {
finishFn := t.setTrace("Delete", key)
return finishFn(t.Conn.Delete(key))
}
func (t *traceConn) Increment(key string, delta uint64) (newValue uint64, err error) {
finishFn := t.setTrace("Increment", key+" "+strconv.FormatUint(delta, 10))
newValue, err = t.Conn.Increment(key, delta)
return newValue, finishFn(err)
}
func (t *traceConn) Decrement(key string, delta uint64) (newValue uint64, err error) {
finishFn := t.setTrace("Decrement", key+" "+strconv.FormatUint(delta, 10))
newValue, err = t.Conn.Decrement(key, delta)
return newValue, finishFn(err)
}
func (t *traceConn) CompareAndSwap(item *Item) error {
finishFn := t.setTrace("CompareAndSwap", item.Key)
return finishFn(t.Conn.CompareAndSwap(item))
}
func (t *traceConn) Touch(key string, seconds int32) (err error) {
finishFn := t.setTrace("Touch", key+" "+strconv.Itoa(int(seconds)))
return finishFn(t.Conn.Touch(key, seconds))
}

32
library/cache/memcache/util.go vendored Normal file
View File

@@ -0,0 +1,32 @@
package memcache
import (
"github.com/gogo/protobuf/proto"
)
// RawItem item with FlagRAW flag.
//
// Expiration is the cache expiration time, in seconds: either a relative
// time from now (up to 1 month), or an absolute Unix epoch time.
// Zero means the Item has no expiration time.
func RawItem(key string, data []byte, flags uint32, expiration int32) *Item {
return &Item{Key: key, Flags: flags | FlagRAW, Value: data, Expiration: expiration}
}
// JSONItem item with FlagJSON flag.
//
// Expiration is the cache expiration time, in seconds: either a relative
// time from now (up to 1 month), or an absolute Unix epoch time.
// Zero means the Item has no expiration time.
func JSONItem(key string, v interface{}, flags uint32, expiration int32) *Item {
return &Item{Key: key, Flags: flags | FlagJSON, Object: v, Expiration: expiration}
}
// ProtobufItem item with FlagProtobuf flag.
//
// Expiration is the cache expiration time, in seconds: either a relative
// time from now (up to 1 month), or an absolute Unix epoch time.
// Zero means the Item has no expiration time.
func ProtobufItem(key string, message proto.Message, flags uint32, expiration int32) *Item {
return &Item{Key: key, Flags: flags | FlagProtobuf, Object: message, Expiration: expiration}
}

26
library/cache/memcache/util_test.go vendored Normal file
View File

@@ -0,0 +1,26 @@
package memcache
import (
"testing"
"github.com/stretchr/testify/assert"
pb "go-common/library/cache/memcache/test"
)
func TestItemUtil(t *testing.T) {
item1 := RawItem("test", []byte("hh"), 0, 0)
assert.Equal(t, "test", item1.Key)
assert.Equal(t, []byte("hh"), item1.Value)
assert.Equal(t, FlagRAW, FlagRAW&item1.Flags)
item1 = JSONItem("test", &Item{}, 0, 0)
assert.Equal(t, "test", item1.Key)
assert.NotNil(t, item1.Object)
assert.Equal(t, FlagJSON, FlagJSON&item1.Flags)
item1 = ProtobufItem("test", &pb.TestItem{}, 0, 0)
assert.Equal(t, "test", item1.Key)
assert.NotNil(t, item1.Object)
assert.Equal(t, FlagProtobuf, FlagProtobuf&item1.Flags)
}

74
library/cache/redis/BUILD vendored Normal file
View File

@@ -0,0 +1,74 @@
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
package(default_visibility = ["//visibility:public"])
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
)
go_library(
name = "go_default_library",
srcs = [
"commandinfo.go",
"conn.go",
"doc.go",
"errors.go",
"log.go",
"mock.go",
"pool.go",
"pubsub.go",
"redis.go",
"reply.go",
"scan.go",
"script.go",
"trace.go",
],
importpath = "go-common/library/cache/redis",
tags = ["automanaged"],
deps = [
"//library/container/pool:go_default_library",
"//library/net/trace:go_default_library",
"//library/stat:go_default_library",
"//library/time:go_default_library",
"//vendor/github.com/pkg/errors:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = [
"commandinfo_test.go",
"conn_test.go",
"pool_test.go",
"pubsub_test.go",
"redis_test.go",
"reply_test.go",
"scan_test.go",
"script_test.go",
"test_test.go",
"trace_test.go",
],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
deps = [
"//library/container/pool:go_default_library",
"//library/net/trace:go_default_library",
"//library/time:go_default_library",
"//vendor/github.com/pkg/errors:go_default_library",
"//vendor/github.com/stretchr/testify/assert:go_default_library",
],
)

10
library/cache/redis/CHANGELOG.md vendored Normal file
View File

@@ -0,0 +1,10 @@
### redis client
##### Version 1.1.1
> 1.add redis mockError type
##### Version 1.1.0
> 1.修改redis pool的实现方式引用container/pool
> 2.pool支持context传入超时以及Get connection WaitTimeout
##### Version 1.0.0
> 1. fix NewPool赋值最大空闲连接数

9
library/cache/redis/CONTRIBUTORS.md vendored Normal file
View File

@@ -0,0 +1,9 @@
# Owner
maojian
# Author
maojian
zhapuyu
# Reviewer
maojian

175
library/cache/redis/LICENSE vendored Normal file
View File

@@ -0,0 +1,175 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

11
library/cache/redis/OWNERS vendored Normal file
View File

@@ -0,0 +1,11 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- maojian
- zhapuyu
labels:
- library
- library/cache/redis
reviewers:
- maojian
- zhapuyu

49
library/cache/redis/README.markdown vendored Normal file
View File

@@ -0,0 +1,49 @@
Redigo
======
[![Build Status](https://travis-ci.org/garyburd/redigo.svg?branch=master)](https://travis-ci.org/garyburd/redigo)
Redigo is a [Go](http://golang.org/) client for the [Redis](http://redis.io/) database.
Features
-------
* A [Print-like](http://godoc.org/github.com/garyburd/redigo/redis#hdr-Executing_Commands) API with support for all Redis commands.
* [Pipelining](http://godoc.org/github.com/garyburd/redigo/redis#hdr-Pipelining), including pipelined transactions.
* [Publish/Subscribe](http://godoc.org/github.com/garyburd/redigo/redis#hdr-Publish_and_Subscribe).
* [Connection pooling](http://godoc.org/github.com/garyburd/redigo/redis#Pool).
* [Script helper type](http://godoc.org/github.com/garyburd/redigo/redis#Script) with optimistic use of EVALSHA.
* [Helper functions](http://godoc.org/github.com/garyburd/redigo/redis#hdr-Reply_Helpers) for working with command replies.
Documentation
-------------
- [API Reference](http://godoc.org/github.com/garyburd/redigo/redis)
- [FAQ](https://github.com/garyburd/redigo/wiki/FAQ)
Installation
------------
Install Redigo using the "go get" command:
go get github.com/garyburd/redigo/redis
The Go distribution is Redigo's only dependency.
Related Projects
----------------
- [rafaeljusto/redigomock](https://godoc.org/github.com/rafaeljusto/redigomock) - A mock library for Redigo.
- [chasex/redis-go-cluster](https://github.com/chasex/redis-go-cluster) - A Redis cluster client implementation.
Contributing
------------
Gary is looking for someone to take over maintenance of this project. If you are interested, contact Gary at the email address listed on his GitHub profile page.
PRs for major features will not be accepted until a new maintainer is found. Bug reports and PRs for bug fixes are welcome.
License
-------
Redigo is available under the [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0.html).

57
library/cache/redis/commandinfo.go vendored Normal file
View File

@@ -0,0 +1,57 @@
// Copyright 2014 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"strings"
)
// redis state
const (
WatchState = 1 << iota
MultiState
SubscribeState
MonitorState
)
// CommandInfo command info.
type CommandInfo struct {
Set, Clear int
}
var commandInfos = map[string]CommandInfo{
"WATCH": {Set: WatchState},
"UNWATCH": {Clear: WatchState},
"MULTI": {Set: MultiState},
"EXEC": {Clear: WatchState | MultiState},
"DISCARD": {Clear: WatchState | MultiState},
"PSUBSCRIBE": {Set: SubscribeState},
"SUBSCRIBE": {Set: SubscribeState},
"MONITOR": {Set: MonitorState},
}
func init() {
for n, ci := range commandInfos {
commandInfos[strings.ToLower(n)] = ci
}
}
// LookupCommandInfo get command info.
func LookupCommandInfo(commandName string) CommandInfo {
if ci, ok := commandInfos[commandName]; ok {
return ci
}
return commandInfos[strings.ToUpper(commandName)]
}

27
library/cache/redis/commandinfo_test.go vendored Normal file
View File

@@ -0,0 +1,27 @@
package redis
import "testing"
func TestLookupCommandInfo(t *testing.T) {
for _, n := range []string{"watch", "WATCH", "wAtch"} {
if LookupCommandInfo(n) == (CommandInfo{}) {
t.Errorf("LookupCommandInfo(%q) = CommandInfo{}, expected non-zero value", n)
}
}
}
func benchmarkLookupCommandInfo(b *testing.B, names ...string) {
for i := 0; i < b.N; i++ {
for _, c := range names {
LookupCommandInfo(c)
}
}
}
func BenchmarkLookupCommandInfoCorrectCase(b *testing.B) {
benchmarkLookupCommandInfo(b, "watch", "WATCH", "monitor", "MONITOR")
}
func BenchmarkLookupCommandInfoMixedCase(b *testing.B) {
benchmarkLookupCommandInfo(b, "wAtch", "WeTCH", "monItor", "MONiTOR")
}

597
library/cache/redis/conn.go vendored Normal file
View File

@@ -0,0 +1,597 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
"net/url"
"regexp"
"strconv"
"sync"
"time"
"go-common/library/stat"
"github.com/pkg/errors"
)
var stats = stat.Cache
// conn is the low-level implementation of Conn
type conn struct {
// Shared
mu sync.Mutex
pending int
err error
conn net.Conn
// Read
readTimeout time.Duration
br *bufio.Reader
// Write
writeTimeout time.Duration
bw *bufio.Writer
// Scratch space for formatting argument length.
// '*' or '$', length, "\r\n"
lenScratch [32]byte
// Scratch space for formatting integers and floats.
numScratch [40]byte
// stat func,default prom
stat func(string, *error) func()
}
func statfunc(cmd string, err *error) func() {
now := time.Now()
return func() {
stats.Timing(fmt.Sprintf("redis:%s", cmd), int64(time.Since(now)/time.Millisecond))
if err != nil {
if msg := formatErr(*err); msg != "" {
stats.Incr("redis", msg)
}
}
}
}
// DialTimeout acts like Dial but takes timeouts for establishing the
// connection to the server, writing a command and reading a reply.
//
// Deprecated: Use Dial with options instead.
func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) {
return Dial(network, address,
DialConnectTimeout(connectTimeout),
DialReadTimeout(readTimeout),
DialWriteTimeout(writeTimeout))
}
// DialOption specifies an option for dialing a Redis server.
type DialOption struct {
f func(*dialOptions)
}
type dialOptions struct {
readTimeout time.Duration
writeTimeout time.Duration
dial func(network, addr string) (net.Conn, error)
db int
password string
stat func(string, *error) func()
}
// DialStats specifies stat func for stats.default statfunc.
func DialStats(fn func(string, *error) func()) DialOption {
return DialOption{func(do *dialOptions) {
do.stat = fn
}}
}
// DialReadTimeout specifies the timeout for reading a single command reply.
func DialReadTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.readTimeout = d
}}
}
// DialWriteTimeout specifies the timeout for writing a single command.
func DialWriteTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.writeTimeout = d
}}
}
// DialConnectTimeout specifies the timeout for connecting to the Redis server.
func DialConnectTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
dialer := net.Dialer{Timeout: d}
do.dial = dialer.Dial
}}
}
// DialNetDial specifies a custom dial function for creating TCP
// connections. If this option is left out, then net.Dial is
// used. DialNetDial overrides DialConnectTimeout.
func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
return DialOption{func(do *dialOptions) {
do.dial = dial
}}
}
// DialDatabase specifies the database to select when dialing a connection.
func DialDatabase(db int) DialOption {
return DialOption{func(do *dialOptions) {
do.db = db
}}
}
// DialPassword specifies the password to use when connecting to
// the Redis server.
func DialPassword(password string) DialOption {
return DialOption{func(do *dialOptions) {
do.password = password
}}
}
// Dial connects to the Redis server at the given network and
// address using the specified options.
func Dial(network, address string, options ...DialOption) (Conn, error) {
do := dialOptions{
dial: net.Dial,
}
for _, option := range options {
option.f(&do)
}
netConn, err := do.dial(network, address)
if err != nil {
return nil, errors.WithStack(err)
}
c := &conn{
conn: netConn,
bw: bufio.NewWriter(netConn),
br: bufio.NewReader(netConn),
readTimeout: do.readTimeout,
writeTimeout: do.writeTimeout,
stat: statfunc,
}
if do.password != "" {
if _, err := c.Do("AUTH", do.password); err != nil {
netConn.Close()
return nil, errors.WithStack(err)
}
}
if do.db != 0 {
if _, err := c.Do("SELECT", do.db); err != nil {
netConn.Close()
return nil, errors.WithStack(err)
}
}
if do.stat != nil {
c.stat = do.stat
}
return c, nil
}
var pathDBRegexp = regexp.MustCompile(`/(\d+)\z`)
// DialURL connects to a Redis server at the given URL using the Redis
// URI scheme. URLs should follow the draft IANA specification for the
// scheme (https://www.iana.org/assignments/uri-schemes/prov/redis).
func DialURL(rawurl string, options ...DialOption) (Conn, error) {
u, err := url.Parse(rawurl)
if err != nil {
return nil, errors.WithStack(err)
}
if u.Scheme != "redis" {
return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme)
}
// As per the IANA draft spec, the host defaults to localhost and
// the port defaults to 6379.
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
// assume port is missing
host = u.Host
port = "6379"
}
if host == "" {
host = "localhost"
}
address := net.JoinHostPort(host, port)
if u.User != nil {
password, isSet := u.User.Password()
if isSet {
options = append(options, DialPassword(password))
}
}
match := pathDBRegexp.FindStringSubmatch(u.Path)
if len(match) == 2 {
db, err := strconv.Atoi(match[1])
if err != nil {
return nil, errors.Errorf("invalid database: %s", u.Path[1:])
}
if db != 0 {
options = append(options, DialDatabase(db))
}
} else if u.Path != "" {
return nil, errors.Errorf("invalid database: %s", u.Path[1:])
}
return Dial("tcp", address, options...)
}
// NewConn new a redis conn.
func NewConn(c *Config) (cn Conn, err error) {
cnop := DialConnectTimeout(time.Duration(c.DialTimeout))
rdop := DialReadTimeout(time.Duration(c.ReadTimeout))
wrop := DialWriteTimeout(time.Duration(c.WriteTimeout))
auop := DialPassword(c.Auth)
// new conn
cn, err = Dial(c.Proto, c.Addr, cnop, rdop, wrop, auop)
return
}
func (c *conn) Close() error {
c.mu.Lock()
err := c.err
if c.err == nil {
c.err = errors.New("redigo: closed")
err = c.conn.Close()
}
c.mu.Unlock()
return err
}
func (c *conn) fatal(err error) error {
c.mu.Lock()
if c.err == nil {
c.err = err
// Close connection to force errors on subsequent calls and to unblock
// other reader or writer.
c.conn.Close()
}
c.mu.Unlock()
return errors.WithStack(c.err)
}
func (c *conn) Err() error {
c.mu.Lock()
err := c.err
c.mu.Unlock()
return err
}
func (c *conn) writeLen(prefix byte, n int) error {
c.lenScratch[len(c.lenScratch)-1] = '\n'
c.lenScratch[len(c.lenScratch)-2] = '\r'
i := len(c.lenScratch) - 3
for {
c.lenScratch[i] = byte('0' + n%10)
i--
n = n / 10
if n == 0 {
break
}
}
c.lenScratch[i] = prefix
_, err := c.bw.Write(c.lenScratch[i:])
return errors.WithStack(err)
}
func (c *conn) writeString(s string) error {
c.writeLen('$', len(s))
c.bw.WriteString(s)
_, err := c.bw.WriteString("\r\n")
return errors.WithStack(err)
}
func (c *conn) writeBytes(p []byte) error {
c.writeLen('$', len(p))
c.bw.Write(p)
_, err := c.bw.WriteString("\r\n")
return errors.WithStack(err)
}
func (c *conn) writeInt64(n int64) error {
return errors.WithStack(c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10)))
}
func (c *conn) writeFloat64(n float64) error {
return errors.WithStack(c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64)))
}
func (c *conn) writeCommand(cmd string, args []interface{}) (err error) {
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
c.writeLen('*', 1+len(args))
err = c.writeString(cmd)
for _, arg := range args {
if err != nil {
break
}
switch arg := arg.(type) {
case string:
err = c.writeString(arg)
case []byte:
err = c.writeBytes(arg)
case int:
err = c.writeInt64(int64(arg))
case int64:
err = c.writeInt64(arg)
case float64:
err = c.writeFloat64(arg)
case bool:
if arg {
err = c.writeString("1")
} else {
err = c.writeString("0")
}
case nil:
err = c.writeString("")
default:
var buf bytes.Buffer
fmt.Fprint(&buf, arg)
err = errors.WithStack(c.writeBytes(buf.Bytes()))
}
}
return err
}
type protocolError string
func (pe protocolError) Error() string {
return fmt.Sprintf("redigo: %s (possible server error or unsupported concurrent read by application)", string(pe))
}
func (c *conn) readLine() ([]byte, error) {
p, err := c.br.ReadSlice('\n')
if err == bufio.ErrBufferFull {
return nil, errors.WithStack(protocolError("long response line"))
}
if err != nil {
return nil, err
}
i := len(p) - 2
if i < 0 || p[i] != '\r' {
return nil, errors.WithStack(protocolError("bad response line terminator"))
}
return p[:i], nil
}
// parseLen parses bulk string and array lengths.
func parseLen(p []byte) (int, error) {
if len(p) == 0 {
return -1, errors.WithStack(protocolError("malformed length"))
}
if p[0] == '-' && len(p) == 2 && p[1] == '1' {
// handle $-1 and $-1 null replies.
return -1, nil
}
var n int
for _, b := range p {
n *= 10
if b < '0' || b > '9' {
return -1, errors.WithStack(protocolError("illegal bytes in length"))
}
n += int(b - '0')
}
return n, nil
}
// parseInt parses an integer reply.
func parseInt(p []byte) (interface{}, error) {
if len(p) == 0 {
return 0, errors.WithStack(protocolError("malformed integer"))
}
var negate bool
if p[0] == '-' {
negate = true
p = p[1:]
if len(p) == 0 {
return 0, errors.WithStack(protocolError("malformed integer"))
}
}
var n int64
for _, b := range p {
n *= 10
if b < '0' || b > '9' {
return 0, errors.WithStack(protocolError("illegal bytes in length"))
}
n += int64(b - '0')
}
if negate {
n = -n
}
return n, nil
}
var (
okReply interface{} = "OK"
pongReply interface{} = "PONG"
)
func (c *conn) readReply() (interface{}, error) {
line, err := c.readLine()
if err != nil {
return nil, err
}
if len(line) == 0 {
return nil, errors.WithStack(protocolError("short response line"))
}
switch line[0] {
case '+':
switch {
case len(line) == 3 && line[1] == 'O' && line[2] == 'K':
// Avoid allocation for frequent "+OK" response.
return okReply, nil
case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G':
// Avoid allocation in PING command benchmarks :)
return pongReply, nil
default:
return string(line[1:]), nil
}
case '-':
return Error(string(line[1:])), nil
case ':':
return parseInt(line[1:])
case '$':
n, err := parseLen(line[1:])
if n < 0 || err != nil {
return nil, err
}
p := make([]byte, n)
_, err = io.ReadFull(c.br, p)
if err != nil {
return nil, errors.WithStack(err)
}
if line1, err := c.readLine(); err != nil {
return nil, err
} else if len(line1) != 0 {
return nil, errors.WithStack(protocolError("bad bulk string format"))
}
return p, nil
case '*':
n, err := parseLen(line[1:])
if n < 0 || err != nil {
return nil, err
}
r := make([]interface{}, n)
for i := range r {
r[i], err = c.readReply()
if err != nil {
return nil, err
}
}
return r, nil
}
return nil, errors.WithStack(protocolError("unexpected response line"))
}
func (c *conn) Send(cmd string, args ...interface{}) (err error) {
c.mu.Lock()
c.pending++
c.mu.Unlock()
if err = c.writeCommand(cmd, args); err != nil {
c.fatal(err)
}
return err
}
func (c *conn) Flush() (err error) {
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
if err = c.bw.Flush(); err != nil {
c.fatal(err)
}
return err
}
func (c *conn) Receive() (reply interface{}, err error) {
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
if reply, err = c.readReply(); err != nil {
return nil, c.fatal(err)
}
// When using pub/sub, the number of receives can be greater than the
// number of sends. To enable normal use of the connection after
// unsubscribing from all channels, we do not decrement pending to a
// negative value.
//
// The pending field is decremented after the reply is read to handle the
// case where Receive is called before Send.
c.mu.Lock()
if c.pending > 0 {
c.pending--
}
c.mu.Unlock()
if err, ok := reply.(Error); ok {
return nil, err
}
return
}
func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
c.mu.Lock()
pending := c.pending
c.pending = 0
c.mu.Unlock()
if cmd == "" && pending == 0 {
return nil, nil
}
var err error
defer c.stat(cmd, &err)()
if cmd != "" {
err = c.writeCommand(cmd, args)
}
if err == nil {
err = errors.WithStack(c.bw.Flush())
}
if err != nil {
return nil, c.fatal(err)
}
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
if cmd == "" {
reply := make([]interface{}, pending)
for i := range reply {
var r interface{}
r, err = c.readReply()
if err != nil {
break
}
reply[i] = r
}
if err != nil {
return nil, c.fatal(err)
}
return reply, nil
}
var reply interface{}
for i := 0; i <= pending; i++ {
var e error
if reply, e = c.readReply(); e != nil {
return nil, c.fatal(e)
}
if e, ok := reply.(Error); ok && err == nil {
err = e
}
}
return reply, err
}
// WithContext FIXME: implement WithContext
func (c *conn) WithContext(ctx context.Context) Conn { return c }

657
library/cache/redis/conn_test.go vendored Normal file
View File

@@ -0,0 +1,657 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"bytes"
"io"
"math"
"net"
"os"
"reflect"
"strings"
"testing"
"time"
)
type testConn struct {
io.Reader
io.Writer
}
func (*testConn) Close() error { return nil }
func (*testConn) LocalAddr() net.Addr { return nil }
func (*testConn) RemoteAddr() net.Addr { return nil }
func (*testConn) SetDeadline(t time.Time) error { return nil }
func (*testConn) SetReadDeadline(t time.Time) error { return nil }
func (*testConn) SetWriteDeadline(t time.Time) error { return nil }
func dialTestConn(r io.Reader, w io.Writer) DialOption {
return DialNetDial(func(net, addr string) (net.Conn, error) {
return &testConn{Reader: r, Writer: w}, nil
})
}
var writeTests = []struct {
args []interface{}
expected string
}{
{
[]interface{}{"SET", "key", "value"},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
},
{
[]interface{}{"SET", "key", "value"},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
},
{
[]interface{}{"SET", "key", byte(100)},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$3\r\n100\r\n",
},
{
[]interface{}{"SET", "key", 100},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$3\r\n100\r\n",
},
{
[]interface{}{"SET", "key", int64(math.MinInt64)},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$20\r\n-9223372036854775808\r\n",
},
{
[]interface{}{"SET", "key", float64(1349673917.939762)},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$21\r\n1.349673917939762e+09\r\n",
},
{
[]interface{}{"SET", "key", ""},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
},
{
[]interface{}{"SET", "key", nil},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
},
{
[]interface{}{"ECHO", true, false},
"*3\r\n$4\r\nECHO\r\n$1\r\n1\r\n$1\r\n0\r\n",
},
}
func TestWrite(t *testing.T) {
for _, tt := range writeTests {
var buf bytes.Buffer
c, _ := Dial("", "", dialTestConn(nil, &buf))
err := c.Send(tt.args[0].(string), tt.args[1:]...)
if err != nil {
t.Errorf("Send(%v) returned error %v", tt.args, err)
continue
}
c.Flush()
actual := buf.String()
if actual != tt.expected {
t.Errorf("Send(%v) = %q, want %q", tt.args, actual, tt.expected)
}
}
}
var errorSentinel = &struct{}{}
var readTests = []struct {
reply string
expected interface{}
}{
{
"+OK\r\n",
"OK",
},
{
"+PONG\r\n",
"PONG",
},
{
"@OK\r\n",
errorSentinel,
},
{
"$6\r\nfoobar\r\n",
[]byte("foobar"),
},
{
"$-1\r\n",
nil,
},
{
":1\r\n",
int64(1),
},
{
":-2\r\n",
int64(-2),
},
{
"*0\r\n",
[]interface{}{},
},
{
"*-1\r\n",
nil,
},
{
"*4\r\n$3\r\nfoo\r\n$3\r\nbar\r\n$5\r\nHello\r\n$5\r\nWorld\r\n",
[]interface{}{[]byte("foo"), []byte("bar"), []byte("Hello"), []byte("World")},
},
{
"*3\r\n$3\r\nfoo\r\n$-1\r\n$3\r\nbar\r\n",
[]interface{}{[]byte("foo"), nil, []byte("bar")},
},
{
// "x" is not a valid length
"$x\r\nfoobar\r\n",
errorSentinel,
},
{
// -2 is not a valid length
"$-2\r\n",
errorSentinel,
},
{
// "x" is not a valid integer
":x\r\n",
errorSentinel,
},
{
// missing \r\n following value
"$6\r\nfoobar",
errorSentinel,
},
{
// short value
"$6\r\nxx",
errorSentinel,
},
{
// long value
"$6\r\nfoobarx\r\n",
errorSentinel,
},
}
func TestRead(t *testing.T) {
for _, tt := range readTests {
c, _ := Dial("", "", dialTestConn(strings.NewReader(tt.reply), nil))
actual, err := c.Receive()
if tt.expected == errorSentinel {
if err == nil {
t.Errorf("Receive(%q) did not return expected error", tt.reply)
}
} else {
if err != nil {
t.Errorf("Receive(%q) returned error %v", tt.reply, err)
continue
}
if !reflect.DeepEqual(actual, tt.expected) {
t.Errorf("Receive(%q) = %v, want %v", tt.reply, actual, tt.expected)
}
}
}
}
var testCommands = []struct {
args []interface{}
expected interface{}
}{
{
[]interface{}{"PING"},
"PONG",
},
{
[]interface{}{"SET", "foo", "bar"},
"OK",
},
{
[]interface{}{"GET", "foo"},
[]byte("bar"),
},
{
[]interface{}{"GET", "nokey"},
nil,
},
{
[]interface{}{"MGET", "nokey", "foo"},
[]interface{}{nil, []byte("bar")},
},
{
[]interface{}{"INCR", "mycounter"},
int64(1),
},
{
[]interface{}{"LPUSH", "mylist", "foo"},
int64(1),
},
{
[]interface{}{"LPUSH", "mylist", "bar"},
int64(2),
},
{
[]interface{}{"LRANGE", "mylist", 0, -1},
[]interface{}{[]byte("bar"), []byte("foo")},
},
{
[]interface{}{"MULTI"},
"OK",
},
{
[]interface{}{"LRANGE", "mylist", 0, -1},
"QUEUED",
},
{
[]interface{}{"PING"},
"QUEUED",
},
{
[]interface{}{"EXEC"},
[]interface{}{
[]interface{}{[]byte("bar"), []byte("foo")},
"PONG",
},
},
}
func TestDoCommands(t *testing.T) {
c, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
for _, cmd := range testCommands {
actual, err := c.Do(cmd.args[0].(string), cmd.args[1:]...)
if err != nil {
t.Errorf("Do(%v) returned error %v", cmd.args, err)
continue
}
if !reflect.DeepEqual(actual, cmd.expected) {
t.Errorf("Do(%v) = %v, want %v", cmd.args, actual, cmd.expected)
}
}
}
func TestPipelineCommands(t *testing.T) {
c, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
for _, cmd := range testCommands {
if err := c.Send(cmd.args[0].(string), cmd.args[1:]...); err != nil {
t.Fatalf("Send(%v) returned error %v", cmd.args, err)
}
}
if err := c.Flush(); err != nil {
t.Errorf("Flush() returned error %v", err)
}
for _, cmd := range testCommands {
actual, err := c.Receive()
if err != nil {
t.Fatalf("Receive(%v) returned error %v", cmd.args, err)
}
if !reflect.DeepEqual(actual, cmd.expected) {
t.Errorf("Receive(%v) = %v, want %v", cmd.args, actual, cmd.expected)
}
}
}
func TestBlankCommmand(t *testing.T) {
c, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
for _, cmd := range testCommands {
if err = c.Send(cmd.args[0].(string), cmd.args[1:]...); err != nil {
t.Fatalf("Send(%v) returned error %v", cmd.args, err)
}
}
reply, err := Values(c.Do(""))
if err != nil {
t.Fatalf("Do() returned error %v", err)
}
if len(reply) != len(testCommands) {
t.Fatalf("len(reply)=%d, want %d", len(reply), len(testCommands))
}
for i, cmd := range testCommands {
actual := reply[i]
if !reflect.DeepEqual(actual, cmd.expected) {
t.Errorf("Receive(%v) = %v, want %v", cmd.args, actual, cmd.expected)
}
}
}
func TestRecvBeforeSend(t *testing.T) {
c, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
done := make(chan struct{})
go func() {
c.Receive()
close(done)
}()
time.Sleep(time.Millisecond)
c.Send("PING")
c.Flush()
<-done
_, err = c.Do("")
if err != nil {
t.Fatalf("error=%v", err)
}
}
func TestError(t *testing.T) {
c, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
c.Do("SET", "key", "val")
_, err = c.Do("HSET", "key", "fld", "val")
if err == nil {
t.Errorf("Expected err for HSET on string key.")
}
if c.Err() != nil {
t.Errorf("Conn has Err()=%v, expect nil", c.Err())
}
_, err = c.Do("SET", "key", "val")
if err != nil {
t.Errorf("Do(SET, key, val) returned error %v, expected nil.", err)
}
}
func TestReadTimeout(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen returned %v", err)
}
defer l.Close()
go func() {
for {
c, err1 := l.Accept()
if err1 != nil {
return
}
go func() {
time.Sleep(time.Second)
c.Write([]byte("+OK\r\n"))
c.Close()
}()
}
}()
// Do
c1, err := Dial(l.Addr().Network(), l.Addr().String(), DialReadTimeout(time.Millisecond))
if err != nil {
t.Fatalf("Dial returned %v", err)
}
defer c1.Close()
_, err = c1.Do("PING")
if err == nil {
t.Fatalf("c1.Do() returned nil, expect error")
}
if c1.Err() == nil {
t.Fatalf("c1.Err() = nil, expect error")
}
// Send/Flush/Receive
c2, err := Dial(l.Addr().Network(), l.Addr().String(), DialReadTimeout(time.Millisecond))
if err != nil {
t.Fatalf("Dial returned %v", err)
}
defer c2.Close()
c2.Send("PING")
c2.Flush()
_, err = c2.Receive()
if err == nil {
t.Fatalf("c2.Receive() returned nil, expect error")
}
if c2.Err() == nil {
t.Fatalf("c2.Err() = nil, expect error")
}
}
var dialErrors = []struct {
rawurl string
expectedError string
}{
{
"localhost",
"invalid redis URL scheme",
},
// The error message for invalid hosts is diffferent in different
// versions of Go, so just check that there is an error message.
{
"redis://weird url",
"",
},
{
"redis://foo:bar:baz",
"",
},
{
"http://www.google.com",
"invalid redis URL scheme: http",
},
{
"redis://localhost:6379/abc123",
"invalid database: abc123",
},
}
func TestDialURLErrors(t *testing.T) {
for _, d := range dialErrors {
_, err := DialURL(d.rawurl)
if err == nil || !strings.Contains(err.Error(), d.expectedError) {
t.Errorf("DialURL did not return expected error (expected %v to contain %s)", err, d.expectedError)
}
}
}
func TestDialURLPort(t *testing.T) {
checkPort := func(network, address string) (net.Conn, error) {
if address != "localhost:6379" {
t.Errorf("DialURL did not set port to 6379 by default (got %v)", address)
}
return nil, nil
}
_, err := DialURL("redis://localhost", DialNetDial(checkPort))
if err != nil {
t.Error("dial error:", err)
}
}
func TestDialURLHost(t *testing.T) {
checkHost := func(network, address string) (net.Conn, error) {
if address != "localhost:6379" {
t.Errorf("DialURL did not set host to localhost by default (got %v)", address)
}
return nil, nil
}
_, err := DialURL("redis://:6379", DialNetDial(checkHost))
if err != nil {
t.Error("dial error:", err)
}
}
func TestDialURLPassword(t *testing.T) {
var buf bytes.Buffer
_, err := DialURL("redis://x:abc123@localhost", dialTestConn(strings.NewReader("+OK\r\n"), &buf))
if err != nil {
t.Error("dial error:", err)
}
expected := "*2\r\n$4\r\nAUTH\r\n$6\r\nabc123\r\n"
actual := buf.String()
if actual != expected {
t.Errorf("commands = %q, want %q", actual, expected)
}
}
func TestDialURLDatabase(t *testing.T) {
var buf bytes.Buffer
_, err := DialURL("redis://localhost/3", dialTestConn(strings.NewReader("+OK\r\n"), &buf))
if err != nil {
t.Error("dial error:", err)
}
expected := "*2\r\n$6\r\nSELECT\r\n$1\r\n3\r\n"
actual := buf.String()
if actual != expected {
t.Errorf("commands = %q, want %q", actual, expected)
}
}
// Connect to local instance of Redis running on the default port.
func ExampleDial() {
c, err := Dial("tcp", ":6379")
if err != nil {
// handle error
}
defer c.Close()
}
// Connect to remote instance of Redis using a URL.
func ExampleDialURL() {
c, err := DialURL(os.Getenv("REDIS_URL"))
if err != nil {
// handle connection error
}
defer c.Close()
}
// TextExecError tests handling of errors in a transaction. See
// http://io/topics/transactions for information on how Redis handles
// errors in a transaction.
func TestExecError(t *testing.T) {
c, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
// Execute commands that fail before EXEC is called.
c.Do("DEL", "k0")
c.Do("ZADD", "k0", 0, 0)
c.Send("MULTI")
c.Send("NOTACOMMAND", "k0", 0, 0)
c.Send("ZINCRBY", "k0", 0, 0)
v, err := c.Do("EXEC")
if err == nil {
t.Fatalf("EXEC returned values %v, expected error", v)
}
// Execute commands that fail after EXEC is called. The first command
// returns an error.
c.Do("DEL", "k1")
c.Do("ZADD", "k1", 0, 0)
c.Send("MULTI")
c.Send("HSET", "k1", 0, 0)
c.Send("ZINCRBY", "k1", 0, 0)
v, err = c.Do("EXEC")
if err != nil {
t.Fatalf("EXEC returned error %v", err)
}
vs, err := Values(v, nil)
if err != nil {
t.Fatalf("Values(v) returned error %v", err)
}
if len(vs) != 2 {
t.Fatalf("len(vs) == %d, want 2", len(vs))
}
if _, ok := vs[0].(error); !ok {
t.Fatalf("first result is type %T, expected error", vs[0])
}
if _, ok := vs[1].([]byte); !ok {
t.Fatalf("second result is type %T, expected []byte", vs[1])
}
// Execute commands that fail after EXEC is called. The second command
// returns an error.
c.Do("ZADD", "k2", 0, 0)
c.Send("MULTI")
c.Send("ZINCRBY", "k2", 0, 0)
c.Send("HSET", "k2", 0, 0)
v, err = c.Do("EXEC")
if err != nil {
t.Fatalf("EXEC returned error %v", err)
}
vs, err = Values(v, nil)
if err != nil {
t.Fatalf("Values(v) returned error %v", err)
}
if len(vs) != 2 {
t.Fatalf("len(vs) == %d, want 2", len(vs))
}
if _, ok := vs[0].([]byte); !ok {
t.Fatalf("first result is type %T, expected []byte", vs[0])
}
if _, ok := vs[1].(error); !ok {
t.Fatalf("second result is type %T, expected error", vs[2])
}
}
func BenchmarkDoEmpty(b *testing.B) {
b.StopTimer()
c, err := DialDefaultServer()
if err != nil {
b.Fatal(err)
}
defer c.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
if _, err := c.Do(""); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkDoPing(b *testing.B) {
b.StopTimer()
c, err := DialDefaultServer()
if err != nil {
b.Fatal(err)
}
defer c.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
if _, err := c.Do("PING"); err != nil {
b.Fatal(err)
}
}
}

169
library/cache/redis/doc.go vendored Normal file
View File

@@ -0,0 +1,169 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
// Package redis is a client for the Redis database.
//
// The Redigo FAQ (https://github.com/garyburd/redigo/wiki/FAQ) contains more
// documentation about this package.
//
// Connections
//
// The Conn interface is the primary interface for working with Redis.
// Applications create connections by calling the Dial, DialWithTimeout or
// NewConn functions. In the future, functions will be added for creating
// sharded and other types of connections.
//
// The application must call the connection Close method when the application
// is done with the connection.
//
// Executing Commands
//
// The Conn interface has a generic method for executing Redis commands:
//
// Do(commandName string, args ...interface{}) (reply interface{}, err error)
//
// The Redis command reference (http://redis.io/commands) lists the available
// commands. An example of using the Redis APPEND command is:
//
// n, err := conn.Do("APPEND", "key", "value")
//
// The Do method converts command arguments to binary strings for transmission
// to the server as follows:
//
// Go Type Conversion
// []byte Sent as is
// string Sent as is
// int, int64 strconv.FormatInt(v)
// float64 strconv.FormatFloat(v, 'g', -1, 64)
// bool true -> "1", false -> "0"
// nil ""
// all other types fmt.Print(v)
//
// Redis command reply types are represented using the following Go types:
//
// Redis type Go type
// error redis.Error
// integer int64
// simple string string
// bulk string []byte or nil if value not present.
// array []interface{} or nil if value not present.
//
// Use type assertions or the reply helper functions to convert from
// interface{} to the specific Go type for the command result.
//
// Pipelining
//
// Connections support pipelining using the Send, Flush and Receive methods.
//
// Send(commandName string, args ...interface{}) error
// Flush() error
// Receive() (reply interface{}, err error)
//
// Send writes the command to the connection's output buffer. Flush flushes the
// connection's output buffer to the server. Receive reads a single reply from
// the server. The following example shows a simple pipeline.
//
// c.Send("SET", "foo", "bar")
// c.Send("GET", "foo")
// c.Flush()
// c.Receive() // reply from SET
// v, err = c.Receive() // reply from GET
//
// The Do method combines the functionality of the Send, Flush and Receive
// methods. The Do method starts by writing the command and flushing the output
// buffer. Next, the Do method receives all pending replies including the reply
// for the command just sent by Do. If any of the received replies is an error,
// then Do returns the error. If there are no errors, then Do returns the last
// reply. If the command argument to the Do method is "", then the Do method
// will flush the output buffer and receive pending replies without sending a
// command.
//
// Use the Send and Do methods to implement pipelined transactions.
//
// c.Send("MULTI")
// c.Send("INCR", "foo")
// c.Send("INCR", "bar")
// r, err := c.Do("EXEC")
// fmt.Println(r) // prints [1, 1]
//
// Concurrency
//
// Connections do not support concurrent calls to the write methods (Send,
// Flush) or concurrent calls to the read method (Receive). Connections do
// allow a concurrent reader and writer.
//
// Because the Do method combines the functionality of Send, Flush and Receive,
// the Do method cannot be called concurrently with the other methods.
//
// For full concurrent access to Redis, use the thread-safe Pool to get and
// release connections from within a goroutine.
//
// Publish and Subscribe
//
// Use the Send, Flush and Receive methods to implement Pub/Sub subscribers.
//
// c.Send("SUBSCRIBE", "example")
// c.Flush()
// for {
// reply, err := c.Receive()
// if err != nil {
// return err
// }
// // process pushed message
// }
//
// The PubSubConn type wraps a Conn with convenience methods for implementing
// subscribers. The Subscribe, PSubscribe, Unsubscribe and PUnsubscribe methods
// send and flush a subscription management command. The receive method
// converts a pushed message to convenient types for use in a type switch.
//
// psc := redis.PubSubConn{c}
// psc.Subscribe("example")
// for {
// switch v := psc.Receive().(type) {
// case redis.Message:
// fmt.Printf("%s: message: %s\n", v.Channel, v.Data)
// case redis.Subscription:
// fmt.Printf("%s: %s %d\n", v.Channel, v.Kind, v.Count)
// case error:
// return v
// }
// }
//
// Reply Helpers
//
// The Bool, Int, Bytes, String, Strings and Values functions convert a reply
// to a value of a specific type. To allow convenient wrapping of calls to the
// connection Do and Receive methods, the functions take a second argument of
// type error. If the error is non-nil, then the helper function returns the
// error. If the error is nil, the function converts the reply to the specified
// type:
//
// exists, err := redis.Bool(c.Do("EXISTS", "foo"))
// if err != nil {
// // handle error return from c.Do or type conversion error.
// }
//
// The Scan function converts elements of a array reply to Go types:
//
// var value1 int
// var value2 string
// reply, err := redis.Values(c.Do("MGET", "key1", "key2"))
// if err != nil {
// // handle error
// }
// if _, err := redis.Scan(reply, &value1, &value2); err != nil {
// // handle error
// }
package redis

33
library/cache/redis/errors.go vendored Normal file
View File

@@ -0,0 +1,33 @@
package redis
import (
"strings"
pkgerr "github.com/pkg/errors"
)
func formatErr(err error) string {
e := pkgerr.Cause(err)
switch e {
case ErrNil, nil:
return ""
default:
es := e.Error()
switch {
case strings.HasPrefix(es, "read"):
return "read timeout"
case strings.HasPrefix(es, "dial"):
return "dial timeout"
case strings.HasPrefix(es, "write"):
return "write timeout"
case strings.Contains(es, "EOF"):
return "eof"
case strings.Contains(es, "reset"):
return "reset"
case strings.Contains(es, "broken"):
return "broken pipe"
default:
return "unexpected err"
}
}
}

117
library/cache/redis/log.go vendored Normal file
View File

@@ -0,0 +1,117 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"bytes"
"fmt"
"log"
)
// NewLoggingConn returns a logging wrapper around a connection.
func NewLoggingConn(conn Conn, logger *log.Logger, prefix string) Conn {
if prefix != "" {
prefix = prefix + "."
}
return &loggingConn{conn, logger, prefix}
}
type loggingConn struct {
Conn
logger *log.Logger
prefix string
}
func (c *loggingConn) Close() error {
err := c.Conn.Close()
var buf bytes.Buffer
fmt.Fprintf(&buf, "%sClose() -> (%v)", c.prefix, err)
c.logger.Output(2, buf.String())
return err
}
func (c *loggingConn) printValue(buf *bytes.Buffer, v interface{}) {
const chop = 32
switch v := v.(type) {
case []byte:
if len(v) > chop {
fmt.Fprintf(buf, "%q...", v[:chop])
} else {
fmt.Fprintf(buf, "%q", v)
}
case string:
if len(v) > chop {
fmt.Fprintf(buf, "%q...", v[:chop])
} else {
fmt.Fprintf(buf, "%q", v)
}
case []interface{}:
if len(v) == 0 {
buf.WriteString("[]")
} else {
sep := "["
fin := "]"
if len(v) > chop {
v = v[:chop]
fin = "...]"
}
for _, vv := range v {
buf.WriteString(sep)
c.printValue(buf, vv)
sep = ", "
}
buf.WriteString(fin)
}
default:
fmt.Fprint(buf, v)
}
}
func (c *loggingConn) print(method, commandName string, args []interface{}, reply interface{}, err error) {
var buf bytes.Buffer
fmt.Fprintf(&buf, "%s%s(", c.prefix, method)
if method != "Receive" {
buf.WriteString(commandName)
for _, arg := range args {
buf.WriteString(", ")
c.printValue(&buf, arg)
}
}
buf.WriteString(") -> (")
if method != "Send" {
c.printValue(&buf, reply)
buf.WriteString(", ")
}
fmt.Fprintf(&buf, "%v)", err)
c.logger.Output(3, buf.String())
}
func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{}, error) {
reply, err := c.Conn.Do(commandName, args...)
c.print("Do", commandName, args, reply, err)
return reply, err
}
func (c *loggingConn) Send(commandName string, args ...interface{}) error {
err := c.Conn.Send(commandName, args...)
c.print("Send", commandName, args, nil, err)
return err
}
func (c *loggingConn) Receive() (interface{}, error) {
reply, err := c.Conn.Receive()
c.print("Receive", "", nil, reply, err)
return reply, err
}

36
library/cache/redis/mock.go vendored Normal file
View File

@@ -0,0 +1,36 @@
package redis
import (
"context"
)
// MockErr for unit test.
type MockErr struct {
Error error
}
// MockWith return a mock conn.
func MockWith(err error) MockErr {
return MockErr{Error: err}
}
// Err .
func (m MockErr) Err() error { return m.Error }
// Close .
func (m MockErr) Close() error { return m.Error }
// Do .
func (m MockErr) Do(commandName string, args ...interface{}) (interface{}, error) { return nil, m.Error }
// Send .
func (m MockErr) Send(commandName string, args ...interface{}) error { return m.Error }
// Flush .
func (m MockErr) Flush() error { return m.Error }
// Receive .
func (m MockErr) Receive() (interface{}, error) { return nil, m.Error }
// WithContext .
func (m MockErr) WithContext(context.Context) Conn { return m }

226
library/cache/redis/pool.go vendored Normal file
View File

@@ -0,0 +1,226 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha1"
"errors"
"io"
"strconv"
"sync"
"time"
"go-common/library/container/pool"
"go-common/library/net/trace"
xtime "go-common/library/time"
)
var beginTime, _ = time.Parse("2006-01-02 15:04:05", "2006-01-02 15:04:05")
var (
errConnClosed = errors.New("redigo: connection closed")
)
// Pool .
type Pool struct {
*pool.Slice
// config
c *Config
}
// Config client settings.
type Config struct {
*pool.Config
Name string // redis name, for trace
Proto string
Addr string
Auth string
DialTimeout xtime.Duration
ReadTimeout xtime.Duration
WriteTimeout xtime.Duration
}
// NewPool creates a new pool.
func NewPool(c *Config, options ...DialOption) (p *Pool) {
if c.DialTimeout <= 0 || c.ReadTimeout <= 0 || c.WriteTimeout <= 0 {
panic("must config redis timeout")
}
p1 := pool.NewSlice(c.Config)
cnop := DialConnectTimeout(time.Duration(c.DialTimeout))
options = append(options, cnop)
rdop := DialReadTimeout(time.Duration(c.ReadTimeout))
options = append(options, rdop)
wrop := DialWriteTimeout(time.Duration(c.WriteTimeout))
options = append(options, wrop)
auop := DialPassword(c.Auth)
options = append(options, auop)
// new pool
p1.New = func(ctx context.Context) (io.Closer, error) {
conn, err := Dial(c.Proto, c.Addr, options...)
if err != nil {
return nil, err
}
return &traceConn{Conn: conn, connTags: []trace.Tag{trace.TagString(trace.TagPeerAddress, c.Addr)}}, nil
}
p = &Pool{Slice: p1, c: c}
return
}
// Get gets a connection. The application must close the returned connection.
// This method always returns a valid connection so that applications can defer
// error handling to the first use of the connection. If there is an error
// getting an underlying connection, then the connection Err, Do, Send, Flush
// and Receive methods return that error.
func (p *Pool) Get(ctx context.Context) Conn {
c, err := p.Slice.Get(ctx)
if err != nil {
return errorConnection{err}
}
c1, _ := c.(Conn)
return &pooledConnection{p: p, c: c1.WithContext(ctx), ctx: ctx, now: beginTime}
}
// Close releases the resources used by the pool.
func (p *Pool) Close() error {
return p.Slice.Close()
}
type pooledConnection struct {
p *Pool
c Conn
state int
now time.Time
cmds []string
ctx context.Context
}
var (
sentinel []byte
sentinelOnce sync.Once
)
func initSentinel() {
p := make([]byte, 64)
if _, err := rand.Read(p); err == nil {
sentinel = p
} else {
h := sha1.New()
io.WriteString(h, "Oops, rand failed. Use time instead.")
io.WriteString(h, strconv.FormatInt(time.Now().UnixNano(), 10))
sentinel = h.Sum(nil)
}
}
func (pc *pooledConnection) Close() error {
c := pc.c
if _, ok := c.(errorConnection); ok {
return nil
}
pc.c = errorConnection{errConnClosed}
if pc.state&MultiState != 0 {
c.Send("DISCARD")
pc.state &^= (MultiState | WatchState)
} else if pc.state&WatchState != 0 {
c.Send("UNWATCH")
pc.state &^= WatchState
}
if pc.state&SubscribeState != 0 {
c.Send("UNSUBSCRIBE")
c.Send("PUNSUBSCRIBE")
// To detect the end of the message stream, ask the server to echo
// a sentinel value and read until we see that value.
sentinelOnce.Do(initSentinel)
c.Send("ECHO", sentinel)
c.Flush()
for {
p, err := c.Receive()
if err != nil {
break
}
if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) {
pc.state &^= SubscribeState
break
}
}
}
_, err := c.Do("")
pc.p.Slice.Put(context.Background(), c, pc.state != 0 || c.Err() != nil)
return err
}
func (pc *pooledConnection) Err() error {
return pc.c.Err()
}
func key(args interface{}) (key string) {
keys, _ := args.([]interface{})
if len(keys) > 0 {
key, _ = keys[0].(string)
}
return
}
func (pc *pooledConnection) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
ci := LookupCommandInfo(commandName)
pc.state = (pc.state | ci.Set) &^ ci.Clear
reply, err = pc.c.Do(commandName, args...)
return
}
func (pc *pooledConnection) Send(commandName string, args ...interface{}) (err error) {
ci := LookupCommandInfo(commandName)
pc.state = (pc.state | ci.Set) &^ ci.Clear
if pc.now.Equal(beginTime) {
// mark first send time
pc.now = time.Now()
}
pc.cmds = append(pc.cmds, commandName)
return pc.c.Send(commandName, args...)
}
func (pc *pooledConnection) Flush() error {
return pc.c.Flush()
}
func (pc *pooledConnection) Receive() (reply interface{}, err error) {
reply, err = pc.c.Receive()
if len(pc.cmds) > 0 {
pc.cmds = pc.cmds[1:]
}
return
}
func (pc *pooledConnection) WithContext(ctx context.Context) Conn {
pc.ctx = ctx
return pc
}
type errorConnection struct{ err error }
func (ec errorConnection) Do(string, ...interface{}) (interface{}, error) {
return nil, ec.err
}
func (ec errorConnection) Send(string, ...interface{}) error { return ec.err }
func (ec errorConnection) Err() error { return ec.err }
func (ec errorConnection) Close() error { return ec.err }
func (ec errorConnection) Flush() error { return ec.err }
func (ec errorConnection) Receive() (interface{}, error) { return nil, ec.err }
func (ec errorConnection) WithContext(context.Context) Conn { return ec }

452
library/cache/redis/pool_test.go vendored Normal file
View File

@@ -0,0 +1,452 @@
// Copyright 2011 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"context"
"errors"
"io"
"reflect"
"sync"
"testing"
"time"
"go-common/library/container/pool"
)
type poolTestConn struct {
d *poolDialer
err error
Conn
}
func (c *poolTestConn) Close() error {
c.d.mu.Lock()
c.d.open--
c.d.mu.Unlock()
return c.Conn.Close()
}
func (c *poolTestConn) Err() error { return c.err }
func (c *poolTestConn) Do(commandName string, args ...interface{}) (interface{}, error) {
if commandName == "ERR" {
c.err = args[0].(error)
commandName = "PING"
}
if commandName != "" {
c.d.commands = append(c.d.commands, commandName)
}
return c.Conn.Do(commandName, args...)
}
func (c *poolTestConn) Send(commandName string, args ...interface{}) error {
c.d.commands = append(c.d.commands, commandName)
return c.Conn.Send(commandName, args...)
}
type poolDialer struct {
mu sync.Mutex
t *testing.T
dialed int
open int
commands []string
dialErr error
}
func (d *poolDialer) dial() (Conn, error) {
d.mu.Lock()
d.dialed++
dialErr := d.dialErr
d.mu.Unlock()
if dialErr != nil {
return nil, d.dialErr
}
c, err := DialDefaultServer()
if err != nil {
return nil, err
}
d.mu.Lock()
d.open++
d.mu.Unlock()
return &poolTestConn{d: d, Conn: c}, nil
}
func (d *poolDialer) check(message string, p *Pool, dialed, open int) {
d.mu.Lock()
if d.dialed != dialed {
d.t.Errorf("%s: dialed=%d, want %d", message, d.dialed, dialed)
}
if d.open != open {
d.t.Errorf("%s: open=%d, want %d", message, d.open, open)
}
// if active := p.ActiveCount(); active != open {
// d.t.Errorf("%s: active=%d, want %d", message, active, open)
// }
d.mu.Unlock()
}
func TestPoolReuse(t *testing.T) {
d := poolDialer{t: t}
p := NewPool(config)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
for i := 0; i < 10; i++ {
c1 := p.Get(context.TODO())
c1.Do("PING")
c2 := p.Get(context.TODO())
c2.Do("PING")
c1.Close()
c2.Close()
}
d.check("before close", p, 2, 2)
p.Close()
d.check("after close", p, 2, 0)
}
func TestPoolMaxIdle(t *testing.T) {
d := poolDialer{t: t}
p := NewPool(config)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
defer p.Close()
for i := 0; i < 10; i++ {
c1 := p.Get(context.TODO())
c1.Do("PING")
c2 := p.Get(context.TODO())
c2.Do("PING")
c3 := p.Get(context.TODO())
c3.Do("PING")
c1.Close()
c2.Close()
c3.Close()
}
d.check("before close", p, 12, 2)
p.Close()
d.check("after close", p, 12, 0)
}
func TestPoolError(t *testing.T) {
d := poolDialer{t: t}
p := NewPool(config)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
defer p.Close()
c := p.Get(context.TODO())
c.Do("ERR", io.EOF)
if c.Err() == nil {
t.Errorf("expected c.Err() != nil")
}
c.Close()
c = p.Get(context.TODO())
c.Do("ERR", io.EOF)
c.Close()
d.check(".", p, 2, 0)
}
func TestPoolClose(t *testing.T) {
d := poolDialer{t: t}
p := NewPool(config)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
defer p.Close()
c1 := p.Get(context.TODO())
c1.Do("PING")
c2 := p.Get(context.TODO())
c2.Do("PING")
c3 := p.Get(context.TODO())
c3.Do("PING")
c1.Close()
if _, err := c1.Do("PING"); err == nil {
t.Errorf("expected error after connection closed")
}
c2.Close()
c2.Close()
p.Close()
d.check("after pool close", p, 3, 1)
if _, err := c1.Do("PING"); err == nil {
t.Errorf("expected error after connection and pool closed")
}
c3.Close()
d.check("after conn close", p, 3, 0)
c1 = p.Get(context.TODO())
if _, err := c1.Do("PING"); err == nil {
t.Errorf("expected error after pool closed")
}
}
func TestPoolConcurrenSendReceive(t *testing.T) {
p := NewPool(config)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return DialDefaultServer()
}
defer p.Close()
c := p.Get(context.TODO())
done := make(chan error, 1)
go func() {
_, err := c.Receive()
done <- err
}()
c.Send("PING")
c.Flush()
err := <-done
if err != nil {
t.Fatalf("Receive() returned error %v", err)
}
_, err = c.Do("")
if err != nil {
t.Fatalf("Do() returned error %v", err)
}
c.Close()
}
func TestPoolMaxActive(t *testing.T) {
d := poolDialer{t: t}
config.Config = &pool.Config{
Active: 2,
Idle: 2,
}
p := NewPool(config)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
defer p.Close()
c1 := p.Get(context.TODO())
c1.Do("PING")
c2 := p.Get(context.TODO())
c2.Do("PING")
d.check("1", p, 2, 2)
c3 := p.Get(context.TODO())
if _, err := c3.Do("PING"); err != pool.ErrPoolExhausted {
t.Errorf("expected pool exhausted")
}
c3.Close()
d.check("2", p, 2, 2)
c2.Close()
d.check("3", p, 2, 2)
c3 = p.Get(context.TODO())
if _, err := c3.Do("PING"); err != nil {
t.Errorf("expected good channel, err=%v", err)
}
c3.Close()
d.check("4", p, 2, 2)
}
func TestPoolMonitorCleanup(t *testing.T) {
d := poolDialer{t: t}
p := NewPool(config)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
defer p.Close()
c := p.Get(context.TODO())
c.Send("MONITOR")
c.Close()
d.check("", p, 1, 0)
}
func TestPoolPubSubCleanup(t *testing.T) {
d := poolDialer{t: t}
p := NewPool(config)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
defer p.Close()
c := p.Get(context.TODO())
c.Send("SUBSCRIBE", "x")
c.Close()
want := []string{"SUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "ECHO"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get(context.TODO())
c.Send("PSUBSCRIBE", "x*")
c.Close()
want = []string{"PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "ECHO"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
}
func TestPoolTransactionCleanup(t *testing.T) {
d := poolDialer{t: t}
p := NewPool(config)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
defer p.Close()
c := p.Get(context.TODO())
c.Do("WATCH", "key")
c.Do("PING")
c.Close()
want := []string{"WATCH", "PING", "UNWATCH"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get(context.TODO())
c.Do("WATCH", "key")
c.Do("UNWATCH")
c.Do("PING")
c.Close()
want = []string{"WATCH", "UNWATCH", "PING"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get(context.TODO())
c.Do("WATCH", "key")
c.Do("MULTI")
c.Do("PING")
c.Close()
want = []string{"WATCH", "MULTI", "PING", "DISCARD"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get(context.TODO())
c.Do("WATCH", "key")
c.Do("MULTI")
c.Do("DISCARD")
c.Do("PING")
c.Close()
want = []string{"WATCH", "MULTI", "DISCARD", "PING"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get(context.TODO())
c.Do("WATCH", "key")
c.Do("MULTI")
c.Do("EXEC")
c.Do("PING")
c.Close()
want = []string{"WATCH", "MULTI", "EXEC", "PING"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
}
func startGoroutines(p *Pool, cmd string, args ...interface{}) chan error {
errs := make(chan error, 10)
for i := 0; i < cap(errs); i++ {
go func() {
c := p.Get(context.TODO())
_, err := c.Do(cmd, args...)
errs <- err
c.Close()
}()
}
// Wait for goroutines to block.
time.Sleep(time.Second / 4)
return errs
}
func TestWaitPoolDialError(t *testing.T) {
testErr := errors.New("test")
d := poolDialer{t: t}
config1 := getConfig()
config1.Config = &pool.Config{
Active: 1,
Idle: 1,
Wait: true,
}
p := NewPool(config1)
p.Slice.New = func(ctx context.Context) (io.Closer, error) {
return d.dial()
}
defer p.Close()
c := p.Get(context.TODO())
errs := startGoroutines(p, "ERR", testErr)
d.check("before close", p, 1, 1)
d.dialErr = errors.New("dial")
c.Close()
nilCount := 0
errCount := 0
timeout := time.After(2 * time.Second)
for i := 0; i < cap(errs); i++ {
select {
case err := <-errs:
switch err {
case nil:
nilCount++
case d.dialErr:
errCount++
default:
t.Fatalf("expected dial error or nil, got %v", err)
}
case <-timeout:
t.Logf("Wait all the time and timeout %d", i)
return
}
}
if nilCount != 1 {
t.Errorf("expected one nil error, got %d", nilCount)
}
if errCount != cap(errs)-1 {
t.Errorf("expected %d dial erors, got %d", cap(errs)-1, errCount)
}
d.check("done", p, cap(errs), 0)
}

152
library/cache/redis/pubsub.go vendored Normal file
View File

@@ -0,0 +1,152 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"errors"
pkgerr "github.com/pkg/errors"
)
var (
errPubSub = errors.New("redigo: unknown pubsub notification")
)
// Subscription represents a subscribe or unsubscribe notification.
type Subscription struct {
// Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe"
Kind string
// The channel that was changed.
Channel string
// The current number of subscriptions for connection.
Count int
}
// Message represents a message notification.
type Message struct {
// The originating channel.
Channel string
// The message data.
Data []byte
}
// PMessage represents a pmessage notification.
type PMessage struct {
// The matched pattern.
Pattern string
// The originating channel.
Channel string
// The message data.
Data []byte
}
// Pong represents a pubsub pong notification.
type Pong struct {
Data string
}
// PubSubConn wraps a Conn with convenience methods for subscribers.
type PubSubConn struct {
Conn Conn
}
// Close closes the connection.
func (c PubSubConn) Close() error {
return c.Conn.Close()
}
// Subscribe subscribes the connection to the specified channels.
func (c PubSubConn) Subscribe(channel ...interface{}) error {
c.Conn.Send("SUBSCRIBE", channel...)
return c.Conn.Flush()
}
// PSubscribe subscribes the connection to the given patterns.
func (c PubSubConn) PSubscribe(channel ...interface{}) error {
c.Conn.Send("PSUBSCRIBE", channel...)
return c.Conn.Flush()
}
// Unsubscribe unsubscribes the connection from the given channels, or from all
// of them if none is given.
func (c PubSubConn) Unsubscribe(channel ...interface{}) error {
c.Conn.Send("UNSUBSCRIBE", channel...)
return c.Conn.Flush()
}
// PUnsubscribe unsubscribes the connection from the given patterns, or from all
// of them if none is given.
func (c PubSubConn) PUnsubscribe(channel ...interface{}) error {
c.Conn.Send("PUNSUBSCRIBE", channel...)
return c.Conn.Flush()
}
// Ping sends a PING to the server with the specified data.
func (c PubSubConn) Ping(data string) error {
c.Conn.Send("PING", data)
return c.Conn.Flush()
}
// Receive returns a pushed message as a Subscription, Message, PMessage, Pong
// or error. The return value is intended to be used directly in a type switch
// as illustrated in the PubSubConn example.
func (c PubSubConn) Receive() interface{} {
reply, err := Values(c.Conn.Receive())
if err != nil {
return err
}
var kind string
reply, err = Scan(reply, &kind)
if err != nil {
return err
}
switch kind {
case "message":
var m Message
if _, err := Scan(reply, &m.Channel, &m.Data); err != nil {
return err
}
return m
case "pmessage":
var pm PMessage
if _, err := Scan(reply, &pm.Pattern, &pm.Channel, &pm.Data); err != nil {
return err
}
return pm
case "subscribe", "psubscribe", "unsubscribe", "punsubscribe":
s := Subscription{Kind: kind}
if _, err := Scan(reply, &s.Channel, &s.Count); err != nil {
return err
}
return s
case "pong":
var p Pong
if _, err := Scan(reply, &p.Data); err != nil {
return err
}
return p
}
return pkgerr.WithStack(errPubSub)
}

146
library/cache/redis/pubsub_test.go vendored Normal file
View File

@@ -0,0 +1,146 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"fmt"
"reflect"
"sync"
"testing"
)
func publish(channel, value interface{}) {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Do("PUBLISH", channel, value)
}
// Applications can receive pushed messages from one goroutine and manage subscriptions from another goroutine.
func ExamplePubSubConn() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
var wg sync.WaitGroup
wg.Add(2)
psc := PubSubConn{Conn: c}
// This goroutine receives and prints pushed notifications from the server.
// The goroutine exits when the connection is unsubscribed from all
// channels or there is an error.
go func() {
defer wg.Done()
for {
switch n := psc.Receive().(type) {
case Message:
fmt.Printf("Message: %s %s\n", n.Channel, n.Data)
case PMessage:
fmt.Printf("PMessage: %s %s %s\n", n.Pattern, n.Channel, n.Data)
case Subscription:
fmt.Printf("Subscription: %s %s %d\n", n.Kind, n.Channel, n.Count)
if n.Count == 0 {
return
}
case error:
fmt.Printf("error: %v\n", n)
return
}
}
}()
// This goroutine manages subscriptions for the connection.
go func() {
defer wg.Done()
psc.Subscribe("example")
psc.PSubscribe("p*")
// The following function calls publish a message using another
// connection to the Redis server.
publish("example", "hello")
publish("example", "world")
publish("pexample", "foo")
publish("pexample", "bar")
// Unsubscribe from all connections. This will cause the receiving
// goroutine to exit.
psc.Unsubscribe()
psc.PUnsubscribe()
}()
wg.Wait()
// Output:
// Subscription: subscribe example 1
// Subscription: psubscribe p* 2
// Message: example hello
// Message: example world
// PMessage: p* pexample foo
// PMessage: p* pexample bar
// Subscription: unsubscribe example 1
// Subscription: punsubscribe p* 0
}
func expectPushed(t *testing.T, c PubSubConn, message string, expected interface{}) {
actual := c.Receive()
if !reflect.DeepEqual(actual, expected) {
t.Errorf("%s = %v, want %v", message, actual, expected)
}
}
func TestPushed(t *testing.T) {
pc, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer pc.Close()
sc, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer sc.Close()
c := PubSubConn{Conn: sc}
c.Subscribe("c1")
expectPushed(t, c, "Subscribe(c1)", Subscription{Kind: "subscribe", Channel: "c1", Count: 1})
c.Subscribe("c2")
expectPushed(t, c, "Subscribe(c2)", Subscription{Kind: "subscribe", Channel: "c2", Count: 2})
c.PSubscribe("p1")
expectPushed(t, c, "PSubscribe(p1)", Subscription{Kind: "psubscribe", Channel: "p1", Count: 3})
c.PSubscribe("p2")
expectPushed(t, c, "PSubscribe(p2)", Subscription{Kind: "psubscribe", Channel: "p2", Count: 4})
c.PUnsubscribe()
expectPushed(t, c, "Punsubscribe(p1)", Subscription{Kind: "punsubscribe", Channel: "p1", Count: 3})
expectPushed(t, c, "Punsubscribe()", Subscription{Kind: "punsubscribe", Channel: "p2", Count: 2})
pc.Do("PUBLISH", "c1", "hello")
expectPushed(t, c, "PUBLISH c1 hello", Message{Channel: "c1", Data: []byte("hello")})
c.Ping("hello")
expectPushed(t, c, `Ping("hello")`, Pong{"hello"})
c.Conn.Send("PING")
c.Conn.Flush()
expectPushed(t, c, `Send("PING")`, Pong{})
}

51
library/cache/redis/redis.go vendored Normal file
View File

@@ -0,0 +1,51 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"context"
)
// Error represents an error returned in a command reply.
type Error string
func (err Error) Error() string { return string(err) }
// Conn represents a connection to a Redis server.
type Conn interface {
// Close closes the connection.
Close() error
// Err returns a non-nil value if the connection is broken. The returned
// value is either the first non-nil value returned from the underlying
// network connection or a protocol parsing error. Applications should
// close broken connections.
Err() error
// Do sends a command to the server and returns the received reply.
Do(commandName string, args ...interface{}) (reply interface{}, err error)
// Send writes the command to the client's output buffer.
Send(commandName string, args ...interface{}) error
// Flush flushes the output buffer to the Redis server.
Flush() error
// Receive receives a single reply from the Redis server
Receive() (reply interface{}, err error)
// WithContext
WithContext(ctx context.Context) Conn
}

132
library/cache/redis/redis_test.go vendored Normal file
View File

@@ -0,0 +1,132 @@
package redis
import (
"context"
"testing"
"time"
"go-common/library/container/pool"
xtime "go-common/library/time"
)
var p *Pool
var config *Config
func init() {
config = getConfig()
p = NewPool(config)
}
func getConfig() (c *Config) {
c = &Config{
Name: "test",
Proto: "tcp",
Addr: "172.16.33.54:6379",
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}
c.Config = &pool.Config{
Active: 20,
Idle: 2,
IdleTimeout: xtime.Duration(90 * time.Second),
}
return
}
func TestRedis(t *testing.T) {
testSet(t, p)
testSend(t, p)
testGet(t, p)
testErr(t, p)
if err := p.Close(); err != nil {
t.Errorf("redis: close error(%v)", err)
}
conn, err := NewConn(config)
if err != nil {
t.Errorf("redis: new conn error(%v)", err)
}
if err := conn.Close(); err != nil {
t.Errorf("redis: close error(%v)", err)
}
}
func testSet(t *testing.T, p *Pool) {
var (
key = "test"
value = "test"
conn = p.Get(context.TODO())
)
defer conn.Close()
if reply, err := conn.Do("set", key, value); err != nil {
t.Errorf("redis: conn.Do(SET, %s, %s) error(%v)", key, value, err)
} else {
t.Logf("redis: set status: %s", reply)
}
}
func testSend(t *testing.T, p *Pool) {
var (
key = "test"
value = "test"
expire = 1000
conn = p.Get(context.TODO())
)
defer conn.Close()
if err := conn.Send("SET", key, value); err != nil {
t.Errorf("redis: conn.Send(SET, %s, %s) error(%v)", key, value, err)
}
if err := conn.Send("EXPIRE", key, expire); err != nil {
t.Errorf("redis: conn.Send(EXPIRE key(%s) expire(%d)) error(%v)", key, expire, err)
}
if err := conn.Flush(); err != nil {
t.Errorf("redis: conn.Flush error(%v)", err)
}
for i := 0; i < 2; i++ {
if _, err := conn.Receive(); err != nil {
t.Errorf("redis: conn.Receive error(%v)", err)
return
}
}
t.Logf("redis: set value: %s", value)
}
func testGet(t *testing.T, p *Pool) {
var (
key = "test"
conn = p.Get(context.TODO())
)
defer conn.Close()
if reply, err := conn.Do("GET", key); err != nil {
t.Errorf("redis: conn.Do(GET, %s) error(%v)", key, err)
} else {
t.Logf("redis: get value: %s", reply)
}
}
func testErr(t *testing.T, p *Pool) {
conn := p.Get(context.TODO())
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
if err := conn.Err(); err == nil {
t.Errorf("redis: err not nil")
} else {
t.Logf("redis: err: %v", err)
}
}
func BenchmarkMemcache(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
conn := p.Get(context.TODO())
if err := conn.Close(); err != nil {
b.Errorf("memcache: close error(%v)", err)
}
}
})
if err := p.Close(); err != nil {
b.Errorf("memcache: close error(%v)", err)
}
}

409
library/cache/redis/reply.go vendored Normal file
View File

@@ -0,0 +1,409 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"errors"
"strconv"
pkgerr "github.com/pkg/errors"
)
// ErrNil indicates that a reply value is nil.
var ErrNil = errors.New("redigo: nil returned")
// Int is a helper that converts a command reply to an integer. If err is not
// equal to nil, then Int returns 0, err. Otherwise, Int converts the
// reply to an int as follows:
//
// Reply type Result
// integer int(reply), nil
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
func Int(reply interface{}, err error) (int, error) {
if err != nil {
return 0, err
}
switch reply := reply.(type) {
case int64:
x := int(reply)
if int64(x) != reply {
return 0, pkgerr.WithStack(strconv.ErrRange)
}
return x, nil
case []byte:
n, err := strconv.ParseInt(string(reply), 10, 0)
return int(n), pkgerr.WithStack(err)
case nil:
return 0, ErrNil
case Error:
return 0, reply
}
return 0, pkgerr.Errorf("redigo: unexpected type for Int, got type %T", reply)
}
// Int64 is a helper that converts a command reply to 64 bit integer. If err is
// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the
// reply to an int64 as follows:
//
// Reply type Result
// integer reply, nil
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
func Int64(reply interface{}, err error) (int64, error) {
if err != nil {
return 0, err
}
switch reply := reply.(type) {
case int64:
return reply, nil
case []byte:
n, err := strconv.ParseInt(string(reply), 10, 64)
return n, pkgerr.WithStack(err)
case nil:
return 0, ErrNil
case Error:
return 0, reply
}
return 0, pkgerr.Errorf("redigo: unexpected type for Int64, got type %T", reply)
}
var errNegativeInt = errors.New("redigo: unexpected value for Uint64")
// Uint64 is a helper that converts a command reply to 64 bit integer. If err is
// not equal to nil, then Int returns 0, err. Otherwise, Int64 converts the
// reply to an int64 as follows:
//
// Reply type Result
// integer reply, nil
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
func Uint64(reply interface{}, err error) (uint64, error) {
if err != nil {
return 0, err
}
switch reply := reply.(type) {
case int64:
if reply < 0 {
return 0, pkgerr.WithStack(errNegativeInt)
}
return uint64(reply), nil
case []byte:
n, err := strconv.ParseUint(string(reply), 10, 64)
return n, err
case nil:
return 0, ErrNil
case Error:
return 0, reply
}
return 0, pkgerr.Errorf("redigo: unexpected type for Uint64, got type %T", reply)
}
// Float64 is a helper that converts a command reply to 64 bit float. If err is
// not equal to nil, then Float64 returns 0, err. Otherwise, Float64 converts
// the reply to an int as follows:
//
// Reply type Result
// bulk string parsed reply, nil
// nil 0, ErrNil
// other 0, error
func Float64(reply interface{}, err error) (float64, error) {
if err != nil {
return 0, err
}
switch reply := reply.(type) {
case []byte:
n, err := strconv.ParseFloat(string(reply), 64)
return n, pkgerr.WithStack(err)
case nil:
return 0, ErrNil
case Error:
return 0, reply
}
return 0, pkgerr.Errorf("redigo: unexpected type for Float64, got type %T", reply)
}
// String is a helper that converts a command reply to a string. If err is not
// equal to nil, then String returns "", err. Otherwise String converts the
// reply to a string as follows:
//
// Reply type Result
// bulk string string(reply), nil
// simple string reply, nil
// nil "", ErrNil
// other "", error
func String(reply interface{}, err error) (string, error) {
if err != nil {
return "", err
}
switch reply := reply.(type) {
case []byte:
return string(reply), nil
case string:
return reply, nil
case nil:
return "", ErrNil
case Error:
return "", reply
}
return "", pkgerr.Errorf("redigo: unexpected type for String, got type %T", reply)
}
// Bytes is a helper that converts a command reply to a slice of bytes. If err
// is not equal to nil, then Bytes returns nil, err. Otherwise Bytes converts
// the reply to a slice of bytes as follows:
//
// Reply type Result
// bulk string reply, nil
// simple string []byte(reply), nil
// nil nil, ErrNil
// other nil, error
func Bytes(reply interface{}, err error) ([]byte, error) {
if err != nil {
return nil, err
}
switch reply := reply.(type) {
case []byte:
return reply, nil
case string:
return []byte(reply), nil
case nil:
return nil, ErrNil
case Error:
return nil, reply
}
return nil, pkgerr.Errorf("redigo: unexpected type for Bytes, got type %T", reply)
}
// Bool is a helper that converts a command reply to a boolean. If err is not
// equal to nil, then Bool returns false, err. Otherwise Bool converts the
// reply to boolean as follows:
//
// Reply type Result
// integer value != 0, nil
// bulk string strconv.ParseBool(reply)
// nil false, ErrNil
// other false, error
func Bool(reply interface{}, err error) (bool, error) {
if err != nil {
return false, err
}
switch reply := reply.(type) {
case int64:
return reply != 0, nil
case []byte:
b, e := strconv.ParseBool(string(reply))
return b, pkgerr.WithStack(e)
case nil:
return false, ErrNil
case Error:
return false, reply
}
return false, pkgerr.Errorf("redigo: unexpected type for Bool, got type %T", reply)
}
// MultiBulk is a helper that converts an array command reply to a []interface{}.
//
// Deprecated: Use Values instead.
func MultiBulk(reply interface{}, err error) ([]interface{}, error) { return Values(reply, err) }
// Values is a helper that converts an array command reply to a []interface{}.
// If err is not equal to nil, then Values returns nil, err. Otherwise, Values
// converts the reply as follows:
//
// Reply type Result
// array reply, nil
// nil nil, ErrNil
// other nil, error
func Values(reply interface{}, err error) ([]interface{}, error) {
if err != nil {
return nil, err
}
switch reply := reply.(type) {
case []interface{}:
return reply, nil
case nil:
return nil, ErrNil
case Error:
return nil, reply
}
return nil, pkgerr.Errorf("redigo: unexpected type for Values, got type %T", reply)
}
// Strings is a helper that converts an array command reply to a []string. If
// err is not equal to nil, then Strings returns nil, err. Nil array items are
// converted to "" in the output slice. Strings returns an error if an array
// item is not a bulk string or nil.
func Strings(reply interface{}, err error) ([]string, error) {
if err != nil {
return nil, err
}
switch reply := reply.(type) {
case []interface{}:
result := make([]string, len(reply))
for i := range reply {
if reply[i] == nil {
continue
}
p, ok := reply[i].([]byte)
if !ok {
return nil, pkgerr.Errorf("redigo: unexpected element type for Strings, got type %T", reply[i])
}
result[i] = string(p)
}
return result, nil
case nil:
return nil, ErrNil
case Error:
return nil, reply
}
return nil, pkgerr.Errorf("redigo: unexpected type for Strings, got type %T", reply)
}
// ByteSlices is a helper that converts an array command reply to a [][]byte.
// If err is not equal to nil, then ByteSlices returns nil, err. Nil array
// items are stay nil. ByteSlices returns an error if an array item is not a
// bulk string or nil.
func ByteSlices(reply interface{}, err error) ([][]byte, error) {
if err != nil {
return nil, err
}
switch reply := reply.(type) {
case []interface{}:
result := make([][]byte, len(reply))
for i := range reply {
if reply[i] == nil {
continue
}
p, ok := reply[i].([]byte)
if !ok {
return nil, pkgerr.Errorf("redigo: unexpected element type for ByteSlices, got type %T", reply[i])
}
result[i] = p
}
return result, nil
case nil:
return nil, ErrNil
case Error:
return nil, reply
}
return nil, pkgerr.Errorf("redigo: unexpected type for ByteSlices, got type %T", reply)
}
// Ints is a helper that converts an array command reply to a []int. If
// err is not equal to nil, then Ints returns nil, err.
func Ints(reply interface{}, err error) ([]int, error) {
var ints []int
values, err := Values(reply, err)
if err != nil {
return ints, err
}
if err := ScanSlice(values, &ints); err != nil {
return ints, err
}
return ints, nil
}
// Int64s is a helper that converts an array command reply to a []int64. If
// err is not equal to nil, then Int64s returns nil, err.
func Int64s(reply interface{}, err error) ([]int64, error) {
var int64s []int64
values, err := Values(reply, err)
if err != nil {
return int64s, err
}
if err := ScanSlice(values, &int64s); err != nil {
return int64s, err
}
return int64s, nil
}
// StringMap is a helper that converts an array of strings (alternating key, value)
// into a map[string]string. The HGETALL and CONFIG GET commands return replies in this format.
// Requires an even number of values in result.
func StringMap(result interface{}, err error) (map[string]string, error) {
values, err := Values(result, err)
if err != nil {
return nil, err
}
if len(values)%2 != 0 {
return nil, pkgerr.New("redigo: StringMap expects even number of values result")
}
m := make(map[string]string, len(values)/2)
for i := 0; i < len(values); i += 2 {
key, okKey := values[i].([]byte)
value, okValue := values[i+1].([]byte)
if !okKey || !okValue {
return nil, pkgerr.New("redigo: ScanMap key not a bulk string value")
}
m[string(key)] = string(value)
}
return m, nil
}
// IntMap is a helper that converts an array of strings (alternating key, value)
// into a map[string]int. The HGETALL commands return replies in this format.
// Requires an even number of values in result.
func IntMap(result interface{}, err error) (map[string]int, error) {
values, err := Values(result, err)
if err != nil {
return nil, err
}
if len(values)%2 != 0 {
return nil, pkgerr.New("redigo: IntMap expects even number of values result")
}
m := make(map[string]int, len(values)/2)
for i := 0; i < len(values); i += 2 {
key, ok := values[i].([]byte)
if !ok {
return nil, pkgerr.New("redigo: ScanMap key not a bulk string value")
}
value, err := Int(values[i+1], nil)
if err != nil {
return nil, err
}
m[string(key)] = value
}
return m, nil
}
// Int64Map is a helper that converts an array of strings (alternating key, value)
// into a map[string]int64. The HGETALL commands return replies in this format.
// Requires an even number of values in result.
func Int64Map(result interface{}, err error) (map[string]int64, error) {
values, err := Values(result, err)
if err != nil {
return nil, err
}
if len(values)%2 != 0 {
return nil, pkgerr.New("redigo: Int64Map expects even number of values result")
}
m := make(map[string]int64, len(values)/2)
for i := 0; i < len(values); i += 2 {
key, ok := values[i].([]byte)
if !ok {
return nil, pkgerr.New("redigo: ScanMap key not a bulk string value")
}
value, err := Int64(values[i+1], nil)
if err != nil {
return nil, err
}
m[string(key)] = value
}
return m, nil
}

179
library/cache/redis/reply_test.go vendored Normal file
View File

@@ -0,0 +1,179 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"fmt"
"reflect"
"testing"
"github.com/pkg/errors"
)
type valueError struct {
v interface{}
err error
}
func ve(v interface{}, err error) valueError {
return valueError{v, err}
}
var replyTests = []struct {
name interface{}
actual valueError
expected valueError
}{
{
"ints([v1, v2])",
ve(Ints([]interface{}{[]byte("4"), []byte("5")}, nil)),
ve([]int{4, 5}, nil),
},
{
"ints(nil)",
ve(Ints(nil, nil)),
ve([]int(nil), ErrNil),
},
{
"strings([v1, v2])",
ve(Strings([]interface{}{[]byte("v1"), []byte("v2")}, nil)),
ve([]string{"v1", "v2"}, nil),
},
{
"strings(nil)",
ve(Strings(nil, nil)),
ve([]string(nil), ErrNil),
},
{
"byteslices([v1, v2])",
ve(ByteSlices([]interface{}{[]byte("v1"), []byte("v2")}, nil)),
ve([][]byte{[]byte("v1"), []byte("v2")}, nil),
},
{
"byteslices(nil)",
ve(ByteSlices(nil, nil)),
ve([][]byte(nil), ErrNil),
},
{
"values([v1, v2])",
ve(Values([]interface{}{[]byte("v1"), []byte("v2")}, nil)),
ve([]interface{}{[]byte("v1"), []byte("v2")}, nil),
},
{
"values(nil)",
ve(Values(nil, nil)),
ve([]interface{}(nil), ErrNil),
},
{
"float64(1.0)",
ve(Float64([]byte("1.0"), nil)),
ve(float64(1.0), nil),
},
{
"float64(nil)",
ve(Float64(nil, nil)),
ve(float64(0.0), ErrNil),
},
{
"uint64(1)",
ve(Uint64(int64(1), nil)),
ve(uint64(1), nil),
},
{
"uint64(-1)",
ve(Uint64(int64(-1), nil)),
ve(uint64(0), ErrNegativeInt),
},
}
func TestReply(t *testing.T) {
for _, rt := range replyTests {
if errors.Cause(rt.actual.err) != rt.expected.err {
t.Errorf("%s returned err %v, want %v", rt.name, rt.actual.err, rt.expected.err)
continue
}
if !reflect.DeepEqual(rt.actual.v, rt.expected.v) {
t.Errorf("%s=%+v, want %+v", rt.name, rt.actual.v, rt.expected.v)
}
}
}
// dial wraps DialDefaultServer() with a more suitable function name for examples.
func dial() (Conn, error) {
return DialDefaultServer()
}
func ExampleBool() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Do("SET", "foo", 1)
exists, _ := Bool(c.Do("EXISTS", "foo"))
fmt.Printf("%#v\n", exists)
// Output:
// true
}
func ExampleInt() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Do("SET", "k1", 1)
n, _ := Int(c.Do("GET", "k1"))
fmt.Printf("%#v\n", n)
n, _ = Int(c.Do("INCR", "k1"))
fmt.Printf("%#v\n", n)
// Output:
// 1
// 2
}
func ExampleInts() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Do("SADD", "set_with_integers", 4, 5, 6)
ints, _ := Ints(c.Do("SMEMBERS", "set_with_integers"))
fmt.Printf("%#v\n", ints)
// Output:
// []int{4, 5, 6}
}
func ExampleString() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Do("SET", "hello", "world")
s, err := String(c.Do("GET", "hello"))
fmt.Printf("%#v\n", s)
// Output:
// "world"
}

559
library/cache/redis/scan.go vendored Normal file
View File

@@ -0,0 +1,559 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
pkgerr "github.com/pkg/errors"
)
func ensureLen(d reflect.Value, n int) {
if n > d.Cap() {
d.Set(reflect.MakeSlice(d.Type(), n, n))
} else {
d.SetLen(n)
}
}
func cannotConvert(d reflect.Value, s interface{}) error {
var sname string
switch s.(type) {
case string:
sname = "Redis simple string"
case Error:
sname = "Redis error"
case int64:
sname = "Redis integer"
case []byte:
sname = "Redis bulk string"
case []interface{}:
sname = "Redis array"
default:
sname = reflect.TypeOf(s).String()
}
return pkgerr.Errorf("cannot convert from %s to %s", sname, d.Type())
}
func convertAssignBulkString(d reflect.Value, s []byte) (err error) {
switch d.Type().Kind() {
case reflect.Float32, reflect.Float64:
var x float64
x, err = strconv.ParseFloat(string(s), d.Type().Bits())
d.SetFloat(x)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var x int64
x, err = strconv.ParseInt(string(s), 10, d.Type().Bits())
d.SetInt(x)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
var x uint64
x, err = strconv.ParseUint(string(s), 10, d.Type().Bits())
d.SetUint(x)
case reflect.Bool:
var x bool
x, err = strconv.ParseBool(string(s))
d.SetBool(x)
case reflect.String:
d.SetString(string(s))
case reflect.Slice:
if d.Type().Elem().Kind() != reflect.Uint8 {
err = cannotConvert(d, s)
} else {
d.SetBytes(s)
}
default:
err = cannotConvert(d, s)
}
err = pkgerr.WithStack(err)
return
}
func convertAssignInt(d reflect.Value, s int64) (err error) {
switch d.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
d.SetInt(s)
if d.Int() != s {
err = strconv.ErrRange
d.SetInt(0)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if s < 0 {
err = strconv.ErrRange
} else {
x := uint64(s)
d.SetUint(x)
if d.Uint() != x {
err = strconv.ErrRange
d.SetUint(0)
}
}
case reflect.Bool:
d.SetBool(s != 0)
default:
err = cannotConvert(d, s)
}
err = pkgerr.WithStack(err)
return
}
func convertAssignValue(d reflect.Value, s interface{}) (err error) {
switch s := s.(type) {
case []byte:
err = convertAssignBulkString(d, s)
case int64:
err = convertAssignInt(d, s)
default:
err = cannotConvert(d, s)
}
return err
}
func convertAssignArray(d reflect.Value, s []interface{}) error {
if d.Type().Kind() != reflect.Slice {
return cannotConvert(d, s)
}
ensureLen(d, len(s))
for i := 0; i < len(s); i++ {
if err := convertAssignValue(d.Index(i), s[i]); err != nil {
return err
}
}
return nil
}
func convertAssign(d interface{}, s interface{}) (err error) {
// Handle the most common destination types using type switches and
// fall back to reflection for all other types.
switch s := s.(type) {
case nil:
// ingore
case []byte:
switch d := d.(type) {
case *string:
*d = string(s)
case *int:
*d, err = strconv.Atoi(string(s))
case *bool:
*d, err = strconv.ParseBool(string(s))
case *[]byte:
*d = s
case *interface{}:
*d = s
case nil:
// skip value
default:
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
err = cannotConvert(d, s)
} else {
err = convertAssignBulkString(d.Elem(), s)
}
}
case int64:
switch d := d.(type) {
case *int:
x := int(s)
if int64(x) != s {
err = strconv.ErrRange
x = 0
}
*d = x
case *bool:
*d = s != 0
case *interface{}:
*d = s
case nil:
// skip value
default:
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
err = cannotConvert(d, s)
} else {
err = convertAssignInt(d.Elem(), s)
}
}
case string:
switch d := d.(type) {
case *string:
*d = string(s)
default:
err = cannotConvert(reflect.ValueOf(d), s)
}
case []interface{}:
switch d := d.(type) {
case *[]interface{}:
*d = s
case *interface{}:
*d = s
case nil:
// skip value
default:
if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr {
err = cannotConvert(d, s)
} else {
err = convertAssignArray(d.Elem(), s)
}
}
case Error:
err = s
default:
err = cannotConvert(reflect.ValueOf(d), s)
}
err = pkgerr.WithStack(err)
return
}
// Scan copies from src to the values pointed at by dest.
//
// The values pointed at by dest must be an integer, float, boolean, string,
// []byte, interface{} or slices of these types. Scan uses the standard strconv
// package to convert bulk strings to numeric and boolean types.
//
// If a dest value is nil, then the corresponding src value is skipped.
//
// If a src element is nil, then the corresponding dest value is not modified.
//
// To enable easy use of Scan in a loop, Scan returns the slice of src
// following the copied values.
func Scan(src []interface{}, dest ...interface{}) ([]interface{}, error) {
if len(src) < len(dest) {
return nil, pkgerr.New("redigo.Scan: array short")
}
var err error
for i, d := range dest {
err = convertAssign(d, src[i])
if err != nil {
err = fmt.Errorf("redigo.Scan: cannot assign to dest %d: %v", i, err)
break
}
}
return src[len(dest):], err
}
type fieldSpec struct {
name string
index []int
omitEmpty bool
}
type structSpec struct {
m map[string]*fieldSpec
l []*fieldSpec
}
func (ss *structSpec) fieldSpec(name []byte) *fieldSpec {
return ss.m[string(name)]
}
func compileStructSpec(t reflect.Type, depth map[string]int, index []int, ss *structSpec) {
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
switch {
case f.PkgPath != "" && !f.Anonymous:
// Ignore unexported fields.
case f.Anonymous:
// TODO: Handle pointers. Requires change to decoder and
// protection against infinite recursion.
if f.Type.Kind() == reflect.Struct {
compileStructSpec(f.Type, depth, append(index, i), ss)
}
default:
fs := &fieldSpec{name: f.Name}
tag := f.Tag.Get("redis")
p := strings.Split(tag, ",")
if len(p) > 0 {
if p[0] == "-" {
continue
}
if len(p[0]) > 0 {
fs.name = p[0]
}
for _, s := range p[1:] {
switch s {
case "omitempty":
fs.omitEmpty = true
default:
panic(fmt.Errorf("redigo: unknown field tag %s for type %s", s, t.Name()))
}
}
}
d, found := depth[fs.name]
if !found {
d = 1 << 30
}
switch {
case len(index) == d:
// At same depth, remove from result.
delete(ss.m, fs.name)
j := 0
for i1 := 0; i1 < len(ss.l); i1++ {
if fs.name != ss.l[i1].name {
ss.l[j] = ss.l[i1]
j++
}
}
ss.l = ss.l[:j]
case len(index) < d:
fs.index = make([]int, len(index)+1)
copy(fs.index, index)
fs.index[len(index)] = i
depth[fs.name] = len(index)
ss.m[fs.name] = fs
ss.l = append(ss.l, fs)
}
}
}
}
var (
structSpecMutex sync.RWMutex
structSpecCache = make(map[reflect.Type]*structSpec)
)
func structSpecForType(t reflect.Type) *structSpec {
structSpecMutex.RLock()
ss, found := structSpecCache[t]
structSpecMutex.RUnlock()
if found {
return ss
}
structSpecMutex.Lock()
defer structSpecMutex.Unlock()
ss, found = structSpecCache[t]
if found {
return ss
}
ss = &structSpec{m: make(map[string]*fieldSpec)}
compileStructSpec(t, make(map[string]int), nil, ss)
structSpecCache[t] = ss
return ss
}
var errScanStructValue = errors.New("redigo.ScanStruct: value must be non-nil pointer to a struct")
// ScanStruct scans alternating names and values from src to a struct. The
// HGETALL and CONFIG GET commands return replies in this format.
//
// ScanStruct uses exported field names to match values in the response. Use
// 'redis' field tag to override the name:
//
// Field int `redis:"myName"`
//
// Fields with the tag redis:"-" are ignored.
//
// Integer, float, boolean, string and []byte fields are supported. Scan uses the
// standard strconv package to convert bulk string values to numeric and
// boolean types.
//
// If a src element is nil, then the corresponding field is not modified.
func ScanStruct(src []interface{}, dest interface{}) error {
d := reflect.ValueOf(dest)
if d.Kind() != reflect.Ptr || d.IsNil() {
return pkgerr.WithStack(errScanStructValue)
}
d = d.Elem()
if d.Kind() != reflect.Struct {
return pkgerr.WithStack(errScanStructValue)
}
ss := structSpecForType(d.Type())
if len(src)%2 != 0 {
return pkgerr.New("redigo.ScanStruct: number of values not a multiple of 2")
}
for i := 0; i < len(src); i += 2 {
s := src[i+1]
if s == nil {
continue
}
name, ok := src[i].([]byte)
if !ok {
return pkgerr.Errorf("redigo.ScanStruct: key %d not a bulk string value", i)
}
fs := ss.fieldSpec(name)
if fs == nil {
continue
}
if err := convertAssignValue(d.FieldByIndex(fs.index), s); err != nil {
return pkgerr.Errorf("redigo.ScanStruct: cannot assign field %s: %v", fs.name, err)
}
}
return nil
}
var (
errScanSliceValue = errors.New("redigo.ScanSlice: dest must be non-nil pointer to a struct")
)
// ScanSlice scans src to the slice pointed to by dest. The elements the dest
// slice must be integer, float, boolean, string, struct or pointer to struct
// values.
//
// Struct fields must be integer, float, boolean or string values. All struct
// fields are used unless a subset is specified using fieldNames.
func ScanSlice(src []interface{}, dest interface{}, fieldNames ...string) error {
d := reflect.ValueOf(dest)
if d.Kind() != reflect.Ptr || d.IsNil() {
return pkgerr.WithStack(errScanSliceValue)
}
d = d.Elem()
if d.Kind() != reflect.Slice {
return pkgerr.WithStack(errScanSliceValue)
}
isPtr := false
t := d.Type().Elem()
if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct {
isPtr = true
t = t.Elem()
}
if t.Kind() != reflect.Struct {
ensureLen(d, len(src))
for i, s := range src {
if s == nil {
continue
}
if err := convertAssignValue(d.Index(i), s); err != nil {
return pkgerr.Errorf("redigo.ScanSlice: cannot assign element %d: %v", i, err)
}
}
return nil
}
ss := structSpecForType(t)
fss := ss.l
if len(fieldNames) > 0 {
fss = make([]*fieldSpec, len(fieldNames))
for i, name := range fieldNames {
fss[i] = ss.m[name]
if fss[i] == nil {
return pkgerr.Errorf("redigo.ScanSlice: ScanSlice bad field name %s", name)
}
}
}
if len(fss) == 0 {
return pkgerr.New("redigo.ScanSlice: no struct fields")
}
n := len(src) / len(fss)
if n*len(fss) != len(src) {
return pkgerr.New("redigo.ScanSlice: length not a multiple of struct field count")
}
ensureLen(d, n)
for i := 0; i < n; i++ {
d1 := d.Index(i)
if isPtr {
if d1.IsNil() {
d1.Set(reflect.New(t))
}
d1 = d1.Elem()
}
for j, fs := range fss {
s := src[i*len(fss)+j]
if s == nil {
continue
}
if err := convertAssignValue(d1.FieldByIndex(fs.index), s); err != nil {
return pkgerr.Errorf("redigo.ScanSlice: cannot assign element %d to field %s: %v", i*len(fss)+j, fs.name, err)
}
}
}
return nil
}
// Args is a helper for constructing command arguments from structured values.
type Args []interface{}
// Add returns the result of appending value to args.
func (args Args) Add(value ...interface{}) Args {
return append(args, value...)
}
// AddFlat returns the result of appending the flattened value of v to args.
//
// Maps are flattened by appending the alternating keys and map values to args.
//
// Slices are flattened by appending the slice elements to args.
//
// Structs are flattened by appending the alternating names and values of
// exported fields to args. If v is a nil struct pointer, then nothing is
// appended. The 'redis' field tag overrides struct field names. See ScanStruct
// for more information on the use of the 'redis' field tag.
//
// Other types are appended to args as is.
func (args Args) AddFlat(v interface{}) Args {
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Struct:
args = flattenStruct(args, rv)
case reflect.Slice:
for i := 0; i < rv.Len(); i++ {
args = append(args, rv.Index(i).Interface())
}
case reflect.Map:
for _, k := range rv.MapKeys() {
args = append(args, k.Interface(), rv.MapIndex(k).Interface())
}
case reflect.Ptr:
if rv.Type().Elem().Kind() == reflect.Struct {
if !rv.IsNil() {
args = flattenStruct(args, rv.Elem())
}
} else {
args = append(args, v)
}
default:
args = append(args, v)
}
return args
}
func flattenStruct(args Args, v reflect.Value) Args {
ss := structSpecForType(v.Type())
for _, fs := range ss.l {
fv := v.FieldByIndex(fs.index)
if fs.omitEmpty {
var empty = false
switch fv.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
empty = fv.Len() == 0
case reflect.Bool:
empty = !fv.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
empty = fv.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
empty = fv.Uint() == 0
case reflect.Float32, reflect.Float64:
empty = fv.Float() == 0
case reflect.Interface, reflect.Ptr:
empty = fv.IsNil()
}
if empty {
continue
}
}
args = append(args, fs.name, fv.Interface())
}
return args
}

438
library/cache/redis/scan_test.go vendored Normal file
View File

@@ -0,0 +1,438 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"fmt"
"math"
"reflect"
"testing"
)
var scanConversionTests = []struct {
src interface{}
dest interface{}
}{
{[]byte("-inf"), math.Inf(-1)},
{[]byte("+inf"), math.Inf(1)},
{[]byte("0"), float64(0)},
{[]byte("3.14159"), float64(3.14159)},
{[]byte("3.14"), float32(3.14)},
{[]byte("-100"), int(-100)},
{[]byte("101"), int(101)},
{int64(102), int(102)},
{[]byte("103"), uint(103)},
{int64(104), uint(104)},
{[]byte("105"), int8(105)},
{int64(106), int8(106)},
{[]byte("107"), uint8(107)},
{int64(108), uint8(108)},
{[]byte("0"), false},
{int64(0), false},
{[]byte("f"), false},
{[]byte("1"), true},
{int64(1), true},
{[]byte("t"), true},
{"hello", "hello"},
{[]byte("hello"), "hello"},
{[]byte("world"), []byte("world")},
{[]interface{}{[]byte("foo")}, []interface{}{[]byte("foo")}},
{[]interface{}{[]byte("foo")}, []string{"foo"}},
{[]interface{}{[]byte("hello"), []byte("world")}, []string{"hello", "world"}},
{[]interface{}{[]byte("bar")}, [][]byte{[]byte("bar")}},
{[]interface{}{[]byte("1")}, []int{1}},
{[]interface{}{[]byte("1"), []byte("2")}, []int{1, 2}},
{[]interface{}{[]byte("1"), []byte("2")}, []float64{1, 2}},
{[]interface{}{[]byte("1")}, []byte{1}},
{[]interface{}{[]byte("1")}, []bool{true}},
}
func TestScanConversion(t *testing.T) {
for _, tt := range scanConversionTests {
values := []interface{}{tt.src}
dest := reflect.New(reflect.TypeOf(tt.dest))
values, err := Scan(values, dest.Interface())
if err != nil {
t.Errorf("Scan(%v) returned error %v", tt, err)
continue
}
if !reflect.DeepEqual(tt.dest, dest.Elem().Interface()) {
t.Errorf("Scan(%v) returned %v, want %v", tt, dest.Elem().Interface(), tt.dest)
}
}
}
var scanConversionErrorTests = []struct {
src interface{}
dest interface{}
}{
{[]byte("1234"), byte(0)},
{int64(1234), byte(0)},
{[]byte("-1"), byte(0)},
{int64(-1), byte(0)},
{[]byte("junk"), false},
{Error("blah"), false},
}
func TestScanConversionError(t *testing.T) {
for _, tt := range scanConversionErrorTests {
values := []interface{}{tt.src}
dest := reflect.New(reflect.TypeOf(tt.dest))
values, err := Scan(values, dest.Interface())
if err == nil {
t.Errorf("Scan(%v) did not return error", tt)
}
}
}
func ExampleScan() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Send("HMSET", "album:1", "title", "Red", "rating", 5)
c.Send("HMSET", "album:2", "title", "Earthbound", "rating", 1)
c.Send("HMSET", "album:3", "title", "Beat")
c.Send("LPUSH", "albums", "1")
c.Send("LPUSH", "albums", "2")
c.Send("LPUSH", "albums", "3")
values, err := Values(c.Do("SORT", "albums",
"BY", "album:*->rating",
"GET", "album:*->title",
"GET", "album:*->rating"))
if err != nil {
fmt.Println(err)
return
}
for len(values) > 0 {
var title string
rating := -1 // initialize to illegal value to detect nil.
values, err = Scan(values, &title, &rating)
if err != nil {
fmt.Println(err)
return
}
if rating == -1 {
fmt.Println(title, "not-rated")
} else {
fmt.Println(title, rating)
}
}
// Output:
// Beat not-rated
// Earthbound 1
// Red 5
}
type s0 struct {
X int
Y int `redis:"y"`
Bt bool
}
type s1 struct {
X int `redis:"-"`
I int `redis:"i"`
U uint `redis:"u"`
S string `redis:"s"`
P []byte `redis:"p"`
B bool `redis:"b"`
Bt bool
Bf bool
s0
}
var scanStructTests = []struct {
title string
reply []string
value interface{}
}{
{"basic",
[]string{"i", "-1234", "u", "5678", "s", "hello", "p", "world", "b", "t", "Bt", "1", "Bf", "0", "X", "123", "y", "456"},
&s1{I: -1234, U: 5678, S: "hello", P: []byte("world"), B: true, Bt: true, Bf: false, s0: s0{X: 123, Y: 456}},
},
}
func TestScanStruct(t *testing.T) {
for _, tt := range scanStructTests {
var reply []interface{}
for _, v := range tt.reply {
reply = append(reply, []byte(v))
}
value := reflect.New(reflect.ValueOf(tt.value).Type().Elem())
if err := ScanStruct(reply, value.Interface()); err != nil {
t.Fatalf("ScanStruct(%s) returned error %v", tt.title, err)
}
if !reflect.DeepEqual(value.Interface(), tt.value) {
t.Fatalf("ScanStruct(%s) returned %v, want %v", tt.title, value.Interface(), tt.value)
}
}
}
func TestBadScanStructArgs(t *testing.T) {
x := []interface{}{"A", "b"}
test := func(v interface{}) {
if err := ScanStruct(x, v); err == nil {
t.Errorf("Expect error for ScanStruct(%T, %T)", x, v)
}
}
test(nil)
var v0 *struct{}
test(v0)
var v1 int
test(&v1)
x = x[:1]
v2 := struct{ A string }{}
test(&v2)
}
var scanSliceTests = []struct {
src []interface{}
fieldNames []string
ok bool
dest interface{}
}{
{
[]interface{}{[]byte("1"), nil, []byte("-1")},
nil,
true,
[]int{1, 0, -1},
},
{
[]interface{}{[]byte("1"), nil, []byte("2")},
nil,
true,
[]uint{1, 0, 2},
},
{
[]interface{}{[]byte("-1")},
nil,
false,
[]uint{1},
},
{
[]interface{}{[]byte("hello"), nil, []byte("world")},
nil,
true,
[][]byte{[]byte("hello"), nil, []byte("world")},
},
{
[]interface{}{[]byte("hello"), nil, []byte("world")},
nil,
true,
[]string{"hello", "", "world"},
},
{
[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
nil,
true,
[]struct{ A, B string }{{"a1", "b1"}, {"a2", "b2"}},
},
{
[]interface{}{[]byte("a1"), []byte("b1")},
nil,
false,
[]struct{ A, B, C string }{{"a1", "b1", ""}},
},
{
[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
nil,
true,
[]*struct{ A, B string }{{"a1", "b1"}, {"a2", "b2"}},
},
{
[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
[]string{"A", "B"},
true,
[]struct{ A, C, B string }{{"a1", "", "b1"}, {"a2", "", "b2"}},
},
{
[]interface{}{[]byte("a1"), []byte("b1"), []byte("a2"), []byte("b2")},
nil,
false,
[]struct{}{},
},
}
func TestScanSlice(t *testing.T) {
for _, tt := range scanSliceTests {
typ := reflect.ValueOf(tt.dest).Type()
dest := reflect.New(typ)
err := ScanSlice(tt.src, dest.Interface(), tt.fieldNames...)
if tt.ok != (err == nil) {
t.Errorf("ScanSlice(%v, []%s, %v) returned error %v", tt.src, typ, tt.fieldNames, err)
continue
}
if tt.ok && !reflect.DeepEqual(dest.Elem().Interface(), tt.dest) {
t.Errorf("ScanSlice(src, []%s) returned %#v, want %#v", typ, dest.Elem().Interface(), tt.dest)
}
}
}
func ExampleScanSlice() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Send("HMSET", "album:1", "title", "Red", "rating", 5)
c.Send("HMSET", "album:2", "title", "Earthbound", "rating", 1)
c.Send("HMSET", "album:3", "title", "Beat", "rating", 4)
c.Send("LPUSH", "albums", "1")
c.Send("LPUSH", "albums", "2")
c.Send("LPUSH", "albums", "3")
values, err := Values(c.Do("SORT", "albums",
"BY", "album:*->rating",
"GET", "album:*->title",
"GET", "album:*->rating"))
if err != nil {
fmt.Println(err)
return
}
var albums []struct {
Title string
Rating int
}
if err := ScanSlice(values, &albums); err != nil {
fmt.Println(err)
return
}
fmt.Printf("%v\n", albums)
// Output:
// [{Earthbound 1} {Beat 4} {Red 5}]
}
var argsTests = []struct {
title string
actual Args
expected Args
}{
{"struct ptr",
Args{}.AddFlat(&struct {
I int `redis:"i"`
U uint `redis:"u"`
S string `redis:"s"`
P []byte `redis:"p"`
M map[string]string `redis:"m"`
Bt bool
Bf bool
}{
-1234, 5678, "hello", []byte("world"), map[string]string{"hello": "world"}, true, false,
}),
Args{"i", int(-1234), "u", uint(5678), "s", "hello", "p", []byte("world"), "m", map[string]string{"hello": "world"}, "Bt", true, "Bf", false},
},
{"struct",
Args{}.AddFlat(struct{ I int }{123}),
Args{"I", 123},
},
{"slice",
Args{}.Add(1).AddFlat([]string{"a", "b", "c"}).Add(2),
Args{1, "a", "b", "c", 2},
},
{"struct omitempty",
Args{}.AddFlat(&struct {
I int `redis:"i,omitempty"`
U uint `redis:"u,omitempty"`
S string `redis:"s,omitempty"`
P []byte `redis:"p,omitempty"`
M map[string]string `redis:"m,omitempty"`
Bt bool `redis:"Bt,omitempty"`
Bf bool `redis:"Bf,omitempty"`
}{
0, 0, "", []byte{}, map[string]string{}, true, false,
}),
Args{"Bt", true},
},
}
func TestArgs(t *testing.T) {
for _, tt := range argsTests {
if !reflect.DeepEqual(tt.actual, tt.expected) {
t.Fatalf("%s is %v, want %v", tt.title, tt.actual, tt.expected)
}
}
}
func ExampleArgs() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
var p1, p2 struct {
Title string `redis:"title"`
Author string `redis:"author"`
Body string `redis:"body"`
}
p1.Title = "Example"
p1.Author = "Gary"
p1.Body = "Hello"
if _, err := c.Do("HMSET", Args{}.Add("id1").AddFlat(&p1)...); err != nil {
fmt.Println(err)
return
}
m := map[string]string{
"title": "Example2",
"author": "Steve",
"body": "Map",
}
if _, err := c.Do("HMSET", Args{}.Add("id2").AddFlat(m)...); err != nil {
fmt.Println(err)
return
}
for _, id := range []string{"id1", "id2"} {
v, err := Values(c.Do("HGETALL", id))
if err != nil {
fmt.Println(err)
return
}
if err := ScanStruct(v, &p2); err != nil {
fmt.Println(err)
return
}
fmt.Printf("%+v\n", p2)
}
// Output:
// {Title:Example Author:Gary Body:Hello}
// {Title:Example2 Author:Steve Body:Map}
}

86
library/cache/redis/script.go vendored Normal file
View File

@@ -0,0 +1,86 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"crypto/sha1"
"encoding/hex"
"io"
"strings"
)
// Script encapsulates the source, hash and key count for a Lua script. See
// http://redis.io/commands/eval for information on scripts in Redis.
type Script struct {
keyCount int
src string
hash string
}
// NewScript returns a new script object. If keyCount is greater than or equal
// to zero, then the count is automatically inserted in the EVAL command
// argument list. If keyCount is less than zero, then the application supplies
// the count as the first value in the keysAndArgs argument to the Do, Send and
// SendHash methods.
func NewScript(keyCount int, src string) *Script {
h := sha1.New()
io.WriteString(h, src)
return &Script{keyCount, src, hex.EncodeToString(h.Sum(nil))}
}
func (s *Script) args(spec string, keysAndArgs []interface{}) []interface{} {
var args []interface{}
if s.keyCount < 0 {
args = make([]interface{}, 1+len(keysAndArgs))
args[0] = spec
copy(args[1:], keysAndArgs)
} else {
args = make([]interface{}, 2+len(keysAndArgs))
args[0] = spec
args[1] = s.keyCount
copy(args[2:], keysAndArgs)
}
return args
}
// Do evaluates the script. Under the covers, Do optimistically evaluates the
// script using the EVALSHA command. If the command fails because the script is
// not loaded, then Do evaluates the script using the EVAL command (thus
// causing the script to load).
func (s *Script) Do(c Conn, keysAndArgs ...interface{}) (interface{}, error) {
v, err := c.Do("EVALSHA", s.args(s.hash, keysAndArgs)...)
if e, ok := err.(Error); ok && strings.HasPrefix(string(e), "NOSCRIPT ") {
v, err = c.Do("EVAL", s.args(s.src, keysAndArgs)...)
}
return v, err
}
// SendHash evaluates the script without waiting for the reply. The script is
// evaluated with the EVALSHA command. The application must ensure that the
// script is loaded by a previous call to Send, Do or Load methods.
func (s *Script) SendHash(c Conn, keysAndArgs ...interface{}) error {
return c.Send("EVALSHA", s.args(s.hash, keysAndArgs)...)
}
// Send evaluates the script without waiting for the reply.
func (s *Script) Send(c Conn, keysAndArgs ...interface{}) error {
return c.Send("EVAL", s.args(s.src, keysAndArgs)...)
}
// Load loads the script without evaluating it.
func (s *Script) Load(c Conn) error {
_, err := c.Do("SCRIPT", "LOAD", s.src)
return err
}

97
library/cache/redis/script_test.go vendored Normal file
View File

@@ -0,0 +1,97 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"fmt"
"reflect"
"testing"
"time"
)
func ExampleScript() {
c, err := Dial("tcp", ":6379")
if err != nil {
// handle error
}
defer c.Close()
// Initialize a package-level variable with a script.
var getScript = NewScript(1, `return call('get', KEYS[1])`)
// In a function, use the script Do method to evaluate the script. The Do
// method optimistically uses the EVALSHA command. If the script is not
// loaded, then the Do method falls back to the EVAL command.
if _, err = getScript.Do(c, "foo"); err != nil {
// handle error
}
}
func TestScript(t *testing.T) {
c, err := DialDefaultServer()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
// To test fall back in Do, we make script unique by adding comment with current time.
script := fmt.Sprintf("--%d\nreturn {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", time.Now().UnixNano())
s := NewScript(2, script)
reply := []interface{}{[]byte("key1"), []byte("key2"), []byte("arg1"), []byte("arg2")}
v, err := s.Do(c, "key1", "key2", "arg1", "arg2")
if err != nil {
t.Errorf("s.Do(c, ...) returned %v", err)
}
if !reflect.DeepEqual(v, reply) {
t.Errorf("s.Do(c, ..); = %v, want %v", v, reply)
}
err = s.Load(c)
if err != nil {
t.Errorf("s.Load(c) returned %v", err)
}
err = s.SendHash(c, "key1", "key2", "arg1", "arg2")
if err != nil {
t.Errorf("s.SendHash(c, ...) returned %v", err)
}
err = c.Flush()
if err != nil {
t.Errorf("c.Flush() returned %v", err)
}
v, err = c.Receive()
if !reflect.DeepEqual(v, reply) {
t.Errorf("s.SendHash(c, ..); c.Receive() = %v, want %v", v, reply)
}
err = s.Send(c, "key1", "key2", "arg1", "arg2")
if err != nil {
t.Errorf("s.Send(c, ...) returned %v", err)
}
err = c.Flush()
if err != nil {
t.Errorf("c.Flush() returned %v", err)
}
v, err = c.Receive()
if !reflect.DeepEqual(v, reply) {
t.Errorf("s.Send(c, ..); c.Receive() = %v, want %v", v, reply)
}
}

173
library/cache/redis/test_test.go vendored Normal file
View File

@@ -0,0 +1,173 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis
import (
"bufio"
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"os/exec"
"strconv"
"strings"
"sync"
"testing"
"time"
)
var (
ErrNegativeInt = errNegativeInt
serverPath = flag.String("redis-server", "redis-server", "Path to redis server binary")
serverBasePort = flag.Int("redis-port", 16379, "Beginning of port range for test servers")
serverLogName = flag.String("redis-log", "", "Write Redis server logs to `filename`")
serverLog = ioutil.Discard
defaultServerMu sync.Mutex
defaultServer *Server
defaultServerErr error
)
type Server struct {
name string
cmd *exec.Cmd
done chan struct{}
}
func NewServer(name string, args ...string) (*Server, error) {
s := &Server{
name: name,
cmd: exec.Command(*serverPath, args...),
done: make(chan struct{}),
}
r, err := s.cmd.StdoutPipe()
if err != nil {
return nil, err
}
err = s.cmd.Start()
if err != nil {
return nil, err
}
ready := make(chan error, 1)
go s.watch(r, ready)
select {
case err = <-ready:
case <-time.After(time.Second * 10):
err = errors.New("timeout waiting for server to start")
}
if err != nil {
s.Stop()
return nil, err
}
return s, nil
}
func (s *Server) watch(r io.Reader, ready chan error) {
fmt.Fprintf(serverLog, "%d START %s \n", s.cmd.Process.Pid, s.name)
var listening bool
var text string
scn := bufio.NewScanner(r)
for scn.Scan() {
text = scn.Text()
fmt.Fprintf(serverLog, "%s\n", text)
if !listening {
if strings.Contains(text, "The server is now ready to accept connections on port") {
listening = true
ready <- nil
}
}
}
if !listening {
ready <- fmt.Errorf("server exited: %s", text)
}
s.cmd.Wait()
fmt.Fprintf(serverLog, "%d STOP %s \n", s.cmd.Process.Pid, s.name)
close(s.done)
}
func (s *Server) Stop() {
s.cmd.Process.Signal(os.Interrupt)
<-s.done
}
// stopDefaultServer stops the server created by DialDefaultServer.
func stopDefaultServer() {
defaultServerMu.Lock()
defer defaultServerMu.Unlock()
if defaultServer != nil {
defaultServer.Stop()
defaultServer = nil
}
}
// startDefaultServer starts the default server if not already running.
func startDefaultServer() error {
defaultServerMu.Lock()
defer defaultServerMu.Unlock()
if defaultServer != nil || defaultServerErr != nil {
return defaultServerErr
}
defaultServer, defaultServerErr = NewServer(
"default",
"--port", strconv.Itoa(*serverBasePort),
"--save", "",
"--appendonly", "no")
return defaultServerErr
}
// DialDefaultServer starts the test server if not already started and dials a
// connection to the server.
func DialDefaultServer() (Conn, error) {
if err := startDefaultServer(); err != nil {
return nil, err
}
c, err := Dial("tcp", fmt.Sprintf(":%d", *serverBasePort), DialReadTimeout(1*time.Second), DialWriteTimeout(1*time.Second))
if err != nil {
return nil, err
}
c.Do("FLUSHDB")
return c, nil
}
func TestMain(m *testing.M) {
os.Exit(func() int {
flag.Parse()
var f *os.File
if *serverLogName != "" {
var err error
f, err = os.OpenFile(*serverLogName, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0600)
if err != nil {
fmt.Fprintf(os.Stderr, "Error opening redis-log: %v\n", err)
return 1
}
defer f.Close()
serverLog = f
}
defer stopDefaultServer()
return m.Run()
}())
}

127
library/cache/redis/trace.go vendored Normal file
View File

@@ -0,0 +1,127 @@
package redis
import (
"context"
"fmt"
"go-common/library/net/trace"
)
const (
_traceComponentName = "library/cache/redis"
_tracePeerService = "redis"
_traceSpanKind = "client"
)
var _internalTags = []trace.Tag{
trace.TagString(trace.TagSpanKind, _traceSpanKind),
trace.TagString(trace.TagComponent, _traceComponentName),
trace.TagString(trace.TagPeerService, _tracePeerService),
}
type traceConn struct {
// tr for pipeline, if tr != nil meaning on pipeline
tr trace.Trace
ctx context.Context
// connTag include e.g. ip,port
connTags []trace.Tag
// origin redis conn
Conn
pending int
}
func (t *traceConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
root, ok := trace.FromContext(t.ctx)
// NOTE: ignored empty commandName
// current sdk will Do empty command after pipeline finished
if !ok || commandName == "" {
return t.Conn.Do(commandName, args...)
}
tr := root.Fork("", "Redis:"+commandName)
tr.SetTag(_internalTags...)
tr.SetTag(t.connTags...)
statement := commandName
if len(args) > 0 {
statement += fmt.Sprintf(" %v", args[0])
}
tr.SetTag(trace.TagString(trace.TagDBStatement, statement))
reply, err = t.Conn.Do(commandName, args...)
tr.Finish(&err)
return
}
func (t *traceConn) Send(commandName string, args ...interface{}) error {
t.pending++
root, ok := trace.FromContext(t.ctx)
if !ok {
return t.Conn.Send(commandName, args...)
}
if t.tr == nil {
t.tr = root.Fork("", "Redis:Pipeline")
t.tr.SetTag(_internalTags...)
t.tr.SetTag(t.connTags...)
}
statement := commandName
if len(args) > 0 {
statement += fmt.Sprintf(" %v", args[0])
}
t.tr.SetLog(
trace.Log(trace.LogEvent, "Send"),
trace.Log("db.statement", statement),
)
err := t.Conn.Send(commandName, args...)
if err != nil {
t.tr.SetTag(trace.TagBool(trace.TagError, true))
t.tr.SetLog(
trace.Log(trace.LogEvent, "Send Fail"),
trace.Log(trace.LogMessage, err.Error()),
)
}
return err
}
func (t *traceConn) Flush() error {
if t.tr == nil {
return t.Conn.Flush()
}
t.tr.SetLog(trace.Log(trace.LogEvent, "Flush"))
err := t.Conn.Flush()
if err != nil {
t.tr.SetTag(trace.TagBool(trace.TagError, true))
t.tr.SetLog(
trace.Log(trace.LogEvent, "Flush Fail"),
trace.Log(trace.LogMessage, err.Error()),
)
}
return err
}
func (t *traceConn) Receive() (reply interface{}, err error) {
if t.tr == nil {
return t.Conn.Receive()
}
t.tr.SetLog(trace.Log(trace.LogEvent, "Receive"))
reply, err = t.Conn.Receive()
if err != nil {
t.tr.SetTag(trace.TagBool(trace.TagError, true))
t.tr.SetLog(
trace.Log(trace.LogEvent, "Receive Fail"),
trace.Log(trace.LogMessage, err.Error()),
)
}
if t.pending > 0 {
t.pending--
}
if t.pending == 0 {
t.tr.Finish(nil)
t.tr = nil
}
return reply, err
}
func (t *traceConn) WithContext(ctx context.Context) Conn {
t.ctx = ctx
return t
}

166
library/cache/redis/trace_test.go vendored Normal file
View File

@@ -0,0 +1,166 @@
package redis
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"go-common/library/net/trace"
)
type mockTrace struct {
tags []trace.Tag
logs []trace.LogField
perr *error
operationName string
finished bool
}
func (m *mockTrace) Fork(serviceName string, operationName string) trace.Trace {
m.operationName = operationName
return m
}
func (m *mockTrace) Follow(serviceName string, operationName string) trace.Trace {
panic("not implemented")
}
func (m *mockTrace) Finish(err *error) {
m.perr = err
m.finished = true
}
func (m *mockTrace) SetTag(tags ...trace.Tag) trace.Trace {
m.tags = append(m.tags, tags...)
return m
}
func (m *mockTrace) SetLog(logs ...trace.LogField) trace.Trace {
m.logs = append(m.logs, logs...)
return m
}
func (m *mockTrace) Visit(fn func(k, v string)) {}
func (m *mockTrace) SetTitle(title string) {}
type mockConn struct{}
func (c *mockConn) Close() error { return nil }
func (c *mockConn) Err() error { return nil }
func (c *mockConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
return nil, nil
}
func (c *mockConn) Send(commandName string, args ...interface{}) error { return nil }
func (c *mockConn) Flush() error { return nil }
func (c *mockConn) Receive() (reply interface{}, err error) { return nil, nil }
func (c *mockConn) WithContext(context.Context) Conn { return c }
func TestTraceDo(t *testing.T) {
tr := &mockTrace{}
ctx := trace.NewContext(context.Background(), tr)
tc := &traceConn{Conn: &mockConn{}}
conn := tc.WithContext(ctx)
conn.Do("GET", "test")
assert.Equal(t, "Redis:GET", tr.operationName)
assert.NotEmpty(t, tr.tags)
assert.True(t, tr.finished)
}
func TestTraceDoErr(t *testing.T) {
tr := &mockTrace{}
ctx := trace.NewContext(context.Background(), tr)
tc := &traceConn{Conn: MockErr{Error: fmt.Errorf("hhhhhhh")}}
conn := tc.WithContext(ctx)
conn.Do("GET", "test")
assert.Equal(t, "Redis:GET", tr.operationName)
assert.True(t, tr.finished)
assert.NotNil(t, *tr.perr)
}
func TestTracePipeline(t *testing.T) {
tr := &mockTrace{}
ctx := trace.NewContext(context.Background(), tr)
tc := &traceConn{Conn: &mockConn{}}
conn := tc.WithContext(ctx)
N := 2
for i := 0; i < N; i++ {
conn.Send("GET", "hello, world")
}
conn.Flush()
for i := 0; i < N; i++ {
conn.Receive()
}
assert.Equal(t, "Redis:Pipeline", tr.operationName)
assert.NotEmpty(t, tr.tags)
assert.NotEmpty(t, tr.logs)
assert.True(t, tr.finished)
}
func TestTracePipelineErr(t *testing.T) {
tr := &mockTrace{}
ctx := trace.NewContext(context.Background(), tr)
tc := &traceConn{Conn: MockErr{Error: fmt.Errorf("hahah")}}
conn := tc.WithContext(ctx)
N := 2
for i := 0; i < N; i++ {
conn.Send("GET", "hello, world")
}
conn.Flush()
for i := 0; i < N; i++ {
conn.Receive()
}
assert.Equal(t, "Redis:Pipeline", tr.operationName)
assert.NotEmpty(t, tr.tags)
assert.NotEmpty(t, tr.logs)
assert.True(t, tr.finished)
var isError bool
for _, tag := range tr.tags {
if tag.Key == "error" {
isError = true
}
}
assert.True(t, isError)
}
func TestSendStatement(t *testing.T) {
tr := &mockTrace{}
ctx := trace.NewContext(context.Background(), tr)
tc := &traceConn{Conn: MockErr{Error: fmt.Errorf("hahah")}}
conn := tc.WithContext(ctx)
conn.Send("SET", "hello", "test")
conn.Flush()
conn.Receive()
assert.Equal(t, "Redis:Pipeline", tr.operationName)
assert.NotEmpty(t, tr.tags)
assert.NotEmpty(t, tr.logs)
assert.Equal(t, "event", tr.logs[0].Key)
assert.Equal(t, "Send", tr.logs[0].Value)
assert.Equal(t, "db.statement", tr.logs[1].Key)
assert.Equal(t, "SET hello", tr.logs[1].Value)
assert.True(t, tr.finished)
var isError bool
for _, tag := range tr.tags {
if tag.Key == "error" {
isError = true
}
}
assert.True(t, isError)
}
func TestDoStatement(t *testing.T) {
tr := &mockTrace{}
ctx := trace.NewContext(context.Background(), tr)
tc := &traceConn{Conn: MockErr{Error: fmt.Errorf("hahah")}}
conn := tc.WithContext(ctx)
conn.Do("SET", "hello", "test")
assert.Equal(t, "Redis:SET", tr.operationName)
assert.Equal(t, "SET hello", tr.tags[len(tr.tags)-1].Value)
assert.True(t, tr.finished)
}

52
library/conf/BUILD Normal file
View File

@@ -0,0 +1,52 @@
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
package(default_visibility = ["//visibility:public"])
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [
":package-srcs",
"//library/conf/dsn:all-srcs",
"//library/conf/env:all-srcs",
"//library/conf/flagvar:all-srcs",
"//library/conf/paladin:all-srcs",
],
tags = ["automanaged"],
)
go_library(
name = "go_default_library",
srcs = [
"client.go",
"client_v2.go",
],
importpath = "go-common/library/conf",
tags = ["automanaged"],
deps = [
"//library/conf/env:go_default_library",
"//library/ecode:go_default_library",
"//library/log:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = [
"client_test.go",
"client_v2_test.go",
],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
)

452
library/conf/client.go Normal file
View File

@@ -0,0 +1,452 @@
package conf
import (
"bytes"
"crypto/md5"
"encoding/hex"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"path"
"strings"
"sync/atomic"
"time"
"go-common/library/conf/env"
"go-common/library/log"
)
const (
// code
_codeOk = 0
_codeNotModified = -304
// api
_apiGet = "http://%s/v1/config/get2?%s"
_apiCheck = "http://%s/v1/config/check?%s"
// timeout
_retryInterval = 1 * time.Second
_httpTimeout = 60 * time.Second
_unknownVersion = -1
commonKey = "common.toml"
)
var (
conf config
)
type version struct {
Code int `json:"code"`
Message string `json:"message"`
Data *struct {
Version int64 `json:"version"`
} `json:"data"`
}
type result struct {
Code int `json:"code"`
Message string `json:"message"`
Data *data `json:"data"`
}
type data struct {
Version int64 `json:"version"`
Content string `json:"content"`
Md5 string `json:"md5"`
}
// Namespace the key-value config object.
type Namespace struct {
Name string `json:"name"`
Data map[string]string `json:"data"`
}
type config struct {
Svr string
Ver string
Path string
Filename string
Host string
Addr string
Env string
Token string
Appoint string
// NOTE: new caster
Region string
Zone string
AppID string
DeployEnv string
TreeID string
}
// Client is config client.
type Client struct {
ver int64 // NOTE: for config v1
diff *ver // NOTE: for config v2
customize string
httpCli *http.Client
data atomic.Value
event chan string
useV2 bool
watchFile map[string]struct{}
watchAll bool
}
func init() {
// env
conf.Svr = os.Getenv("CONF_APPID")
conf.Ver = os.Getenv("CONF_VERSION")
conf.Addr = os.Getenv("CONF_HOST")
conf.Host = os.Getenv("CONF_HOSTNAME")
conf.Path = os.Getenv("CONF_PATH")
conf.Env = os.Getenv("CONF_ENV")
conf.Token = os.Getenv("CONF_TOKEN")
conf.Appoint = os.Getenv("CONF_APPOINT")
conf.Region = os.Getenv("REGION")
conf.Zone = os.Getenv("ZONE")
conf.AppID = os.Getenv("APP_ID")
conf.DeployEnv = os.Getenv("DEPLOY_ENV")
conf.TreeID = os.Getenv("TREE_ID")
// flags
hostname, _ := os.Hostname()
flag.StringVar(&conf.Svr, "conf_appid", conf.Svr, `app name.`)
flag.StringVar(&conf.Ver, "conf_version", conf.Ver, `app version.`)
flag.StringVar(&conf.Addr, "conf_host", conf.Addr, `config center api host.`)
flag.StringVar(&conf.Host, "conf_hostname", hostname, `hostname.`)
flag.StringVar(&conf.Path, "conf_path", conf.Path, `config file path.`)
flag.StringVar(&conf.Env, "conf_env", conf.Env, `config Env.`)
flag.StringVar(&conf.Token, "conf_token", conf.Token, `config Token.`)
flag.StringVar(&conf.Appoint, "conf_appoint", conf.Appoint, `config Appoint.`)
/*
flag.StringVar(&conf.Region, "region", conf.Region, `region.`)
flag.StringVar(&conf.Zone, "zone", conf.Zone, `zone.`)
flag.StringVar(&conf.AppID, "app_id", conf.AppID, `app id.`)
flag.StringVar(&conf.DeployEnv, "deploy_env", conf.DeployEnv, `deploy env.`)
*/
conf.Region = env.Region
conf.Zone = env.Zone
conf.AppID = env.AppID
conf.DeployEnv = env.DeployEnv
// FIXME(linli) remove treeid
flag.StringVar(&conf.TreeID, "tree_id", conf.TreeID, `tree id.`)
}
// New new a ugc config center client.
func New() (cli *Client, err error) {
cli = &Client{
httpCli: &http.Client{Timeout: _httpTimeout},
event: make(chan string, 10),
}
if conf.Svr != "" && conf.Host != "" && conf.Path != "" && conf.Addr != "" && conf.Ver != "" && conf.Env != "" && conf.Token != "" &&
(strings.HasPrefix(conf.Ver, "shsb") || (strings.HasPrefix(conf.Ver, "shylf"))) {
if err = cli.init(); err != nil {
return nil, err
}
go cli.updateproc()
return
}
if conf.Zone != "" && conf.AppID != "" && conf.Host != "" && conf.Path != "" && conf.Addr != "" && conf.Ver != "" && conf.DeployEnv != "" && conf.Token != "" {
if err = cli.init2(); err != nil {
return nil, err
}
go cli.updateproc2()
cli.useV2 = true
return
}
err = fmt.Errorf("at least one params is empty. app=%s, version=%s, hostname=%s, addr=%s, path=%s, Env=%s, Token =%s, DeployEnv=%s, TreeID=%s, appID=%s",
conf.Svr, conf.Ver, conf.Host, conf.Addr, conf.Path, conf.Env, conf.Token, conf.DeployEnv, conf.TreeID, conf.AppID)
return
}
// Path get confFile Path.
func (c *Client) Path() string {
return conf.Path
}
// Toml return config value.
func (c *Client) Toml() (cf string, ok bool) {
if c.useV2 {
return c.Toml2()
}
var (
m map[string]*Namespace
n *Namespace
)
if m, ok = c.data.Load().(map[string]*Namespace); !ok {
return
}
if n, ok = m[""]; !ok {
return
}
cf, ok = n.Data[commonKey]
return
}
// Value return config value.
func (c *Client) Value(key string) (cf string, ok bool) {
if c.useV2 {
return c.Value2(key)
}
var (
m map[string]*Namespace
n *Namespace
)
if m, ok = c.data.Load().(map[string]*Namespace); !ok {
return
}
if n, ok = m[""]; !ok {
return
}
cf, ok = n.Data[key]
return
}
// SetCustomize set customize value.
func (c *Client) SetCustomize(value string) {
c.customize = value
}
// Event client update event.
func (c *Client) Event() <-chan string {
return c.event
}
// Watch watch filename change.
func (c *Client) Watch(filename ...string) {
if c.watchFile == nil {
c.watchFile = map[string]struct{}{}
}
for _, f := range filename {
c.watchFile[f] = struct{}{}
}
}
// WatchAll watch all filename change.
func (c *Client) WatchAll() {
c.watchAll = true
}
// checkLocal check local config is ok
func (c *Client) init() (err error) {
var ver int64
if ver, err = c.checkVersion(_unknownVersion); err != nil {
fmt.Printf("get remote version error(%v)\n", err)
return
}
for i := 0; i < 3; i++ {
if ver == _unknownVersion {
fmt.Println("get null version")
return
}
if err = c.download(ver); err == nil {
return
}
fmt.Printf("retry times: %d, c.download() error(%v)\n", i, err)
time.Sleep(_retryInterval)
}
return
}
func (c *Client) updateproc() (err error) {
var ver int64
for {
time.Sleep(_retryInterval)
if ver, err = c.checkVersion(c.ver); err != nil {
log.Error("c.checkVersion(%d) error(%v)", c.ver, err)
continue
} else if ver == c.ver {
continue
}
if err = c.download(ver); err != nil {
log.Error("c.download() error(%s)", err)
continue
}
c.event <- ""
}
}
// download download config from config service
func (c *Client) download(ver int64) (err error) {
var data *data
if data, err = c.getConfig(ver); err != nil {
return
}
return c.update(data)
}
// poll config server
func (c *Client) checkVersion(reqVer int64) (ver int64, err error) {
var (
url string
req *http.Request
resp *http.Response
rb []byte
)
if url = c.makeURL(_apiCheck, reqVer); url == "" {
err = fmt.Errorf("checkVersion() c.makeUrl() error url empty")
return
}
// http
if req, err = http.NewRequest("GET", url, nil); err != nil {
return
}
if resp, err = c.httpCli.Do(req); err != nil {
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("checkVersion() http error url(%s) status: %d", url, resp.StatusCode)
return
}
// ok
if rb, err = ioutil.ReadAll(resp.Body); err != nil {
return
}
v := &version{}
if err = json.Unmarshal(rb, v); err != nil {
return
}
switch v.Code {
case _codeOk:
if v.Data == nil {
err = fmt.Errorf("checkVersion() response error result: %v", v)
return
}
ver = v.Data.Version
case _codeNotModified:
ver = reqVer
default:
err = fmt.Errorf("checkVersion() response error result: %v", v)
}
return
}
// updateVersion update config version
func (c *Client) getConfig(ver int64) (data *data, err error) {
var (
url string
req *http.Request
resp *http.Response
rb []byte
res = &result{}
)
if url = c.makeURL(_apiGet, ver); url == "" {
err = fmt.Errorf("getConfig() c.makeUrl() error url empty")
return
}
// http
if req, err = http.NewRequest("GET", url, nil); err != nil {
return
}
if resp, err = c.httpCli.Do(req); err != nil {
return
}
defer resp.Body.Close()
// ok
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("getConfig() http error url(%s) status: %d", url, resp.StatusCode)
return
}
if rb, err = ioutil.ReadAll(resp.Body); err != nil {
return
}
if err = json.Unmarshal(rb, res); err != nil {
return
}
switch res.Code {
case _codeOk:
// has new config
if res.Data == nil {
err = fmt.Errorf("getConfig() response error result: %v", res)
return
}
data = res.Data
default:
err = fmt.Errorf("getConfig() response error result: %v", res)
}
return
}
// update write config
func (c *Client) update(d *data) (err error) {
var (
tmp = make(map[string]*Namespace)
bs = []byte(d.Content)
buf = new(bytes.Buffer)
n *Namespace
ok bool
)
// md5 file
if mh := md5.Sum(bs); hex.EncodeToString(mh[:]) != d.Md5 {
err = fmt.Errorf("md5 mismatch, local:%s, remote:%s", hex.EncodeToString(mh[:]), d.Md5)
return
}
// write conf
if err = json.Unmarshal(bs, &tmp); err != nil {
return
}
for _, value := range tmp {
for k, v := range value.Data {
if strings.Contains(k, ".toml") {
buf.WriteString(v)
buf.WriteString("\n")
}
if err = ioutil.WriteFile(path.Join(conf.Path, k), []byte(v), 0644); err != nil {
return
}
}
}
if n, ok = tmp[""]; !ok {
n = &Namespace{Data: make(map[string]string)}
tmp[""] = n
}
n.Data[commonKey] = buf.String()
// update current version
c.ver = d.Version
c.data.Store(tmp)
return
}
// makeUrl signed url
func (c *Client) makeURL(api string, ver int64) (query string) {
params := url.Values{}
// service
params.Set("service", conf.Svr)
params.Set("hostname", conf.Host)
params.Set("build", conf.Ver)
params.Set("version", fmt.Sprint(ver))
params.Set("ip", localIP())
params.Set("environment", conf.Env)
params.Set("token", conf.Token)
params.Set("appoint", conf.Appoint)
params.Set("customize", c.customize)
// api
query = fmt.Sprintf(api, conf.Addr, params.Encode())
return
}
// localIP return local IP of the host.
func localIP() string {
addrs, err := net.InterfaceAddrs()
if err != nil {
return ""
}
for _, address := range addrs {
// check the address type and if it is not a loopback the display it
if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
if ipnet.IP.To4() != nil {
return ipnet.IP.String()
}
}
}
return ""
}

View File

@@ -0,0 +1,90 @@
package conf
import (
"net/http"
"testing"
)
func TestConf_client(t *testing.T) {
c := initConf()
testClientValue(t, c)
testCheckVersion(t, c)
testUpdate(t, c)
testDownload(t, c)
testGetConfig(t, c)
}
func TestClientNew(t *testing.T) {
initConf()
if _, err := New(); err != nil {
t.Errorf("client.New() error(%v)", err)
t.FailNow()
}
}
func testClientValue(t *testing.T, c *Client) {
key := "breaker"
testUpdate(t, c)
test1, ok := c.Value(key)
if !ok {
t.Errorf("client.Value() error")
t.FailNow()
}
t.Logf("get the result test1(%s)", test1)
}
func testCheckVersion(t *testing.T, c *Client) {
ver, err := c.checkVersion(_unknownVersion)
if err != nil && ver == _unknownVersion {
t.Errorf("client.checkVersion() error(%v) ver(%d)", err, ver)
t.FailNow()
}
}
func testDownload(t *testing.T, c *Client) {
ver := int64(102)
if err := c.download(ver); err != nil {
t.Errorf("client.downloda() error(%v) ", err)
t.FailNow()
}
}
func testUpdate(t *testing.T, c *Client) {
data := &data{
Version: 199,
Content: "{\"\":{\"name\":\"\",\"data\":{\"breaker\":\"fuck778\",\"degrade\":\"shit233333\"}},\"redis\":{\"name\":\"redis\",\"data\":{\"444\":\"555\",\"address\":\"172.123.0\",\"array\":\"4,12,test,4\",\"float\":\"3.123\",\"router\":\"test=1,fuck=shit,abc=test\",\"switch\":\"true\",\"timeout\":\"30s\"}}}",
Md5: "0843192c43148cbbf43aabb24e3e6442",
}
if err := c.update(data); err != nil {
t.Errorf("client.update() error(%v)", err)
t.FailNow()
}
}
func testGetConfig(t *testing.T, c *Client) {
ver := int64(102)
data, err := c.getConfig(ver)
if err != nil {
t.Errorf("client.getconfiig() error(%v)", err)
t.FailNow()
}
t.Logf("get the result data(%v)", data)
}
func initConf() (c *Client) {
conf.Addr = "172.16.33.134:9011"
conf.Host = "testHost"
conf.Path = "./"
conf.Svr = "config_test"
conf.Ver = "shsb-docker-1"
conf.Env = "10"
conf.Token = "qmVUPwNXnNfcSpuyqbiIBb0H4GcbSZFV"
//conf.Appoint = "88"
c = &Client{
httpCli: &http.Client{Timeout: _httpTimeout},
event: make(chan string, 10),
}
c.data.Store(make(map[string]*Namespace))
return
}

435
library/conf/client_v2.go Normal file
View File

@@ -0,0 +1,435 @@
package conf
import (
"bytes"
"crypto/md5"
"encoding/hex"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"path"
"strings"
"time"
"go-common/library/ecode"
"go-common/library/log"
)
const (
// api
_apiGet1 = "http://%s/config/v2/get?%s"
_apiCheck1 = "http://%s/config/v2/check?%s"
_apiCreate = "http://%s/config/v2/create"
_apiUpdate = "http://%s/config/v2/update"
_apiConfIng = "http://%s/config/v2/config/ing?%s"
)
type version1 struct {
Code int `json:"code"`
Message string `json:"message"`
Data *ver `json:"data"`
}
type ver struct {
Version int64 `json:"version"`
Diffs []int64 `json:"diffs"`
}
type confIng struct {
Code int `json:"code"`
Message string `json:"message"`
Data *Value `json:"data"`
}
type res struct {
Code int `json:"code"`
Message string `json:"message"`
}
//Value value.
type Value struct {
CID int64 `json:"cid"`
Name string `json:"name"`
Config string `json:"config"`
}
// Toml2 return config value.
func (c *Client) Toml2() (cf string, ok bool) {
var (
m map[string]*Value
val *Value
)
if m, ok = c.data.Load().(map[string]*Value); !ok {
return
}
if val, ok = m[commonKey]; !ok {
return
}
cf = val.Config
return
}
// Value2 return config value.
func (c *Client) Value2(key string) (cf string, ok bool) {
var (
m map[string]*Value
val *Value
)
if m, ok = c.data.Load().(map[string]*Value); !ok {
return
}
if val, ok = m[key]; !ok {
return
}
cf = val.Config
return
}
// init check local config is ok
func (c *Client) init2() (err error) {
var v *ver
c.data.Store(make(map[string]*Value))
if v, err = c.checkVersion2(&ver{Version: _unknownVersion}); err != nil {
fmt.Printf("get remote version error(%v)\n", err)
return
}
for i := 0; i < 3; i++ {
if v.Version == _unknownVersion {
fmt.Println("get null version")
return
}
if err = c.download2(v, true); err == nil {
return
}
fmt.Printf("retry times: %d, c.download() error(%v)\n", i, err)
time.Sleep(_retryInterval)
}
return
}
func (c *Client) updateproc2() (err error) {
var ver *ver
for {
time.Sleep(_retryInterval)
if ver, err = c.checkVersion2(c.diff); err != nil {
log.Error("c.checkVersion(%d) error(%v)", c.ver, err)
continue
} else if ver.Version == c.diff.Version {
continue
}
if err = c.download2(ver, false); err != nil {
log.Error("c.download() error(%s)", err)
continue
}
}
}
// download download config from config service
func (c *Client) download2(ver *ver, isFirst bool) (err error) {
var (
d *data
tmp []*Value
oConfs, confs map[string]*Value
buf = new(bytes.Buffer)
ok bool
)
if d, err = c.getConfig2(ver); err != nil {
return
}
bs := []byte(d.Content)
// md5 file
if mh := md5.Sum(bs); hex.EncodeToString(mh[:]) != d.Md5 {
err = fmt.Errorf("md5 mismatch, local:%s, remote:%s", hex.EncodeToString(mh[:]), d.Md5)
return
}
// write conf
if err = json.Unmarshal(bs, &tmp); err != nil {
return
}
confs = make(map[string]*Value)
if oConfs, ok = c.data.Load().(map[string]*Value); ok {
for k, v := range oConfs {
confs[k] = v
}
}
for _, v := range tmp {
if err = ioutil.WriteFile(path.Join(conf.Path, v.Name), []byte(v.Config), 0644); err != nil {
return
}
confs[v.Name] = v
}
for _, v := range confs {
if strings.Contains(v.Name, ".toml") {
buf.WriteString(v.Config)
buf.WriteString("\n")
}
}
confs[commonKey] = &Value{Config: buf.String()}
// update current version
c.diff = ver
c.data.Store(confs)
if isFirst {
return
}
for _, v := range tmp {
if c.watchAll {
c.event <- v.Name
continue
}
if c.watchFile == nil {
continue
}
if _, ok := c.watchFile[v.Name]; ok {
c.event <- v.Name
}
}
return
}
// poll config server
func (c *Client) checkVersion2(reqVer *ver) (ver *ver, err error) {
var (
url string
req *http.Request
resp *http.Response
rb []byte
)
if url, err = c.makeURL2(_apiCheck1, reqVer); err != nil {
err = fmt.Errorf("checkVersion() c.makeUrl() error url empty")
return
}
// http
if req, err = http.NewRequest("GET", url, nil); err != nil {
return
}
if resp, err = c.httpCli.Do(req); err != nil {
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("checkVersion() http error url(%s) status: %d", url, resp.StatusCode)
return
}
// ok
if rb, err = ioutil.ReadAll(resp.Body); err != nil {
return
}
v := &version1{}
if err = json.Unmarshal(rb, v); err != nil {
return
}
switch v.Code {
case _codeOk:
if v.Data == nil {
err = fmt.Errorf("checkVersion() response error result: %v", v)
return
}
ver = v.Data
case _codeNotModified:
ver = reqVer
default:
err = fmt.Errorf("checkVersion() response error result: %v", v)
}
return
}
// updateVersion update config version
func (c *Client) getConfig2(ver *ver) (data *data, err error) {
var (
url string
req *http.Request
resp *http.Response
rb []byte
res = &result{}
)
if url, err = c.makeURL2(_apiGet1, ver); err != nil {
err = fmt.Errorf("getConfig() c.makeUrl() error url empty")
return
}
// http
if req, err = http.NewRequest("GET", url, nil); err != nil {
return
}
if resp, err = c.httpCli.Do(req); err != nil {
return
}
defer resp.Body.Close()
// ok
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("getConfig() http error url(%s) status: %d", url, resp.StatusCode)
return
}
if rb, err = ioutil.ReadAll(resp.Body); err != nil {
return
}
if err = json.Unmarshal(rb, res); err != nil {
return
}
switch res.Code {
case _codeOk:
// has new config
if res.Data == nil {
err = fmt.Errorf("getConfig() response error result: %v", res)
return
}
data = res.Data
default:
err = fmt.Errorf("getConfig() response error result: %v", res)
}
return
}
// makeUrl signed url
func (c *Client) makeURL2(api string, ver *ver) (query string, err error) {
var ids []byte
params := url.Values{}
// service
params.Set("service", service())
params.Set("hostname", conf.Host)
params.Set("build", conf.Ver)
params.Set("version", fmt.Sprint(ver.Version))
if ids, err = json.Marshal(ver.Diffs); err != nil {
return
}
params.Set("ids", string(ids))
params.Set("ip", localIP())
params.Set("token", conf.Token)
params.Set("appoint", conf.Appoint)
params.Set("customize", c.customize)
// api
query = fmt.Sprintf(api, conf.Addr, params.Encode())
return
}
//Create create.
func (c *Client) Create(name, content, operator, mark string) (err error) {
var (
resp *http.Response
rb []byte
res = &res{}
)
params := url.Values{}
params.Set("service", service())
params.Set("name", name)
params.Set("content", content)
params.Set("operator", operator)
params.Set("mark", mark)
params.Set("token", conf.Token)
if resp, err = c.httpCli.PostForm(fmt.Sprintf(_apiCreate, conf.Addr), params); err != nil {
return
}
defer resp.Body.Close()
// ok
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("Create() http error url(%s) status: %d", fmt.Sprintf(_apiCreate, conf.Addr), resp.StatusCode)
return
}
if rb, err = ioutil.ReadAll(resp.Body); err != nil {
return
}
if err = json.Unmarshal(rb, res); err != nil {
return
}
if res.Code != ecode.OK.Code() {
err = ecode.Int(res.Code)
}
return
}
//Update update.
func (c *Client) Update(ID int64, content, operator, mark string) (err error) {
var (
resp *http.Response
rb []byte
res = &result{}
)
params := url.Values{}
params.Set("conf_id", fmt.Sprintf("%d", ID))
params.Set("content", content)
params.Set("operator", operator)
params.Set("mark", mark)
params.Set("service", service())
params.Set("token", conf.Token)
if resp, err = c.httpCli.PostForm(fmt.Sprintf(_apiUpdate, conf.Addr), params); err != nil {
return
}
defer resp.Body.Close()
// ok
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("Update() http error url(%s) status: %d", fmt.Sprintf(_apiUpdate, conf.Addr), resp.StatusCode)
return
}
if rb, err = ioutil.ReadAll(resp.Body); err != nil {
return
}
if err = json.Unmarshal(rb, res); err != nil {
return
}
if res.Code != ecode.OK.Code() {
err = ecode.Int(res.Code)
}
return
}
//ConfIng confIng.
func (c *Client) ConfIng(name string) (v *Value, err error) {
var (
req *http.Request
resp *http.Response
rb []byte
res = &confIng{}
)
params := url.Values{}
params.Set("name", name)
params.Set("service", service())
params.Set("token", conf.Token)
// http
if req, err = http.NewRequest("GET", fmt.Sprintf(_apiConfIng, conf.Addr, params.Encode()), nil); err != nil {
return
}
if resp, err = c.httpCli.Do(req); err != nil {
return
}
defer resp.Body.Close()
// ok
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("ConfIng() http error url(%s) status: %d", _apiCreate, resp.StatusCode)
return
}
if rb, err = ioutil.ReadAll(resp.Body); err != nil {
return
}
if err = json.Unmarshal(rb, res); err != nil {
return
}
if res.Code != ecode.OK.Code() {
err = ecode.Int(res.Code)
return
}
v = res.Data
return
}
//Configs configs.
func (c *Client) Configs() (confs []*Value, ok bool) {
var (
m map[string]*Value
)
if m, ok = c.data.Load().(map[string]*Value); !ok {
return
}
for _, v := range m {
if v.CID == 0 {
continue
}
confs = append(confs, v)
}
return
}
func service() string {
return fmt.Sprintf("%s_%s_%s", conf.TreeID, conf.DeployEnv, conf.Zone)
}

View File

@@ -0,0 +1,96 @@
package conf
import (
"net/http"
"testing"
)
func TestConf_client2(t *testing.T) {
c := initConf2()
testClientValue2(t, c)
testCheckVersion2(t, c)
testDownload2(t, c)
testGetConfig2(t, c)
}
func testClientValue2(t *testing.T, c *Client) {
key := "test.toml"
testDownload2(t, c)
test1, ok := c.Value2(key)
if !ok {
t.Errorf("client.Value() error")
t.FailNow()
}
t.Logf("get the result test1(%s)", test1)
}
func testCheckVersion2(t *testing.T, c *Client) {
unknow := &ver{Version: _unknownVersion}
ver, err := c.checkVersion2(unknow)
if err != nil {
t.Errorf("client.checkVersion() error(%v) ver(%d)", err, ver)
t.FailNow()
}
}
func testDownload2(t *testing.T, c *Client) {
ver := &ver{Version: 13}
if err := c.download2(ver, true); err != nil {
t.Errorf("client.downloda() error(%v) ", err)
t.FailNow()
}
}
func testGetConfig2(t *testing.T, c *Client) {
ver := &ver{Version: 13}
data, err := c.getConfig2(ver)
if err != nil {
t.Errorf("client.getconfiig() error(%v)", err)
t.FailNow()
}
t.Logf("get the result data(%v)", data)
}
func TestClient_Create(t *testing.T) {
c := initConf2()
if err := c.Create("zjx11.toml", "test comment", "zjx", "mark"); err != nil {
t.Errorf("client.Create() error(%v)", err)
t.FailNow()
}
}
func TestClient_Update(t *testing.T) {
c := initConf2()
if err := c.Update(21, "test comment11", "zjx", "mark"); err != nil {
t.Errorf("client.Create() error(%v)", err)
t.FailNow()
}
}
func TestClient_ConfIng(t *testing.T) {
c := initConf2()
if val, err := c.ConfIng("zjx1.toml"); err != nil {
t.Errorf("client.Create() error(%v)", err)
t.FailNow()
} else {
t.Logf("%v", val)
}
}
func initConf2() (c *Client) {
conf.Addr = "172.16.33.134:9011"
conf.Host = "testHost"
conf.Path = "./"
conf.AppID = "main.common-arch.msm-service"
conf.Svr = "msm-service"
conf.Ver = "server-1"
conf.DeployEnv = "dev"
conf.Zone = "sh001"
conf.Token = "45338e440bdc11e880ce02420a0a0204"
conf.TreeID = "2888"
c = &Client{
httpCli: &http.Client{Timeout: _httpTimeout},
event: make(chan string, 10),
}
return
}

56
library/conf/dsn/BUILD Normal file
View File

@@ -0,0 +1,56 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_test",
"go_library",
)
go_test(
name = "go_default_test",
srcs = [
"dsn_test.go",
"query_test.go",
],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
deps = ["//library/time:go_default_library"],
)
go_library(
name = "go_default_library",
srcs = [
"doc.go",
"dsn.go",
"query.go",
],
importpath = "go-common/library/conf/dsn",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = ["//vendor/gopkg.in/go-playground/validator.v9:go_default_library"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)
go_test(
name = "go_default_xtest",
srcs = ["example_test.go"],
tags = ["automanaged"],
deps = [
"//library/conf/dsn:go_default_library",
"//library/time:go_default_library",
],
)

63
library/conf/dsn/doc.go Normal file
View File

@@ -0,0 +1,63 @@
// Package dsn implements dsn parse with struct bind
/*
DSN 格式类似 URI, DSN 结构如下图
network:[//[username[:password]@]address[:port][,address[:port]]][/path][?query][#fragment]
与 URI 的主要区别在于 scheme 被替换为 network, host 被替换为 address 并且支持多个 address.
network 与 net 包中 network 意义相同, tcp、udp、unix 等, address 支持多个使用 ',' 分割, 如果
network 为 unix 等本地 sock 协议则使用 Path, 有且只有一个
dsn 包主要提供了 Parse, Bind 和 validate 功能
Parse 解析 dsn 字符串成 DSN struct, DSN struct 与 url.URL 几乎完全一样
Bind 提供将 DSN 数据绑定到一个 struct 的功能, 通过 tag dsn:"key,[default]" 指定绑定的字段, 目前支持两种类型的数据绑定
内置变量 key:
network string tcp, udp, unix 等, 参考 net 包中的 network
username string
password string
address string or []string address 可以绑定到 string 或者 []string, 如果为 string 则取 address 第一个
Query: 通过 query.name 可以取到 query 上的数据
数组可以通过传递多个获得
array=1&array=2&array3 -> []int `tag:"query.array"`
struct 支持嵌套
foo.sub.name=hello&foo.tm=hello
struct Foo {
Tm string `dsn:"query.tm"`
Sub struct {
Name string `dsn:"query.name"`
} `dsn:"query.sub"`
}
默认值: 通过 dsn:"key,[default]" 默认值暂时不支持数组
忽略 Bind: 通过 dsn:"-" 忽略 Bind
自定义 Bind: 可以同时实现 encoding.TextUnmarshaler 自定义 Bind 实现
Validate: 参考 https://github.com/go-playground/validator
使用参考: example_test.go
DSN 命名规范:
没有历史遗留的情况下,尽量使用 Address, Network, Username, Password 等命名,代替之前的 Proto 和 Addr 等命名
Query 命名参考, 使用驼峰小写开头:
timeout 通用超时
dialTimeout 连接建立超时
readTimeout 读操作超时
writeTimeout 写操作超时
readsTimeout 批量读超时
writesTimeout 批量写超时
*/
package dsn

108
library/conf/dsn/dsn.go Normal file
View File

@@ -0,0 +1,108 @@
// Package dsn provide parse dsn and bind to struct
// see http://git.bilibili.co/platform/go-common/issues/279
package dsn
import (
"net/url"
"reflect"
"strings"
"gopkg.in/go-playground/validator.v9"
)
var _validator *validator.Validate
func init() {
_validator = validator.New()
}
// DSN a DSN represents a parsed DSN as same as url.URL.
type DSN struct {
*url.URL
}
// Bind dsn to specify struct and validate use use go-playground/validator format
//
// The bind of each struct field can be customized by the format string
// stored under the 'dsn' key in the struct field's tag. The format string
// gives the name of the field, possibly followed by a comma-separated
// list of options. The name may be empty in order to specify options
// without overriding the default field name.
//
// A two type data you can bind to struct
// built-in values, use below keys to bind built-in value
// username
// password
// address
// network
// the value in query string, use query.{name} to bind value in query string
//
// As a special case, if the field tag is "-", the field is always omitted.
// NOTE: that a field with name "-" can still be generated using the tag "-,".
//
// Examples of struct field tags and their meanings:
// // Field bind username
// Field string `dsn:"username"`
// // Field is ignored by this package.
// Field string `dsn:"-"`
// // Field bind value from query
// Field string `dsn:"query.name"`
//
func (d *DSN) Bind(v interface{}) (url.Values, error) {
assignFuncs := make(map[string]assignFunc)
if d.User != nil {
username := d.User.Username()
password, ok := d.User.Password()
if ok {
assignFuncs["password"] = stringsAssignFunc(password)
}
assignFuncs["username"] = stringsAssignFunc(username)
}
assignFuncs["address"] = addressesAssignFunc(d.Addresses())
assignFuncs["network"] = stringsAssignFunc(d.Scheme)
query, err := bindQuery(d.Query(), v, assignFuncs)
if err != nil {
return nil, err
}
return query, _validator.Struct(v)
}
func addressesAssignFunc(addresses []string) assignFunc {
return func(v reflect.Value, to tagOpt) error {
if v.Kind() == reflect.String {
if addresses[0] == "" && to.Default != "" {
v.SetString(to.Default)
} else {
v.SetString(addresses[0])
}
return nil
}
if !(v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.String) {
return &BindTypeError{Value: strings.Join(addresses, ","), Type: v.Type()}
}
vals := reflect.MakeSlice(v.Type(), len(addresses), len(addresses))
for i, address := range addresses {
vals.Index(i).SetString(address)
}
if v.CanSet() {
v.Set(vals)
}
return nil
}
}
// Addresses parse host split by ','
// For Unix networks, return ['path']
func (d *DSN) Addresses() []string {
switch d.Scheme {
case "unix", "unixgram", "unixpacket":
return []string{d.Path}
}
return strings.Split(d.Host, ",")
}
// Parse parses rawdsn into a URL structure.
func Parse(rawdsn string) (*DSN, error) {
u, err := url.Parse(rawdsn)
return &DSN{URL: u}, err
}

View File

@@ -0,0 +1,79 @@
package dsn
import (
"net/url"
"reflect"
"testing"
"time"
xtime "go-common/library/time"
)
type config struct {
Network string `dsn:"network"`
Addresses []string `dsn:"address"`
Username string `dsn:"username"`
Password string `dsn:"password"`
Timeout xtime.Duration `dsn:"query.timeout"`
Sub Sub `dsn:"query.sub"`
Def string `dsn:"query.def,hello"`
}
type Sub struct {
Foo int `dsn:"query.foo"`
}
func TestBind(t *testing.T) {
var cfg config
rawdsn := "tcp://root:toor@172.12.23.34,178.23.34.45?timeout=1s&sub.foo=1&hello=world"
dsn, err := Parse(rawdsn)
if err != nil {
t.Fatal(err)
}
values, err := dsn.Bind(&cfg)
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(values, url.Values{"hello": {"world"}}) {
t.Errorf("unexpect values get %v", values)
}
cfg2 := config{
Network: "tcp",
Addresses: []string{"172.12.23.34", "178.23.34.45"},
Password: "toor",
Username: "root",
Sub: Sub{Foo: 1},
Timeout: xtime.Duration(time.Second),
Def: "hello",
}
if !reflect.DeepEqual(cfg, cfg2) {
t.Errorf("unexpect config get %v, expect %v", cfg, cfg2)
}
}
type config2 struct {
Network string `dsn:"network"`
Address string `dsn:"address"`
Timeout xtime.Duration `dsn:"query.timeout"`
}
func TestUnix(t *testing.T) {
var cfg config2
rawdsn := "unix:///run/xxx.sock?timeout=1s&sub.foo=1&hello=world"
dsn, err := Parse(rawdsn)
if err != nil {
t.Fatal(err)
}
_, err = dsn.Bind(&cfg)
if err != nil {
t.Error(err)
}
cfg2 := config2{
Network: "unix",
Address: "/run/xxx.sock",
Timeout: xtime.Duration(time.Second),
}
if !reflect.DeepEqual(cfg, cfg2) {
t.Errorf("unexpect config2 get %v, expect %v", cfg, cfg2)
}
}

View File

@@ -0,0 +1,31 @@
package dsn_test
import (
"log"
"go-common/library/conf/dsn"
xtime "go-common/library/time"
)
// Config struct
type Config struct {
Network string `dsn:"network" validate:"required"`
Host string `dsn:"host" validate:"required"`
Username string `dsn:"username" validate:"required"`
Password string `dsn:"password" validate:"required"`
Timeout xtime.Duration `dsn:"query.timeout,1s"`
Offset int `dsn:"query.offset" validate:"gte=0"`
}
func ExampleParse() {
cfg := &Config{}
d, err := dsn.Parse("tcp://root:toor@172.12.12.23:2233?timeout=10s")
if err != nil {
log.Fatal(err)
}
_, err = d.Bind(cfg)
if err != nil {
log.Fatal(err)
}
log.Printf("%v", cfg)
}

422
library/conf/dsn/query.go Normal file
View File

@@ -0,0 +1,422 @@
package dsn
import (
"encoding"
"net/url"
"reflect"
"runtime"
"strconv"
"strings"
)
const (
_tagID = "dsn"
_queryPrefix = "query."
)
// InvalidBindError describes an invalid argument passed to DecodeQuery.
// (The argument to DecodeQuery must be a non-nil pointer.)
type InvalidBindError struct {
Type reflect.Type
}
func (e *InvalidBindError) Error() string {
if e.Type == nil {
return "Bind(nil)"
}
if e.Type.Kind() != reflect.Ptr {
return "Bind(non-pointer " + e.Type.String() + ")"
}
return "Bind(nil " + e.Type.String() + ")"
}
// BindTypeError describes a query value that was
// not appropriate for a value of a specific Go type.
type BindTypeError struct {
Value string
Type reflect.Type
}
func (e *BindTypeError) Error() string {
return "cannot decode " + e.Value + " into Go value of type " + e.Type.String()
}
type assignFunc func(v reflect.Value, to tagOpt) error
func stringsAssignFunc(val string) assignFunc {
return func(v reflect.Value, to tagOpt) error {
if v.Kind() != reflect.String || !v.CanSet() {
return &BindTypeError{Value: "string", Type: v.Type()}
}
if val == "" {
v.SetString(to.Default)
} else {
v.SetString(val)
}
return nil
}
}
// bindQuery parses url.Values and stores the result in the value pointed to by v.
// if v is nil or not a pointer, bindQuery returns an InvalidDecodeError
func bindQuery(query url.Values, v interface{}, assignFuncs map[string]assignFunc) (url.Values, error) {
if assignFuncs == nil {
assignFuncs = make(map[string]assignFunc)
}
d := decodeState{
data: query,
used: make(map[string]bool),
assignFuncs: assignFuncs,
}
err := d.decode(v)
ret := d.unused()
return ret, err
}
type tagOpt struct {
Name string
Default string
}
func parseTag(tag string) tagOpt {
vs := strings.SplitN(tag, ",", 2)
if len(vs) == 2 {
return tagOpt{Name: vs[0], Default: vs[1]}
}
return tagOpt{Name: vs[0]}
}
type decodeState struct {
data url.Values
used map[string]bool
assignFuncs map[string]assignFunc
}
func (d *decodeState) unused() url.Values {
ret := make(url.Values)
for k, v := range d.data {
if !d.used[k] {
ret[k] = v
}
}
return ret
}
func (d *decodeState) decode(v interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok {
panic(r)
}
err = r.(error)
}
}()
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return &InvalidBindError{reflect.TypeOf(v)}
}
return d.root(rv)
}
func (d *decodeState) root(v reflect.Value) error {
var tu encoding.TextUnmarshaler
tu, v = d.indirect(v)
if tu != nil {
return tu.UnmarshalText([]byte(d.data.Encode()))
}
// TODO support map, slice as root
if v.Kind() != reflect.Struct {
return &BindTypeError{Value: d.data.Encode(), Type: v.Type()}
}
tv := v.Type()
for i := 0; i < tv.NumField(); i++ {
fv := v.Field(i)
field := tv.Field(i)
to := parseTag(field.Tag.Get(_tagID))
if to.Name == "-" {
continue
}
if af, ok := d.assignFuncs[to.Name]; ok {
if err := af(fv, tagOpt{}); err != nil {
return err
}
continue
}
if !strings.HasPrefix(to.Name, _queryPrefix) {
continue
}
to.Name = to.Name[len(_queryPrefix):]
if err := d.value(fv, "", to); err != nil {
return err
}
}
return nil
}
func combinekey(prefix string, to tagOpt) string {
key := to.Name
if prefix != "" {
key = prefix + "." + key
}
return key
}
func (d *decodeState) value(v reflect.Value, prefix string, to tagOpt) (err error) {
key := combinekey(prefix, to)
d.used[key] = true
var tu encoding.TextUnmarshaler
tu, v = d.indirect(v)
if tu != nil {
if val, ok := d.data[key]; ok {
return tu.UnmarshalText([]byte(val[0]))
}
if to.Default != "" {
return tu.UnmarshalText([]byte(to.Default))
}
return
}
switch v.Kind() {
case reflect.Bool:
err = d.valueBool(v, prefix, to)
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
err = d.valueInt64(v, prefix, to)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
err = d.valueUint64(v, prefix, to)
case reflect.Float32, reflect.Float64:
err = d.valueFloat64(v, prefix, to)
case reflect.String:
err = d.valueString(v, prefix, to)
case reflect.Slice:
err = d.valueSlice(v, prefix, to)
case reflect.Struct:
err = d.valueStruct(v, prefix, to)
case reflect.Ptr:
if !d.hasKey(combinekey(prefix, to)) {
break
}
if !v.CanSet() {
break
}
nv := reflect.New(v.Type().Elem())
v.Set(nv)
err = d.value(nv, prefix, to)
}
return
}
func (d *decodeState) hasKey(key string) bool {
for k := range d.data {
if strings.HasPrefix(k, key+".") || k == key {
return true
}
}
return false
}
func (d *decodeState) valueBool(v reflect.Value, prefix string, to tagOpt) error {
key := combinekey(prefix, to)
val := d.data.Get(key)
if val == "" {
if to.Default == "" {
return nil
}
val = to.Default
}
return d.setBool(v, val)
}
func (d *decodeState) setBool(v reflect.Value, val string) error {
bval, err := strconv.ParseBool(val)
if err != nil {
return &BindTypeError{Value: val, Type: v.Type()}
}
v.SetBool(bval)
return nil
}
func (d *decodeState) valueInt64(v reflect.Value, prefix string, to tagOpt) error {
key := combinekey(prefix, to)
val := d.data.Get(key)
if val == "" {
if to.Default == "" {
return nil
}
val = to.Default
}
return d.setInt64(v, val)
}
func (d *decodeState) setInt64(v reflect.Value, val string) error {
ival, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return &BindTypeError{Value: val, Type: v.Type()}
}
v.SetInt(ival)
return nil
}
func (d *decodeState) valueUint64(v reflect.Value, prefix string, to tagOpt) error {
key := combinekey(prefix, to)
val := d.data.Get(key)
if val == "" {
if to.Default == "" {
return nil
}
val = to.Default
}
return d.setUint64(v, val)
}
func (d *decodeState) setUint64(v reflect.Value, val string) error {
uival, err := strconv.ParseUint(val, 10, 64)
if err != nil {
return &BindTypeError{Value: val, Type: v.Type()}
}
v.SetUint(uival)
return nil
}
func (d *decodeState) valueFloat64(v reflect.Value, prefix string, to tagOpt) error {
key := combinekey(prefix, to)
val := d.data.Get(key)
if val == "" {
if to.Default == "" {
return nil
}
val = to.Default
}
return d.setFloat64(v, val)
}
func (d *decodeState) setFloat64(v reflect.Value, val string) error {
fval, err := strconv.ParseFloat(val, 64)
if err != nil {
return &BindTypeError{Value: val, Type: v.Type()}
}
v.SetFloat(fval)
return nil
}
func (d *decodeState) valueString(v reflect.Value, prefix string, to tagOpt) error {
key := combinekey(prefix, to)
val := d.data.Get(key)
if val == "" {
if to.Default == "" {
return nil
}
val = to.Default
}
return d.setString(v, val)
}
func (d *decodeState) setString(v reflect.Value, val string) error {
v.SetString(val)
return nil
}
func (d *decodeState) valueSlice(v reflect.Value, prefix string, to tagOpt) error {
key := combinekey(prefix, to)
strs, ok := d.data[key]
if !ok {
strs = strings.Split(to.Default, ",")
}
if len(strs) == 0 {
return nil
}
et := v.Type().Elem()
var setFunc func(reflect.Value, string) error
switch et.Kind() {
case reflect.Bool:
setFunc = d.setBool
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
setFunc = d.setInt64
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
setFunc = d.setUint64
case reflect.Float32, reflect.Float64:
setFunc = d.setFloat64
case reflect.String:
setFunc = d.setString
default:
return &BindTypeError{Type: et, Value: strs[0]}
}
vals := reflect.MakeSlice(v.Type(), len(strs), len(strs))
for i, str := range strs {
if err := setFunc(vals.Index(i), str); err != nil {
return err
}
}
if v.CanSet() {
v.Set(vals)
}
return nil
}
func (d *decodeState) valueStruct(v reflect.Value, prefix string, to tagOpt) error {
tv := v.Type()
for i := 0; i < tv.NumField(); i++ {
fv := v.Field(i)
field := tv.Field(i)
fto := parseTag(field.Tag.Get(_tagID))
if fto.Name == "-" {
continue
}
if af, ok := d.assignFuncs[fto.Name]; ok {
if err := af(fv, tagOpt{}); err != nil {
return err
}
continue
}
if !strings.HasPrefix(fto.Name, _queryPrefix) {
continue
}
fto.Name = fto.Name[len(_queryPrefix):]
if err := d.value(fv, to.Name, fto); err != nil {
return err
}
}
return nil
}
func (d *decodeState) indirect(v reflect.Value) (encoding.TextUnmarshaler, reflect.Value) {
v0 := v
haveAddr := false
if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
haveAddr = true
v = v.Addr()
}
for {
if v.Kind() == reflect.Interface && !v.IsNil() {
e := v.Elem()
if e.Kind() == reflect.Ptr && !e.IsNil() && e.Elem().Kind() == reflect.Ptr {
haveAddr = false
v = e
continue
}
}
if v.Kind() != reflect.Ptr {
break
}
if v.Elem().Kind() != reflect.Ptr && v.CanSet() {
break
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
if v.Type().NumMethod() > 0 {
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
return u, reflect.Value{}
}
}
if haveAddr {
v = v0
haveAddr = false
} else {
v = v.Elem()
}
}
return nil, v
}

View File

@@ -0,0 +1,128 @@
package dsn
import (
"net/url"
"reflect"
"testing"
"time"
xtime "go-common/library/time"
)
type cfg1 struct {
Name string `dsn:"query.name"`
Def string `dsn:"query.def,hello"`
DefSlice []int `dsn:"query.defslice,1,2,3,4"`
Ignore string `dsn:"-"`
FloatNum float64 `dsn:"query.floatNum"`
}
type cfg2 struct {
Timeout xtime.Duration `dsn:"query.timeout"`
}
type cfg3 struct {
Username string `dsn:"username"`
Timeout xtime.Duration `dsn:"query.timeout"`
}
type cfg4 struct {
Timeout xtime.Duration `dsn:"query.timeout,1s"`
}
func TestDecodeQuery(t *testing.T) {
type args struct {
query url.Values
v interface{}
assignFuncs map[string]assignFunc
}
tests := []struct {
name string
args args
want url.Values
cfg interface{}
wantErr bool
}{
{
name: "test generic",
args: args{
query: url.Values{
"name": {"hello"},
"Ignore": {"test"},
"floatNum": {"22.33"},
"adb": {"123"},
},
v: &cfg1{},
},
want: url.Values{
"Ignore": {"test"},
"adb": {"123"},
},
cfg: &cfg1{
Name: "hello",
Def: "hello",
DefSlice: []int{1, 2, 3, 4},
FloatNum: 22.33,
},
},
{
name: "test go-common/library/time",
args: args{
query: url.Values{
"timeout": {"1s"},
},
v: &cfg2{},
},
want: url.Values{},
cfg: &cfg2{xtime.Duration(time.Second)},
},
{
name: "test empty go-common/library/time",
args: args{
query: url.Values{},
v: &cfg2{},
},
want: url.Values{},
cfg: &cfg2{},
},
{
name: "test go-common/library/time",
args: args{
query: url.Values{},
v: &cfg4{},
},
want: url.Values{},
cfg: &cfg4{xtime.Duration(time.Second)},
},
{
name: "test build-in value",
args: args{
query: url.Values{
"timeout": {"1s"},
},
v: &cfg3{},
assignFuncs: map[string]assignFunc{"username": stringsAssignFunc("hello")},
},
want: url.Values{},
cfg: &cfg3{
Timeout: xtime.Duration(time.Second),
Username: "hello",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := bindQuery(tt.args.query, tt.args.v, tt.args.assignFuncs)
if (err != nil) != tt.wantErr {
t.Errorf("DecodeQuery() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("DecodeQuery() = %v, want %v", got, tt.want)
}
if !reflect.DeepEqual(tt.args.v, tt.cfg) {
t.Errorf("DecodeQuery() = %v, want %v", tt.args.v, tt.cfg)
}
})
}
}

37
library/conf/env/BUILD vendored Normal file
View File

@@ -0,0 +1,37 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_library(
name = "go_default_library",
srcs = ["env.go"],
importpath = "go-common/library/conf/env",
tags = ["automanaged"],
visibility = ["//visibility:public"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)
go_test(
name = "go_default_test",
srcs = ["env_test.go"],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
)

92
library/conf/env/env.go vendored Normal file
View File

@@ -0,0 +1,92 @@
// Package env get env & app config, all the public field must after init()
// finished and flag.Parse().
package env
import (
"flag"
"os"
)
// deploy env.
const (
DeployEnvDev = "dev"
DeployEnvFat1 = "fat1"
DeployEnvUat = "uat"
DeployEnvPre = "pre"
DeployEnvProd = "prod"
)
// env default value.
const (
// env
_region = "sh"
_zone = "sh001"
_deployEnv = "dev"
)
// env configuration.
var (
// Region avaliable region where app at.
Region string
// Zone avaliable zone where app at.
Zone string
// Hostname machine hostname.
Hostname string
// DeployEnv deploy env where app at.
DeployEnv string
// IP FIXME(haoguanwei) #240
IP = os.Getenv("POD_IP")
// AppID is global unique application id, register by service tree.
// such as main.arch.disocvery.
AppID string
// Color is the identification of different experimental group in one caster cluster.
Color string
)
// app default value.
const (
_httpPort = "8000"
_gorpcPort = "8099"
_grpcPort = "9000"
)
// app configraution.
var (
// HTTPPort app listen http port.
HTTPPort string
// GORPCPort app listen gorpc port.
GORPCPort string
// GRPCPort app listen grpc port.
GRPCPort string
)
func init() {
var err error
if Hostname, err = os.Hostname(); err != nil || Hostname == "" {
Hostname = os.Getenv("HOSTNAME")
}
addFlag(flag.CommandLine)
}
func addFlag(fs *flag.FlagSet) {
// env
fs.StringVar(&Region, "region", defaultString("REGION", _region), "avaliable region. or use REGION env variable, value: sh etc.")
fs.StringVar(&Zone, "zone", defaultString("ZONE", _zone), "avaliable zone. or use ZONE env variable, value: sh001/sh002 etc.")
fs.StringVar(&DeployEnv, "deploy.env", defaultString("DEPLOY_ENV", _deployEnv), "deploy env. or use DEPLOY_ENV env variable, value: dev/fat1/uat/pre/prod etc.")
fs.StringVar(&AppID, "appid", os.Getenv("APP_ID"), "appid is global unique application id, register by service tree. or use APP_ID env variable.")
fs.StringVar(&Color, "deploy.color", os.Getenv("DEPLOY_COLOR"), "deploy.color is the identification of different experimental group.")
// app
fs.StringVar(&HTTPPort, "http.port", defaultString("DISCOVERY_HTTP_PORT", _httpPort), "app listen http port, default: 8000")
fs.StringVar(&GORPCPort, "gorpc.port", defaultString("DISCOVERY_GORPC_PORT", _gorpcPort), "app listen gorpc port, default: 8099")
fs.StringVar(&GRPCPort, "grpc.port", defaultString("DISCOVERY_GRPC_PORT", _grpcPort), "app listen grpc port, default: 9000")
}
func defaultString(env, value string) string {
v := os.Getenv(env)
if v == "" {
return value
}
return v
}

122
library/conf/env/env_test.go vendored Normal file
View File

@@ -0,0 +1,122 @@
package env
import (
"flag"
"fmt"
"os"
"testing"
)
func TestDefaultString(t *testing.T) {
v := defaultString("a", "test")
if v != "test" {
t.Fatal("v must be test")
}
if err := os.Setenv("a", "test1"); err != nil {
t.Fatal(err)
}
v = defaultString("a", "test")
if v != "test1" {
t.Fatal("v must be test1")
}
}
func TestEnv(t *testing.T) {
tests := []struct {
flag string
env string
def string
val *string
}{
{
"region",
"REGION",
_region,
&Region,
},
{
"zone",
"ZONE",
_zone,
&Zone,
},
{
"deploy.env",
"DEPLOY_ENV",
_deployEnv,
&DeployEnv,
},
{
"appid",
"APP_ID",
"",
&AppID,
},
{
"http.port",
"DISCOVERY_HTTP_PORT",
_httpPort,
&HTTPPort,
},
{
"gorpc.port",
"DISCOVERY_GORPC_PORT",
_gorpcPort,
&GORPCPort,
},
{
"grpc.port",
"DISCOVERY_GRPC_PORT",
_grpcPort,
&GRPCPort,
},
{
"deploy.color",
"DEPLOY_COLOR",
"",
&Color,
},
}
for _, test := range tests {
// flag set value
t.Run(fmt.Sprintf("%s: flag set", test.env), func(t *testing.T) {
fs := flag.NewFlagSet("", flag.ContinueOnError)
addFlag(fs)
err := fs.Parse([]string{fmt.Sprintf("-%s=%s", test.flag, "test")})
if err != nil {
t.Fatal(err)
}
if *test.val != "test" {
t.Fatal("val must be test")
}
})
// flag not set, env set
t.Run(fmt.Sprintf("%s: flag not set, env set", test.env), func(t *testing.T) {
*test.val = ""
os.Setenv(test.env, "test2")
fs := flag.NewFlagSet("", flag.ContinueOnError)
addFlag(fs)
err := fs.Parse([]string{})
if err != nil {
t.Fatal(err)
}
if *test.val != "test2" {
t.Fatal("val must be test")
}
})
// flag not set, env not set
t.Run(fmt.Sprintf("%s: flag not set, env not set", test.env), func(t *testing.T) {
*test.val = ""
os.Setenv(test.env, "")
fs := flag.NewFlagSet("", flag.ContinueOnError)
addFlag(fs)
err := fs.Parse([]string{})
if err != nil {
t.Fatal(err)
}
if *test.val != test.def {
t.Fatal("val must be test")
}
})
}
}

View File

@@ -0,0 +1,28 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
)
go_library(
name = "go_default_library",
srcs = ["flagvar.go"],
importpath = "go-common/library/conf/flagvar",
tags = ["automanaged"],
visibility = ["//visibility:public"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,18 @@
package flagvar
import (
"strings"
)
// StringVars []string implement flag.Value
type StringVars []string
func (s StringVars) String() string {
return strings.Join(s, ",")
}
// Set implement flag.Value
func (s *StringVars) Set(val string) error {
*s = append(*s, val)
return nil
}

View File

@@ -0,0 +1,80 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_library(
name = "go_default_library",
srcs = [
"client.go",
"default.go",
"file.go",
"helper.go",
"map.go",
"mock.go",
"sven.go",
"toml.go",
"value.go",
],
importpath = "go-common/library/conf/paladin",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//library/conf/env:go_default_library",
"//library/ecode:go_default_library",
"//library/log:go_default_library",
"//library/net/ip:go_default_library",
"//library/net/netutil:go_default_library",
"//vendor/github.com/BurntSushi/toml:go_default_library",
"//vendor/github.com/naoina/toml:go_default_library",
"//vendor/github.com/pkg/errors:go_default_library",
],
)
go_test(
name = "go_default_xtest",
srcs = [
"example_test.go",
"file_test.go",
"map_test.go",
"mock_test.go",
],
tags = ["automanaged"],
deps = [
"//library/conf/paladin:go_default_library",
"//vendor/github.com/naoina/toml:go_default_library",
"//vendor/github.com/stretchr/testify/assert:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)
go_test(
name = "go_default_test",
srcs = [
"sven_test.go",
"value_test.go",
],
embed = [":go_default_library"],
tags = ["automanaged"],
deps = [
"//library/conf/env:go_default_library",
"//vendor/github.com/naoina/toml:go_default_library",
"//vendor/github.com/stretchr/testify/assert:go_default_library",
],
)

View File

@@ -0,0 +1,24 @@
### paladin
##### Version 1.1.2
> 1.修改map key to lower case
##### Version 1.1.1
> 1.修复错误时会panic
##### Version 1.1.0
> 1.修正Var统一显示处理panic
> 2.修复本地文件路径问题
> 3.添加helper快捷处理default value
##### Version 1.0.2
> 1.修下unmarshal方法为toml
##### Version 1.0.1
> 1.default/map/value添加Unmarshal方法
##### Version 1.0.0
> 1.支持sven、file、mock配置读取
> 2.支持struct、paladin.Map对象解析
> 3.支持Set接口进行配置Reload
> 4.支持Watch自定义订阅key value变化

View File

@@ -0,0 +1,10 @@
# Owner
maojian
# Author
chenzhihui
# Reviewer
maojian
haoguanwei
lintanghui

View File

@@ -0,0 +1,10 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- chenzhihui
- maojian
reviewers:
- chenzhihui
- haoguanwei
- lintanghui
- maojian

View File

@@ -0,0 +1,86 @@
#### paladin
##### 项目简介
paladin 是一个config SDK客户端包括了sven、file、mock几个抽象功能方便使用本地文件或者sven配置中心并且集成了对象自动reload功能。
sven:
```
caster配置项
配置地址CONF_HOST: config.bilibili.co
配置版本CONF_VERSION: docker-1/server-1
配置路径CONF_PATH: /data/conf/app
配置TokenCONF_TOKEN: token
配置指定版本CONF_APPOINT: 27600
依赖环境变量:
TREE_ID/DEPLOY_ENV/ZONE/HOSTNAME/POD_IP
```
local files:
```
/data/app/msm-service -conf=/data/conf/app/msm-servie.toml
// or multi file
/data/app/msm-service -conf=/data/conf/app/
```
example:
```
type exampleConf struct {
Bool bool
Int int64
Float float64
String string
}
func (e *exampleConf) Set(text string) error {
var ec exampleConf
if err := toml.Unmarshal([]byte(text), &ec); err != nil {
return err
}
*e = ec
return nil
}
func ExampleClient() {
if err := paladin.Init(); err != nil {
panic(err)
}
var (
ec exampleConf
eo exampleConf
m paladin.TOML
strs []string
)
// config unmarshal
if err := paladin.Get("example.toml").UnmarshalTOML(&ec); err != nil {
panic(err)
}
// config setter
if err := paladin.Watch("example.toml", &ec); err != nil {
panic(err)
}
// paladin map
if err := paladin.Watch("example.toml", &m); err != nil {
panic(err)
}
s, err := m.Value("key").String()
b, err := m.Value("key").Bool()
i, err := m.Value("key").Int64()
f, err := m.Value("key").Float64()
// value slice
err = m.Value("strings").Slice(&strs)
// watch key
for event := range paladin.WatchEvent(context.TODO(), "key") {
fmt.Println(event)
}
}
```
##### 编译环境
- **请只用 Golang v1.8.x 以上版本编译执行**
##### 依赖包
> 1. github.com/naoina/toml
> 2. github.com/pkg/errors

View File

@@ -0,0 +1,49 @@
package paladin
import (
"context"
)
const (
// EventAdd config add event.
EventAdd EventType = iota
// EventUpdate config update event.
EventUpdate
// EventRemove config remove event.
EventRemove
)
// EventType is config event.
type EventType int
// Event is watch event.
type Event struct {
Event EventType
Key string
Value string
}
// Watcher is config watcher.
type Watcher interface {
WatchEvent(context.Context, ...string) <-chan Event
Close() error
}
// Setter is value setter.
type Setter interface {
Set(string) error
}
// Getter is value getter.
type Getter interface {
// Get a config value by a config key(may be a sven filename).
Get(string) *Value
// GetAll return all config key->value map.
GetAll() *Map
}
// Client is config client.
type Client interface {
Watcher
Getter
}

View File

@@ -0,0 +1,84 @@
package paladin
import (
"context"
"flag"
"go-common/library/log"
)
var (
// DefaultClient default client.
DefaultClient Client
confPath string
vars = make(map[string][]Setter) // NOTE: no thread safe
)
func init() {
flag.StringVar(&confPath, "conf", "", "default config path")
}
// Init init config client.
func Init() (err error) {
if confPath != "" {
DefaultClient, err = NewFile(confPath)
} else {
DefaultClient, err = NewSven()
}
if err != nil {
return
}
go func() {
for event := range DefaultClient.WatchEvent(context.Background()) {
if event.Event != EventUpdate && event.Event != EventAdd {
continue
}
if sets, ok := vars[event.Key]; ok {
for _, s := range sets {
if err := s.Set(event.Value); err != nil {
log.Error("paladin: vars:%v event:%v error(%v)", s, event, err)
}
}
}
}
}()
return
}
// Watch watch on a key. The configuration implements the setter interface, which is invoked when the configuration changes.
func Watch(key string, s Setter) error {
v := DefaultClient.Get(key)
str, err := v.Raw()
if err != nil {
return err
}
if err := s.Set(str); err != nil {
return err
}
vars[key] = append(vars[key], s)
return nil
}
// WatchEvent watch on multi keys. Events are returned when the configuration changes.
func WatchEvent(ctx context.Context, keys ...string) <-chan Event {
return DefaultClient.WatchEvent(ctx, keys...)
}
// Get return value by key.
func Get(key string) *Value {
return DefaultClient.Get(key)
}
// GetAll return all config map.
func GetAll() *Map {
return DefaultClient.GetAll()
}
// Keys return values key.
func Keys() []string {
return DefaultClient.GetAll().Keys()
}
// Close close watcher.
func Close() error {
return DefaultClient.Close()
}

View File

@@ -0,0 +1,112 @@
package paladin_test
import (
"context"
"fmt"
"go-common/library/conf/paladin"
"github.com/naoina/toml"
)
type exampleConf struct {
Bool bool
Int int64
Float float64
String string
Strings []string
}
func (e *exampleConf) Set(text string) error {
var ec exampleConf
if err := toml.Unmarshal([]byte(text), &ec); err != nil {
return err
}
*e = ec
return nil
}
// ExampleClient is a example client usage.
// exmaple.toml:
/*
bool = true
int = 100
float = 100.1
string = "text"
strings = ["a", "b", "c"]
*/
func ExampleClient() {
if err := paladin.Init(); err != nil {
panic(err)
}
var ec exampleConf
// var setter
if err := paladin.Watch("example.toml", &ec); err != nil {
panic(err)
}
if err := paladin.Get("example.toml").UnmarshalTOML(&ec); err != nil {
panic(err)
}
// use exampleConf
// watch event key
go func() {
for event := range paladin.WatchEvent(context.TODO(), "key") {
fmt.Println(event)
}
}()
}
// ExampleMap is a example map usage.
// exmaple.toml:
/*
bool = true
int = 100
float = 100.1
string = "text"
strings = ["a", "b", "c"]
[object]
string = "text"
bool = true
int = 100
float = 100.1
strings = ["a", "b", "c"]
*/
func ExampleMap() {
var (
m paladin.TOML
strs []string
)
// paladin toml
if err := paladin.Watch("example.toml", &m); err != nil {
panic(err)
}
// value string
s, err := m.Get("string").String()
if err != nil {
s = "default"
}
fmt.Println(s)
// value bool
b, err := m.Get("bool").Bool()
if err != nil {
b = false
}
fmt.Println(b)
// value int
i, err := m.Get("int").Int64()
if err != nil {
i = 100
}
fmt.Println(i)
// value float
f, err := m.Get("float").Float64()
if err != nil {
f = 100.1
}
fmt.Println(f)
// value slice
if err = m.Get("strings").Slice(&strs); err == nil {
fmt.Println(strs)
}
}

View File

@@ -0,0 +1,82 @@
package paladin
import (
"context"
"errors"
"io/ioutil"
"os"
"path"
"path/filepath"
)
var _ Client = &file{}
// file is file config client.
type file struct {
ch chan Event
values *Map
}
// NewFile new a config file client.
// conf = /data/conf/app/
// conf = /data/conf/app/xxx.toml
func NewFile(base string) (Client, error) {
// paltform slash
base = filepath.FromSlash(base)
fi, err := os.Stat(base)
if err != nil {
panic(err)
}
// dirs or file to paths
var paths []string
if fi.IsDir() {
files, err := ioutil.ReadDir(base)
if err != nil {
panic(err)
}
for _, file := range files {
if !file.IsDir() {
paths = append(paths, path.Join(base, file.Name()))
}
}
} else {
paths = append(paths, base)
}
// laod config file to values
values := make(map[string]*Value, len(paths))
for _, file := range paths {
if file == "" {
return nil, errors.New("paladin: path is empty")
}
b, err := ioutil.ReadFile(file)
if err != nil {
return nil, err
}
s := string(b)
values[path.Base(file)] = &Value{val: s, raw: s}
}
m := new(Map)
m.Store(values)
return &file{values: m, ch: make(chan Event, 10)}, nil
}
// Get return value by key.
func (f *file) Get(key string) *Value {
return f.values.Get(key)
}
// GetAll return value map.
func (f *file) GetAll() *Map {
return f.values
}
// WatchEvent watch multi key.
func (f *file) WatchEvent(ctx context.Context, key ...string) <-chan Event {
return f.ch
}
// Close close watcher.
func (f *file) Close() error {
close(f.ch)
return nil
}

View File

@@ -0,0 +1,67 @@
package paladin_test
import (
"io/ioutil"
"os"
"testing"
"go-common/library/conf/paladin"
"github.com/stretchr/testify/assert"
)
func TestNewFile(t *testing.T) {
// test data
path := "/tmp/test_conf/"
assert.Nil(t, os.MkdirAll(path, 0700))
assert.Nil(t, ioutil.WriteFile(path+"test.toml", []byte(`
text = "hello"
number = 100
slice = [1, 2, 3]
sliceStr = ["1", "2", "3"]
`), 0644))
// test client
cli, err := paladin.NewFile(path + "test.toml")
assert.Nil(t, err)
assert.NotNil(t, cli)
// test map
m := paladin.Map{}
text, err := cli.Get("test.toml").String()
assert.Nil(t, err)
assert.Nil(t, m.Set(text), "text")
s, err := m.Get("text").String()
assert.Nil(t, err)
assert.Equal(t, s, "hello", "text")
n, err := m.Get("number").Int64()
assert.Nil(t, err)
assert.Equal(t, n, int64(100), "number")
}
func TestNewFilePath(t *testing.T) {
// test data
path := "/tmp/test_conf/"
assert.Nil(t, os.MkdirAll(path, 0700))
assert.Nil(t, ioutil.WriteFile(path+"test.toml", []byte(`
text = "hello"
number = 100
`), 0644))
assert.Nil(t, ioutil.WriteFile(path+"abc.toml", []byte(`
text = "hello"
number = 100
`), 0644))
// test client
cli, err := paladin.NewFile(path)
assert.Nil(t, err)
assert.NotNil(t, cli)
// test map
m := paladin.Map{}
text, err := cli.Get("test.toml").String()
assert.Nil(t, err)
assert.Nil(t, m.Set(text), "text")
s, err := m.Get("text").String()
assert.Nil(t, err, s)
assert.Equal(t, s, "hello", "text")
n, err := m.Get("number").Int64()
assert.Nil(t, err, s)
assert.Equal(t, n, int64(100), "number")
}

View File

@@ -0,0 +1,76 @@
package paladin
import "time"
// Bool return bool value.
func Bool(v *Value, def bool) bool {
b, err := v.Bool()
if err != nil {
return def
}
return b
}
// Int return int value.
func Int(v *Value, def int) int {
i, err := v.Int()
if err != nil {
return def
}
return i
}
// Int32 return int32 value.
func Int32(v *Value, def int32) int32 {
i, err := v.Int32()
if err != nil {
return def
}
return i
}
// Int64 return int64 value.
func Int64(v *Value, def int64) int64 {
i, err := v.Int64()
if err != nil {
return def
}
return i
}
// Float32 return float32 value.
func Float32(v *Value, def float32) float32 {
f, err := v.Float32()
if err != nil {
return def
}
return f
}
// Float64 return float32 value.
func Float64(v *Value, def float64) float64 {
f, err := v.Float64()
if err != nil {
return def
}
return f
}
// String return string value.
func String(v *Value, def string) string {
s, err := v.String()
if err != nil {
return def
}
return s
}
// Duration parses a duration string. A duration string is a possibly signed sequence of decimal numbers
// each with optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func Duration(v *Value, def time.Duration) time.Duration {
dur, err := v.Duration()
if err != nil {
return def
}
return dur
}

View File

@@ -0,0 +1,55 @@
package paladin
import (
"strings"
"sync/atomic"
)
// keyNamed key naming to lower case.
func keyNamed(key string) string {
return strings.ToLower(key)
}
// Map is config map, key(filename) -> value(file).
type Map struct {
values atomic.Value
}
// Store sets the value of the Value to values map.
func (m *Map) Store(values map[string]*Value) {
dst := make(map[string]*Value, len(values))
for k, v := range values {
dst[keyNamed(k)] = v
}
m.values.Store(dst)
}
// Load returns the value set by the most recent Store.
func (m *Map) Load() map[string]*Value {
return m.values.Load().(map[string]*Value)
}
// Exist check if values map exist a key.
func (m *Map) Exist(key string) bool {
_, ok := m.Load()[keyNamed(key)]
return ok
}
// Get return get value by key.
func (m *Map) Get(key string) *Value {
v, ok := m.Load()[keyNamed(key)]
if ok {
return v
}
return &Value{}
}
// Keys return map keys.
func (m *Map) Keys() []string {
values := m.Load()
keys := make([]string, 0, len(values))
for key := range values {
keys = append(keys, key)
}
return keys
}

View File

@@ -0,0 +1,94 @@
package paladin_test
import (
"testing"
"go-common/library/conf/paladin"
"github.com/naoina/toml"
"github.com/stretchr/testify/assert"
)
type fruit struct {
Fruit []struct {
Name string
}
}
func (f *fruit) Set(text string) error {
return toml.Unmarshal([]byte(text), f)
}
func TestMap(t *testing.T) {
s := `
# kv
text = "hello"
number = 100
point = 100.1
boolean = true
KeyCase = "test"
# slice
numbers = [1, 2, 3]
strings = ["a", "b", "c"]
empty = []
[[fruit]]
name = "apple"
[[fruit]]
name = "banana"
# table
[database]
server = "192.168.1.1"
connection_max = 5000
enabled = true
[pool]
[pool.breaker]
xxx = "xxx"
`
m := paladin.Map{}
assert.Nil(t, m.Set(s), s)
str, err := m.Get("text").String()
assert.Nil(t, err)
assert.Equal(t, str, "hello", "text")
n, err := m.Get("number").Int64()
assert.Nil(t, err)
assert.Equal(t, n, int64(100), "number")
p, err := m.Get("point").Float64()
assert.Nil(t, err)
assert.Equal(t, p, 100.1, "point")
b, err := m.Get("boolean").Bool()
assert.Nil(t, err)
assert.Equal(t, b, true, "boolean")
// key lower case
lb, err := m.Get("Boolean").Bool()
assert.Nil(t, err)
assert.Equal(t, lb, true, "boolean")
lt, err := m.Get("KeyCase").String()
assert.Nil(t, err)
assert.Equal(t, lt, "test", "key case")
var sliceInt []int64
err = m.Get("numbers").Slice(&sliceInt)
assert.Nil(t, err)
assert.Equal(t, sliceInt, []int64{1, 2, 3})
var sliceStr []string
err = m.Get("strings").Slice(&sliceStr)
assert.Nil(t, err)
assert.Equal(t, sliceStr, []string{"a", "b", "c"})
err = m.Get("strings").Slice(&sliceStr)
assert.Nil(t, err)
assert.Equal(t, sliceStr, []string{"a", "b", "c"})
// errors
err = m.Get("strings").Slice(sliceInt)
assert.NotNil(t, err)
err = m.Get("strings").Slice(&sliceInt)
assert.NotNil(t, err)
var obj struct {
Name string
}
err = m.Get("strings").Slice(obj)
assert.NotNil(t, err)
err = m.Get("strings").Slice(&obj)
assert.NotNil(t, err)
}

View File

@@ -0,0 +1,45 @@
package paladin
import (
"context"
)
var _ Client = &mock{}
// mock is mock config client.
type mock struct {
ch chan Event
values *Map
}
// NewMock new a config mock client.
func NewMock(vs map[string]string) Client {
values := make(map[string]*Value, len(vs))
for k, v := range vs {
values[k] = &Value{val: v, raw: v}
}
m := new(Map)
m.Store(values)
return &mock{values: m, ch: make(chan Event)}
}
// Get return value by key.
func (m *mock) Get(key string) *Value {
return m.values.Get(key)
}
// GetAll return value map.
func (m *mock) GetAll() *Map {
return m.values
}
// WatchEvent watch multi key.
func (m *mock) WatchEvent(ctx context.Context, key ...string) <-chan Event {
return m.ch
}
// Close close watcher.
func (m *mock) Close() error {
close(m.ch)
return nil
}

View File

@@ -0,0 +1,37 @@
package paladin_test
import (
"testing"
"go-common/library/conf/paladin"
"github.com/stretchr/testify/assert"
)
func TestMock(t *testing.T) {
cs := map[string]string{
"key_toml": `
key_bool = true
key_int = 100
key_float = 100.1
key_string = "text"
`,
}
cli := paladin.NewMock(cs)
// test vlaue
var m paladin.TOML
err := cli.Get("key_toml").Unmarshal(&m)
assert.Nil(t, err)
b, err := m.Get("key_bool").Bool()
assert.Nil(t, err)
assert.Equal(t, b, true)
i, err := m.Get("key_int").Int64()
assert.Nil(t, err)
assert.Equal(t, i, int64(100))
f, err := m.Get("key_float").Float64()
assert.Nil(t, err)
assert.Equal(t, f, float64(100.1))
s, err := m.Get("key_string").String()
assert.Nil(t, err)
assert.Equal(t, s, "text")
}

View File

@@ -0,0 +1,372 @@
package paladin
import (
"context"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"path"
"strconv"
"sync"
"time"
"go-common/library/conf/env"
"go-common/library/ecode"
"go-common/library/log"
xip "go-common/library/net/ip"
"go-common/library/net/netutil"
"github.com/pkg/errors"
)
const (
_apiGet = "http://%s/config/v2/get?%s"
_apiCheck = "http://%s/config/v2/check?%s"
_maxLoadRetries = 3
)
var (
_ Client = &sven{}
svenHost string
svenVersion string
svenPath string
svenToken string
svenAppoint string
svenTreeid string
_debug bool
)
func init() {
flag.StringVar(&svenHost, "conf_host", os.Getenv("CONF_HOST"), `config api host.`)
flag.StringVar(&svenVersion, "conf_version", os.Getenv("CONF_VERSION"), `app version.`)
flag.StringVar(&svenPath, "conf_path", os.Getenv("CONF_PATH"), `config file path.`)
flag.StringVar(&svenToken, "conf_token", os.Getenv("CONF_TOKEN"), `config token.`)
flag.StringVar(&svenAppoint, "conf_appoint", os.Getenv("CONF_APPOINT"), `config appoint.`)
flag.StringVar(&svenTreeid, "tree_id", os.Getenv("TREE_ID"), `tree id.`)
if env.DeployEnv == env.DeployEnvDev {
_debug = true
}
}
type watcher struct {
keys []string
ch chan Event
}
func newWatcher(keys []string) *watcher {
return &watcher{keys: keys, ch: make(chan Event, 5)}
}
func (w *watcher) HasKey(key string) bool {
if len(w.keys) == 0 {
return true
}
for _, k := range w.keys {
if k == key {
return true
}
}
return false
}
func (w *watcher) Handle(event Event) {
select {
case w.ch <- event:
default:
log.Error("paladin: discard event:%+v", event)
}
}
func (w *watcher) Chan() <-chan Event {
return w.ch
}
func (w *watcher) Close() {
close(w.ch)
}
// sven is sven config client.
type sven struct {
values *Map
wmu sync.RWMutex
watchers map[*watcher]struct{}
httpCli *http.Client
backoff *netutil.BackoffConfig
}
// NewSven new a config client.
func NewSven() (Client, error) {
s := &sven{
values: new(Map),
watchers: make(map[*watcher]struct{}),
httpCli: &http.Client{Timeout: 60 * time.Second},
backoff: &netutil.BackoffConfig{
MaxDelay: 5 * time.Second,
BaseDelay: 1.0 * time.Second,
Factor: 1.6,
Jitter: 0.2,
},
}
if err := s.checkEnv(); err != nil {
return nil, err
}
ver, err := s.load()
if err != nil {
return nil, err
}
go s.watchproc(ver)
return s, nil
}
func (s *sven) checkEnv() error {
if svenHost == "" || svenVersion == "" || svenPath == "" || svenToken == "" || svenTreeid == "" {
return fmt.Errorf("config env invalid. conf_host(%s) conf_version(%s) conf_path(%s) conf_token(%s) conf_appoint(%s) tree_id(%s)", svenHost, svenVersion, svenPath, svenToken, svenAppoint, svenTreeid)
}
return nil
}
// Get return value by key.
func (s *sven) Get(key string) *Value {
return s.values.Get(key)
}
// GetAll return value map.
func (s *sven) GetAll() *Map {
return s.values
}
// WatchEvent watch with the specified keys.
func (s *sven) WatchEvent(ctx context.Context, keys ...string) <-chan Event {
w := newWatcher(keys)
s.wmu.Lock()
s.watchers[w] = struct{}{}
s.wmu.Unlock()
return w.Chan()
}
// Close close watcher.
func (s *sven) Close() (err error) {
s.wmu.RLock()
for w := range s.watchers {
w.Close()
}
s.wmu.RUnlock()
return
}
func (s *sven) fireEvent(event Event) {
s.wmu.RLock()
for w := range s.watchers {
if w.HasKey(event.Key) {
w.Handle(event)
}
}
s.wmu.RUnlock()
}
func (s *sven) load() (ver int64, err error) {
var (
v *version
cs []*content
)
if v, err = s.check(-1); err != nil {
log.Error("paladin: s.check(-1) error(%v)", err)
return
}
for i := 0; i < _maxLoadRetries; i++ {
if cs, err = s.config(v); err == nil {
all := make(map[string]*Value, len(cs))
for _, v := range cs {
all[v.Name] = &Value{val: v.Config, raw: v.Config}
}
s.values.Store(all)
return v.Version, nil
}
log.Error("paladin: s.config(%v) error(%v)", ver, err)
time.Sleep(s.backoff.Backoff(i))
}
return 0, err
}
func (s *sven) watchproc(ver int64) {
var retry int
for {
v, err := s.check(ver)
if err != nil {
if ecode.NotModified.Equal(err) {
time.Sleep(time.Second)
continue
}
log.Error("paladin: s.check(%d) error(%v)", ver, err)
retry++
time.Sleep(s.backoff.Backoff(retry))
continue
}
cs, err := s.config(v)
if err != nil {
log.Error("paladin: s.config(%v) error(%v)", ver, err)
retry++
time.Sleep(s.backoff.Backoff(retry))
continue
}
all := s.values.Load()
news := make(map[string]*Value, len(cs))
for _, v := range cs {
if _, ok := all[v.Name]; !ok {
go s.fireEvent(Event{Event: EventAdd, Key: v.Name, Value: v.Config})
} else if v.Config != "" {
go s.fireEvent(Event{Event: EventUpdate, Key: v.Name, Value: v.Config})
} else {
go s.fireEvent(Event{Event: EventRemove, Key: v.Name, Value: v.Config})
}
news[v.Name] = &Value{val: v.Config, raw: v.Config}
}
for k, v := range all {
if _, ok := news[k]; !ok {
news[k] = v
}
}
s.values.Store(news)
ver = v.Version
retry = 0
}
}
type version struct {
Version int64 `json:"version"`
Diffs []int64 `json:"diffs"`
}
type config struct {
Version int64 `json:"version"`
Content string `json:"content"`
Md5 string `json:"md5"`
}
type content struct {
Cid int64 `json:"cid"`
Name string `json:"name"`
Config string `json:"config"`
}
func (s *sven) check(ver int64) (v *version, err error) {
params := newParams()
params.Set("version", strconv.FormatInt(ver, 10))
params.Set("appoint", svenAppoint)
var res struct {
Code int `json:"code"`
Data *version `json:"data"`
}
uri := fmt.Sprintf(_apiCheck, svenHost, params.Encode())
if _debug {
fmt.Printf("paladin: check(%d) uri(%s)\n", ver, uri)
}
req, err := http.NewRequest("GET", uri, nil)
if err != nil {
return
}
resp, err := s.httpCli.Do(req)
if err != nil {
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
err = errors.Errorf("paladin: httpCli.GET(%s) error(%d)", params.Encode(), resp.StatusCode)
return
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return
}
if err = json.Unmarshal(b, &res); err != nil {
return
}
if ec := ecode.Int(res.Code); !ec.Equal(ecode.OK) {
err = ec
return
}
if res.Data == nil {
err = errors.Errorf("paladin: http version is nil. params(%s)", params.Encode())
return
}
v = res.Data
return
}
func (s *sven) config(ver *version) (cts []*content, err error) {
ids, _ := json.Marshal(ver.Diffs)
params := newParams()
params.Set("version", strconv.FormatInt(ver.Version, 10))
params.Set("ids", string(ids))
var res struct {
Code int `json:"code"`
Data *config `json:"data"`
}
uri := fmt.Sprintf(_apiGet, svenHost, params.Encode())
if _debug {
fmt.Printf("paladin: config(%+v) uri(%s)\n", ver, uri)
}
req, err := http.NewRequest("GET", uri, nil)
if err != nil {
return
}
resp, err := s.httpCli.Do(req)
if err != nil {
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
err = errors.Errorf("paladin: httpCli.GET(%s) error(%d)", params.Encode(), resp.StatusCode)
return
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return
}
if err = json.Unmarshal(b, &res); err != nil {
return
}
if !ecode.Int(res.Code).Equal(ecode.OK) || res.Data == nil {
err = errors.Errorf("paladin: http config is nil. params(%s) ecode(%d)", params.Encode(), res.Code)
return
}
if err = json.Unmarshal([]byte(res.Data.Content), &cts); err != nil {
return
}
for _, c := range cts {
if err = ioutil.WriteFile(path.Join(svenPath, c.Name), []byte(c.Config), 0644); err != nil {
return
}
}
return
}
func newParams() url.Values {
params := url.Values{}
params.Set("service", serviceName())
params.Set("build", svenVersion)
params.Set("token", svenToken)
params.Set("hostname", env.Hostname)
params.Set("ip", ipAddr())
return params
}
func ipAddr() string {
if env.IP != "" {
return env.IP
}
return xip.InternalIP()
}
func serviceName() string {
return fmt.Sprintf("%s_%s_%s", svenTreeid, env.DeployEnv, env.Zone)
}

View File

@@ -0,0 +1,119 @@
package paladin
import (
"context"
"testing"
"time"
"go-common/library/conf/env"
"github.com/naoina/toml"
"github.com/stretchr/testify/assert"
)
type testObj struct {
Bool bool
Int int64
Float float64
String string
}
func (t *testObj) Set(text string) error {
return toml.Unmarshal([]byte(text), t)
}
type testConf struct {
Bool bool
Int int64
Float float64
String string
Object *testObj
}
func (t *testConf) Set(text string) error {
return toml.Unmarshal([]byte(text), t)
}
func TestSven(t *testing.T) {
svenHost = "config.bilibili.co"
svenVersion = "server-1"
svenPath = "/tmp"
svenToken = "1afe5efaf45e11e7b3f8c6cd4f230d8c"
svenAppoint = ""
svenTreeid = "2888"
env.Region = "sh"
env.Zone = "sh001"
env.Hostname = "test"
env.DeployEnv = "dev"
env.AppID = "main.common-arch.msm-service"
sven, err := NewSven()
assert.Nil(t, err)
testSvenMap(t, sven)
testSvenValue(t, sven)
testWatch(t, sven)
}
func testSvenMap(t *testing.T, cli Client) {
m := Map{}
text, err := cli.Get("test.toml").String()
assert.Nil(t, err)
assert.Nil(t, m.Set(text), text)
b, err := m.Get("bool").Bool()
assert.Nil(t, err)
assert.Equal(t, b, true, "bool")
// int64
i, err := m.Get("int").Int64()
assert.Nil(t, err)
assert.Equal(t, i, int64(100), "int64")
// float64
f, err := m.Get("float").Float64()
assert.Nil(t, err)
assert.Equal(t, f, 100.1, "float64")
// string
s, err := m.Get("string").String()
assert.Nil(t, err)
assert.Equal(t, s, "text", "string")
// error
n, err := m.Get("not_exsit").String()
assert.NotNil(t, err)
assert.Equal(t, n, "", "not_exsit")
obj := new(testObj)
text, err = m.Get("object").Raw()
assert.Nil(t, err)
assert.Nil(t, obj.Set(text))
assert.Equal(t, obj.Bool, true, "bool")
assert.Equal(t, obj.Int, int64(100), "int64")
assert.Equal(t, obj.Float, 100.1, "float64")
assert.Equal(t, obj.String, "text", "string")
}
func testSvenValue(t *testing.T, cli Client) {
v := new(testConf)
text, err := cli.Get("test.toml").Raw()
assert.Nil(t, err)
assert.Nil(t, v.Set(text))
assert.Equal(t, v.Bool, true, "bool")
assert.Equal(t, v.Int, int64(100), "int64")
assert.Equal(t, v.Float, 100.1, "float64")
assert.Equal(t, v.String, "text", "string")
assert.Equal(t, v.Object.Bool, true, "bool")
assert.Equal(t, v.Object.Int, int64(100), "int64")
assert.Equal(t, v.Object.Float, 100.1, "float64")
assert.Equal(t, v.Object.String, "text", "string")
}
func testWatch(t *testing.T, cli Client) {
ch := cli.WatchEvent(context.Background())
select {
case <-time.After(time.Second):
t.Log("watch timeout")
case e := <-ch:
s, err := cli.Get("static").String()
assert.Nil(t, err)
assert.Equal(t, s, e.Value, "watch value")
t.Logf("watch event:%+v", e)
}
}

View File

@@ -0,0 +1,68 @@
package paladin
import (
"reflect"
"strconv"
"github.com/naoina/toml"
"github.com/pkg/errors"
)
// TOML is toml map.
type TOML = Map
// Set set the map by value.
func (m *TOML) Set(text string) error {
if err := m.UnmarshalText([]byte(text)); err != nil {
return err
}
return nil
}
// UnmarshalText implemented toml.
func (m *TOML) UnmarshalText(text []byte) error {
raws := map[string]interface{}{}
if err := toml.Unmarshal(text, &raws); err != nil {
return err
}
values := map[string]*Value{}
for k, v := range raws {
k = keyNamed(k)
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Map:
b, err := toml.Marshal(v)
if err != nil {
return err
}
// NOTE: value is map[string]interface{}
values[k] = &Value{val: v, raw: string(b)}
case reflect.Slice:
raw := map[string]interface{}{
k: v,
}
b, err := toml.Marshal(raw)
if err != nil {
return err
}
// NOTE: value is []interface{}
values[k] = &Value{val: v, raw: string(b)}
case reflect.Bool:
b := v.(bool)
values[k] = &Value{val: b, raw: strconv.FormatBool(b)}
case reflect.Int64:
i := v.(int64)
values[k] = &Value{val: i, raw: strconv.FormatInt(i, 10)}
case reflect.Float64:
f := v.(float64)
values[k] = &Value{val: f, raw: strconv.FormatFloat(f, 'f', -1, 64)}
case reflect.String:
s := v.(string)
values[k] = &Value{val: s, raw: s}
default:
return errors.Errorf("UnmarshalTOML: unknown kind(%v)", rv.Kind())
}
}
m.Store(values)
return nil
}

View File

@@ -0,0 +1,170 @@
package paladin
import (
"encoding"
"reflect"
"time"
"github.com/BurntSushi/toml"
"github.com/pkg/errors"
)
// ErrNotExist value key not exist.
var (
ErrNotExist = errors.New("paladin: value key not exist")
ErrTypeAssertion = errors.New("paladin: value type assertion no match")
ErrDifferentTypes = errors.New("paladin: value different types")
)
// Value is config value, maybe a json/toml/ini/string file.
type Value struct {
val interface{}
slice interface{}
raw string
}
// Bool return bool value.
func (v *Value) Bool() (bool, error) {
if v.val == nil {
return false, ErrNotExist
}
b, ok := v.val.(bool)
if !ok {
return false, ErrTypeAssertion
}
return b, nil
}
// Int return int value.
func (v *Value) Int() (int, error) {
i, err := v.Int64()
if err != nil {
return 0, nil
}
return int(i), nil
}
// Int32 return int32 value.
func (v *Value) Int32() (int32, error) {
i, err := v.Int64()
if err != nil {
return 0, nil
}
return int32(i), nil
}
// Int64 return int64 value.
func (v *Value) Int64() (int64, error) {
if v.val == nil {
return 0, ErrNotExist
}
i, ok := v.val.(int64)
if !ok {
return 0, ErrTypeAssertion
}
return i, nil
}
// Float32 return float32 value.
func (v *Value) Float32() (float32, error) {
f, err := v.Float64()
if err != nil {
return 0.0, err
}
return float32(f), nil
}
// Float64 return float64 value.
func (v *Value) Float64() (float64, error) {
if v.val == nil {
return 0.0, ErrNotExist
}
f, ok := v.val.(float64)
if !ok {
return 0.0, ErrTypeAssertion
}
return f, nil
}
// String return string value.
func (v *Value) String() (string, error) {
if v.val == nil {
return "", ErrNotExist
}
s, ok := v.val.(string)
if !ok {
return "", ErrTypeAssertion
}
return s, nil
}
// Duration parses a duration string. A duration string is a possibly signed sequence of decimal numbers
// each with optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (v *Value) Duration() (time.Duration, error) {
s, err := v.String()
if err != nil {
return time.Duration(0), err
}
return time.ParseDuration(s)
}
// Raw return raw value.
func (v *Value) Raw() (string, error) {
if v.val == nil {
return "", ErrNotExist
}
return v.raw, nil
}
// Slice scan a slcie interface.
func (v *Value) Slice(dst interface{}) error {
// NOTE: val is []interface{}, slice is []type
if v.val == nil {
return ErrNotExist
}
rv := reflect.ValueOf(dst)
if rv.Kind() != reflect.Ptr || rv.Elem().Kind() != reflect.Slice {
return ErrDifferentTypes
}
el := rv.Elem()
kind := el.Type().Elem().Kind()
if v.slice == nil {
src, ok := v.val.([]interface{})
if !ok {
return ErrDifferentTypes
}
for _, s := range src {
if reflect.TypeOf(s).Kind() != kind {
return ErrTypeAssertion
}
el = reflect.Append(el, reflect.ValueOf(s))
}
v.slice = el.Interface()
rv.Elem().Set(el)
return nil
}
sv := reflect.ValueOf(v.slice)
if sv.Type().Elem().Kind() != kind {
return ErrTypeAssertion
}
rv.Elem().Set(sv)
return nil
}
// Unmarshal is the interface implemented by an object that can unmarshal a textual representation of itself.
func (v *Value) Unmarshal(un encoding.TextUnmarshaler) error {
text, err := v.Raw()
if err != nil {
return err
}
return un.UnmarshalText([]byte(text))
}
// UnmarshalTOML unmarhsal toml to struct.
func (v *Value) UnmarshalTOML(dst interface{}) error {
text, err := v.Raw()
if err != nil {
return err
}
return toml.Unmarshal([]byte(text), dst)
}

View File

@@ -0,0 +1,206 @@
package paladin
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
type testUnmarshler struct {
Text string
Int int
}
func TestValueUnmarshal(t *testing.T) {
s := `
int = 100
text = "hello"
`
v := Value{val: s, raw: s}
obj := new(testUnmarshler)
assert.Nil(t, v.UnmarshalTOML(obj))
// error
v = Value{val: nil, raw: ""}
assert.NotNil(t, v.UnmarshalTOML(obj))
}
func TestValue(t *testing.T) {
var tests = []struct {
in interface{}
out interface{}
}{
{
"text",
"text",
},
{
time.Duration(time.Second * 10),
"10s",
},
{
int64(100),
int64(100),
},
{
float64(100.1),
float64(100.1),
},
{
true,
true,
},
{
nil,
nil,
},
}
for _, test := range tests {
t.Run(fmt.Sprint(test.in), func(t *testing.T) {
v := Value{val: test.in, raw: fmt.Sprint(test.in)}
switch test.in.(type) {
case nil:
s, err := v.String()
assert.NotNil(t, err)
assert.Equal(t, s, "", test.in)
i, err := v.Int64()
assert.NotNil(t, err)
assert.Equal(t, i, int64(0), test.in)
f, err := v.Float64()
assert.NotNil(t, err)
assert.Equal(t, f, float64(0.0), test.in)
b, err := v.Bool()
assert.NotNil(t, err)
assert.Equal(t, b, false, test.in)
case string:
val, err := v.String()
assert.Nil(t, err)
assert.Equal(t, val, test.out.(string), test.in)
case int64:
val, err := v.Int()
assert.Nil(t, err)
assert.Equal(t, val, int(test.out.(int64)), test.in)
val32, err := v.Int32()
assert.Nil(t, err)
assert.Equal(t, val32, int32(test.out.(int64)), test.in)
val64, err := v.Int64()
assert.Nil(t, err)
assert.Equal(t, val64, test.out.(int64), test.in)
case float64:
val32, err := v.Float32()
assert.Nil(t, err)
assert.Equal(t, val32, float32(test.out.(float64)), test.in)
val64, err := v.Float64()
assert.Nil(t, err)
assert.Equal(t, val64, test.out.(float64), test.in)
case bool:
val, err := v.Bool()
assert.Nil(t, err)
assert.Equal(t, val, test.out.(bool), test.in)
case time.Duration:
v.val = test.out
val, err := v.Duration()
assert.Nil(t, err)
assert.Equal(t, val, test.in.(time.Duration), test.out)
}
})
}
}
func TestValueSlice(t *testing.T) {
var tests = []struct {
in interface{}
out interface{}
}{
{
nil,
nil,
},
{
[]interface{}{"a", "b", "c"},
[]string{"a", "b", "c"},
},
{
[]interface{}{1, 2, 3},
[]int64{1, 2, 3},
},
{
[]interface{}{1.1, 1.2, 1.3},
[]float64{1.1, 1.2, 1.3},
},
{
[]interface{}{true, false, true},
[]bool{true, false, true},
},
}
for _, test := range tests {
t.Run(fmt.Sprint(test.in), func(t *testing.T) {
v := Value{val: test.in, raw: fmt.Sprint(test.in)}
switch test.in.(type) {
case nil:
var s []string
assert.NotNil(t, v.Slice(&s))
case []string:
var s []string
assert.Nil(t, v.Slice(&s))
assert.Equal(t, s, test.out)
case []int64:
var s []int64
assert.Nil(t, v.Slice(&s))
assert.Equal(t, s, test.out)
case []float64:
var s []float64
assert.Nil(t, v.Slice(&s))
assert.Equal(t, s, test.out)
case []bool:
var s []bool
assert.Nil(t, v.Slice(&s))
assert.Equal(t, s, test.out)
}
})
}
}
func BenchmarkValueInt(b *testing.B) {
v := &Value{val: int64(100), raw: "100"}
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
v.Int64()
}
})
}
func BenchmarkValueFloat(b *testing.B) {
v := &Value{val: float64(100.1), raw: "100.1"}
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
v.Float64()
}
})
}
func BenchmarkValueBool(b *testing.B) {
v := &Value{val: true, raw: "true"}
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
v.Bool()
}
})
}
func BenchmarkValueString(b *testing.B) {
v := &Value{val: "text", raw: "text"}
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
v.String()
}
})
}
func BenchmarkValueSlice(b *testing.B) {
v := &Value{val: []interface{}{1, 2, 3}, raw: "100"}
b.RunParallel(func(pb *testing.PB) {
var slice []int64
for pb.Next() {
v.Slice(&slice)
}
})
}

18
library/container/BUILD Normal file
View File

@@ -0,0 +1,18 @@
package(default_visibility = ["//visibility:public"])
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [
":package-srcs",
"//library/container/pool:all-srcs",
"//library/container/queue/aqm:all-srcs",
],
tags = ["automanaged"],
)

View File

@@ -0,0 +1,49 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_test",
"go_library",
)
go_test(
name = "go_default_test",
srcs = [
"list_test.go",
"slice_test.go",
],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
deps = [
"//library/time:go_default_library",
"//vendor/github.com/stretchr/testify/assert:go_default_library",
],
)
go_library(
name = "go_default_library",
srcs = [
"list.go",
"pool.go",
"slice.go",
],
importpath = "go-common/library/container/pool",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = ["//library/time:go_default_library"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,16 @@
### pool
##### Version 1.0.4
> 1.优化context的设置超时逻辑减小锁时间
##### Version 1.0.3
> 1.优化单元测试
##### Version 1.0.2
> 1.完善Put方法的单元测试
##### Version 1.0.1
> 1.优化单元测试和修复expire的bug
##### Version 1.0.0
> 1.初始化pool库

View File

@@ -0,0 +1,9 @@
# Owner
zhapuyu
# Author
zhapuyu
# Reviewer
haoguanwei
maojian

View File

@@ -0,0 +1,11 @@
# See the OWNERS docs at https://go.k8s.io/owners
approvers:
- zhapuyu
labels:
- library
- library/container/pool
reviewers:
- haoguanwei
- maojian
- zhapuyu

View File

@@ -0,0 +1,13 @@
#### pool
##### 项目简介
通用连接池实现
##### 编译环境
- **请只用 Golang v1.8.x 以上版本编译执行**
##### 依赖包
> 1.公共包go-common

View File

@@ -0,0 +1,227 @@
package pool
import (
"container/list"
"context"
"io"
"sync"
"time"
)
var _ Pool = &List{}
// List .
type List struct {
// New is an application supplied function for creating and configuring a
// item.
//
// The item returned from new must not be in a special state
// (subscribed to pubsub channel, transaction started, ...).
New func(ctx context.Context) (io.Closer, error)
// mu protects fields defined below.
mu sync.Mutex
cond chan struct{}
closed bool
active int
// clean stale items
cleanerCh chan struct{}
// Stack of item with most recently used at the front.
idles list.List
// Config pool configuration
conf *Config
}
// NewList creates a new pool.
func NewList(c *Config) *List {
// check Config
if c == nil || c.Active < c.Idle {
panic("config nil or Idle Must <= Active")
}
// new pool
p := &List{conf: c}
p.cond = make(chan struct{})
p.startCleanerLocked(time.Duration(c.IdleTimeout))
return p
}
// Reload reload config.
func (p *List) Reload(c *Config) error {
p.mu.Lock()
p.startCleanerLocked(time.Duration(c.IdleTimeout))
p.conf = c
p.mu.Unlock()
return nil
}
// startCleanerLocked
func (p *List) startCleanerLocked(d time.Duration) {
if d <= 0 {
// if set 0, staleCleaner() will return directly
return
}
if d < time.Duration(p.conf.IdleTimeout) && p.cleanerCh != nil {
select {
case p.cleanerCh <- struct{}{}:
default:
}
}
// run only one, clean stale items.
if p.cleanerCh == nil {
p.cleanerCh = make(chan struct{}, 1)
go p.staleCleaner()
}
}
// staleCleaner clean stale items proc.
func (p *List) staleCleaner() {
ticker := time.NewTicker(100 * time.Millisecond)
for {
select {
case <-ticker.C:
case <-p.cleanerCh: // maxLifetime was changed or db was closed.
}
p.mu.Lock()
if p.closed || p.conf.IdleTimeout <= 0 {
p.mu.Unlock()
return
}
for i, n := 0, p.idles.Len(); i < n; i++ {
e := p.idles.Back()
if e == nil {
// no possible
break
}
ic := e.Value.(item)
if !ic.expired(time.Duration(p.conf.IdleTimeout)) {
// not need continue.
break
}
p.idles.Remove(e)
p.release()
p.mu.Unlock()
ic.c.Close()
p.mu.Lock()
}
p.mu.Unlock()
}
}
// Get returns a item from the idles List or
// get a new item.
func (p *List) Get(ctx context.Context) (io.Closer, error) {
p.mu.Lock()
if p.closed {
p.mu.Unlock()
return nil, ErrPoolClosed
}
for {
// get idles item.
for i, n := 0, p.idles.Len(); i < n; i++ {
e := p.idles.Front()
if e == nil {
break
}
ic := e.Value.(item)
p.idles.Remove(e)
p.mu.Unlock()
if !ic.expired(time.Duration(p.conf.IdleTimeout)) {
return ic.c, nil
}
ic.c.Close()
p.mu.Lock()
p.release()
}
// Check for pool closed before dialing a new item.
if p.closed {
p.mu.Unlock()
return nil, ErrPoolClosed
}
// new item if under limit.
if p.conf.Active == 0 || p.active < p.conf.Active {
newItem := p.New
p.active++
p.mu.Unlock()
c, err := newItem(ctx)
if err != nil {
p.mu.Lock()
p.release()
p.mu.Unlock()
c = nil
}
return c, err
}
if p.conf.WaitTimeout == 0 && !p.conf.Wait {
p.mu.Unlock()
return nil, ErrPoolExhausted
}
wt := p.conf.WaitTimeout
p.mu.Unlock()
// slowpath: reset context timeout
nctx := ctx
cancel := func() {}
if wt > 0 {
_, nctx, cancel = wt.Shrink(ctx)
}
select {
case <-nctx.Done():
cancel()
return nil, nctx.Err()
case <-p.cond:
}
cancel()
p.mu.Lock()
}
}
// Put put item into pool.
func (p *List) Put(ctx context.Context, c io.Closer, forceClose bool) error {
p.mu.Lock()
if !p.closed && !forceClose {
p.idles.PushFront(item{createdAt: nowFunc(), c: c})
if p.idles.Len() > p.conf.Idle {
c = p.idles.Remove(p.idles.Back()).(item).c
} else {
c = nil
}
}
if c == nil {
p.signal()
p.mu.Unlock()
return nil
}
p.release()
p.mu.Unlock()
return c.Close()
}
// Close releases the resources used by the pool.
func (p *List) Close() error {
p.mu.Lock()
idles := p.idles
p.idles.Init()
p.closed = true
p.active -= idles.Len()
p.mu.Unlock()
for e := idles.Front(); e != nil; e = e.Next() {
e.Value.(item).c.Close()
}
return nil
}
// release decrements the active count and signals waiters. The caller must
// hold p.mu during the call.
func (p *List) release() {
p.active--
p.signal()
}
func (p *List) signal() {
select {
default:
case p.cond <- struct{}{}:
}
}

View File

@@ -0,0 +1,322 @@
package pool
import (
"context"
"io"
"testing"
"time"
xtime "go-common/library/time"
"github.com/stretchr/testify/assert"
)
func TestListGetPut(t *testing.T) {
// new pool
config := &Config{
Active: 1,
Idle: 1,
IdleTimeout: xtime.Duration(90 * time.Second),
WaitTimeout: xtime.Duration(10 * time.Millisecond),
Wait: false,
}
pool := NewList(config)
pool.New = func(ctx context.Context) (io.Closer, error) {
return &closer{}, nil
}
// test Get Put
conn, err := pool.Get(context.TODO())
assert.Nil(t, err)
c1 := connection{pool: pool, c: conn}
c1.HandleNormal()
c1.Close()
}
func TestListPut(t *testing.T) {
var id = 0
type connID struct {
io.Closer
id int
}
config := &Config{
Active: 1,
Idle: 1,
IdleTimeout: xtime.Duration(1 * time.Second),
// WaitTimeout: xtime.Duration(10 * time.Millisecond),
Wait: false,
}
pool := NewList(config)
pool.New = func(ctx context.Context) (io.Closer, error) {
id = id + 1
return &connID{id: id, Closer: &closer{}}, nil
}
// test Put(ctx, conn, true)
conn, err := pool.Get(context.TODO())
assert.Nil(t, err)
conn1 := conn.(*connID)
// Put(ctx, conn, true) drop the connection.
pool.Put(context.TODO(), conn, true)
conn, err = pool.Get(context.TODO())
assert.Nil(t, err)
conn2 := conn.(*connID)
assert.NotEqual(t, conn1.id, conn2.id)
}
func TestListIdleTimeout(t *testing.T) {
var id = 0
type connID struct {
io.Closer
id int
}
config := &Config{
Active: 1,
Idle: 1,
// conn timeout
IdleTimeout: xtime.Duration(1 * time.Millisecond),
}
pool := NewList(config)
pool.New = func(ctx context.Context) (io.Closer, error) {
id = id + 1
return &connID{id: id, Closer: &closer{}}, nil
}
// test Put(ctx, conn, true)
conn, err := pool.Get(context.TODO())
assert.Nil(t, err)
conn1 := conn.(*connID)
// Put(ctx, conn, true) drop the connection.
pool.Put(context.TODO(), conn, false)
time.Sleep(5 * time.Millisecond)
// idletimeout and get new conn
conn, err = pool.Get(context.TODO())
assert.Nil(t, err)
conn2 := conn.(*connID)
assert.NotEqual(t, conn1.id, conn2.id)
}
func TestListContextTimeout(t *testing.T) {
// new pool
config := &Config{
Active: 1,
Idle: 1,
IdleTimeout: xtime.Duration(90 * time.Second),
WaitTimeout: xtime.Duration(10 * time.Millisecond),
Wait: false,
}
pool := NewList(config)
pool.New = func(ctx context.Context) (io.Closer, error) {
return &closer{}, nil
}
// test context timeout
ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
defer cancel()
conn, err := pool.Get(ctx)
assert.Nil(t, err)
_, err = pool.Get(ctx)
// context timeout error
assert.NotNil(t, err)
pool.Put(context.TODO(), conn, false)
_, err = pool.Get(ctx)
assert.Nil(t, err)
}
func TestListPoolExhausted(t *testing.T) {
// test pool exhausted
config := &Config{
Active: 1,
Idle: 1,
IdleTimeout: xtime.Duration(90 * time.Second),
// WaitTimeout: xtime.Duration(10 * time.Millisecond),
Wait: false,
}
pool := NewList(config)
pool.New = func(ctx context.Context) (io.Closer, error) {
return &closer{}, nil
}
ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
defer cancel()
conn, err := pool.Get(context.TODO())
assert.Nil(t, err)
_, err = pool.Get(ctx)
// config active == 1, so no avaliable conns make connection exhausted.
assert.NotNil(t, err)
pool.Put(context.TODO(), conn, false)
_, err = pool.Get(ctx)
assert.Nil(t, err)
}
func TestListStaleClean(t *testing.T) {
var id = 0
type connID struct {
io.Closer
id int
}
config := &Config{
Active: 1,
Idle: 1,
IdleTimeout: xtime.Duration(1 * time.Second),
// WaitTimeout: xtime.Duration(10 * time.Millisecond),
Wait: false,
}
pool := NewList(config)
pool.New = func(ctx context.Context) (io.Closer, error) {
id = id + 1
return &connID{id: id, Closer: &closer{}}, nil
}
conn, err := pool.Get(context.TODO())
assert.Nil(t, err)
conn1 := conn.(*connID)
pool.Put(context.TODO(), conn, false)
conn, err = pool.Get(context.TODO())
assert.Nil(t, err)
conn2 := conn.(*connID)
assert.Equal(t, conn1.id, conn2.id)
pool.Put(context.TODO(), conn, false)
// sleep more than idleTimeout
time.Sleep(2 * time.Second)
conn, err = pool.Get(context.TODO())
assert.Nil(t, err)
conn3 := conn.(*connID)
assert.NotEqual(t, conn1.id, conn3.id)
}
func BenchmarkList1(b *testing.B) {
config := &Config{
Active: 30,
Idle: 30,
IdleTimeout: xtime.Duration(90 * time.Second),
WaitTimeout: xtime.Duration(10 * time.Millisecond),
Wait: false,
}
pool := NewList(config)
pool.New = func(ctx context.Context) (io.Closer, error) {
return &closer{}, nil
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
conn, err := pool.Get(context.TODO())
if err != nil {
b.Error(err)
continue
}
c1 := connection{pool: pool, c: conn}
c1.HandleQuick()
c1.Close()
}
})
}
func BenchmarkList2(b *testing.B) {
config := &Config{
Active: 30,
Idle: 30,
IdleTimeout: xtime.Duration(90 * time.Second),
WaitTimeout: xtime.Duration(10 * time.Millisecond),
Wait: false,
}
pool := NewList(config)
pool.New = func(ctx context.Context) (io.Closer, error) {
return &closer{}, nil
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
conn, err := pool.Get(context.TODO())
if err != nil {
b.Error(err)
continue
}
c1 := connection{pool: pool, c: conn}
c1.HandleNormal()
c1.Close()
}
})
}
func BenchmarkPool3(b *testing.B) {
config := &Config{
Active: 30,
Idle: 30,
IdleTimeout: xtime.Duration(90 * time.Second),
WaitTimeout: xtime.Duration(10 * time.Millisecond),
Wait: false,
}
pool := NewList(config)
pool.New = func(ctx context.Context) (io.Closer, error) {
return &closer{}, nil
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
conn, err := pool.Get(context.TODO())
if err != nil {
b.Error(err)
continue
}
c1 := connection{pool: pool, c: conn}
c1.HandleSlow()
c1.Close()
}
})
}
func BenchmarkList4(b *testing.B) {
config := &Config{
Active: 30,
Idle: 30,
IdleTimeout: xtime.Duration(90 * time.Second),
// WaitTimeout: xtime.Duration(10 * time.Millisecond),
Wait: false,
}
pool := NewList(config)
pool.New = func(ctx context.Context) (io.Closer, error) {
return &closer{}, nil
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
conn, err := pool.Get(context.TODO())
if err != nil {
b.Error(err)
continue
}
c1 := connection{pool: pool, c: conn}
c1.HandleSlow()
c1.Close()
}
})
}
func BenchmarkList5(b *testing.B) {
config := &Config{
Active: 30,
Idle: 30,
IdleTimeout: xtime.Duration(90 * time.Second),
// WaitTimeout: xtime.Duration(10 * time.Millisecond),
Wait: true,
}
pool := NewList(config)
pool.New = func(ctx context.Context) (io.Closer, error) {
return &closer{}, nil
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
conn, err := pool.Get(context.TODO())
if err != nil {
b.Error(err)
continue
}
c1 := connection{pool: pool, c: conn}
c1.HandleSlow()
c1.Close()
}
})
}

Some files were not shown because too many files have changed in this diff Show More