[feat] 将 gRPC recv 的 channel 改为 io read/write, 方便使用 json.Decoder 解析 message

This commit is contained in:
what 2023-06-01 10:28:08 +08:00
parent d8f58df39a
commit a30b34e137
5 changed files with 126 additions and 133 deletions

View File

@ -1,7 +1,9 @@
package grpcall package grpcall
import ( import (
"bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"io" "io"
"sort" "sort"
@ -67,7 +69,7 @@ func (this Grpcall) invokeUnary(ctx context.Context, stub grpcdynamic.Stub, meth
// 创建请求 message // 创建请求 message
req := msgFactory.NewMessage(method.GetInputType()) req := msgFactory.NewMessage(method.GetInputType())
// 将数据格式化到 message // 将数据格式化到 message
if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil { if err := NewJsonRequestParser(anyResolver, strings.NewReader(data)).Next(req); err != nil {
return nil, err return nil, err
} }
@ -86,14 +88,21 @@ func (this Grpcall) invokeUnary(ctx context.Context, stub grpcdynamic.Stub, meth
AnyResolver: anyResolver, AnyResolver: anyResolver,
} }
respData, err := marshaler.MarshalToString(respMsg) pReader, pWriter := io.Pipe()
resp := &Response{ resp := &Response{
method: method, method: method,
header: respHeaders, header: respHeaders,
data: respData, recv: pReader,
} }
go func() {
defer pWriter.Close()
if err = marshaler.Marshal(pWriter, respMsg); err != nil {
pWriter.CloseWithError(err)
}
}()
return resp, err return resp, err
} }
@ -101,7 +110,7 @@ func (this Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub,
// 创建请求 message // 创建请求 message
req := msgFactory.NewMessage(method.GetInputType()) req := msgFactory.NewMessage(method.GetInputType())
// 将数据格式化到 message // 将数据格式化到 message
if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil { if err := NewJsonRequestParser(anyResolver, strings.NewReader(data)).Next(req); err != nil {
return nil, err return nil, err
} }
@ -114,11 +123,12 @@ func (this Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub,
return nil, err return nil, err
} }
pReader, pWriter := io.Pipe()
resp := &Response{ resp := &Response{
method: method, method: method,
header: &metadata.MD{}, header: &metadata.MD{},
recv: make(chan string), recv: pReader,
done: make(chan error),
cancel: cancel, cancel: cancel,
} }
@ -129,9 +139,8 @@ func (this Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub,
} }
defer func() { defer func() {
close(resp.done)
close(resp.recv)
cancel() cancel()
pWriter.Close()
}() }()
for { for {
@ -139,25 +148,22 @@ func (this Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub,
msg, err := stream.RecvMsg() msg, err := stream.RecvMsg()
if err == io.EOF { if err == io.EOF {
resp.done <- nil
return return
} }
if err != nil { if err != nil {
resp.done <- err pWriter.CloseWithError(err)
return return
} }
if *resp.header, err = stream.Header(); err != nil { if *resp.header, err = stream.Header(); err != nil {
resp.done <- err pWriter.CloseWithError(err)
return return
} }
if data, err := msgParser.MarshalToString(msg); err != nil { if err = msgParser.Marshal(pWriter, msg); err != nil {
resp.done <- err pWriter.CloseWithError(err)
return return
} else {
resp.recv <- data
} }
msg.Reset() msg.Reset()
@ -171,7 +177,7 @@ func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub,
// 创建请求 message // 创建请求 message
req := msgFactory.NewMessage(method.GetInputType()) req := msgFactory.NewMessage(method.GetInputType())
// 将数据格式化到 message // 将数据格式化到 message
if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil { if err := NewJsonRequestParser(anyResolver, strings.NewReader(data)).Next(req); err != nil {
return nil, err return nil, err
} }
@ -189,12 +195,14 @@ func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub,
return nil, err return nil, err
} }
pRecvReader, pRecvWriter := io.Pipe()
resp := &Response{ resp := &Response{
method: method, method: method,
header: &metadata.MD{}, header: &metadata.MD{},
recv: make(chan string), recv: pRecvReader,
send: make(chan string), send: make(chan []byte),
done: make(chan error), sendCompleted: make(chan error),
cancel: cancel, cancel: cancel,
} }
@ -202,36 +210,39 @@ func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub,
go func() { go func() {
defer func() { defer func() {
stream.CloseSend() stream.CloseSend()
close(resp.send)
cancel() cancel()
close(resp.sendCompleted)
close(resp.send)
}() }()
reqParser := jsonpb.Unmarshaler{AnyResolver: anyResolver}
for { for {
select { select {
case <-cancel_ctx.Done():
resp.sendCompleted <- nil
return
case <-ctx.Done():
resp.sendCompleted <- nil
return
case msg, ok := <-resp.send: case msg, ok := <-resp.send:
if !ok { if !ok {
resp.done <- nil resp.sendCompleted <- fmt.Errorf("read send message error")
return return
} }
// 将数据格式化到 message if err := reqParser.UnmarshalNext(json.NewDecoder(bytes.NewReader(msg)), req); err != nil {
if err := NewJsonRequestParser(anyResolver, msg).Next(req); err != nil { resp.sendCompleted <- err
resp.done <- err
return return
} }
if err := stream.SendMsg(req); err != nil { if err := stream.SendMsg(req); err != nil {
resp.done <- err resp.sendCompleted <- err
return return
} }
}
case <-cancel_ctx.Done(): resp.sendCompleted <- nil
resp.done <- nil
return
case <-ctx.Done():
resp.done <- nil
return
}
} }
}() }()
@ -243,9 +254,8 @@ func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub,
} }
defer func() { defer func() {
close(resp.done)
close(resp.recv)
cancel() cancel()
pRecvWriter.Close()
}() }()
for { for {
@ -253,25 +263,22 @@ func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub,
msg, err := stream.RecvMsg() msg, err := stream.RecvMsg()
if err == io.EOF { if err == io.EOF {
resp.done <- nil
return return
} }
if err != nil { if err != nil {
resp.done <- err pRecvWriter.CloseWithError(err)
return return
} }
if *resp.header, err = stream.Header(); err != nil { if *resp.header, err = stream.Header(); err != nil {
resp.done <- err pRecvWriter.CloseWithError(err)
return return
} }
if data, err := msgParser.MarshalToString(msg); err != nil { if err = msgParser.Marshal(pRecvWriter, msg); err != nil {
resp.done <- err pRecvWriter.CloseWithError(err)
return return
} else {
resp.recv <- data
} }
msg.Reset() msg.Reset()

View File

@ -2,6 +2,7 @@ package grpcall
import ( import (
"fmt" "fmt"
"io"
"os" "os"
"testing" "testing"
"time" "time"
@ -46,7 +47,10 @@ func TestInvokeUnary(t *testing.T) {
if resp, err := gc.Invoke("User.UserResourceStatus", "TestInt64Value", `123`, nil); err != nil { if resp, err := gc.Invoke("User.UserResourceStatus", "TestInt64Value", `123`, nil); err != nil {
t.Error(err) t.Error(err)
} else { } else {
t.Log(resp.Data()) // t.Log(ioutil.ReadAll(resp.Recv()))
str := ""
decoder := resp.RecvDecoder()
t.Log(decoder.Decode(&str), str)
} }
} }
@ -54,27 +58,24 @@ func TestInvokeUnary(t *testing.T) {
func TestInvokeServStream(t *testing.T) { func TestInvokeServStream(t *testing.T) {
if resp, err := gc.Invoke("User.UserResourceStatus", "GetEvent", `{}`, nil); err != nil { if resp, err := gc.Invoke("User.UserResourceStatus", "GetEvent", `{}`, nil); err != nil {
t.Error(err) t.Error(err)
} else if recv, err := resp.Recv(); err != nil {
t.Error(err)
} else if done, err := resp.Done(); err != nil {
t.Error(err)
} else { } else {
flag := make(chan bool) defer resp.Close()
go func() { decoder := resp.RecvDecoder()
for { for {
select { data := struct {
case msg := <-recv: Uuid string `json:"uuid"`
fmt.Println("msg", msg, resp.Header()) Type string `json:"type"`
case err := <-done: }{}
if err != nil {
if err := decoder.Decode(&data); err != nil {
if err != io.EOF && err != io.ErrClosedPipe {
t.Error(err) t.Error(err)
} }
flag <- true break
return
} }
fmt.Println(data)
} }
}()
<-flag
} }
} }
@ -82,41 +83,32 @@ func TestInvokeServStream(t *testing.T) {
func TestInvokeBidiStream(t *testing.T) { func TestInvokeBidiStream(t *testing.T) {
if resp, err := gc.Invoke("User.UserResourceStatus", "TestBidiStream", `"hello"`, nil); err != nil { if resp, err := gc.Invoke("User.UserResourceStatus", "TestBidiStream", `"hello"`, nil); err != nil {
t.Error(err) t.Error(err)
} else if send, err := resp.Send(); err != nil {
t.Error(err)
} else if recv, err := resp.Recv(); err != nil {
t.Error(err)
} else if done, err := resp.Done(); err != nil {
t.Error(err)
} else { } else {
flag := make(chan bool) defer resp.Close()
decoder := resp.RecvDecoder()
go func() { go func() {
timer := time.NewTimer(time.Second)
defer timer.Stop()
for { for {
select { data := struct {
case msg := <-recv: Uuid string `json:"uuid"`
fmt.Println("msg", msg, resp.Header()) Type string `json:"type"`
case err := <-done: }{}
if err != nil {
if err := decoder.Decode(&data); err != nil {
if err != io.EOF && err != io.ErrClosedPipe {
t.Error(err) t.Error(err)
} }
flag <- true
return return
case <-timer.C:
resp.Cancel()
timer.Reset(time.Second)
} }
fmt.Println(333, data)
} }
}() }()
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
send <- fmt.Sprintf(`"abc %d"`, i) fmt.Println("消息发送, ", i, resp.Send(fmt.Sprintf(`"abc %d"`, i)))
} }
<-flag time.Sleep(time.Second)
} }
} }

View File

@ -3,7 +3,7 @@ package grpcall
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings" "io"
"github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
@ -22,14 +22,18 @@ type JsonRequestParser struct {
func (this JsonRequestParser) Next(msg proto.Message) error { func (this JsonRequestParser) Next(msg proto.Message) error {
if err := this.unmarshaler.UnmarshalNext(this.dec, msg); err != nil { if err := this.unmarshaler.UnmarshalNext(this.dec, msg); err != nil {
if err == io.EOF {
return err
}
return fmt.Errorf("unmarshal request json data error, %s", err) return fmt.Errorf("unmarshal request json data error, %s", err)
} }
return nil return nil
} }
func NewJsonRequestParser(resolver jsonpb.AnyResolver, data string) RequestParser { func NewJsonRequestParser(resolver jsonpb.AnyResolver, in io.Reader) RequestParser {
return &JsonRequestParser{ return &JsonRequestParser{
dec: json.NewDecoder(strings.NewReader(data)), dec: json.NewDecoder(in),
unmarshaler: jsonpb.Unmarshaler{AnyResolver: resolver}, unmarshaler: jsonpb.Unmarshaler{AnyResolver: resolver},
} }
} }

View File

@ -6,7 +6,6 @@ import (
"github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/ptypes/wrappers" "github.com/golang/protobuf/ptypes/wrappers"
"google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/wrapperspb"
) )
func TestJsonpbUnmarshalString(t *testing.T) { func TestJsonpbUnmarshalString(t *testing.T) {
@ -43,24 +42,24 @@ func TestJsonpbMarshalEmpty(t *testing.T) {
} }
} }
func TestNewJsonRequestParser(t *testing.T) { // func TestNewJsonRequestParser(t *testing.T) {
var str_msg wrapperspb.StringValue // var str_msg wrapperspb.StringValue
inData := `"abc"` // inData := `"abc"`
if err := NewJsonRequestParser(nil, inData).Next(&str_msg); err != nil { // if err := NewJsonRequestParser(nil, inData).Next(&str_msg); err != nil {
t.Error(err) // t.Error(err)
} // }
t.Log("string", str_msg.GetValue()) // t.Log("string", str_msg.GetValue())
var int_msg wrapperspb.Int64Value // var int_msg wrapperspb.Int64Value
inData = `"10"` // or `10` // inData = `"10"` // or `10`
if err := NewJsonRequestParser(nil, inData).Next(&int_msg); err != nil { // if err := NewJsonRequestParser(nil, inData).Next(&int_msg); err != nil {
t.Error(err) // t.Error(err)
} // }
t.Log("int", int_msg.GetValue()) // t.Log("int", int_msg.GetValue())
} // }

View File

@ -2,7 +2,9 @@ package grpcall
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"io"
"github.com/jhump/protoreflect/desc" "github.com/jhump/protoreflect/desc"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@ -11,10 +13,9 @@ import (
type Response struct { type Response struct {
method *desc.MethodDescriptor method *desc.MethodDescriptor
header *metadata.MD header *metadata.MD
data string recv *io.PipeReader
send chan string send chan []byte
recv chan string sendCompleted chan error
done chan error
cancel context.CancelFunc cancel context.CancelFunc
} }
@ -22,39 +23,29 @@ func (this Response) Header() metadata.MD {
return *this.header return *this.header
} }
func (this Response) Data() (string, error) { func (this Response) Recv() io.Reader {
if this.method.IsClientStreaming() || this.method.IsServerStreaming() { return this.recv
return "", fmt.Errorf("%q is %s, cannot use unary data", this.method.GetFullyQualifiedName(), this.GetCategory()) }
}
return this.data, nil func (this Response) RecvDecoder() *json.Decoder {
return json.NewDecoder(this.Recv())
} }
// 发送数据流 // 发送数据流
func (this Response) Send() (chan<- string, error) { func (this Response) Send(data string) error {
if !this.method.IsClientStreaming() || !this.method.IsServerStreaming() { if !this.method.IsClientStreaming() || !this.method.IsServerStreaming() {
return nil, fmt.Errorf("%q is %s, cannot use send streaming func", this.method.GetFullyQualifiedName(), this.GetCategory()) return fmt.Errorf("%q is %s, not streaming gRPC", this.method.GetFullyQualifiedName(), this.GetCategory())
} }
return this.send, nil
}
// 接收数据流 this.send <- []byte(data)
func (this Response) Recv() (<-chan string, error) {
if !this.method.IsServerStreaming() {
return nil, fmt.Errorf("%q is %s, cannot use streaming func", this.method.GetFullyQualifiedName(), this.GetCategory())
}
return this.recv, nil
}
// 数据是否响应结束 return <-this.sendCompleted
func (this Response) Done() (<-chan error, error) {
if !this.method.IsServerStreaming() {
return nil, fmt.Errorf("%q is %s, cannot use streaming func", this.method.GetFullyQualifiedName(), this.GetCategory())
}
return this.done, nil
} }
// 取消接收数据 // 取消接收数据
func (this Response) Cancel() { func (this Response) Close() {
this.recv.Close()
if this.method.IsServerStreaming() { if this.method.IsServerStreaming() {
this.cancel() this.cancel()
} }