Create & Init Project...

This commit is contained in:
2019-04-22 18:49:16 +08:00
commit fc4fa37393
25440 changed files with 4054998 additions and 0 deletions

View File

@@ -0,0 +1,65 @@
# Bazel rules for go-common/app/interface/live/push-live/service.
# All rules are tagged "automanaged": they are kept in sync by generation
# tooling, so prefer regenerating over hand-editing where possible.
package(default_visibility = ["//visibility:public"])

load(
    "@io_bazel_rules_go//go:def.bzl",
    "go_test",
    "go_library",
)

# Package unit tests; embeds the library so tests compile against its sources.
go_test(
    name = "go_default_test",
    srcs = [
        "common_message_test.go",
        "mids_test.go",
        "push_test.go",
        "service_test.go",
        "start_live_test.go",
    ],
    embed = [":go_default_library"],
    tags = ["automanaged"],
    deps = [
        "//app/interface/live/push-live/conf:go_default_library",
        "//app/interface/live/push-live/dao:go_default_library",
        "//app/interface/live/push-live/model:go_default_library",
        "//library/cache/redis:go_default_library",
        "//vendor/github.com/smartystreets/goconvey/convey:go_default_library",
    ],
)

# The push-live service library itself.
go_library(
    name = "go_default_library",
    srcs = [
        "common_message.go",
        "mids.go",
        "push.go",
        "service.go",
        "start_live.go",
    ],
    importpath = "go-common/app/interface/live/push-live/service",
    tags = ["automanaged"],
    visibility = ["//visibility:public"],
    deps = [
        "//app/interface/live/push-live/conf:go_default_library",
        "//app/interface/live/push-live/dao:go_default_library",
        "//app/interface/live/push-live/model:go_default_library",
        "//library/cache/redis:go_default_library",
        "//library/log:go_default_library",
        "//library/queue/databus:go_default_library",
        "//library/sync/errgroup:go_default_library",
        "//vendor/github.com/pkg/errors:go_default_library",
    ],
)

# Source globs consumed by the repo-wide srcs aggregation targets.
filegroup(
    name = "package-srcs",
    srcs = glob(["**"]),
    tags = ["automanaged"],
    visibility = ["//visibility:private"],
)

filegroup(
    name = "all-srcs",
    srcs = [":package-srcs"],
    tags = ["automanaged"],
    visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,94 @@
package service
import (
"context"
"encoding/json"
"go-common/app/interface/live/push-live/dao"
"go-common/app/interface/live/push-live/model"
"go-common/library/cache/redis"
"go-common/library/log"
"go-common/library/queue/databus"
)
// LiveCommonMessage consumes one generic live push message from databus:
// decode payload -> build task -> dedupe mids -> per-business filter -> push,
// then asynchronously record the task and set per-mid smoothing keys.
// The databus message is always committed, even when decoding fails.
func (s *Service) LiveCommonMessage(ctx context.Context, msg *databus.Message) (err error) {
	defer msg.Commit()
	var (
		mids   []int64
		mMap   = make(map[int64]bool)  // dedupe set of incoming mids
		midMap = make(map[int][]int64) // business -> final filtered mid list
	)
	m := new(model.LiveCommonMessage)
	if err = json.Unmarshal(msg.Value, &m); err != nil {
		log.Error("[service.common_message|LiveCommonMessage] json Unmarshal error(%v), model(%v)", err, m)
		return
	}
	task := s.InitCommonTask(m)
	if mids, err = s.convertStrToInt64(m.MsgContent.Mids); err != nil {
		log.Error("[service.push|LiveCommonMessage] format Mids error(%v), task(%v), model(%v)", err, task, m)
		return
	}
	// remove duplicated mid
	for _, mid := range mids {
		mMap[mid] = true
	}
	// mid filter
	business := m.MsgContent.Business
	filteredMids := s.midFilter(mMap, business, task)
	midMap[business] = filteredMids
	log.Info("[service.push|LiveCommonMessage] message info: before(%d), after(%d), model(%v), task(%v)",
		len(mMap), len(midMap[business]), m, task)
	total := s.Push(task, midMap)
	// create push task
	// NOTE(review): these goroutines are not tracked by s.wg; a shutdown right
	// after Push could lose the task record / interval keys — confirm acceptable.
	go s.CreatePushTask(task, total)
	go s.setPushInterval(business, s.safeGetExpired(), filteredMids, task)
	log.Info("[service.push|LiveCommonMessage] common message push done, total(%d), err(%v)", total, err)
	return
}
// InitCommonTask builds an ApPushTask for a generic live message by copying
// the push fields straight from the message content. Common messages have no
// single target, so TargetID stays zero.
func (s *Service) InitCommonTask(m *model.LiveCommonMessage) (task *model.ApPushTask) {
	content := m.MsgContent
	task = new(model.ApPushTask)
	task.Type = model.LivePushType
	task.TargetID = 0 // no single target room/user for common messages
	task.AlertTitle = content.AlertTitle
	task.AlertBody = content.AlertBody
	task.MidSource = content.Business
	task.LinkType = content.LinkType
	task.LinkValue = content.LinkValue
	task.ExpireTime = content.ExpireTime
	task.Group = content.Group
	return task
}
// setPushInterval writes a push-smoothing ("interval") redis key for every
// pushed mid so activity-reservation pushes are rate-limited per user.
// It only applies to the activity business; total is the number of keys
// successfully written, and err holds the last per-key failure (failed keys
// are logged and skipped so one bad key does not abort the batch).
//
// Fix: the business guard used the magic number 111; use the named
// model.ActivityBusiness constant, consistent with needDecrease in mids.go.
func (s *Service) setPushInterval(business int, expired int32, mids []int64, task *model.ApPushTask) (total int, err error) {
	if business != model.ActivityBusiness {
		return
	}
	var conn redis.Conn
	defer func() {
		if conn != nil {
			conn.Close()
		}
	}()
	// Dedicated connection to the push-interval redis instance.
	conn, err = redis.Dial(s.c.Redis.PushInterval.Proto, s.c.Redis.PushInterval.Addr, s.dao.RedisOption()...)
	if err != nil {
		log.Error("[service.common_message|setPushInterval] redis.Dial error(%v), task(%v), mids(%d)",
			err, task, len(mids))
		return
	}
	// One SET <key> <linkValue> EX <expired> per mid.
	for _, mid := range mids {
		key := dao.GetIntervalKey(mid)
		if _, err = conn.Do("SET", key, task.LinkValue, "EX", expired); err != nil {
			log.Error("[service.common_message|setPushInterval] set redis error(%v), task(%v), mid(%d)",
				err, task, mid)
			continue
		}
		total++
	}
	return
}

View File

@@ -0,0 +1,98 @@
package service
import (
. "github.com/smartystreets/goconvey/convey"
"go-common/app/interface/live/push-live/dao"
"go-common/app/interface/live/push-live/model"
"go-common/library/cache/redis"
"math/rand"
"strconv"
"testing"
)
// makeTestCommonPushTask builds an ApPushTask through Service.InitCommonTask
// from synthetic common-message fields, mirroring the production path.
func makeTestCommonPushTask(title, body, linkValue, group string, business, expireTime int) (task *model.ApPushTask) {
	msg := &model.LiveCommonMessage{
		MsgContent: model.LiveCommonMessageContent{
			Business:   business,
			Group:      group,
			Mids:       "",
			AlertTitle: title,
			AlertBody:  body,
			LinkValue:  linkValue,
			ExpireTime: expireTime,
		},
	}
	return s.InitCommonTask(msg)
}
// TestService_InitCommonTask verifies that InitCommonTask copies every
// message-content field into the resulting push task unchanged.
func TestService_InitCommonTask(t *testing.T) {
	initd()
	Convey("should return init struct", t, func() {
		title := "room_title"
		body := "测试"
		group := "group"
		linkValue := strconv.Itoa(rand.Intn(9999))
		expireTime := rand.Intn(10000) + 1
		business := rand.Intn(9999)
		task := makeTestCommonPushTask(title, body, linkValue, group, business, expireTime)
		So(task.AlertTitle, ShouldResemble, title)
		So(task.AlertBody, ShouldResemble, body)
		So(task.ExpireTime, ShouldResemble, expireTime)
		So(task.LinkValue, ShouldResemble, linkValue)
		So(task.MidSource, ShouldEqual, business)
		So(task.Group, ShouldEqual, group)
	})
}
// TestService_setPushInterval covers both branches of setPushInterval:
// a non-activity business (no-op, zero total) and the activity business
// (one redis key per mid, cleaned up afterwards).
// Requires a reachable redis instance from the test config.
func TestService_setPushInterval(t *testing.T) {
	initd()
	Convey("test setPushInterval", t, func() {
		var (
			resTotal int
			total    int
			business int
			task     *model.ApPushTask
			mids     []int64
			err      error
		)
		Convey("test business will not exec logic", func() {
			// rand.Intn(100) yields 0-99, never the activity business (111).
			business = rand.Intn(100)
			task = &model.ApPushTask{}
			total = 10
			mids = makeMids(total)
			resTotal, err = s.setPushInterval(business, rand.Int31(), mids, task)
			So(err, ShouldBeNil)
			So(resTotal, ShouldEqual, 0)
		})
		Convey("test business will exec logic", func() {
			var conn redis.Conn
			business = 111
			task = &model.ApPushTask{
				LinkValue: "test",
			}
			total = 10
			mids = makeMids(total)
			resTotal, err = s.setPushInterval(business, 300, mids, task)
			So(err, ShouldBeNil)
			So(resTotal, ShouldEqual, total)
			// clean: best-effort DEL of every interval key written above
			conn, err = redis.Dial(s.c.Redis.PushInterval.Proto, s.c.Redis.PushInterval.Addr, s.dao.RedisOption()...)
			So(err, ShouldBeNil)
			for _, mid := range mids {
				key := dao.GetIntervalKey(mid)
				conn.Do("DEL", key)
			}
			conn.Close()
		})
	})
}
// makeMids returns total pseudo-random mids for use as test input.
func makeMids(total int) []int64 {
	mids := make([]int64, total)
	for i := range mids {
		mids[i] = rand.Int63()
	}
	return mids
}

View File

@@ -0,0 +1,96 @@
package service
import (
"context"
"go-common/app/interface/live/push-live/dao"
"go-common/app/interface/live/push-live/model"
"go-common/library/log"
"strings"
"sync"
"time"
)
// midFilter is the single entry point for all mid-filtering logic. It splits
// the deduped mid set into chunks of Push.IntervalLimit, runs the dao filter
// chain over each chunk in its own goroutine, and merges the surviving mids.
// For non-activity businesses it also fires an async daily-limit decrease
// for every mid that passed the filters.
func (s *Service) midFilter(ml map[int64]bool, business int, task *model.ApPushTask) (midMap []int64) {
	var (
		mutex        sync.Mutex
		i            int
		midsList     [][]int64
		wg           sync.WaitGroup
		needDecrease = needDecrease(business)
		filterConf   = &dao.FilterConfig{
			Business:        business,
			IntervalExpired: s.safeGetExpired(),
			IntervalValue:   intervalValueByLinkValue(task.LinkValue),
			DailyExpired:    dailyExpired(time.Now()),
			Task:            task}
	)
	midMap = make([]int64, 0, len(ml))
	// split mids by limit
	mids := make([]int64, 0, s.c.Push.IntervalLimit)
	for mid := range ml {
		mids = append(mids, mid)
		i++
		if i == s.c.Push.IntervalLimit {
			i = 0
			midsList = append(midsList, mids)
			mids = make([]int64, 0, s.c.Push.IntervalLimit)
		}
	}
	// trailing partial chunk
	if len(mids) > 0 {
		midsList = append(midsList, mids)
	}
	// filter goroutines: one per chunk, joined via wg
	for i := 0; i < len(midsList); i++ {
		wg.Add(1)
		go func(index int, mids []int64) {
			var (
				filteredMids []int64
				f            *dao.Filter
				err          error
				ctx          = context.TODO()
			)
			defer func() {
				log.Info("[service.mids|midFilter] BatchFilter before(%d), after(%d), task(%v), business(%d), err(%v)",
					len(mids), len(filteredMids), task, business, err)
				wg.Done()
			}()
			// new filter
			f, err = s.dao.NewFilter(filterConf)
			if err != nil {
				return
			}
			filteredMids = f.BatchFilter(ctx, s.dao.NewFilterChain(f), mids)
			// NOTE(review): f.Done() is only invoked on the empty-result
			// path — confirm the filter needs no cleanup when mids survive.
			if len(filteredMids) == 0 {
				f.Done()
				return
			}
			// after filter, do something
			if needDecrease {
				go f.BatchDecreaseLimit(ctx, filteredMids)
			}
			// merge this chunk's survivors under the lock
			mutex.Lock()
			midMap = append(midMap, filteredMids...)
			mutex.Unlock()
		}(i, midsList[i])
	}
	wg.Wait()
	log.Info("[service.mids|midFilter] filtered task(%v), before(%d), after(%d), type(%d)",
		task, len(ml), len(midMap), business)
	return
}
// intervalValueByLinkValue extracts the room id — the segment before the
// first comma — from a push link value. A value with no comma is returned
// unchanged.
func intervalValueByLinkValue(linkValue string) string {
	if i := strings.Index(linkValue, ","); i >= 0 {
		return linkValue[:i]
	}
	return linkValue
}
// needDecrease reports whether the business's daily push limit should be
// decreased after a push; activity pushes are exempt.
func needDecrease(business int) bool {
	switch business {
	case model.ActivityBusiness:
		return false
	default:
		return true
	}
}

View File

@@ -0,0 +1,53 @@
package service
import (
"testing"
. "github.com/smartystreets/goconvey/convey"
"go-common/app/interface/live/push-live/dao"
"go-common/app/interface/live/push-live/model"
"go-common/library/cache/redis"
"math/rand"
)
// TestService_MidFilter runs midFilter with a random non-activity business
// over fresh random mids and expects every mid to survive the filter chain,
// then cleans up the redis keys the filters created.
// Requires a reachable redis instance from the test config.
func TestService_MidFilter(t *testing.T) {
	initd()
	Convey("test mid filter", t, func() {
		var (
			total    int
			midList  map[int64]bool
			business int
			resMids  []int64
			task     *model.ApPushTask
			err      error
			conn     redis.Conn
		)
		business = rand.Intn(99) + 1 // should through all filters
		// init mids input
		total = 10
		midList = make(map[int64]bool, total)
		for i := 0; i < total; i++ {
			mid := rand.Int63()
			midList[mid] = true
		}
		// init task
		task = &model.ApPushTask{
			LinkValue: "test",
		}
		// do mid filter
		resMids = s.midFilter(midList, business, task)
		So(len(resMids), ShouldEqual, total)
		// clean: best-effort DEL of the daily-limit and interval keys
		conn, err = redis.Dial(s.c.Redis.PushInterval.Proto, s.c.Redis.PushInterval.Addr, s.dao.RedisOption()...)
		So(err, ShouldBeNil)
		for mid := range midList {
			keys := []string{
				dao.GetDailyLimitKey(mid),
				dao.GetIntervalKey(mid),
			}
			for _, key := range keys {
				conn.Do("DEL", key)
			}
		}
	})
}

View File

@@ -0,0 +1,92 @@
package service
import (
"context"
"go-common/app/interface/live/push-live/model"
"go-common/library/log"
"time"
)
// liveMessageConsumeproc is the consumer loop for both live-push databus
// topics (start-live and common messages). It returns — decrementing s.wg —
// as soon as either subscription channel is closed.
func (s *Service) liveMessageConsumeproc() {
	defer func() {
		log.Warn("liveMessageConsumeproc exited.")
		s.wg.Done()
	}()
	var (
		liveStartMsgs  = s.liveStartSub.Messages()
		liveCommonMsgs = s.liveCommonSub.Messages()
	)
	for {
		select {
		case msg, ok := <-liveStartMsgs:
			if !ok {
				log.Warn("[service.push|liveMessageConsumeproc] liveStartSub has been closed.")
				return
			}
			log.Info("[service.push|liveMessageConsumeproc] consume liveStartSub key(%s) offset(%d) message(%s)",
				msg.Key, msg.Offset, msg.Value)
			s.LiveStartMessage(context.TODO(), msg)
		case msg, ok := <-liveCommonMsgs:
			if !ok {
				log.Warn("[service.push|liveMessageConsumeproc] liveCommonSub has been closed.")
				return
			}
			log.Info("[service.push|liveMessageConsumeproc] consume liveCommonSub key(%s) offset(%d) message(%s)",
				msg.Key, msg.Offset, msg.Value)
			s.LiveCommonMessage(context.TODO(), msg)
		default:
			// NOTE(review): default + sleep turns this into a 3s poll and can
			// add up to 3s latency per message; a blocking select (no default)
			// would behave identically on channel close — confirm the poll is
			// intentional before removing it.
			time.Sleep(time.Second * 3)
			continue
		}
	}
}
// Push assembles the business parameters and sends the task to every mid
// list in midMap via the push platform, returning the count actually pushed.
func (s *Service) Push(task *model.ApPushTask, midMap map[int][]int64) (total int) {
	var expected int
	for pushType, mids := range midMap {
		n := len(mids)
		expected += n
		if n == 0 {
			continue
		}
		// BatchPush splits the mids, retries on failure, and returns the
		// number actually delivered.
		task.Group = s.GetPushGroup(pushType, task.Group)
		pushed := s.dao.BatchPush(&mids, task)
		log.Info("[service.push|Push] push type(%d), count(%d), target_id(%v)", pushType, pushed, task.TargetID)
		total += pushed
	}
	if expected == 0 {
		log.Info("[service.push|Push] None to push, task(%+v)", task)
		return
	}
	log.Info("[service.push|Push] push done.should(%d), actual(%d), task(%+v).", expected, total, task)
	return
}
// CreatePushTask persists the finished push task, recording the final push
// total; failures (error or zero affected rows) are logged.
func (s *Service) CreatePushTask(task *model.ApPushTask, total int) (affected int64, err error) {
	task.Total = total
	if affected, err = s.dao.CreateNewTask(context.TODO(), task); err != nil || affected == 0 {
		log.Error("[service.push|CreatePushTask] CreateNewTask error(%v), task(%+v)", err, task)
		return
	}
	log.Info("[service.push|CreatePushTask] CreateNewTask success, task(%+v)", task)
	return
}
// GetPushGroup resolves the push group for a push type.
// Compatibility: start-live pushes use fixed groups per relation type
// (regular vs special follow); other (common-message) pushes carry their
// own group g, which is returned unchanged.
func (s *Service) GetPushGroup(t int, g string) string {
	switch t {
	case model.RelationAttention:
		return model.AttentionGroup
	case model.RelationSpecial:
		return model.SpecialGroup
	}
	return g
}

View File

@@ -0,0 +1,59 @@
package service
import (
. "github.com/smartystreets/goconvey/convey"
"go-common/app/interface/live/push-live/model"
"math/rand"
"strconv"
"testing"
)
// makeTestInitPushTask builds a start-live push task through
// Service.InitPushTask from synthetic message fields.
func makeTestInitPushTask(targetID int64, uname, linkValue,
	roomTitle string, expireTime int) (task *model.ApPushTask) {
	m := &model.StartLiveMessage{
		TargetID:   targetID,
		Uname:      uname,
		LinkValue:  linkValue,
		RoomTitle:  roomTitle,
		ExpireTime: expireTime,
	}
	task = s.InitPushTask(m)
	return
}
// TestService_Push verifies that pushing an empty mid list reports zero
// pushed notifications.
func TestService_Push(t *testing.T) {
	initd()
	Convey("test push func", t, func() {
		// test empty mids
		targetID := rand.Int63n(100) + 1
		uname := "测试"
		linkValue := strconv.Itoa(rand.Intn(9999))
		roomTitle := "room_title"
		expireTime := rand.Intn(10000) + 1
		task := makeTestInitPushTask(targetID, uname, linkValue, roomTitle, expireTime)
		midMap := make(map[int][]int64)
		midMap[model.RelationAttention] = []int64{}
		total := s.Push(task, midMap)
		So(total, ShouldEqual, 0)
	})
}
// TestService_GetPushGroup checks the group mapping for the two relation
// push types and the fallback to the caller-supplied group otherwise.
func TestService_GetPushGroup(t *testing.T) {
	initd()
	Convey("test get group by different push type", t, func() {
		var (
			group     string
			testGroup = "test_group"
		)
		group = s.GetPushGroup(model.RelationAttention, "")
		So(group, ShouldEqual, model.AttentionGroup)
		group = s.GetPushGroup(model.RelationSpecial, "")
		So(group, ShouldEqual, model.SpecialGroup)
		group = s.GetPushGroup(rand.Intn(9999), testGroup)
		So(group, ShouldEqual, testGroup)
	})
}

View File

@@ -0,0 +1,238 @@
package service
import (
"context"
"errors"
"fmt"
"math"
"strconv"
"strings"
"sync"
"time"
"go-common/app/interface/live/push-live/conf"
"go-common/app/interface/live/push-live/dao"
"go-common/library/cache/redis"
"go-common/library/log"
"go-common/library/queue/databus"
)
var (
	// _limitDecreaseUUIDKey is the redis key template used to dedupe
	// limit-decrease API requests ("ld:" + business + targetID + uuid).
	_limitDecreaseUUIDKey = "ld:%s"
	// Sentinel errors returned by LimitDecrease and its helpers.
	errLimitRequestRepeat = errors.New("limit decrease request repeat")
	errConvertMidString   = errors.New("convert mid string error")
	errConvertBusiness    = errors.New("convert business error")
)
// Service is the push-live service: it consumes live databus messages,
// resolves target mids, and pushes live notifications.
type Service struct {
	c               *conf.Config     // service configuration
	dao             *dao.Dao         // storage / push-platform access
	liveStartSub    *databus.Databus // start-live topic subscription
	liveCommonSub   *databus.Databus // common live-message topic subscription
	wg              sync.WaitGroup   // tracks consumer + config-refresh goroutines
	closeCh         chan bool        // closed by Close to stop loadPushConfig
	pushTypes       []string         // current push strategies (guarded by mutex)
	intervalExpired int32            // smoothing-key TTL in seconds (guarded by mutex)
	mutex           sync.RWMutex     // guards pushTypes and intervalExpired
}
// New creates the push-live Service: it wires the dao and the two databus
// subscriptions, starts the config-refresh loop, and spawns the configured
// number of message-consumer goroutines (all tracked in s.wg for Close).
func New(c *conf.Config) (s *Service) {
	s = &Service{
		c:             c,
		dao:           dao.New(c),
		liveStartSub:  databus.New(c.LiveRoomSub),
		liveCommonSub: databus.New(c.LiveCommonSub),
		closeCh:       make(chan bool),
		pushTypes:     make([]string, 0, 4),
		// mutex needs no explicit initializer: the zero value of
		// sync.RWMutex is ready to use (the old `mutex: sync.RWMutex{}`
		// was redundant).
	}
	s.wg.Add(1)
	go s.loadPushConfig()
	for i := 0; i < c.Push.ConsumerProcNum; i++ {
		s.wg.Add(1)
		go s.liveMessageConsumeproc()
	}
	return s
}
// loadPushConfig refreshes the push interval TTL and the push strategy
// types from the dao once a minute until closeCh is closed.
// Fix: time.Minute is already a time.Duration, so the redundant
// time.Duration(time.Minute) conversions are dropped.
func (s *Service) loadPushConfig() {
	var ctx = context.TODO()
	defer s.wg.Done()
	for {
		// A closed closeCh signals shutdown; poll it non-blockingly so the
		// refresh loop below keeps running otherwise.
		select {
		case _, ok := <-s.closeCh:
			if !ok {
				log.Info("[service.push|loadPushConfig] s.loadPushConfig is closed by closeCh")
				return
			}
		default:
		}
		// get push delay time
		interval, err := s.dao.GetPushInterval(ctx)
		if err != nil || interval < 0 {
			time.Sleep(time.Minute)
			continue
		}
		s.mutex.Lock()
		s.intervalExpired = interval
		s.mutex.Unlock()
		// get push options
		types, err := s.dao.GetPushConfig(ctx)
		if err != nil || len(types) == 0 {
			time.Sleep(time.Minute)
			continue
		}
		s.mutex.Lock()
		s.pushTypes = types
		s.mutex.Unlock()
		time.Sleep(time.Minute)
	}
}
// safeGetExpired returns the current interval-expired TTL (seconds) under
// the read lock; loadPushConfig updates it concurrently.
func (s *Service) safeGetExpired() int32 {
	s.mutex.RLock()
	defer s.mutex.RUnlock()
	return s.intervalExpired
}
// LimitDecrease decreases the daily push limit for every mid in midStr
// (comma-separated), deduping repeated requests by (business, targetID,
// uuid). The decrease itself is fired asynchronously; errors returned here
// come only from dedupe or input validation.
func (s *Service) LimitDecrease(ctx context.Context, business, targetID, uuid, midStr string) (err error) {
	var (
		f    *dao.Filter
		mids []int64
		b    int
	)
	// Reject a duplicate request (same business/target/uuid).
	err = s.limitDecreaseUnique(getUniqueKey(business, targetID, uuid))
	if err != nil {
		log.Error("[service.service|LimitDecrease] limitDecreaseUnique error(%v), uuid(%s), business(%s), targetID(%s), mid(%s)",
			err, uuid, business, targetID, midStr)
		return
	}
	b, err = strconv.Atoi(business)
	if err != nil {
		log.Error("[service.service|LimitDecrease] strconv business params error(%v)", err)
		err = errConvertBusiness
		return
	}
	filterConf := &dao.FilterConfig{
		Business:     b,
		DailyExpired: dailyExpired(time.Now())}
	// convert mid string to []int64
	mids, err = s.convertStrToInt64(midStr)
	if err != nil {
		log.Error("[service.service|LimitDecrease] convertStrToInt64 error(%v), business(%s), uuid(%s), mids(%s)",
			err, business, uuid, midStr)
		err = errConvertMidString
		return
	}
	// async decrease limit (fire-and-forget; not tracked by s.wg)
	f, err = s.dao.NewFilter(filterConf)
	if err != nil {
		log.Error("[service.service|LimitDecrease] new filter error(%v), business(%s), uuid(%s), mids(%v)",
			err, business, uuid, mids)
		return
	}
	go f.BatchDecreaseLimit(ctx, mids)
	return
}
// Ping reports service health; it currently always reports healthy.
func (s *Service) Ping(c context.Context) (err error) {
	return nil
}
// Close shuts the service down in order: stop the config-refresh loop via
// closeCh, close the databus subscriptions so consumer loops exit, wait for
// all tracked goroutines, then close the dao.
func (s *Service) Close() {
	close(s.closeCh)
	s.subClose()
	s.wg.Wait()
	s.dao.Close()
}
// subClose closes both databus subscriptions, which closes their message
// channels and lets liveMessageConsumeproc return.
func (s *Service) subClose() {
	s.liveCommonSub.Close()
	s.liveStartSub.Close()
}
// dailyExpired
func dailyExpired(from time.Time) float64 {
tm1 := time.Date(from.Year(), from.Month(), from.Day(), 0, 0, 0, 0, from.Location())
tm2 := tm1.AddDate(0, 0, 1)
return math.Floor(tm2.Sub(from).Seconds())
}
// convertStrToInt64 parses a comma-separated mid string into an []int64.
// Unparsable entries are logged and skipped; an error is returned only when
// every entry fails. An empty input yields a nil slice and no error.
//
// Fix: parse directly with strconv.ParseInt(…, 10, 64). The previous
// strconv.Atoi parses into the platform-dependent int, which would overflow
// on large mids in 32-bit builds.
func (s *Service) convertStrToInt64(m string) (mInts []int64, err error) {
	if m == "" {
		return
	}
	var errCount int
	mSplit := strings.Split(m, ",")
	for _, mStr := range mSplit {
		mInt, convErr := strconv.ParseInt(mStr, 10, 64)
		if convErr != nil {
			log.Error("[service.push|formatMidstr] convert mid(%v), error(%v)", mStr, convErr)
			errCount++
			continue
		}
		mInts = append(mInts, mInt)
	}
	if errCount == len(mSplit) {
		err = fmt.Errorf("[service.push|formatMidstr] convert all mid failed, midstr(%s)", m)
	}
	return
}
// limitDecreaseUnique atomically claims key in redis (SET … NX, expiring at
// the next midnight). It returns errLimitRequestRepeat when the key already
// exists, i.e. the request is a duplicate.
func (s *Service) limitDecreaseUnique(key string) (err error) {
	var (
		conn  redis.Conn
		reply interface{}
	)
	defer func() {
		if conn != nil {
			conn.Close()
		}
	}()
	conn, err = redis.Dial(s.c.Redis.PushInterval.Proto, s.c.Redis.PushInterval.Addr, s.dao.RedisOption()...)
	if err != nil {
		log.Error("[service.service|limitDecreaseUnique] redis.Dial error(%v)", err)
		return
	}
	// SET key <now> EX <seconds> NX — the stored value is informational only.
	// NOTE(review): EX receives a float64 from dailyExpired; presumably the
	// driver serializes it acceptably for redis — confirm.
	reply, err = conn.Do("SET", key, time.Now(), "EX", dailyExpired(time.Now()), "NX")
	if err != nil {
		return
	}
	// With NX, a nil reply means the key already existed: duplicate request.
	if reply == nil {
		err = errLimitRequestRepeat
		return
	}
	return
}
// getUniqueKey builds the anti-replay redis key for a limit-decrease request
// by concatenating business, target id, and uuid into the key template.
func getUniqueKey(a, b, c string) string {
	joined := a + b + c
	return fmt.Sprintf(_limitDecreaseUUIDKey, joined)
}

View File

@@ -0,0 +1,84 @@
package service
import (
"context"
"flag"
. "github.com/smartystreets/goconvey/convey"
"go-common/app/interface/live/push-live/conf"
"go-common/library/cache/redis"
"path/filepath"
"testing"
)
var (
	// s is the shared Service instance used by all tests in this package;
	// initd constructs it from the test config.
	s        *Service
	targetID int64
)
// initd loads the test TOML config and builds the shared Service under test.
// It is called at the top of every test so repeated calls must be tolerated.
func initd() {
	dir, _ := filepath.Abs("../cmd/push-live-test.toml")
	flag.Set("conf", dir)
	conf.Init()
	s = New(conf.Conf)
}
// TestService_ConvertStrToInt64 checks the happy path of parsing a
// comma-separated mid string into int64 values.
func TestService_ConvertStrToInt64(t *testing.T) {
	initd()
	Convey("test convert", t, func() {
		mStr := "1,2,3"
		mInt64 := []int64{
			int64(1), int64(2), int64(3),
		}
		mRes, err := s.convertStrToInt64(mStr)
		So(err, ShouldBeNil)
		So(mRes, ShouldResemble, mInt64)
	})
}
// TestService_limitDecreaseUnique verifies that a first claim of a fresh key
// succeeds, then removes the key. Requires a reachable redis instance.
func TestService_limitDecreaseUnique(t *testing.T) {
	initd()
	Convey("test limit decrease request unique", t, func() {
		var (
			err  error
			conn redis.Conn
			key  string
		)
		Convey("test success request", func() {
			key = "test_request_unique"
			conn, err = redis.Dial(s.c.Redis.PushInterval.Proto, s.c.Redis.PushInterval.Addr, s.dao.RedisOption()...)
			So(err, ShouldBeNil)
			err = s.limitDecreaseUnique(key)
			So(err, ShouldBeNil)
			// clean
			conn.Do("DEL", key)
			conn.Close()
		})
	})
}
// TestService_LimitDecrease exercises the full LimitDecrease path with a
// fresh (business, targetID, uuid) tuple, then removes the dedupe key.
// Requires a reachable redis instance.
func TestService_LimitDecrease(t *testing.T) {
	initd()
	Convey("test LimitDecrease service", t, func() {
		var (
			ctx                              = context.Background()
			business, targetID, uuid, midStr string
			err                              error
			conn                             redis.Conn
		)
		Convey("test success", func() {
			business = "111"
			targetID = "123"
			uuid = "test"
			midStr = "1,2,3"
			conn, err = redis.Dial(s.c.Redis.PushInterval.Proto, s.c.Redis.PushInterval.Addr, s.dao.RedisOption()...)
			So(err, ShouldBeNil)
			err = s.LimitDecrease(ctx, business, targetID, uuid, midStr)
			So(err, ShouldBeNil)
			// clean: drop the anti-replay key so the test can rerun
			key := getUniqueKey(business, targetID, uuid)
			conn.Do("DEL", key)
			conn.Close()
		})
	})
}

View File

@@ -0,0 +1,185 @@
package service
import (
"context"
"encoding/json"
"fmt"
"go-common/app/interface/live/push-live/model"
"go-common/library/log"
"go-common/library/queue/databus"
"go-common/library/sync/errgroup"
"sync"
"github.com/pkg/errors"
)
// LiveStartMessage consumes one start-live databus message: decode, build
// the push task, resolve target mids per the configured strategies, push,
// and record the task asynchronously. The message is always committed, even
// when decoding fails.
func (s *Service) LiveStartMessage(ctx context.Context, msg *databus.Message) (err error) {
	defer msg.Commit()
	var total int
	// message
	m := new(model.StartLiveMessage)
	if err = json.Unmarshal(msg.Value, &m); err != nil {
		log.Error("[service.start_live|LiveStartMessage] json Unmarshal error(%v)", err)
		return
	}
	task := s.InitPushTask(m)
	midMap := s.GetMids(ctx, task)
	// do push
	total = s.Push(task, midMap)
	// create push task (fire-and-forget; not tracked by s.wg)
	go s.CreatePushTask(task, total)
	log.Info("[service.push|LiveStartMessage] start live push done, total(%d), task(%v), model(%v), err(%v)",
		total, task, m, err)
	return
}
// InitPushTask builds the start-live push task from a StartLiveMessage,
// deriving the MidSource bitmask from the currently configured strategies.
func (s *Service) InitPushTask(m *model.StartLiveMessage) (task *model.ApPushTask) {
	// Snapshot pushTypes under the read lock; loadPushConfig updates it
	// concurrently.
	s.mutex.RLock()
	currentPushTypes := s.pushTypes
	s.mutex.RUnlock()
	// push task model
	task = new(model.ApPushTask)
	task.Type = model.LivePushType
	task.TargetID = m.TargetID
	task.AlertTitle = m.Uname
	task.AlertBody = m.RoomTitle
	task.MidSource = s.getSourceByTypes(currentPushTypes)
	task.LinkType = s.c.Push.LinkType
	task.LinkValue = m.LinkValue
	task.ExpireTime = m.ExpireTime
	return task
}
// GetMids resolves, per configured push strategy, the mids that should
// receive this start-live push, returned grouped by relation type (regular
// vs special follow). Sources are fetched concurrently and unioned, and
// blacklisted mids are dropped before the final midFilter pass.
func (s *Service) GetMids(c context.Context, task *model.ApPushTask) map[int][]int64 {
	var (
		mutex        sync.Mutex
		group        = errgroup.Group{}
		fans         = make(map[int64]bool)
		fansSP       = make(map[int64]bool)
		midMap       = make(map[int][]int64)
		midBlackList = make(map[int64]bool)
	)
	// Fetch the push blacklist; on failure proceed with an empty one.
	mb, err := s.dao.GetBlackList(c, task)
	if err != nil {
		log.Error("[service.start_live|GetMids] get black list error(%v), task(%+v)", err, task)
	} else {
		midBlackList = mb
		log.Info("[service.start_live|GetMids] get black list len(%d), task(%+v)", len(midBlackList), task)
	}
	// try get latest push options and expired time
	s.mutex.RLock()
	currentPushTypes := s.pushTypes
	s.mutex.RUnlock()
	// Fetch each source in its own goroutine, then union the results.
	for _, t := range currentPushTypes {
		// Copy of the loop variable for the closure (pre-Go1.22 capture
		// rules); string(t) of a string value is effectively just a copy.
		tp := string(t)
		group.Go(func() (e error) {
			var mFans, mSpe map[int64]bool
			switch tp {
			case model.StrategySwitch:
				// live-switch users (may contain both follow kinds)
				mFans, mSpe, e = s.GetFansBySwitch(context.TODO(), task.TargetID)
			case model.StrategySpecial:
				// special follows only
				mFans, mSpe, e = s.dao.Fans(context.TODO(), task.TargetID, model.RelationSpecial)
			case model.StrategyFans:
				// regular follows only
				mFans, mSpe, e = s.dao.Fans(context.TODO(), task.TargetID, model.RelationAttention)
			case model.StrategySwitchSpecial:
				// special follows restricted to live-switch users
				mFans, mSpe, e = s.GetFansBySwitchAndSpecial(context.TODO(), task.TargetID)
			default:
				log.Error("[service.mids|GetMids] strategy invalid, type(%s), task(%+v)", tp, task)
				e = fmt.Errorf("[service.mids|GetMids] strategy invalid, type(%s), task(%+v)", tp, task)
				return e
			}
			if e != nil {
				log.Error("[service.mids|GetMids] get mid error(%v), type(%s), task(%+v)", e, tp, task)
				return e
			}
			// Union across sources, dropping blacklisted ids, under the lock.
			mutex.Lock()
			for fansID := range mFans {
				if _, ok := midBlackList[fansID]; !ok {
					fans[fansID] = true
				}
			}
			for fansID := range mSpe {
				if _, ok := midBlackList[fansID]; !ok {
					fansSP[fansID] = true
				}
			}
			mutex.Unlock()
			log.Info("[service.mids|GetMids] get mids by type(%s), task(%+v), common(%d), special(%d)",
				tp, task, len(mFans), len(mSpe))
			return e
		})
	}
	group.Wait()
	if len(fansSP) > 0 {
		midMap[model.RelationSpecial] = s.midFilter(fansSP, model.StartLiveBusiness, task)
	}
	if len(fans) > 0 {
		midMap[model.RelationAttention] = s.midFilter(fans, model.StartLiveBusiness, task)
	}
	return midMap
}
// GetFansBySwitch returns the fans who enabled the live switch for targetID,
// split into regular (fans) and special (fansSP) followers.
func (s *Service) GetFansBySwitch(c context.Context, targetID int64) (fans map[int64]bool, fansSP map[int64]bool, err error) {
	// Switch data from the live side may mix regular and special follows.
	m, err := s.dao.GetFansBySwitch(c, targetID)
	if err != nil {
		err = errors.WithStack(err)
		log.Error("[service.mids|GetMidsBySwitch] get switch mids error(%v), targetID(%v)", err, targetID)
		return
	}
	// Separate regular follows from special follows.
	fans, fansSP, err = s.dao.SeparateFans(c, targetID, m)
	return
}
// GetFansBySwitchAndSpecial returns the intersection of live-switch users
// and special followers for targetID; the regular-follow result is always
// discarded (fans stays nil).
func (s *Service) GetFansBySwitchAndSpecial(c context.Context, targetID int64) (fans map[int64]bool, fansSP map[int64]bool, err error) {
	// Switch data from the live side may mix regular and special follows.
	m, err := s.dao.GetFansBySwitch(c, targetID)
	if err != nil {
		err = errors.WithStack(err)
		log.Error("[service.mids|GetMidsBySwitch] get switch mids error(%v), targetID(%v)", err, targetID)
		return
	}
	// Keep only the special-follow portion of the switch data.
	_, fansSP, err = s.dao.SeparateFans(c, targetID, m)
	return
}
// getSourceByTypes builds the Task.MidSource bitmask from the configured
// push strategy types; unknown types contribute 0.
// NOTE(review): sources are combined with XOR, so a strategy appearing an
// even number of times cancels itself out of the mask — OR (|) would be the
// conventional flag union. Confirm types is always duplicate-free before
// relying on this.
func (s *Service) getSourceByTypes(types []string) int {
	var source, midSource int
	for _, t := range types {
		switch t {
		case model.StrategySwitch:
			source = model.TaskSourceSwitch
		case model.StrategySpecial:
			source = model.TaskSourceSpecial
		case model.StrategyFans:
			source = model.TaskSourceFans
		case model.StrategySwitchSpecial:
			source = model.TaskSourceSwitchSpe
		default:
			source = 0
		}
		midSource = midSource ^ source
	}
	return midSource
}

View File

@@ -0,0 +1,92 @@
package service
import (
"context"
. "github.com/smartystreets/goconvey/convey"
"go-common/app/interface/live/push-live/model"
"math/rand"
"strconv"
"testing"
)
// TestService_InitPushTask verifies InitPushTask copies the start-live
// message fields into the resulting push task.
func TestService_InitPushTask(t *testing.T) {
	initd()
	Convey("should return init struct", t, func() {
		targetID = rand.Int63n(100) + 1
		uname := "测试"
		linkValue := strconv.Itoa(rand.Intn(9999))
		roomTitle := "room_title"
		expireTime := rand.Intn(10000) + 1
		task := makeTestInitPushTask(targetID, uname, linkValue, roomTitle, expireTime)
		So(task.TargetID, ShouldResemble, targetID)
		So(task.AlertTitle, ShouldResemble, uname)
		So(task.AlertBody, ShouldResemble, roomTitle)
		So(task.ExpireTime, ShouldResemble, expireTime)
		So(task.LinkValue, ShouldResemble, linkValue)
	})
}
// TestService_GetSourceByTypes checks that a random sub-slice of the known
// strategy types maps to a MidSource bitmask within [0, 15].
// Fix: renamed from TestDao_GetSourceByTypes — the test exercises
// Service.getSourceByTypes, so it takes the TestService_ prefix used by
// every other test in this package.
func TestService_GetSourceByTypes(t *testing.T) {
	initd()
	Convey("Get mid_source by different types", t, func() {
		types := []string{model.StrategySwitch, model.StrategyFans, model.StrategySpecial, model.StrategySwitchSpecial}
		length := len(types)
		currentX := rand.Intn(length)
		currentY := rand.Intn(length)
		// Take a random (possibly empty) contiguous slice of the strategies.
		var currentTypes []string
		if currentX >= currentY {
			currentTypes = types[currentY:currentX]
		} else {
			currentTypes = types[currentX:currentY]
		}
		midSource := s.getSourceByTypes(currentTypes)
		So(midSource, ShouldBeGreaterThanOrEqualTo, 0)
		So(midSource, ShouldBeLessThanOrEqualTo, 15)
	})
}
// TestService_GetFansBySwitch expects both regular and special fans for a
// hard-coded live target id; it depends on real upstream data for 27515316.
func TestService_GetFansBySwitch(t *testing.T) {
	initd()
	Convey("should find some fans id by given target id", t, func() {
		targetID = 27515316
		fans, fansSP, err := s.GetFansBySwitch(context.Background(), targetID)
		So(len(fans), ShouldBeGreaterThan, 0)
		So(len(fansSP), ShouldBeGreaterThan, 0)
		So(err, ShouldBeNil)
	})
}
// TestService_GetFansBySwitchAndSpecial expects only special fans (regular
// fans are discarded by the method); depends on real data for 27515316.
func TestService_GetFansBySwitchAndSpecial(t *testing.T) {
	initd()
	Convey("should find some fans id by given target id", t, func() {
		targetID = 27515316
		fans, fansSP, err := s.GetFansBySwitchAndSpecial(context.Background(), targetID)
		So(len(fans), ShouldEqual, 0)
		So(len(fansSP), ShouldBeGreaterThan, 0)
		So(err, ShouldBeNil)
	})
}
// TestService_GetMids runs the full mid-resolution pipeline for a hard-coded
// target id and expects every returned relation group to be non-empty.
// NOTE(review): s.pushTypes is assigned here without holding s.mutex — this
// races with loadPushConfig; and the literal "Switch"/"Special" strings are
// presumably the values of model.StrategySwitch/StrategySpecial — confirm.
func TestService_GetMids(t *testing.T) {
	initd()
	Convey("should find some fans id by given target id", t, func() {
		targetID = 27515316
		uname := "测试"
		linkValue := strconv.Itoa(rand.Intn(9999))
		roomTitle := "room_title"
		expireTime := rand.Intn(10000) + 1
		task := makeTestInitPushTask(targetID, uname, linkValue, roomTitle, expireTime)
		types := []string{"Switch", "Special"}
		s.pushTypes = types
		midMap := s.GetMids(context.Background(), task)
		for _, list := range midMap {
			So(len(list), ShouldBeGreaterThan, 0)
		}
	})
}