From a30b34e137bfacebff5a39d13792c00044bee937 Mon Sep 17 00:00:00 2001 From: what Date: Thu, 1 Jun 2023 10:28:08 +0800 Subject: [PATCH] =?UTF-8?q?[feat]=20=E5=B0=86=20gRPC=20recv=20=E7=9A=84=20?= =?UTF-8?q?channel=20=E6=94=B9=E4=B8=BA=20io=20read/write,=20=E6=96=B9?= =?UTF-8?q?=E4=BE=BF=E4=BD=BF=E7=94=A8=20json.Decoder=20=E8=A7=A3=E6=9E=90?= =?UTF-8?q?=20message?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- grpcall.go | 95 ++++++++++++++++++++++++++----------------------- grpcall_test.go | 74 +++++++++++++++++--------------------- request.go | 10 ++++-- request_test.go | 29 ++++++++------- response.go | 51 +++++++++++--------------- 5 files changed, 126 insertions(+), 133 deletions(-) diff --git a/grpcall.go b/grpcall.go index 29ddc4a..a78ce40 100644 --- a/grpcall.go +++ b/grpcall.go @@ -1,7 +1,9 @@ package grpcall import ( + "bytes" "context" + "encoding/json" "fmt" "io" "sort" @@ -67,7 +69,7 @@ func (this Grpcall) invokeUnary(ctx context.Context, stub grpcdynamic.Stub, meth // 创建请求 message req := msgFactory.NewMessage(method.GetInputType()) // 将数据格式化到 message - if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil { + if err := NewJsonRequestParser(anyResolver, strings.NewReader(data)).Next(req); err != nil { return nil, err } @@ -86,14 +88,21 @@ func (this Grpcall) invokeUnary(ctx context.Context, stub grpcdynamic.Stub, meth AnyResolver: anyResolver, } - respData, err := marshaler.MarshalToString(respMsg) + pReader, pWriter := io.Pipe() resp := &Response{ method: method, header: respHeaders, - data: respData, + recv: pReader, } + go func() { + defer pWriter.Close() + if err = marshaler.Marshal(pWriter, respMsg); err != nil { + pWriter.CloseWithError(err) + } + }() + return resp, err } @@ -101,7 +110,7 @@ func (this Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub, // 创建请求 message req := msgFactory.NewMessage(method.GetInputType()) // 将数据格式化到 message - if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil { + if err := NewJsonRequestParser(anyResolver, strings.NewReader(data)).Next(req); err != nil { return nil, err } @@ -114,11 +123,12 @@ func (this Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub, return nil, err } + pReader, pWriter := io.Pipe() + resp := &Response{ method: method, header: &metadata.MD{}, - recv: make(chan string), - done: make(chan error), + recv: pReader, cancel: cancel, } @@ -129,9 +139,8 @@ func (this Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub, } defer func() { - close(resp.done) - close(resp.recv) cancel() + pWriter.Close() }() for { @@ -139,25 +148,22 @@ func (this Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub, msg, err := stream.RecvMsg() if err == io.EOF { - resp.done <- nil return } if err != nil { - resp.done <- err + pWriter.CloseWithError(err) return } if *resp.header, err = stream.Header(); err != nil { - resp.done <- err + pWriter.CloseWithError(err) return } - if data, err := msgParser.MarshalToString(msg); err != nil { - resp.done <- err + if err = msgParser.Marshal(pWriter, msg); err != nil { + pWriter.CloseWithError(err) return - } else { - resp.recv <- data } msg.Reset() @@ -171,7 +177,7 @@ func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub, // 创建请求 message req := msgFactory.NewMessage(method.GetInputType()) // 将数据格式化到 message - if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil { + if err := NewJsonRequestParser(anyResolver, strings.NewReader(data)).Next(req); err != nil { return nil, err } @@ -189,49 +195,54 @@ func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub, return nil, err } + pRecvReader, pRecvWriter := io.Pipe() + resp := &Response{ - method: method, - header: &metadata.MD{}, - recv: make(chan string), - send: make(chan string), - done: make(chan error), - cancel: cancel, + method: method, + header: &metadata.MD{}, + recv: pRecvReader, + send: make(chan []byte), + sendCompleted: make(chan error), + cancel: cancel, } // send go func() { defer func() { stream.CloseSend() - close(resp.send) cancel() + close(resp.sendCompleted) + close(resp.send) }() + reqParser := jsonpb.Unmarshaler{AnyResolver: anyResolver} + for { select { + case <-cancel_ctx.Done(): + resp.sendCompleted <- nil + return + case <-ctx.Done(): + resp.sendCompleted <- nil + return case msg, ok := <-resp.send: if !ok { - resp.done <- nil + resp.sendCompleted <- fmt.Errorf("read send message error") return } - // 将数据格式化到 message - if err := NewJsonRequestParser(anyResolver, msg).Next(req); err != nil { - resp.done <- err + if err := reqParser.UnmarshalNext(json.NewDecoder(bytes.NewReader(msg)), req); err != nil { + resp.sendCompleted <- err return } if err := stream.SendMsg(req); err != nil { - resp.done <- err + resp.sendCompleted <- err return } - - case <-cancel_ctx.Done(): - resp.done <- nil - return - case <-ctx.Done(): - resp.done <- nil - return } + + resp.sendCompleted <- nil } }() @@ -243,9 +254,8 @@ func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub, } defer func() { - close(resp.done) - close(resp.recv) cancel() + pRecvWriter.Close() }() for { @@ -253,25 +263,22 @@ func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub, msg, err := stream.RecvMsg() if err == io.EOF { - resp.done <- nil return } if err != nil { - resp.done <- err + pRecvWriter.CloseWithError(err) return } if *resp.header, err = stream.Header(); err != nil { - resp.done <- err + pRecvWriter.CloseWithError(err) return } - if data, err := msgParser.MarshalToString(msg); err != nil { - resp.done <- err + if err = msgParser.Marshal(pRecvWriter, msg); err != nil { + pRecvWriter.CloseWithError(err) return - } else { - resp.recv <- data } msg.Reset() diff --git a/grpcall_test.go b/grpcall_test.go index 0175eaf..cae9cff 100644 --- a/grpcall_test.go +++ b/grpcall_test.go @@ -2,6 +2,7 @@ package grpcall import ( "fmt" + "io" "os" "testing" "time" @@ -46,7 +47,10 @@ func TestInvokeUnary(t *testing.T) { if resp, err := gc.Invoke("User.UserResourceStatus", "TestInt64Value", `123`, nil); err != nil { t.Error(err) } else { - t.Log(resp.Data()) + // t.Log(ioutil.ReadAll(resp.Recv())) + str := "" + decoder := resp.RecvDecoder() + t.Log(decoder.Decode(&str), str) } } @@ -54,27 +58,24 @@ func TestInvokeUnary(t *testing.T) { func TestInvokeServStream(t *testing.T) { if resp, err := gc.Invoke("User.UserResourceStatus", "GetEvent", `{}`, nil); err != nil { t.Error(err) - } else if recv, err := resp.Recv(); err != nil { - t.Error(err) - } else if done, err := resp.Done(); err != nil { - t.Error(err) } else { - flag := make(chan bool) - go func() { - for { - select { - case msg := <-recv: - fmt.Println("msg", msg, resp.Header()) - case err := <-done: - if err != nil { - t.Error(err) - } - flag <- true - return + defer resp.Close() + decoder := resp.RecvDecoder() + + for { + data := struct { + Uuid string `json:"uuid"` + Type string `json:"type"` + }{} + + if err := decoder.Decode(&data); err != nil { + if err != io.EOF && err != io.ErrClosedPipe { + t.Error(err) } + break } - }() - <-flag + fmt.Println(data) + } } } @@ -82,41 +83,32 @@ func TestInvokeServStream(t *testing.T) { func TestInvokeBidiStream(t *testing.T) { if resp, err := gc.Invoke("User.UserResourceStatus", "TestBidiStream", `"hello"`, nil); err != nil { t.Error(err) - } else if send, err := resp.Send(); err != nil { - t.Error(err) - } else if recv, err := resp.Recv(); err != nil { - t.Error(err) - } else if done, err := resp.Done(); err != nil { - t.Error(err) } else { - flag := make(chan bool) + defer resp.Close() + + decoder := resp.RecvDecoder() go func() { - timer := time.NewTimer(time.Second) - defer timer.Stop() - for { - select { - case msg := <-recv: - fmt.Println("msg", msg, resp.Header()) - case err := <-done: - if err != nil { + data := struct { + Uuid string `json:"uuid"` + Type string `json:"type"` + }{} + + if err := decoder.Decode(&data); err != nil { + if err != io.EOF && err != io.ErrClosedPipe { t.Error(err) } - flag <- true return - case <-timer.C: - resp.Cancel() - timer.Reset(time.Second) } - + fmt.Println(333, data) } }() for i := 0; i < 10; i++ { - send <- fmt.Sprintf(`"abc %d"`, i) + fmt.Println("消息发送, ", i, resp.Send(fmt.Sprintf(`"abc %d"`, i))) } - <-flag + time.Sleep(time.Second) } } diff --git a/request.go b/request.go index a9dafce..6624c6f 100644 --- a/request.go +++ b/request.go @@ -3,7 +3,7 @@ package grpcall import ( "encoding/json" "fmt" - "strings" + "io" "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" @@ -22,14 +22,18 @@ type JsonRequestParser struct { func (this JsonRequestParser) Next(msg proto.Message) error { if err := this.unmarshaler.UnmarshalNext(this.dec, msg); err != nil { + if err == io.EOF { + return err + } + return fmt.Errorf("unmarshal request json data error, %s", err) } return nil } -func NewJsonRequestParser(resolver jsonpb.AnyResolver, data string) RequestParser { +func NewJsonRequestParser(resolver jsonpb.AnyResolver, in io.Reader) RequestParser { return &JsonRequestParser{ - dec: json.NewDecoder(strings.NewReader(data)), + dec: json.NewDecoder(in), unmarshaler: jsonpb.Unmarshaler{AnyResolver: resolver}, } } diff --git a/request_test.go b/request_test.go index 511a668..0f4204b 100644 --- a/request_test.go +++ b/request_test.go @@ -6,7 +6,6 @@ import ( "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/ptypes/wrappers" "google.golang.org/protobuf/types/known/emptypb" - "google.golang.org/protobuf/types/known/wrapperspb" ) func TestJsonpbUnmarshalString(t *testing.T) { @@ -43,24 +42,24 @@ func TestJsonpbMarshalEmpty(t *testing.T) { } } -func TestNewJsonRequestParser(t *testing.T) { - var str_msg wrapperspb.StringValue +// func TestNewJsonRequestParser(t *testing.T) { +// var str_msg wrapperspb.StringValue - inData := `"abc"` +// inData := `"abc"` - if err := NewJsonRequestParser(nil, inData).Next(&str_msg); err != nil { - t.Error(err) - } +// if err := NewJsonRequestParser(nil, inData).Next(&str_msg); err != nil { +// t.Error(err) +// } - t.Log("string", str_msg.GetValue()) +// t.Log("string", str_msg.GetValue()) - var int_msg wrapperspb.Int64Value +// var int_msg wrapperspb.Int64Value - inData = `"10"` // or `10` +// inData = `"10"` // or `10` - if err := NewJsonRequestParser(nil, inData).Next(&int_msg); err != nil { - t.Error(err) - } +// if err := NewJsonRequestParser(nil, inData).Next(&int_msg); err != nil { +// t.Error(err) +// } - t.Log("int", int_msg.GetValue()) -} +// t.Log("int", int_msg.GetValue()) +// } diff --git a/response.go b/response.go index 7db359a..c7a4745 100644 --- a/response.go +++ b/response.go @@ -2,59 +2,50 @@ package grpcall import ( "context" + "encoding/json" "fmt" + "io" "github.com/jhump/protoreflect/desc" "google.golang.org/grpc/metadata" ) type Response struct { - method *desc.MethodDescriptor - header *metadata.MD - data string - send chan string - recv chan string - done chan error - cancel context.CancelFunc + method *desc.MethodDescriptor + header *metadata.MD + recv *io.PipeReader + send chan []byte + sendCompleted chan error + cancel context.CancelFunc } func (this Response) Header() metadata.MD { return *this.header } -func (this Response) Data() (string, error) { - if this.method.IsClientStreaming() || this.method.IsServerStreaming() { - return "", fmt.Errorf("%q is %s, cannot use unary data", this.method.GetFullyQualifiedName(), this.GetCategory()) - } - return this.data, nil +func (this Response) Recv() io.Reader { + return this.recv +} + +func (this Response) RecvDecoder() *json.Decoder { + return json.NewDecoder(this.Recv()) } // 发送数据流 -func (this Response) Send() (chan<- string, error) { +func (this Response) Send(data string) error { if !this.method.IsClientStreaming() || !this.method.IsServerStreaming() { - return nil, fmt.Errorf("%q is %s, cannot use send streaming func", this.method.GetFullyQualifiedName(), this.GetCategory()) + return fmt.Errorf("%q is %s, not streaming gRPC", this.method.GetFullyQualifiedName(), this.GetCategory()) } - return this.send, nil -} -// 接收数据流 -func (this Response) Recv() (<-chan string, error) { - if !this.method.IsServerStreaming() { - return nil, fmt.Errorf("%q is %s, cannot use streaming func", this.method.GetFullyQualifiedName(), this.GetCategory()) - } - return this.recv, nil -} + this.send <- []byte(data) -// 数据是否响应结束 -func (this Response) Done() (<-chan error, error) { - if !this.method.IsServerStreaming() { - return nil, fmt.Errorf("%q is %s, cannot use streaming func", this.method.GetFullyQualifiedName(), this.GetCategory()) - } - return this.done, nil + return <-this.sendCompleted } // 取消接收数据 -func (this Response) Cancel() { +func (this Response) Close() { + this.recv.Close() + if this.method.IsServerStreaming() { this.cancel() }