diff --git a/tars/message.go b/tars/message.go index 9e001ba2..4bfb66e0 100644 --- a/tars/message.go +++ b/tars/message.go @@ -38,6 +38,7 @@ type Message struct { isHash bool Async bool Callback model.Callback + RespCh chan *requestf.ResponsePacket } // Init define the beginTime diff --git a/tars/servant.go b/tars/servant.go index 6e4551d8..65eea035 100755 --- a/tars/servant.go +++ b/tars/servant.go @@ -129,9 +129,7 @@ func (s *ServantProxy) TarsInvoke(ctx context.Context, cType byte, msg := buildMessage(ctx, cType, sFuncName, buf, status, reqContext, resp, s) timeout := time.Duration(s.syncTimeout) * time.Millisecond - err := s.invokeFilters(ctx, msg, timeout) - - if err != nil { + if err := s.invokeFilters(ctx, msg, timeout); err != nil { return err } *resp = *msg.Resp @@ -238,8 +236,8 @@ func (s *ServantProxy) doInvoke(ctx context.Context, msg *Message, timeout time. } atomic.AddInt32(&s.queueLen, 1) - readCh := make(chan *requestf.ResponsePacket) - adp.resp.Store(msg.Req.IRequestId, readCh) + msg.RespCh = make(chan *requestf.ResponsePacket) + adp.resp.Store(msg.Req.IRequestId, msg.RespCh) var releaseFunc = func() { CheckPanic() atomic.AddInt32(&s.queueLen, -1) @@ -265,7 +263,7 @@ func (s *ServantProxy) doInvoke(ctx context.Context, msg *Message, timeout time. if msg.Async { go func() { defer releaseFunc() - err := s.waitInvoke(msg, adp, timeout, needCheck) + err := s.waitResp(msg, timeout, needCheck) s.manager.postInvoke() msg.End() s.reportStat(msg, err) @@ -280,21 +278,18 @@ func (s *ServantProxy) doInvoke(ctx context.Context, msg *Message, timeout time. return nil } - return s.waitInvoke(msg, adp, timeout, needCheck) + return s.waitResp(msg, timeout, needCheck) } -func (s *ServantProxy) waitInvoke(msg *Message, adp *AdapterProxy, timeout time.Duration, needCheck bool) error { - ch, _ := adp.resp.Load(msg.Req.IRequestId) - readCh := ch.(chan *requestf.ResponsePacket) - +func (s *ServantProxy) waitResp(msg *Message, timeout time.Duration, needCheck bool) error { + adp := msg.Adp select { case <-rtimer.After(timeout): msg.Status = basef.TARSINVOKETIMEOUT adp.failAdd() - msg.End() return fmt.Errorf("request timeout, begin time:%d, cost:%d, obj:%s, func:%s, addr:(%s:%d), reqid:%d", msg.BeginTime, msg.Cost(), msg.Req.SServantName, msg.Req.SFuncName, adp.point.Host, adp.point.Port, msg.Req.IRequestId) - case msg.Resp = <-readCh: + case msg.Resp = <-msg.RespCh: if needCheck { go func() { adp.reset() diff --git a/tars/tools/tars2go/gen_go.go b/tars/tools/tars2go/gen_go.go index 2c187f65..f91e7afc 100755 --- a/tars/tools/tars2go/gen_go.go +++ b/tars/tools/tars2go/gen_go.go @@ -1257,6 +1257,9 @@ func (gen *GenGo) genIFProxyFun(interfName string, fun *FunInfo, withContext boo // trace if !isOneWay && !withoutTrace { + if isAsync { + c.WriteString(`if callback != nil {`) + } c.WriteString(` trace, ok := current.GetTarsTrace(tarsCtx) if ok && trace.Call() { @@ -1280,6 +1283,9 @@ if ok && trace.Call() { } tars.Trace(trace.GetTraceKey(tarstrace.EstCS), tarstrace.AnnotationCS, tars.GetClientConfig().ModuleName, obj.servant.Name(), "` + fun.Name + `", 0, traceParam, "") }`) + if isAsync { + c.WriteString(`}`) + } c.WriteString("\n\n") } c.WriteString(`var statusMap map[string]string