Create & Init Project...

This commit is contained in:
2019-04-22 18:49:16 +08:00
commit fc4fa37393
25440 changed files with 4054998 additions and 0 deletions

35
app/interface/main/mcn/tool/cache/BUILD vendored Normal file
View File

@@ -0,0 +1,35 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
)
go_library(
name = "go_default_library",
srcs = ["cache.go"],
importpath = "go-common/app/interface/main/mcn/tool/cache",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//library/cache:go_default_library",
"//library/cache/memcache:go_default_library",
"//library/log:go_default_library",
"//library/net/metadata:go_default_library",
"//library/stat/prom:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,180 @@
package cache
import (
"context"
"fmt"
"go-common/library/cache"
"go-common/library/cache/memcache"
"go-common/library/log"
"go-common/library/net/metadata"
"go-common/library/stat/prom"
"time"
)
//DataLoader cache interface
type DataLoader interface {
Key() (key string)
Value() (value interface{})
// LoadValue return value need cache
// if err, nothing will cache
// if value == nil, and IsNullCached is true, empty will be cached
LoadValue(c context.Context) (value interface{}, err error)
Expire() time.Duration
Desc() string
}
// Get
// Delete
// Add
//MCWrapper wrapper for mc
type MCWrapper struct {
mc *memcache.Pool
cache *cache.Cache
// 是否缓存空值,防止缓存穿透
IsNullCached bool
}
// null definition
const (
IsNull = 1
NotNull = 0
)
type cacheValue struct {
Null int8 `json:"n"` // not 0 means null
Value interface{} `json:"v"`
}
//IsNull return true is it's null
func (s *cacheValue) IsNull() bool {
return s.Null != NotNull
}
//New new memcache wrapper
func New(mc *memcache.Pool) *MCWrapper {
return &MCWrapper{
mc: mc,
cache: cache.New(10, 1024),
}
}
func (m *MCWrapper) addRaw(c context.Context, data DataLoader, cacheV *cacheValue) (err error) {
if data == nil {
return
}
conn := m.mc.Get(c)
defer conn.Close()
key := data.Key()
item := &memcache.Item{Key: key, Object: cacheV, Expiration: int32(data.Expire() / time.Second), Flags: memcache.FlagJSON}
if err = conn.Set(item); err != nil {
actionDesc := "Add" + data.Desc()
prom.BusinessErrCount.Incr("mc:" + actionDesc)
log.Errorv(c, log.KV(actionDesc, fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
log.Info("Add key ok, key=%s, null=%d", key, cacheV.Null)
return
}
//Add add cache data
func (m *MCWrapper) Add(c context.Context, data DataLoader) (err error) {
var cacheV = &cacheValue{
Value: data.Value(),
}
return m.addRaw(c, data, cacheV)
}
//Delete delete cache data
func (m *MCWrapper) Delete(c context.Context, data DataLoader) (err error) {
conn := m.mc.Get(c)
defer conn.Close()
key := data.Key()
if err = conn.Delete(key); err != nil {
if err == memcache.ErrNotFound {
err = nil
return
}
actionDesc := "Del" + data.Desc()
prom.BusinessErrCount.Incr("mc:" + actionDesc)
log.Errorv(c, log.KV(actionDesc, fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
return
}
//Get get data
func (m *MCWrapper) Get(c context.Context, data DataLoader) (err error) {
_, err = m.getRaw(c, data)
return
}
func (m *MCWrapper) getRaw(c context.Context, data DataLoader) (v *cacheValue, err error) {
conn := m.mc.Get(c)
defer conn.Close()
key := data.Key()
value, err := conn.Get(key)
if err != nil {
if err == memcache.ErrNotFound {
err = nil
return
}
actionDesc := "Cache" + data.Desc()
prom.BusinessErrCount.Incr("mc:" + actionDesc)
log.Errorv(c, log.KV(actionDesc, fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
var cacheV = cacheValue{
Value: data.Value(),
}
err = conn.Scan(value, &cacheV)
if err != nil {
actionDesc := "Cache" + data.Desc()
prom.BusinessErrCount.Incr("mc:" + actionDesc)
log.Errorv(c, log.KV(actionDesc, fmt.Sprintf("%+v", err)), log.KV("key", key))
return
}
v = &cacheV
return
}
//GetOrLoad get from cache, if not found, then call data.LoadValue to load
func (m *MCWrapper) GetOrLoad(c context.Context, data DataLoader) (err error) {
var v *cacheValue
v, err = m.getRaw(c, data)
if err != nil {
return
}
if v != nil && !v.IsNull() {
prom.CacheHit.Incr(data.Desc())
return
}
// 没有找到对应的缓存,需求去拉取
prom.CacheMiss.Incr(data.Desc())
res, err := data.LoadValue(c)
if err != nil {
return
}
// 没有查到值,并且不缓存空值
if res == nil && !m.IsNullCached {
return
}
var cacheV = &cacheValue{
Value: res,
}
if res == nil {
cacheV.Null = IsNull
}
m.cache.Save(func() {
m.addRaw(metadata.WithContext(c), data, cacheV)
})
return
}

View File

@@ -0,0 +1,48 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_test",
"go_library",
)
go_test(
name = "go_default_test",
srcs = [
"http_client_test.go",
"sqltool_test.go",
],
embed = [":go_default_library"],
tags = ["automanaged"],
deps = ["//library/log:go_default_library"],
)
go_library(
name = "go_default_library",
srcs = [
"http_client.go",
"sqltool.go",
],
importpath = "go-common/app/interface/main/mcn/tool/datacenter",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//library/log:go_default_library",
"//library/stat:go_default_library",
"//vendor/github.com/pkg/errors:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,256 @@
package datacenter
import (
"bytes"
"context"
"crypto/md5"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strings"
"time"
"go-common/library/log"
pkgerr "github.com/pkg/errors"
"go-common/library/stat"
"strconv"
)
/*
访问数据平台的http client处理了签名、接口监控等
*/
//ClientConfig client config
type ClientConfig struct {
Key string
Secret string
Dial time.Duration
Timeout time.Duration
KeepAlive time.Duration
}
//New new client
func New(c *ClientConfig) *HttpClient {
return &HttpClient{
client: &http.Client{},
conf: c,
}
}
//HttpClient http client
type HttpClient struct {
client *http.Client
conf *ClientConfig
Debug bool
}
//Response response
type Response struct {
Code int `json:"code"`
Msg string `json:"msg"`
Result interface{} `json:"result"`
}
const (
keyAppKey = "appKey"
keyAppID = "apiId"
keyTimeStamp = "timestamp"
keySign = "sign"
keySignMethod = "signMethod"
keyVersion = "version"
//TimeStampFormat time format in second
TimeStampFormat = "2006-01-02 15:04:05"
)
var (
clientStats = stat.HTTPClient
)
// Get issues a GET to the specified URL.
func (client *HttpClient) Get(c context.Context, uri string, params url.Values, res interface{}) (err error) {
req, err := client.NewRequest(http.MethodGet, uri, params)
if err != nil {
return
}
return client.Do(c, req, res)
}
// NewRequest new http request with method, uri, ip, values and headers.
// TODO(zhoujiahui): param realIP should be removed later.
func (client *HttpClient) NewRequest(method, uri string, params url.Values) (req *http.Request, err error) {
signStr, err := client.sign(params)
if err != nil {
err = pkgerr.Wrapf(err, "uri:%s,params:%v", uri, params)
return
}
params.Add(keySign, signStr)
enc := params.Encode()
ru := uri
if enc != "" {
ru = uri + "?" + enc
}
if method == http.MethodGet {
req, err = http.NewRequest(http.MethodGet, ru, nil)
} else {
req, err = http.NewRequest(http.MethodPost, uri, strings.NewReader(enc))
}
if err != nil {
err = pkgerr.Wrapf(err, "method:%s,uri:%s", method, ru)
return
}
const (
_contentType = "Content-Type"
_urlencoded = "application/x-www-form-urlencoded"
)
if method == http.MethodPost {
req.Header.Set(_contentType, _urlencoded)
}
return
}
// Do sends an HTTP request and returns an HTTP json response.
func (client *HttpClient) Do(c context.Context, req *http.Request, res interface{}, v ...string) (err error) {
var bs []byte
if bs, err = client.Raw(c, req, v...); err != nil {
return
}
if res != nil {
if err = json.Unmarshal(bs, res); err != nil {
err = pkgerr.Wrapf(err, "host:%s, url:%s, response:%s", req.URL.Host, realURL(req), string(bs))
}
}
return
}
//Raw get from url
func (client *HttpClient) Raw(c context.Context, req *http.Request, v ...string) (bs []byte, err error) {
var resp *http.Response
var uri = fmt.Sprintf("%s://%s%s", req.URL.Scheme, req.Host, req.URL.Path)
var now = time.Now()
var code string
defer func() {
clientStats.Timing(uri, int64(time.Since(now)/time.Millisecond))
if code != "" {
clientStats.Incr(uri, code)
}
}()
req = req.WithContext(c)
if resp, err = client.client.Do(req); err != nil {
err = pkgerr.Wrapf(err, "host:%s, url:%s", req.URL.Host, realURL(req))
code = "failed"
return
}
defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest {
err = pkgerr.Errorf("incorrect http status:%d host:%s, url:%s", resp.StatusCode, req.URL.Host, realURL(req))
code = strconv.Itoa(resp.StatusCode)
return
}
if bs, err = readAll(resp.Body, 16*1024); err != nil {
err = pkgerr.Wrapf(err, "host:%s, url:%s", req.URL.Host, realURL(req))
return
}
if client.Debug {
log.Info("reqeust: host:%s, url:%s, response body:%s", req.URL.Host, realURL(req), string(bs))
}
return
}
// sign calc appkey and appsecret sign.
// see http://info.bilibili.co/pages/viewpage.action?pageId=5410881#id-%E6%95%B0%E6%8D%AE%E7%9B%98%EF%BC%8D%E5%AE%89%E5%85%A8%E8%AE%A4%E8%AF%81-%E4%BA%8C%E7%AD%BE%E5%90%8D%E7%AE%97%E6%B3%95
func (client *HttpClient) sign(params url.Values) (sign string, err error) {
key := client.conf.Key
secret := client.conf.Secret
if params == nil {
params = url.Values{}
}
params.Set(keyAppKey, key)
params.Set(keyVersion, "1.0")
if params.Get(keyTimeStamp) == "" {
params.Set(keyTimeStamp, time.Now().Format(TimeStampFormat))
}
params.Set(keySignMethod, "md5")
var needSignParams = url.Values{}
needSignParams.Add(keyAppKey, key)
needSignParams.Add(keyTimeStamp, params.Get(keyTimeStamp))
needSignParams.Add(keyVersion, params.Get(keyVersion))
//tmp := params.Encode()
var valueMap = map[string][]string(needSignParams)
var buf bytes.Buffer
// 开头与结尾加secret
buf.Write([]byte(secret))
keys := make([]string, 0, len(valueMap))
for k := range valueMap {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
vs := valueMap[k]
prefix := k
buf.WriteString(prefix)
for _, v := range vs {
buf.WriteString(v)
break
}
}
buf.Write([]byte(secret))
var md5 = md5.New()
md5.Write(buf.Bytes())
sign = fmt.Sprintf("%X", md5.Sum(nil))
return
}
// readAll reads from r until an error or EOF and returns the data it read
// from the internal buffer allocated with a specified capacity.
func readAll(r io.Reader, capacity int64) (b []byte, err error) {
buf := bytes.NewBuffer(make([]byte, 0, capacity))
// If the buffer overflows, we will get bytes.ErrTooLarge.
// Return that as an error. Any other panic remains.
defer func() {
e := recover()
if e == nil {
return
}
if panicErr, ok := e.(error); ok && panicErr == bytes.ErrTooLarge {
err = panicErr
} else {
panic(e)
}
}()
_, err = buf.ReadFrom(r)
return buf.Bytes(), err
}
// realUrl return url with http://host/params.
func realURL(req *http.Request) string {
if req.Method == http.MethodGet {
return req.URL.String()
} else if req.Method == http.MethodPost {
ru := req.URL.Path
if req.Body != nil {
rd, ok := req.Body.(io.Reader)
if ok {
buf := bytes.NewBuffer([]byte{})
buf.ReadFrom(rd)
ru = ru + "?" + buf.String()
}
}
return ru
}
return req.URL.Path
}
// SetTransport set client transport
func (client *HttpClient) SetTransport(t http.RoundTripper) {
client.client.Transport = t
}

View File

@@ -0,0 +1,81 @@
package datacenter
import (
"context"
"crypto/md5"
"fmt"
"go-common/library/log"
"net/url"
"os"
"testing"
)
func TestMain(m *testing.M) {
log.Init(nil)
os.Exit(m.Run())
}
func TestSign(t *testing.T) {
var params = url.Values{
"timestamp": {"2018-11-19 18:50:28"},
"version": {"1.0"},
}
var result = "appKey12345timestamp2018-11-19 18:50:28version1.0"
var md5hash = md5.New()
var conf = ClientConfig{
Key: "12345", Secret: "56473",
}
md5hash.Write([]byte(conf.Secret + result + conf.Secret))
var signString = fmt.Sprintf("%X", md5hash.Sum(nil))
t.Logf("sign string=%s", signString)
c := New(&conf)
s, e := c.sign(params)
if e != nil {
t.Errorf("err happend, %+v", e)
t.FailNow()
return
}
if s != signString {
t.Logf("fail, expect=%s, get=%s", signString, s)
t.FailNow()
return
}
}
func TestAPI(t *testing.T) {
var q = &Query{}
q.Select("id, day").
Where(
ConditionMapType{"day": ConditionLte("2018-10-28")}).
Limit(20, 0).Order("day")
var params = url.Values{
//"timestamp": {"2018-11-19 18:50:28"},
//"version": {"1.0"},
"query": {q.String()},
}
var conf = ClientConfig{
Key: "b9739fc84d087c4b3c1aa297d01999e6", Secret: "5018928e83a23c0cc9773f2571de01e5",
}
c := New(&conf)
var res = &Response{}
var realResult struct {
Result []struct {
ID int `json:"id"`
Day string `json:"day"`
}
}
var middle interface{} = &realResult.Result
res.Result = middle
c.Debug = true
var e = c.Get(context.Background(), "http://berserker.bilibili.co/avenger/api/151/query", params, res)
if e != nil {
t.Errorf("err=%+v", e)
t.FailNow()
}
t.Logf("result=%+v", res)
for k, v := range realResult.Result {
t.Logf("k=%v, v=%+v", k, v)
}
}

View File

@@ -0,0 +1,233 @@
package datacenter
import (
"encoding/json"
"fmt"
"go-common/library/log"
"strings"
"text/scanner"
)
// operator
const (
opIn = "in"
opNin = "nin"
opLike = "like"
opLte = "lte" // <=
opLt = "lt" // <
opGte = "gte" // >=
opGt = "gt" // >
opNull = "null"
)
// value for Null operator
const (
IsNull = 1
IsNotNull = -1
)
// value for sort
const (
Desc = -1
Asc = 1
)
//ConditionMapType condition map
type ConditionMapType map[string]map[string]interface{}
//ConditionType condition's in map
type ConditionType map[string]interface{}
//ConditionIn in
func ConditionIn(v ...interface{}) ConditionType {
return ConditionType{
opIn: v,
}
}
func conditionHelper(k string, v interface{}) ConditionType {
return ConditionType{
k: v,
}
}
//ConditionLte <=
func ConditionLte(v interface{}) ConditionType {
return conditionHelper(opLte, v)
}
//ConditionLt <
func ConditionLt(v interface{}) ConditionType {
return conditionHelper(opLt, v)
}
//ConditionGte >=
func ConditionGte(v interface{}) ConditionType {
return conditionHelper(opGte, v)
}
//ConditionGt >
func ConditionGt(v interface{}) ConditionType {
return conditionHelper(opGt, v)
}
//SortType sort
type SortType map[string]int
//Query query
type Query struct {
selection []map[string]string
// <field, <operator, value> >
where map[string]map[string]interface{}
sort map[string]int
limit map[string]int
err error
}
const (
keyField = "name"
keyAs = "as"
)
func makeField(field string) map[string]string {
return map[string]string{keyField: field}
}
func makeFieldAs(field, as string) map[string]string {
return map[string]string{keyField: field, keyAs: as}
}
//Select select fields, use similar as sql
func (q *Query) Select(fields string) *Query {
if q.err != nil {
return q
}
var fieldsAll = strings.Split(fields, ",")
for _, v := range fieldsAll {
var s scanner.Scanner
s.Init(strings.NewReader(v))
var tokens []string
for tok := s.Scan(); tok != scanner.EOF; tok = s.Scan() {
txt := s.TokenText()
tokens = append(tokens, txt)
}
switch len(tokens) {
case 1:
if tokens[0] == "*" {
q.selection = []map[string]string{}
return q
}
q.selection = append(q.selection, makeField(tokens[0]))
case 2:
q.selection = append(q.selection, makeFieldAs(tokens[0], tokens[1]))
case 3:
q.selection = append(q.selection, makeFieldAs(tokens[0], tokens[2]))
}
}
return q
}
//Where where condition, see test for examples
func (q *Query) Where(conditions ...ConditionMapType) *Query {
if q.err != nil {
return q
}
if q.where == nil {
q.where = make(ConditionMapType, len(conditions))
}
for _, mapData := range conditions {
for k1, v1 := range mapData {
if q.where[k1] == nil {
q.where[k1] = make(map[string]interface{})
}
// combine all pair of map[string]interface{}(v1) into q.where[k1]
for k2, v2 := range v1 {
q.where[k1][k2] = v2
}
}
}
return q
}
//Order order field, use similar as sql
func (q *Query) Order(sort string) *Query {
if q.err != nil {
return q
}
var fields = strings.Split(sort, ",")
if q.sort == nil {
q.sort = make(map[string]int, len(fields))
}
for _, v := range fields {
var s scanner.Scanner
s.Init(strings.NewReader(v))
var tokens []string
for tok := s.Scan(); tok != scanner.EOF; tok = s.Scan() {
txt := s.TokenText()
tokens = append(tokens, txt)
}
switch len(tokens) {
case 1:
q.sort[tokens[0]] = Asc
case 2:
var order = Asc
switch strings.ToLower(tokens[1]) {
case "asc":
order = Asc
case "desc":
order = Desc
}
q.sort[tokens[0]] = order
default:
q.err = fmt.Errorf("parse order fail, [%s]", sort)
log.Error("%s", q.err)
return q
}
}
return q
}
//Limit limit, same as sql
func (q *Query) Limit(limit, offset int) *Query {
if q.err != nil {
return q
}
if q.limit == nil {
q.limit = make(map[string]int, 2)
}
q.limit["limit"] = limit
q.limit["skip"] = offset
return q
}
//String to string
func (q *Query) String() (res string) {
if q.err != nil {
return q.err.Error()
}
var resultMap = map[string]interface{}{}
if q.selection != nil {
resultMap["select"] = q.selection
}
if q.where != nil {
resultMap["where"] = q.where
}
if q.sort != nil {
resultMap["sort"] = q.sort
}
if q.limit != nil {
resultMap["page"] = q.limit
}
resBytes, _ := json.Marshal(resultMap)
res = string(resBytes)
return
}
//Error return error if get error
func (q *Query) Error() error {
return q.err
}

View File

@@ -0,0 +1,34 @@
package datacenter
import "testing"
func TestSelect(t *testing.T) {
var q = &Query{}
q.Select("a,b b2,c as c2")
t.Logf("query=%s", q)
q = &Query{}
q.Select(" * ")
t.Logf("query=%s", q)
}
func TestWhere(t *testing.T) {
var q = &Query{}
q.Select("a,b,c as yeah").Where(
ConditionMapType{
"field1": ConditionIn(1, 2, 3, 4),
"field3": ConditionIn("OK"),
},
ConditionMapType{
"field2": ConditionGt(100),
},
ConditionMapType{
"field1": ConditionGte(100),
})
t.Logf("query=%s", q)
}
func TestSort(t *testing.T) {
var q = &Query{}
q.Order("field1 desc, field2")
t.Logf("query=%s", q)
}

View File

@@ -0,0 +1,32 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
)
go_library(
name = "go_default_library",
srcs = ["validate.go"],
importpath = "go-common/app/interface/main/mcn/tool/validate",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = [
"//library/net/http/blademaster/binding:go_default_library",
"//vendor/gopkg.in/go-playground/validator.v9:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,36 @@
package validate
import (
"regexp"
"go-common/library/net/http/blademaster/binding"
"gopkg.in/go-playground/validator.v9"
)
var (
//RegIDcheck 检查身份证
RegIDcheck = regexp.MustCompile(`(^\d{15}$)|(^\d{18}$)|(^\d{17}(\d|X)$)`)
//RegHTTPCheck 检查HTTP格式
RegHTTPCheck = regexp.MustCompile(`^((https|http|ftp|rtsp|mms)?:\/\/)[^\s]+`)
//RegPhoneCheck 检查电话格式
RegPhoneCheck = regexp.MustCompile(`1[345678]\d{9}`)
)
func init() {
binding.Validator.RegisterValidation("idcheck", idcheck)
binding.Validator.RegisterValidation("httpcheck", httpcheck)
binding.Validator.RegisterValidation("phonecheck", phonecheck)
}
func idcheck(fl validator.FieldLevel) bool {
return RegIDcheck.MatchString(fl.Field().String())
}
func httpcheck(fl validator.FieldLevel) bool {
return RegHTTPCheck.MatchString(fl.Field().String())
}
func phonecheck(fl validator.FieldLevel) bool {
return RegPhoneCheck.MatchString(fl.Field().String())
}

View File

@@ -0,0 +1,38 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_library(
name = "go_default_library",
srcs = ["pool.go"],
importpath = "go-common/app/interface/main/mcn/tool/worker",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = ["//library/log:go_default_library"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)
go_test(
name = "go_default_test",
srcs = ["pool_test.go"],
embed = [":go_default_library"],
rundir = ".",
tags = ["automanaged"],
)

View File

@@ -0,0 +1,137 @@
package worker
import (
"fmt"
"runtime"
"sync"
"time"
"go-common/library/log"
)
const (
_ratio = float32(0.8)
)
var (
_default = &Conf{
QueueSize: 1024,
WorkerProcMax: 32,
WorkerNumber: runtime.NumCPU() - 1,
}
)
// Conf .
type Conf struct {
QueueSize int
WorkerProcMax int
WorkerNumber int
}
// Pool .
type Pool struct {
c *Conf
queue chan func()
workerNumber int
close chan struct{}
wg sync.WaitGroup
}
// New .
func New(conf *Conf) (w *Pool) {
if conf == nil {
conf = _default
}
w = &Pool{
c: conf,
queue: make(chan func(), conf.QueueSize),
workerNumber: conf.WorkerNumber,
close: make(chan struct{}),
}
w.start()
go w.moni()
return
}
func (w *Pool) start() {
for i := 0; i < w.workerNumber; i++ {
w.wg.Add(1)
go w.workerRoutine()
}
}
func (w *Pool) moni() {
var conf = w.c
for {
time.Sleep(time.Second * 5)
var ratio = float32(len(w.queue)) / float32(conf.QueueSize)
if ratio >= _ratio {
if w.workerNumber >= conf.WorkerProcMax {
log.Warn("work thread more than max(%d)", conf.WorkerProcMax)
return
}
var next = minInt(w.workerNumber<<1, w.c.WorkerProcMax)
var diff = next - w.workerNumber
log.Info("current thread count=%d, queue ratio=%f, create new thread number=(%d)", w.workerNumber, ratio, diff)
for i := 0; i < diff; i++ {
w.wg.Add(1)
go w.workerRoutine()
}
w.workerNumber = next
}
}
}
// Close .
func (w *Pool) Close() {
close(w.close)
}
// Wait .
func (w *Pool) Wait() {
w.wg.Wait()
}
func (w *Pool) workerRoutine() {
defer func() {
w.wg.Done()
if x := recover(); x != nil {
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
log.Error("w.workerRoutine panic(%+v) :\n %s", x, buf)
w.wg.Add(1)
go w.workerRoutine()
}
}()
loop:
for {
select {
case f := <-w.queue:
f()
case <-w.close:
log.Info("workerRoutine close()")
break loop
}
}
for f := range w.queue {
f()
}
}
// Add .
func (w *Pool) Add(f func()) error {
select {
case w.queue <- f:
default:
return fmt.Errorf("task channel is full")
}
return nil
}
func minInt(a, b int) int {
if a < b {
return a
}
return b
}

View File

@@ -0,0 +1,61 @@
package worker
import (
"testing"
"time"
)
func TestIncrease(t *testing.T) {
var (
conf = &Conf{
QueueSize: 10,
WorkerProcMax: 10,
WorkerNumber: 1,
}
workerPool = New(conf)
)
for i := 0; i < 10; i++ {
workerPool.Add(longtime)
}
time.Sleep(6 * time.Second)
var expect = minInt(conf.WorkerNumber<<1, conf.WorkerProcMax)
if workerPool.workerNumber != expect {
t.Logf("worker number=%d, expect=%d", workerPool.workerNumber, expect)
t.FailNow()
}
for i := 0; i < 10; i++ {
workerPool.Add(longtime)
}
time.Sleep(6 * time.Second)
expect = minInt(conf.WorkerNumber<<2, conf.WorkerProcMax)
if workerPool.workerNumber != expect {
t.Logf("worker number=%d, expect=%d", workerPool.workerNumber, expect)
t.FailNow()
}
for i := 0; i < 10; i++ {
workerPool.Add(longtime)
}
time.Sleep(6 * time.Second)
expect = minInt(conf.WorkerNumber<<3, conf.WorkerProcMax)
if workerPool.workerNumber != expect {
t.Logf("worker number=%d, expect=%d", workerPool.workerNumber, expect)
t.FailNow()
}
for i := 0; i < 10; i++ {
workerPool.Add(longtime)
}
time.Sleep(6 * time.Second)
expect = minInt(conf.WorkerNumber<<4, conf.WorkerProcMax)
if workerPool.workerNumber != expect {
t.Logf("worker number=%d, expect=%d", workerPool.workerNumber, expect)
t.FailNow()
}
}
func longtime() {
time.Sleep(20 * time.Second)
}