[feat] 将 gRPC recv 的 channel 改为 io read/write, 方便使用 json.Decoder 解析 message

This commit is contained in:
what 2023-06-01 10:28:08 +08:00
parent d8f58df39a
commit a30b34e137
5 changed files with 126 additions and 133 deletions

View File

@ -1,7 +1,9 @@
package grpcall
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"sort"
@ -67,7 +69,7 @@ func (this Grpcall) invokeUnary(ctx context.Context, stub grpcdynamic.Stub, meth
// 创建请求 message
req := msgFactory.NewMessage(method.GetInputType())
// 将数据格式化到 message
if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil {
if err := NewJsonRequestParser(anyResolver, strings.NewReader(data)).Next(req); err != nil {
return nil, err
}
@ -86,14 +88,21 @@ func (this Grpcall) invokeUnary(ctx context.Context, stub grpcdynamic.Stub, meth
AnyResolver: anyResolver,
}
respData, err := marshaler.MarshalToString(respMsg)
pReader, pWriter := io.Pipe()
resp := &Response{
method: method,
header: respHeaders,
data: respData,
recv: pReader,
}
go func() {
defer pWriter.Close()
if err = marshaler.Marshal(pWriter, respMsg); err != nil {
pWriter.CloseWithError(err)
}
}()
return resp, err
}
@ -101,7 +110,7 @@ func (this Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub,
// 创建请求 message
req := msgFactory.NewMessage(method.GetInputType())
// 将数据格式化到 message
if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil {
if err := NewJsonRequestParser(anyResolver, strings.NewReader(data)).Next(req); err != nil {
return nil, err
}
@ -114,11 +123,12 @@ func (this Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub,
return nil, err
}
pReader, pWriter := io.Pipe()
resp := &Response{
method: method,
header: &metadata.MD{},
recv: make(chan string),
done: make(chan error),
recv: pReader,
cancel: cancel,
}
@ -129,9 +139,8 @@ func (this Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub,
}
defer func() {
close(resp.done)
close(resp.recv)
cancel()
pWriter.Close()
}()
for {
@ -139,25 +148,22 @@ func (this Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub,
msg, err := stream.RecvMsg()
if err == io.EOF {
resp.done <- nil
return
}
if err != nil {
resp.done <- err
pWriter.CloseWithError(err)
return
}
if *resp.header, err = stream.Header(); err != nil {
resp.done <- err
pWriter.CloseWithError(err)
return
}
if data, err := msgParser.MarshalToString(msg); err != nil {
resp.done <- err
if err = msgParser.Marshal(pWriter, msg); err != nil {
pWriter.CloseWithError(err)
return
} else {
resp.recv <- data
}
msg.Reset()
@ -171,7 +177,7 @@ func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub,
// 创建请求 message
req := msgFactory.NewMessage(method.GetInputType())
// 将数据格式化到 message
if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil {
if err := NewJsonRequestParser(anyResolver, strings.NewReader(data)).Next(req); err != nil {
return nil, err
}
@ -189,12 +195,14 @@ func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub,
return nil, err
}
pRecvReader, pRecvWriter := io.Pipe()
resp := &Response{
method: method,
header: &metadata.MD{},
recv: make(chan string),
send: make(chan string),
done: make(chan error),
recv: pRecvReader,
send: make(chan []byte),
sendCompleted: make(chan error),
cancel: cancel,
}
@ -202,36 +210,39 @@ func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub,
go func() {
defer func() {
stream.CloseSend()
close(resp.send)
cancel()
close(resp.sendCompleted)
close(resp.send)
}()
reqParser := jsonpb.Unmarshaler{AnyResolver: anyResolver}
for {
select {
case <-cancel_ctx.Done():
resp.sendCompleted <- nil
return
case <-ctx.Done():
resp.sendCompleted <- nil
return
case msg, ok := <-resp.send:
if !ok {
resp.done <- nil
resp.sendCompleted <- fmt.Errorf("read send message error")
return
}
// 将数据格式化到 message
if err := NewJsonRequestParser(anyResolver, msg).Next(req); err != nil {
resp.done <- err
if err := reqParser.UnmarshalNext(json.NewDecoder(bytes.NewReader(msg)), req); err != nil {
resp.sendCompleted <- err
return
}
if err := stream.SendMsg(req); err != nil {
resp.done <- err
resp.sendCompleted <- err
return
}
}
case <-cancel_ctx.Done():
resp.done <- nil
return
case <-ctx.Done():
resp.done <- nil
return
}
resp.sendCompleted <- nil
}
}()
@ -243,9 +254,8 @@ func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub,
}
defer func() {
close(resp.done)
close(resp.recv)
cancel()
pRecvWriter.Close()
}()
for {
@ -253,25 +263,22 @@ func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub,
msg, err := stream.RecvMsg()
if err == io.EOF {
resp.done <- nil
return
}
if err != nil {
resp.done <- err
pRecvWriter.CloseWithError(err)
return
}
if *resp.header, err = stream.Header(); err != nil {
resp.done <- err
pRecvWriter.CloseWithError(err)
return
}
if data, err := msgParser.MarshalToString(msg); err != nil {
resp.done <- err
if err = msgParser.Marshal(pRecvWriter, msg); err != nil {
pRecvWriter.CloseWithError(err)
return
} else {
resp.recv <- data
}
msg.Reset()

View File

@ -2,6 +2,7 @@ package grpcall
import (
"fmt"
"io"
"os"
"testing"
"time"
@ -46,7 +47,10 @@ func TestInvokeUnary(t *testing.T) {
if resp, err := gc.Invoke("User.UserResourceStatus", "TestInt64Value", `123`, nil); err != nil {
t.Error(err)
} else {
t.Log(resp.Data())
// t.Log(ioutil.ReadAll(resp.Recv()))
str := ""
decoder := resp.RecvDecoder()
t.Log(decoder.Decode(&str), str)
}
}
@ -54,27 +58,24 @@ func TestInvokeUnary(t *testing.T) {
func TestInvokeServStream(t *testing.T) {
if resp, err := gc.Invoke("User.UserResourceStatus", "GetEvent", `{}`, nil); err != nil {
t.Error(err)
} else if recv, err := resp.Recv(); err != nil {
t.Error(err)
} else if done, err := resp.Done(); err != nil {
t.Error(err)
} else {
flag := make(chan bool)
go func() {
defer resp.Close()
decoder := resp.RecvDecoder()
for {
select {
case msg := <-recv:
fmt.Println("msg", msg, resp.Header())
case err := <-done:
if err != nil {
data := struct {
Uuid string `json:"uuid"`
Type string `json:"type"`
}{}
if err := decoder.Decode(&data); err != nil {
if err != io.EOF && err != io.ErrClosedPipe {
t.Error(err)
}
flag <- true
return
break
}
fmt.Println(data)
}
}()
<-flag
}
}
@ -82,41 +83,32 @@ func TestInvokeServStream(t *testing.T) {
func TestInvokeBidiStream(t *testing.T) {
if resp, err := gc.Invoke("User.UserResourceStatus", "TestBidiStream", `"hello"`, nil); err != nil {
t.Error(err)
} else if send, err := resp.Send(); err != nil {
t.Error(err)
} else if recv, err := resp.Recv(); err != nil {
t.Error(err)
} else if done, err := resp.Done(); err != nil {
t.Error(err)
} else {
flag := make(chan bool)
defer resp.Close()
decoder := resp.RecvDecoder()
go func() {
timer := time.NewTimer(time.Second)
defer timer.Stop()
for {
select {
case msg := <-recv:
fmt.Println("msg", msg, resp.Header())
case err := <-done:
if err != nil {
data := struct {
Uuid string `json:"uuid"`
Type string `json:"type"`
}{}
if err := decoder.Decode(&data); err != nil {
if err != io.EOF && err != io.ErrClosedPipe {
t.Error(err)
}
flag <- true
return
case <-timer.C:
resp.Cancel()
timer.Reset(time.Second)
}
fmt.Println(333, data)
}
}()
for i := 0; i < 10; i++ {
send <- fmt.Sprintf(`"abc %d"`, i)
fmt.Println("消息发送, ", i, resp.Send(fmt.Sprintf(`"abc %d"`, i)))
}
<-flag
time.Sleep(time.Second)
}
}

View File

@ -3,7 +3,7 @@ package grpcall
import (
"encoding/json"
"fmt"
"strings"
"io"
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
@ -22,14 +22,18 @@ type JsonRequestParser struct {
func (this JsonRequestParser) Next(msg proto.Message) error {
if err := this.unmarshaler.UnmarshalNext(this.dec, msg); err != nil {
if err == io.EOF {
return err
}
return fmt.Errorf("unmarshal request json data error, %s", err)
}
return nil
}
func NewJsonRequestParser(resolver jsonpb.AnyResolver, data string) RequestParser {
func NewJsonRequestParser(resolver jsonpb.AnyResolver, in io.Reader) RequestParser {
return &JsonRequestParser{
dec: json.NewDecoder(strings.NewReader(data)),
dec: json.NewDecoder(in),
unmarshaler: jsonpb.Unmarshaler{AnyResolver: resolver},
}
}

View File

@ -6,7 +6,6 @@ import (
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/ptypes/wrappers"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
func TestJsonpbUnmarshalString(t *testing.T) {
@ -43,24 +42,24 @@ func TestJsonpbMarshalEmpty(t *testing.T) {
}
}
func TestNewJsonRequestParser(t *testing.T) {
var str_msg wrapperspb.StringValue
// func TestNewJsonRequestParser(t *testing.T) {
// var str_msg wrapperspb.StringValue
inData := `"abc"`
// inData := `"abc"`
if err := NewJsonRequestParser(nil, inData).Next(&str_msg); err != nil {
t.Error(err)
}
// if err := NewJsonRequestParser(nil, inData).Next(&str_msg); err != nil {
// t.Error(err)
// }
t.Log("string", str_msg.GetValue())
// t.Log("string", str_msg.GetValue())
var int_msg wrapperspb.Int64Value
// var int_msg wrapperspb.Int64Value
inData = `"10"` // or `10`
// inData = `"10"` // or `10`
if err := NewJsonRequestParser(nil, inData).Next(&int_msg); err != nil {
t.Error(err)
}
// if err := NewJsonRequestParser(nil, inData).Next(&int_msg); err != nil {
// t.Error(err)
// }
t.Log("int", int_msg.GetValue())
}
// t.Log("int", int_msg.GetValue())
// }

View File

@ -2,7 +2,9 @@ package grpcall
import (
"context"
"encoding/json"
"fmt"
"io"
"github.com/jhump/protoreflect/desc"
"google.golang.org/grpc/metadata"
@ -11,10 +13,9 @@ import (
type Response struct {
method *desc.MethodDescriptor
header *metadata.MD
data string
send chan string
recv chan string
done chan error
recv *io.PipeReader
send chan []byte
sendCompleted chan error
cancel context.CancelFunc
}
@ -22,39 +23,29 @@ func (this Response) Header() metadata.MD {
return *this.header
}
func (this Response) Data() (string, error) {
if this.method.IsClientStreaming() || this.method.IsServerStreaming() {
return "", fmt.Errorf("%q is %s, cannot use unary data", this.method.GetFullyQualifiedName(), this.GetCategory())
func (this Response) Recv() io.Reader {
return this.recv
}
return this.data, nil
func (this Response) RecvDecoder() *json.Decoder {
return json.NewDecoder(this.Recv())
}
// 发送数据流
func (this Response) Send() (chan<- string, error) {
func (this Response) Send(data string) error {
if !this.method.IsClientStreaming() || !this.method.IsServerStreaming() {
return nil, fmt.Errorf("%q is %s, cannot use send streaming func", this.method.GetFullyQualifiedName(), this.GetCategory())
}
return this.send, nil
return fmt.Errorf("%q is %s, not streaming gRPC", this.method.GetFullyQualifiedName(), this.GetCategory())
}
// 接收数据流
func (this Response) Recv() (<-chan string, error) {
if !this.method.IsServerStreaming() {
return nil, fmt.Errorf("%q is %s, cannot use streaming func", this.method.GetFullyQualifiedName(), this.GetCategory())
}
return this.recv, nil
}
this.send <- []byte(data)
// 数据是否响应结束
func (this Response) Done() (<-chan error, error) {
if !this.method.IsServerStreaming() {
return nil, fmt.Errorf("%q is %s, cannot use streaming func", this.method.GetFullyQualifiedName(), this.GetCategory())
}
return this.done, nil
return <-this.sendCompleted
}
// 取消接收数据
func (this Response) Cancel() {
func (this Response) Close() {
this.recv.Close()
if this.method.IsServerStreaming() {
this.cancel()
}