495 lines
15 KiB
Go
495 lines
15 KiB
Go
package mysql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"go-common/app/admin/main/aegis/model/common"
|
|
modtask "go-common/app/admin/main/aegis/model/task"
|
|
xsql "go-common/library/database/sql"
|
|
"go-common/library/log"
|
|
"go-common/library/xstr"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// Statements against the task table: lookup and state transitions.

// _taskSQL selects every mapped column of one task by primary key.
const _taskSQL = "SELECT id,business_id,flow_id,rid,admin_id,uid,state,weight,utime,gtime,mid,fans,`group`,reason,ctime,mtime from task WHERE id=?"

// _listCheckSQL selects the ids that exist among an IN list; callers append
// extra " AND ..." conditions after formatting the id list into %s.
const _listCheckSQL = "SELECT id FROM task WHERE id IN (%s)"

// _dispatchByIDSQL stamps gtime on one task that matches id/state/uid and
// whose gtime is still zero.
const _dispatchByIDSQL = "UPDATE task SET gtime=? WHERE id=? AND state=? AND uid=? AND gtime=0"

// _queryGtimeSQL reads the gtime of one task matching id/state/uid.
const _queryGtimeSQL = "SELECT gtime FROM task WHERE id=? AND state=? AND uid=?"

// _dispatchSQL stamps gtime on up to LIMIT tasks of a uid in a given state
// whose gtime is still the zero datetime.
const _dispatchSQL = "UPDATE task SET gtime=? WHERE state=? AND uid=? AND gtime='0000-00-00 00:00:00' ORDER BY weight LIMIT ?"

// _releaseSQL clears admin_id/uid/state/gtime on the tasks of a uid within a
// business/flow that are in the given state, or in state 0 with an admin set.
const _releaseSQL = "UPDATE task SET admin_id=0,uid=0,state=0,gtime='0000-00-00 00:00:00' WHERE business_id=? AND flow_id=? AND uid=? AND (state=? OR (state=0 AND admin_id>0))"

// _resetGtimeSQL resets gtime to the zero datetime for the tasks of a uid
// within a business/flow that are in the given state.
const _resetGtimeSQL = "UPDATE task SET gtime='0000-00-00 00:00:00' WHERE state=? AND business_id=? AND flow_id=? AND uid=?"

// _seizeSQL moves one task from an expected state to a new state and owner.
const _seizeSQL = "UPDATE task SET state=?,uid=? WHERE id=? AND state=?"

// _submitSQL updates state, uid and utime of one task matching id/state/uid.
const _submitSQL = "UPDATE task SET state=?,uid=?,utime=? WHERE id=? AND state=? AND uid=?"

// _delaySQL updates state, uid and reason of one task matching id/state/uid
// and resets its gtime to the zero datetime.
const _delaySQL = "UPDATE task SET state=?,uid=?,reason=?,gtime='0000-00-00 00:00:00' WHERE id=? AND state=? AND uid=?"

// Statements against the task_consumer table.

// _consumerSQL upserts the state of a (business_id, flow_id, uid) consumer.
const _consumerSQL = "INSERT INTO task_consumer (business_id,flow_id,uid,state) VALUES (?,?,?,?) ON DUPLICATE KEY UPDATE state=?"

// _onlinesSQL lists uid and mtime of the consumers of a business/flow that
// are in the given state.
const _onlinesSQL = "SELECT uid,mtime FROM task_consumer WHERE business_id=? AND flow_id=? AND state=?"

// _isconsumerOnSQL reads the state of one consumer.
const _isconsumerOnSQL = "SELECT state FROM task_consumer WHERE business_id=? AND flow_id=? AND uid=?"

// Scanning, counting and listing statements against the task table.

// _queryTaskSQL pages through tasks in a given state by ascending id,
// restricted to rows with mtime<=? and id>?.
const _queryTaskSQL = "SELECT id,business_id,flow_id,uid,weight FROM task WHERE state=? AND mtime<=? AND id>? ORDER BY id LIMIT ?"

// _countPersonalSQL counts the tasks of a uid within a business/flow that are
// in the given state.
const _countPersonalSQL = "SELECT count(*) FROM task WHERE state=? AND business_id=? AND flow_id=? AND uid=?"

// _queryForSeizeSQL selects, by descending weight, the ids of tasks in the
// given state within a business/flow whose uid is 0 or the given uid.
const _queryForSeizeSQL = "SELECT id FROM task WHERE state=? AND business_id=? AND flow_id=? AND uid IN (0,?) ORDER BY weight DESC LIMIT ?"

// _listTasksSQL lists full task rows by descending weight; %s receives the
// WHERE clause and the two placeholders are the LIMIT arguments.
const _listTasksSQL = "SELECT `id`,`business_id`,`flow_id`,`rid`,`admin_id`,`uid`,`state`,`weight`,`utime`,`gtime`,`mid`,`fans`,`group`,`reason`,`ctime`,`mtime` FROM task %s ORDER BY weight DESC LIMIT ?,?"
|
|
|
|
// TaskFromDB .
|
|
func (d *Dao) TaskFromDB(c context.Context, id int64) (task *modtask.Task, err error) {
|
|
task = &modtask.Task{}
|
|
err = d.db.QueryRow(c, _taskSQL, id).
|
|
Scan(&task.ID, &task.BusinessID, &task.FlowID, &task.RID, &task.AdminID, &task.UID, &task.State,
|
|
&task.Weight, &task.Utime, &task.Gtime, &task.MID, &task.Fans, &task.Group, &task.Reason, &task.Ctime, &task.Mtime)
|
|
if err != nil {
|
|
task = nil
|
|
if err == sql.ErrNoRows {
|
|
log.Error("TaskFromDB(%d) norows", id)
|
|
err = nil
|
|
return
|
|
}
|
|
log.Error("TaskFromDB(%d) error(%v)", id, errors.WithStack(err))
|
|
}
|
|
return
|
|
}
|
|
|
|
// DispatchByID 派遣任务,更新gtime
|
|
func (d *Dao) DispatchByID(c context.Context, mtasks map[int64]*modtask.Task, ids []int64, args ...interface{}) (missids map[int64]struct{}, err error) {
|
|
var (
|
|
gtime = time.Now()
|
|
uid = args[0].(int64)
|
|
)
|
|
|
|
missids = make(map[int64]struct{})
|
|
|
|
for _, id := range ids {
|
|
var (
|
|
rows int64
|
|
gt time.Time
|
|
res sql.Result
|
|
)
|
|
if err = d.db.QueryRow(c, _queryGtimeSQL, id, modtask.TaskStateDispatch, uid).Scan(>); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
missids[id] = struct{}{}
|
|
err = nil
|
|
continue
|
|
}
|
|
log.Error("d.db.QueryRow error(%v)", errors.WithStack(err))
|
|
return
|
|
}
|
|
|
|
if gt.IsZero() {
|
|
res, err = d.db.Exec(c, _dispatchByIDSQL, gtime, id, modtask.TaskStateDispatch, uid)
|
|
if err != nil {
|
|
log.Error("Exec error(%v)", errors.WithStack(err))
|
|
return
|
|
}
|
|
if rows, err = res.RowsAffected(); err != nil {
|
|
log.Error("RowsAffected error(%v)", errors.WithStack(err))
|
|
return
|
|
}
|
|
if rows == 0 {
|
|
missids[id] = struct{}{}
|
|
} else {
|
|
mtasks[id].Gtime = common.IntTime(gtime.Unix())
|
|
}
|
|
} else {
|
|
mtasks[id].Gtime = common.IntTime(gt.Unix())
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// DBDispatch 直接数据库派遣
|
|
func (d *Dao) DBDispatch(c context.Context, opt *modtask.NextOptions) (tasks []*modtask.Task, count int64, err error) {
|
|
var (
|
|
res sql.Result
|
|
gtime = time.Now()
|
|
)
|
|
|
|
// 1.直接更新派遣时间
|
|
res, err = d.db.Exec(c, _dispatchSQL, gtime, modtask.TaskStateDispatch, opt.UID, opt.DispatchCount)
|
|
if err != nil {
|
|
log.Error("Exec error(%v)", errors.WithStack(err))
|
|
return
|
|
}
|
|
if count, err = res.RowsAffected(); err != nil {
|
|
log.Error("RowsAffected error(%v)", errors.WithStack(err))
|
|
return
|
|
}
|
|
|
|
// 2.读取任务
|
|
wherecache := fmt.Sprintf("WHERE state=%d AND uid=%d AND gtime!='0000-00-00 00:00:00'", modtask.TaskStateDispatch, opt.UID)
|
|
|
|
return d.listTasks(c, &modtask.ListOptions{BaseOptions: opt.BaseOptions, Pager: common.Pager{Pn: 1, Ps: int(opt.DispatchCount)}}, wherecache)
|
|
}
|
|
|
|
// Release 释放任务
|
|
func (d *Dao) Release(c context.Context, opt *common.BaseOptions, delay bool) (rows int64, err error) {
|
|
sql := _releaseSQL
|
|
if delay {
|
|
sql = _releaseSQL + " AND gtime='0000-00-00 00:00:00'"
|
|
}
|
|
|
|
log.Info("Mysql Release(%+v) delay(%v)", opt, delay)
|
|
res, err := d.db.Exec(c, sql, opt.BusinessID, opt.FlowID, opt.UID, modtask.TaskStateDispatch)
|
|
if err != nil {
|
|
log.Error("db.Exec(%s)[%d,%d,%d,%d] error(%v)", sql, opt.BusinessID, opt.FlowID, opt.UID, modtask.TaskStateDispatch, err)
|
|
return
|
|
}
|
|
// 已经下发的延迟5分钟释放
|
|
if delay {
|
|
_, err = d.db.Exec(c, _resetGtimeSQL, modtask.TaskStateDispatch, opt.BusinessID, opt.FlowID, opt.UID)
|
|
if err != nil {
|
|
log.Error("db.Exec(%s)[%d,%d,%d,%d] error(%v)", sql, modtask.TaskStateDispatch, opt.BusinessID, opt.FlowID, opt.UID, err)
|
|
}
|
|
|
|
time.AfterFunc(5*time.Minute, func() {
|
|
d.Release(context.Background(), opt, false)
|
|
})
|
|
}
|
|
|
|
return res.RowsAffected()
|
|
}
|
|
|
|
// Seize 抢占任务
|
|
func (d *Dao) Seize(c context.Context, mapids map[int64]int64) (count int64, err error) {
|
|
tx, err := d.db.Begin(c)
|
|
if err != nil {
|
|
log.Error("d.Seize.Begin error(%v)", errors.WithStack(err))
|
|
return
|
|
}
|
|
defer tx.Commit()
|
|
|
|
for tid, uid := range mapids {
|
|
var (
|
|
rows int64
|
|
res sql.Result
|
|
)
|
|
res, err = tx.Exec(_seizeSQL, modtask.TaskStateDispatch, uid, tid, modtask.TaskStateInit)
|
|
if err != nil {
|
|
log.Error("Exec error(%v)", errors.WithStack(err))
|
|
tx.Rollback()
|
|
return
|
|
}
|
|
if rows, err = res.RowsAffected(); err != nil {
|
|
log.Error("RowsAffected error(%v)", errors.WithStack(err))
|
|
tx.Rollback()
|
|
return
|
|
}
|
|
if rows == 1 {
|
|
count++
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// Delay 延迟任务
|
|
func (d *Dao) Delay(c context.Context, opt *modtask.DelayOptions) (rows int64, err error) {
|
|
var (
|
|
res sql.Result
|
|
)
|
|
res, err = d.db.Exec(c, _delaySQL, modtask.TaskStateDelay, opt.UID, opt.Reason, opt.TaskID, modtask.TaskStateDispatch, opt.UID)
|
|
if err != nil {
|
|
log.Error("Exec error(%v)", errors.WithStack(err))
|
|
return
|
|
}
|
|
if rows, err = res.RowsAffected(); err != nil {
|
|
log.Error("RowsAffected error(%v)", errors.WithStack(err))
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
// ListCheckUnSeized .
|
|
func (d *Dao) ListCheckUnSeized(c context.Context, mtasks map[int64]*modtask.Task, ids []int64, args ...interface{}) (missids map[int64]struct{}, err error) {
|
|
wherecase := fmt.Sprintf("state = %d", modtask.TaskStateInit)
|
|
return d.listCheck(c, wherecase, ids)
|
|
}
|
|
|
|
// ListCheckSeized .
|
|
func (d *Dao) ListCheckSeized(c context.Context, mtasks map[int64]*modtask.Task, ids []int64, args ...interface{}) (missids map[int64]struct{}, err error) {
|
|
if len(args) < 1 {
|
|
return
|
|
}
|
|
uid := args[0].(int64)
|
|
wherecase := fmt.Sprintf("state = %d", modtask.TaskStateDispatch)
|
|
if uid != 0 {
|
|
wherecase += fmt.Sprintf(" AND uid=%d", uid)
|
|
}
|
|
return d.listCheck(c, wherecase, ids)
|
|
}
|
|
|
|
// ListCheckDelay .
|
|
func (d *Dao) ListCheckDelay(c context.Context, mtasks map[int64]*modtask.Task, ids []int64, args ...interface{}) (missids map[int64]struct{}, err error) {
|
|
if len(args) < 1 {
|
|
return
|
|
}
|
|
uid := args[0].(int64)
|
|
wherecase := fmt.Sprintf("state=%d", modtask.TaskStateDelay)
|
|
if uid != 0 {
|
|
wherecase += fmt.Sprintf(" AND uid=%d", uid)
|
|
}
|
|
return d.listCheck(c, wherecase, ids)
|
|
}
|
|
|
|
// ListTasks .
|
|
func (d *Dao) ListTasks(c context.Context, opt *modtask.ListOptions) (tasks []*modtask.Task, count int64, err error) {
|
|
var (
|
|
wherecase string
|
|
cases []string
|
|
state int8
|
|
isDefault bool
|
|
)
|
|
|
|
switch opt.State {
|
|
case 1:
|
|
state = modtask.TaskStateInit
|
|
case 2:
|
|
state = modtask.TaskStateDispatch
|
|
case 3:
|
|
state = modtask.TaskStateDelay
|
|
case 4:
|
|
state = modtask.TaskStateDispatch
|
|
cases = append(cases, "admin_id>0")
|
|
default:
|
|
isDefault = true
|
|
cases = append(cases, fmt.Sprintf("state<%d", modtask.TaskStateSubmit))
|
|
}
|
|
if !isDefault {
|
|
cases = append(cases, fmt.Sprintf("state=%d", state))
|
|
if !opt.BisLeader && (opt.State == 2 || opt.State == 3 || opt.State == 4) {
|
|
cases = append(cases, fmt.Sprintf("uid=%d", opt.UID))
|
|
}
|
|
}
|
|
|
|
wherecase = fmt.Sprintf("WHERE business_id=%d AND flow_id=%d AND ", opt.BusinessID, opt.FlowID) + strings.Join(cases, " AND ")
|
|
return d.listTasks(c, opt, wherecase)
|
|
}
|
|
|
|
func (d *Dao) listTasks(c context.Context, opt *modtask.ListOptions, wherecase string) (tasks []*modtask.Task, count int64, err error) {
|
|
countSQL := fmt.Sprintf("SELECT count(*) FROM task %s", wherecase)
|
|
|
|
if err = d.db.QueryRow(c, countSQL).Scan(&count); err != nil {
|
|
log.Error("QueryRow error(%v)", err)
|
|
return
|
|
}
|
|
|
|
if count > 0 {
|
|
var (
|
|
rows *xsql.Rows
|
|
listSQL = fmt.Sprintf(_listTasksSQL, wherecase)
|
|
)
|
|
|
|
if rows, err = d.db.Query(c, listSQL, (opt.Pn-1)*opt.Ps, opt.Pn*opt.Ps); err != nil {
|
|
log.Error("Query error(%v)", err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
task := &modtask.Task{}
|
|
if err = rows.Scan(&task.ID, &task.BusinessID, &task.FlowID, &task.RID, &task.AdminID, &task.UID, &task.State,
|
|
&task.Weight, &task.Utime, &task.Gtime, &task.MID, &task.Fans, &task.Group, &task.Reason, &task.Ctime, &task.Mtime); err != nil {
|
|
log.Error("Scan error(%v)", err)
|
|
return
|
|
}
|
|
tasks = append(tasks, task)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (d *Dao) listCheck(c context.Context, wherecase string, ids []int64) (missids map[int64]struct{}, err error) {
|
|
if len(ids) == 0 {
|
|
return
|
|
}
|
|
missids = make(map[int64]struct{})
|
|
mapids := make(map[int64]struct{})
|
|
|
|
log.Info("listCheck ids(%v)", ids)
|
|
defer func() {
|
|
log.Info("listCheck missids(%v)", missids)
|
|
}()
|
|
|
|
for _, id := range ids {
|
|
mapids[id] = struct{}{}
|
|
}
|
|
|
|
var (
|
|
rows *xsql.Rows
|
|
sqlstring = fmt.Sprintf(_listCheckSQL, xstr.JoinInts(ids)) + " AND " + wherecase
|
|
)
|
|
|
|
if rows, err = d.db.Query(c, sqlstring); err != nil {
|
|
log.Error("db.Query(%s) error(%v)", sqlstring, errors.WithStack(err))
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var id int64
|
|
if err = rows.Scan(&id); err != nil {
|
|
log.Error("rows.Scan error(%v)", errors.WithStack(err))
|
|
return
|
|
}
|
|
delete(mapids, id)
|
|
}
|
|
|
|
for id := range mapids {
|
|
missids[id] = struct{}{}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// ConsumerOn .
|
|
func (d *Dao) ConsumerOn(c context.Context, opt *common.BaseOptions) (err error) {
|
|
return d.consumer(c, opt, modtask.ActionConsumerOn)
|
|
}
|
|
|
|
// ConsumerOff .
|
|
func (d *Dao) ConsumerOff(c context.Context, opt *common.BaseOptions) (err error) {
|
|
return d.consumer(c, opt, modtask.ActionConsumerOff)
|
|
}
|
|
|
|
// IsConsumerOn .
|
|
func (d *Dao) IsConsumerOn(c context.Context, opt *common.BaseOptions) (on bool, err error) {
|
|
var state int8
|
|
if err = d.db.QueryRow(c, _isconsumerOnSQL, opt.BusinessID, opt.FlowID, opt.UID).Scan(&state); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
err = nil
|
|
return
|
|
}
|
|
log.Error("d.db.QueryRow error(%v)", err)
|
|
return
|
|
}
|
|
if state == modtask.ActionConsumerOn {
|
|
on = true
|
|
}
|
|
return
|
|
}
|
|
|
|
func (d *Dao) consumer(c context.Context, opt *common.BaseOptions, action int8) (err error) {
|
|
var (
|
|
res sql.Result
|
|
)
|
|
res, err = d.db.Exec(c, _consumerSQL, opt.BusinessID, opt.FlowID, opt.UID, action, action)
|
|
if err != nil {
|
|
log.Error("Exec error(%v)", errors.WithStack(err))
|
|
return
|
|
}
|
|
if _, err = res.RowsAffected(); err != nil {
|
|
log.Error("RowsAffected error(%v)", errors.WithStack(err))
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
// ConsumerStat 24小时内有活动或者在线的用户
|
|
func (d *Dao) ConsumerStat(c context.Context, bizid, flowid int64) (items []*modtask.WatchItem, err error) {
|
|
var rows *xsql.Rows
|
|
sql := "SELECT uid,mtime,state from task_consumer where business_id=? AND flow_id=? AND (mtime > ? or state=1) order by mtime desc"
|
|
if rows, err = d.db.Query(c, sql, bizid, flowid, time.Now().Add(-24*time.Hour)); err != nil {
|
|
log.Error("ConsumerStat error(%v)", err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
item := &modtask.WatchItem{}
|
|
if err = rows.Scan(&item.UID, &item.Mtime, &item.State); err != nil {
|
|
log.Error("ConsumerStat error(%v)", err)
|
|
return
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Onlines 在线列表
|
|
func (d *Dao) Onlines(c context.Context, opt *common.BaseOptions) (uids map[int64]time.Time, err error) {
|
|
var (
|
|
rows *xsql.Rows
|
|
)
|
|
|
|
rows, err = d.db.Query(c, _onlinesSQL, opt.BusinessID, opt.FlowID, modtask.ActionConsumerOn)
|
|
if err != nil {
|
|
log.Error("db.Query error(%v)", err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
uids = make(map[int64]time.Time)
|
|
for rows.Next() {
|
|
var (
|
|
uid int64
|
|
mtime time.Time
|
|
)
|
|
if err = rows.Scan(&uid, &mtime); err != nil {
|
|
log.Error("rows.Scan error(%v)", err)
|
|
return
|
|
}
|
|
uids[uid] = mtime
|
|
}
|
|
return
|
|
}
|
|
|
|
// QueryTask .
|
|
func (d *Dao) QueryTask(c context.Context, state int8, mtime time.Time, id, limit int64) (tasks []*modtask.Task, lastid int64, err error) {
|
|
var rows *xsql.Rows
|
|
rows, err = d.db.Query(c, _queryTaskSQL, state, mtime, id, limit)
|
|
if err != nil {
|
|
log.Error("db.Query error(%v)", err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
task := &modtask.Task{}
|
|
if err = rows.Scan(&task.ID, &task.BusinessID, &task.FlowID, &task.UID, &task.Weight); err != nil {
|
|
log.Error("rows.Scan error(%v)", err)
|
|
return
|
|
}
|
|
|
|
tasks = append(tasks, task)
|
|
lastid = task.ID
|
|
}
|
|
return
|
|
}
|
|
|
|
// CountPersonal count personal task
|
|
func (d *Dao) CountPersonal(c context.Context, opt *common.BaseOptions) (count int64, err error) {
|
|
if err = d.db.QueryRow(c, _countPersonalSQL, modtask.TaskStateDispatch, opt.BusinessID, opt.FlowID, opt.UID).Scan(&count); err != nil {
|
|
log.Error("QueryRow error(%v)", errors.WithStack(err))
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
// QueryForSeize 查询当前可抢占的任务
|
|
func (d *Dao) QueryForSeize(c context.Context, businessID, flowID, uid, seizecount int64) (hitids []int64, err error) {
|
|
log.Info("task-QueryForSeize businessID(%d), flowID(%d), uid(%d), seizecount(%d)", businessID, flowID, uid, seizecount)
|
|
defer func() { log.Info("task-QueryForSeize hitids(%v), err(%v)", hitids, err) }()
|
|
var rows *xsql.Rows
|
|
rows, err = d.db.Query(c, _queryForSeizeSQL, modtask.TaskStateInit, businessID, flowID, uid, seizecount)
|
|
if err != nil {
|
|
log.Error("db.Query error(%v)", err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var id int64
|
|
if err = rows.Scan(&id); err != nil {
|
|
log.Error("rows.Scan error(%v)", err)
|
|
return
|
|
}
|
|
hitids = append(hitids, id)
|
|
}
|
|
return
|
|
}
|