Skip to content

Commit

Permalink
feat(tars): Support rpc async callback call
Browse files Browse the repository at this point in the history
  • Loading branch information
reallovelei authored and lbbniu committed Apr 26, 2023
1 parent 35ac1fa commit fe8a4e9
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 63 deletions.
2 changes: 1 addition & 1 deletion tars/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ func (c *AdapterProxy) doKeepAlive() {
IRequestId: c.servantProxy.genRequestID(),
SServantName: c.servantProxy.name,
SFuncName: "tars_ping",
ITimeout: int32(c.servantProxy.timeout),
ITimeout: int32(c.servantProxy.asyncTimeout),
}
msg := &Message{Req: &req, Ser: c.servantProxy}
msg.Init()
Expand Down
1 change: 1 addition & 0 deletions tars/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ func (a *application) initConfig() {
a.cltCfg.Stat = cMap["stat"]
a.cltCfg.Property = cMap["property"]
a.cltCfg.ModuleName = cMap["modulename"]
a.cltCfg.SyncInvokeTimeout = c.GetIntWithDef("/tars/application/client<sync-invoke-timeout>", SyncInvokeTimeout)
a.cltCfg.AsyncInvokeTimeout = c.GetIntWithDef("/tars/application/client<async-invoke-timeout>", AsyncInvokeTimeout)
a.cltCfg.RefreshEndpointInterval = c.GetIntWithDef("/tars/application/client<refresh-endpoint-interval>", refreshEndpointInterval)
a.cltCfg.ReportInterval = c.GetIntWithDef("/tars/application/client<report-interval>", reportInterval)
Expand Down
4 changes: 3 additions & 1 deletion tars/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ type clientConfig struct {
ReportInterval int
CheckStatusInterval int
KeepAliveInterval int
AsyncInvokeTimeout int
// add client timeout
SyncInvokeTimeout int
AsyncInvokeTimeout int
ClientQueueLen int
ClientIdleTimeout time.Duration
ClientReadTimeout time.Duration
Expand Down Expand Up @@ -152,6 +153,7 @@ func newClientConfig() *clientConfig {
ReportInterval: reportInterval,
CheckStatusInterval: checkStatusInterval,
KeepAliveInterval: keepAliveInterval,
SyncInvokeTimeout: SyncInvokeTimeout,
AsyncInvokeTimeout: AsyncInvokeTimeout,
ClientQueueLen: ClientQueueLen,
ClientIdleTimeout: tools.ParseTimeOut(ClientIdleTimeout),
Expand Down
61 changes: 61 additions & 0 deletions tars/message.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package tars

import (
"context"
"time"

"github.com/TarsCloud/TarsGo/tars/model"
"github.com/TarsCloud/TarsGo/tars/protocol/res/basef"
"github.com/TarsCloud/TarsGo/tars/protocol/res/requestf"
"github.com/TarsCloud/TarsGo/tars/selector"
"github.com/TarsCloud/TarsGo/tars/util/current"
"github.com/TarsCloud/TarsGo/tars/util/tools"
)

// HashType is the hash type
Expand All @@ -31,6 +36,8 @@ type Message struct {
hashCode uint32
hashType HashType
isHash bool
Async bool
Callback model.Callback
}

// Init define the beginTime
Expand Down Expand Up @@ -66,3 +73,57 @@ func (m *Message) HashType() selector.HashType {
func (m *Message) IsHash() bool {
return m.isHash
}

func buildMessage(ctx context.Context, cType byte,
sFuncName string,
buf []byte,
status map[string]string,
reqContext map[string]string,
resp *requestf.ResponsePacket,
s *ServantProxy) *Message {

// 将ctx中的dyeing信息传入到request中
var msgType int32
if dyeingKey, ok := current.GetDyeingKey(ctx); ok {
TLOG.Debug("dyeing debug: find dyeing key:", dyeingKey)
if status == nil {
status = make(map[string]string)
}
status[current.StatusDyedKey] = dyeingKey
msgType |= basef.TARSMESSAGETYPEDYED
}

// 将ctx中的trace信息传入到request中
if trace, ok := current.GetTarsTrace(ctx); ok && trace.Call() {
traceKey := trace.GetTraceFullKey(false)
TLOG.Debug("trace debug: find trace key:", traceKey)
if status == nil {
status = make(map[string]string)
}
status[current.StatusTraceKey] = traceKey
msgType |= basef.TARSMESSAGETYPETRACE
}

req := requestf.RequestPacket{
IVersion: s.version,
CPacketType: int8(cType),
IMessageType: msgType,
IRequestId: s.genRequestID(),
SServantName: s.name,
SFuncName: sFuncName,
ITimeout: int32(s.syncTimeout),
SBuffer: tools.ByteToInt8(buf),
Context: reqContext,
Status: status,
}
msg := &Message{Req: &req, Ser: s, Resp: resp}
msg.Init()

if ok, hashType, hashCode, isHash := current.GetClientHash(ctx); ok {
msg.isHash = isHash
msg.hashType = HashType(hashType)
msg.hashCode = hashCode
}

return msg
}
13 changes: 13 additions & 0 deletions tars/model/servant.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import (
"github.com/TarsCloud/TarsGo/tars/protocol/res/requestf"
)

type Callback interface {
Dispatch(context.Context, *requestf.RequestPacket, *requestf.ResponsePacket, error) (int32, error)
}

// Servant is interface for call the remote server.
type Servant interface {
Name() string
Expand All @@ -17,6 +21,15 @@ type Servant interface {
status map[string]string,
context map[string]string,
resp *requestf.ResponsePacket) error

TarsInvokeAsync(ctx context.Context, cType byte,
sFuncName string,
buf []byte,
status map[string]string,
context map[string]string,
resp *requestf.ResponsePacket,
callback Callback) error

TarsSetTimeout(t int)
TarsSetProtocol(Protocol)
Endpoints() []*endpoint.Endpoint
Expand Down
155 changes: 95 additions & 60 deletions tars/servant.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"github.com/TarsCloud/TarsGo/tars/util/current"
"github.com/TarsCloud/TarsGo/tars/util/endpoint"
"github.com/TarsCloud/TarsGo/tars/util/rtimer"
"github.com/TarsCloud/TarsGo/tars/util/tools"
)

var (
Expand All @@ -31,13 +30,14 @@ const (

// ServantProxy tars servant proxy instance
type ServantProxy struct {
name string
comm *Communicator
manager EndpointManager
timeout int
version int16
proto model.Protocol
queueLen int32
name string
comm *Communicator
manager EndpointManager
syncTimeout int
asyncTimeout int
version int16
proto model.Protocol
queueLen int32

pushCallback func([]byte)
}
Expand All @@ -51,7 +51,6 @@ func newServantProxy(comm *Communicator, objName string, opts ...EndpointManager
s := &ServantProxy{
comm: comm,
proto: &protocol.TarsProtocol{},
timeout: comm.Client.AsyncInvokeTimeout,
version: basef.TARSVERSION,
}
pos := strings.Index(objName, "@")
Expand All @@ -67,6 +66,12 @@ func newServantProxy(comm *Communicator, objName string, opts ...EndpointManager

// init manager
s.manager = GetManager(comm, objName, opts...)

s.comm = comm
s.proto = &protocol.TarsProtocol{}
s.syncTimeout = s.comm.Client.SyncInvokeTimeout
s.asyncTimeout = s.comm.Client.AsyncInvokeTimeout
s.version = basef.TARSVERSION
return s
}

Expand All @@ -77,7 +82,7 @@ func (s *ServantProxy) Name() string {

// TarsSetTimeout sets the timeout for client calling the server , which is in ms.
func (s *ServantProxy) TarsSetTimeout(t int) {
s.timeout = t
s.syncTimeout = t
}

// TarsSetVersion set tars version
Expand Down Expand Up @@ -122,53 +127,44 @@ func (s *ServantProxy) TarsInvoke(ctx context.Context, cType byte,
resp *requestf.ResponsePacket) error {
defer CheckPanic()

// 将ctx中的dyeing信息传入到request中
var msgType int32
if dyeingKey, ok := current.GetDyeingKey(ctx); ok {
TLOG.Debug("dyeing debug: find dyeing key:", dyeingKey)
if status == nil {
status = make(map[string]string)
}
status[current.StatusDyedKey] = dyeingKey
msgType |= basef.TARSMESSAGETYPEDYED
}
msg := buildMessage(ctx, cType, sFuncName, buf, status, reqContext, resp, s)
timeout := time.Duration(s.syncTimeout) * time.Millisecond
err := s.invokeFilters(ctx, msg, timeout)

// 将ctx中的trace信息传入到request中
if trace, ok := current.GetTarsTrace(ctx); ok && trace.Call() {
traceKey := trace.GetTraceFullKey(false)
TLOG.Debug("trace debug: find trace key:", traceKey)
if status == nil {
status = make(map[string]string)
}
status[current.StatusTraceKey] = traceKey
msgType |= basef.TARSMESSAGETYPETRACE
if err != nil {
return err
}
*resp = *msg.Resp
return nil
}

req := requestf.RequestPacket{
IVersion: s.version,
CPacketType: int8(cType),
IRequestId: s.genRequestID(),
SServantName: s.name,
SFuncName: sFuncName,
SBuffer: tools.ByteToInt8(buf),
ITimeout: int32(s.timeout),
Context: reqContext,
Status: status,
IMessageType: msgType,
}
msg := &Message{Req: &req, Ser: s, Resp: resp}
msg.Init()

timeout := time.Duration(s.timeout) * time.Millisecond
if ok, hashType, hashCode, isHash := current.GetClientHash(ctx); ok {
msg.isHash = isHash
msg.hashType = HashType(hashType)
msg.hashCode = hashCode
// TarsInvokeAsync is used for client invoking server.
func (s *ServantProxy) TarsInvokeAsync(ctx context.Context, cType byte,
sFuncName string,
buf []byte,
status map[string]string,
reqContext map[string]string,
resp *requestf.ResponsePacket,
callback model.Callback) error {
defer CheckPanic()

msg := buildMessage(ctx, cType, sFuncName, buf, status, reqContext, resp, s)
msg.Req.ITimeout = int32(s.asyncTimeout)
if callback == nil {
msg.Req.CPacketType = basef.TARSONEWAY
} else {
msg.Async = true
msg.Callback = callback
}

timeout := time.Duration(s.asyncTimeout) * time.Millisecond
return s.invokeFilters(ctx, msg, timeout)
}

func (s *ServantProxy) invokeFilters(ctx context.Context, msg *Message, timeout time.Duration) error {
if ok, to, isTimeout := current.GetClientTimeout(ctx); ok && isTimeout {
timeout = time.Duration(to) * time.Millisecond
req.ITimeout = int32(to)
msg.Req.ITimeout = int32(to)
}

var err error
Expand Down Expand Up @@ -196,27 +192,32 @@ func (s *ServantProxy) TarsInvoke(ctx context.Context, cType byte,
}
}
}
s.manager.postInvoke()
// no async rpc call
if !msg.Async {
s.manager.postInvoke()
msg.End()
s.reportStat(msg, err)
}

return err
}

func (s *ServantProxy) reportStat(msg *Message, err error) {
if err != nil {
msg.End()
TLOG.Errorf("Invoke error: %s, %s, %v, cost:%d", s.name, sFuncName, err.Error(), msg.Cost())
TLOG.Errorf("Invoke error: %s, %s, %v, cost:%d", s.name, msg.Req.SFuncName, err.Error(), msg.Cost())
if msg.Resp == nil {
ReportStat(msg, StatSuccess, StatSuccess, StatFailed)
} else if msg.Status == basef.TARSINVOKETIMEOUT {
ReportStat(msg, StatSuccess, StatFailed, StatSuccess)
} else {
ReportStat(msg, StatSuccess, StatSuccess, StatFailed)
}
return err
return
}
msg.End()
*resp = *msg.Resp
ReportStat(msg, StatFailed, StatSuccess, StatSuccess)
return err
}

func (s *ServantProxy) doInvoke(ctx context.Context, msg *Message, timeout time.Duration) error {
func (s *ServantProxy) doInvoke(ctx context.Context, msg *Message, timeout time.Duration) (err error) {
adp, needCheck := s.manager.SelectAdapterProxy(msg)
if adp == nil {
return errors.New("no adapter Proxy selected:" + msg.Req.SServantName)
Expand All @@ -239,19 +240,53 @@ func (s *ServantProxy) doInvoke(ctx context.Context, msg *Message, timeout time.
atomic.AddInt32(&s.queueLen, 1)
readCh := make(chan *requestf.ResponsePacket)
adp.resp.Store(msg.Req.IRequestId, readCh)
defer func() {
var releaseFunc = func() {
CheckPanic()
atomic.AddInt32(&s.queueLen, -1)
adp.resp.Delete(msg.Req.IRequestId)
}
defer func() {
if !msg.Async || err != nil {
releaseFunc()
}
}()
if err := adp.Send(msg.Req); err != nil {

if err = adp.Send(msg.Req); err != nil {
adp.failAdd()
return err
}

if msg.Req.CPacketType == basef.TARSONEWAY {
adp.successAdd()
return nil
}

// async call rpc
if msg.Async {
go func() {
defer releaseFunc()
err := s.waitInvoke(msg, adp, timeout, needCheck)
s.manager.postInvoke()
msg.End()
s.reportStat(msg, err)
if msg.Status != basef.TARSINVOKETIMEOUT {
current.SetResponseContext(ctx, msg.Resp.Context)
current.SetResponseStatus(ctx, msg.Resp.Status)
}
if _, err := msg.Callback.Dispatch(ctx, msg.Req, msg.Resp, err); err != nil {
TLOG.Errorf("Callback error: %s, %s, %+v", s.name, msg.Req.SFuncName, err)
}
}()
return nil
}

return s.waitInvoke(msg, adp, timeout, needCheck)
}

func (s *ServantProxy) waitInvoke(msg *Message, adp *AdapterProxy, timeout time.Duration, needCheck bool) error {
ch, _ := adp.resp.Load(msg.Req.IRequestId)
readCh := ch.(chan *requestf.ResponsePacket)

select {
case <-rtimer.After(timeout):
msg.Status = basef.TARSINVOKETIMEOUT
Expand Down
4 changes: 3 additions & 1 deletion tars/setting.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,10 @@ const (
// communicator default ,update from remote config
refreshEndpointInterval int = 60000
reportInterval int = 5000
// SyncInvokeTimeout sync invoke timeout
SyncInvokeTimeout int = 3000
// AsyncInvokeTimeout async invoke timeout
AsyncInvokeTimeout int = 3000
AsyncInvokeTimeout int = 5000

// check endpoint status every 1000 ms
checkStatusInterval int = 1000
Expand Down

0 comments on commit fe8a4e9

Please sign in to comment.