first commit

This commit is contained in:
what 2023-05-30 17:30:06 +08:00
commit 65e06fdb77
10 changed files with 1022 additions and 0 deletions

21
README.md Normal file
View File

@ -0,0 +1,21 @@
# grpcall
## 生成 protoset file 命令
```
protoc --go_out=. \
--go-grpc_out=. \
--descriptor_set_out=device.protoset \
--include_imports \
./proto3/device.proto
```
## protoc 版本
```
protoc --version // libprotoc 3.20.3
```
## Credits
[https://github.com/rfyiamcool/grpcall](https://github.com/rfyiamcool/grpcall)

20
go.mod Normal file
View File

@ -0,0 +1,20 @@
module git.fsdpf.net/go/grpcall
go 1.19
require (
github.com/golang/protobuf v1.5.3
github.com/jhump/protoreflect v1.15.1
github.com/samber/lo v1.38.1
google.golang.org/grpc v1.55.0
google.golang.org/protobuf v1.30.0
)
require (
github.com/bufbuild/protocompile v0.4.0 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
golang.org/x/net v0.8.0 // indirect
golang.org/x/sys v0.6.0 // indirect
golang.org/x/text v0.8.0 // indirect
google.golang.org/genproto v0.0.0-20230306155012-7f2fa6fef1f4 // indirect
)

29
go.sum Normal file
View File

@ -0,0 +1,29 @@
github.com/bufbuild/protocompile v0.4.0 h1:LbFKd2XowZvQ/kajzguUp2DC9UEIQhIq77fZZlaQsNA=
github.com/bufbuild/protocompile v0.4.0/go.mod h1:3v93+mbWn/v3xzN+31nwkJfrEpAUwp+BagBSZWx+TP8=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/jhump/protoreflect v1.15.1 h1:HUMERORf3I3ZdX05WaQ6MIpd/NJ434hTp5YiKgfCL6c=
github.com/jhump/protoreflect v1.15.1/go.mod h1:jD/2GMKKE6OqX8qTjhADU1e6DShO+gavG9e0Q693nKo=
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20230306155012-7f2fa6fef1f4 h1:DdoeryqhaXp1LtT/emMP1BRJPHHKFi5akj/nbx/zNTA=
google.golang.org/genproto v0.0.0-20230306155012-7f2fa6fef1f4/go.mod h1:NWraEVixdDnqcqQ30jipen1STv2r/n24Wb7twVTGR4s=
google.golang.org/grpc v1.55.0 h1:3Oj82/tFSCeUrRTg/5E/7d/W5A1tj6Ky1ABAuZuv5ag=
google.golang.org/grpc v1.55.0/go.mod h1:iYEXKGkEBhg1PjZQvoYEVPTDkHo1/bjTnfwTeGONTY8=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=

483
grpcall.go Normal file
View File

@ -0,0 +1,483 @@
package grpcall
import (
"context"
"fmt"
"io"
"sort"
"strings"
"time"
"github.com/golang/protobuf/jsonpb"
"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/dynamic"
"github.com/jhump/protoreflect/dynamic/grpcdynamic"
"github.com/samber/lo"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
type Grpcall struct {
conn *grpc.ClientConn
pDescriptor ProtoDescriptor
}
// 构建请求函数
func (this Grpcall) Invoke(ctx context.Context, service, method, data string, headers []string) (*Response, error) {
// 获取 gRPC 服务方法描述信息
mtd, err := this.GetServiceMethodDescriptor(service, method)
if err != nil {
return nil, err
}
md := this.MakeMetadata(headers)
ctx = metadata.NewOutgoingContext(ctx, md)
anyResolver, err := this.GetAnyResolver()
if err != nil {
return nil, err
}
// service.method message 构建器
msgFactory, err := this.MakeServiceMethodMessageFactory(mtd)
if err != nil {
return nil, err
}
stub := grpcdynamic.NewStubWithMessageFactory(this.conn, msgFactory)
if mtd.IsClientStreaming() && mtd.IsServerStreaming() {
return this.invokeBidiStream(ctx, stub, mtd, msgFactory, anyResolver, data)
}
if mtd.IsServerStreaming() {
return this.invokeServStream(ctx, stub, mtd, msgFactory, anyResolver, data)
}
return this.invokeUnary(ctx, stub, mtd, msgFactory, anyResolver, data)
}
func (this Grpcall) invokeUnary(ctx context.Context, stub grpcdynamic.Stub, method *desc.MethodDescriptor, msgFactory *dynamic.MessageFactory, anyResolver jsonpb.AnyResolver, data string) (*Response, error) {
// 创建请求 message
req := msgFactory.NewMessage(method.GetInputType())
// 将数据格式化到 message
if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil {
return nil, err
}
respHeaders := &metadata.MD{}
respMsg, err := stub.InvokeRpc(ctx, method, req, grpc.Header(respHeaders))
respStatus, ok := status.FromError(err)
if !ok || respStatus.Code() != codes.OK {
return nil, fmt.Errorf("grpc call %q failed, %s", method.GetFullyQualifiedName(), err)
}
marshaler := jsonpb.Marshaler{
EmitDefaults: true,
AnyResolver: anyResolver,
}
respData, err := marshaler.MarshalToString(respMsg)
resp := &Response{
method: method,
header: respHeaders,
data: respData,
}
return resp, err
}
func (this Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub, method *desc.MethodDescriptor, msgFactory *dynamic.MessageFactory, anyResolver jsonpb.AnyResolver, data string) (*Response, error) {
// 创建请求 message
req := msgFactory.NewMessage(method.GetInputType())
// 将数据格式化到 message
if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil {
return nil, err
}
cancel_ctx, cancel := context.WithCancel(ctx)
stream, err := stub.InvokeRpcServerStream(cancel_ctx, method, req)
if err != nil {
cancel()
return nil, err
}
resp := &Response{
method: method,
header: &metadata.MD{},
recv: make(chan string),
done: make(chan error),
cancel: cancel,
}
go func() {
msgParser := jsonpb.Marshaler{
EmitDefaults: true,
AnyResolver: anyResolver,
}
defer func() {
close(resp.done)
close(resp.recv)
cancel()
}()
for {
// 从流中接收消息
msg, err := stream.RecvMsg()
if err == io.EOF {
resp.done <- nil
return
}
if err != nil {
resp.done <- err
return
}
if *resp.header, err = stream.Header(); err != nil {
resp.done <- err
return
}
if data, err := msgParser.MarshalToString(msg); err != nil {
resp.done <- err
return
} else {
resp.recv <- data
}
msg.Reset()
}
}()
return resp, nil
}
func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub, method *desc.MethodDescriptor, msgFactory *dynamic.MessageFactory, anyResolver jsonpb.AnyResolver, data string) (*Response, error) {
// 创建请求 message
req := msgFactory.NewMessage(method.GetInputType())
// 将数据格式化到 message
if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil {
return nil, err
}
cancel_ctx, cancel := context.WithCancel(ctx)
stream, err := stub.InvokeRpcBidiStream(cancel_ctx, method)
if err != nil {
cancel()
return nil, err
}
if err := stream.SendMsg(req); err != nil {
cancel()
return nil, err
}
resp := &Response{
method: method,
header: &metadata.MD{},
recv: make(chan string),
send: make(chan string),
done: make(chan error),
cancel: cancel,
}
// send
go func() {
defer func() {
stream.CloseSend()
close(resp.send)
cancel()
}()
for {
select {
case msg, ok := <-resp.send:
if !ok {
resp.done <- nil
return
}
// 将数据格式化到 message
if err := NewJsonRequestParser(anyResolver, msg).Next(req); err != nil {
resp.done <- err
return
}
if err := stream.SendMsg(req); err != nil {
resp.done <- err
return
}
case <-cancel_ctx.Done():
resp.done <- nil
return
case <-ctx.Done():
resp.done <- nil
return
}
}
}()
// recv
go func() {
msgParser := jsonpb.Marshaler{
EmitDefaults: true,
AnyResolver: anyResolver,
}
defer func() {
close(resp.done)
close(resp.recv)
cancel()
}()
for {
// 从流中接收消息
msg, err := stream.RecvMsg()
if err == io.EOF {
resp.done <- nil
return
}
if err != nil {
resp.done <- err
return
}
if *resp.header, err = stream.Header(); err != nil {
resp.done <- err
return
}
if data, err := msgParser.MarshalToString(msg); err != nil {
resp.done <- err
return
} else {
resp.recv <- data
}
msg.Reset()
}
}()
return resp, nil
}
func (this Grpcall) GetClientConn() *grpc.ClientConn {
return this.conn
}
func (this Grpcall) GetAnyResolver() (jsonpb.AnyResolver, error) {
files, err := this.GetAllFilesDescriptor()
if err != nil {
return nil, err
}
dExtension := dynamic.ExtensionRegistry{}
for _, fd := range files {
dExtension.AddExtensionsFromFile(fd)
}
mf := dynamic.NewMessageFactoryWithExtensionRegistry(&dExtension)
return dynamic.AnyResolver(mf, files...), nil
}
func (this Grpcall) GetProtoDescriptor() ProtoDescriptor {
return this.pDescriptor
}
func (this Grpcall) GetAllFilesDescriptor() ([]*desc.FileDescriptor, error) {
return this.pDescriptor.GetAllFiles()
}
// 获取所有服务
func (this Grpcall) GetServices() ([]string, error) {
return this.pDescriptor.ListServices()
}
// 获取服务下的所有方法
func (this Grpcall) GetServiceMethods(name string) ([]string, error) {
if dsc, err := this.pDescriptor.FindSymbol(name); err != nil {
return nil, err
} else if sd, ok := dsc.(*desc.ServiceDescriptor); !ok {
return nil, fmt.Errorf("%s, not found methods", name)
} else {
methods := make([]string, 0, len(sd.GetMethods()))
for _, method := range sd.GetMethods() {
methods = append(methods, method.GetFullyQualifiedName())
}
sort.Strings(methods)
return methods, nil
}
}
// 获取 gRPC 服务描述
func (this Grpcall) GetServiceDescriptor(service string) (*desc.ServiceDescriptor, error) {
if service == "" {
return nil, fmt.Errorf("service name is invalid")
}
dsc, err := this.pDescriptor.FindSymbol(service)
if err != nil {
return nil, err
}
sd, ok := dsc.(*desc.ServiceDescriptor)
if !ok {
return nil, fmt.Errorf("server not expose service, %q", service)
}
return sd, nil
}
// 获取 gRPC service.method 描述
func (this Grpcall) GetServiceMethodDescriptor(service, method string) (*desc.MethodDescriptor, error) {
serv, err := this.GetServiceDescriptor(service)
if err != nil {
return nil, err
}
if mtd := serv.FindMethodByName(method); mtd != nil {
return mtd, nil
}
return nil, fmt.Errorf("service %q does not include a method named %q", service, method)
}
// 生成 service.method message
func (this *Grpcall) MakeServiceMethodMessageFactory(method *desc.MethodDescriptor) (*dynamic.MessageFactory, error) {
// 扩展字段的信息,包括扩展字段的名称、标识符、类型等。用于注册扩展、查找扩展、解析扩展以及处理扩展字段的值。
var ext dynamic.ExtensionRegistry
excluded := []string{}
if err := this.RegServiceMethoddMessageExtensions(&ext, method.GetInputType(), excluded); err != nil {
return nil, fmt.Errorf("register service.method extensions for message %s: %v", method.GetInputType().GetFullyQualifiedName(), err)
}
if err := this.RegServiceMethoddMessageExtensions(&ext, method.GetOutputType(), excluded); err != nil {
return nil, fmt.Errorf("register service.method extensions for message %s: %v", method.GetOutputType().GetFullyQualifiedName(), err)
}
return dynamic.NewMessageFactoryWithExtensionRegistry(&ext), nil
}
// 注册 service.method message extensions
func (this Grpcall) RegServiceMethoddMessageExtensions(msgExt *dynamic.ExtensionRegistry, msg *desc.MessageDescriptor, excluded []string) error {
msgTypeName := msg.GetFullyQualifiedName()
// 避免重复注册
if lo.Contains(excluded, msgTypeName) {
return nil
} else {
excluded = append(excluded, msgTypeName)
}
// 注册 message extensions 信息
if len(msg.GetExtensionRanges()) > 0 {
fds, err := this.pDescriptor.AllExtensionsForType(msgTypeName)
if err != nil {
return fmt.Errorf("failed to query for extensions of type %s: %v", msgTypeName, err)
}
for i := 0; i < len(fds); i++ {
if err := msgExt.AddExtension(fds[i]); err != nil {
return fmt.Errorf("could not register extension %s of type %s: %v", fds[i].GetFullyQualifiedName(), msgTypeName, err)
}
}
}
// 递归注册 message 所有字段
for _, msgField := range msg.GetFields() {
if msgField.GetMessageType() == nil {
continue
}
if err := this.RegServiceMethoddMessageExtensions(msgExt, msgField.GetMessageType(), excluded); err != nil {
return err
}
}
return nil
}
func (this Grpcall) MakeMetadata(headers []string) metadata.MD {
md := make(metadata.MD)
for _, part := range headers {
if part != "" {
pieces := strings.SplitN(part, ":", 2)
if len(pieces) == 1 {
pieces = append(pieces, "") // if no value was specified, just make it "" (maybe the header value doesn't matter)
}
headerName := strings.ToLower(strings.TrimSpace(pieces[0]))
val := strings.TrimSpace(pieces[1])
if strings.HasSuffix(headerName, "-bin") {
if v, err := metadata_decode(val); err == nil {
val = v
}
}
md[headerName] = append(md[headerName], val)
}
}
return md
}
func (this Grpcall) Close() error {
return this.conn.Close()
}
func NewGrpcall(addr string, protosets []string, opts ...grpc.DialOption) (g *Grpcall, err error) {
if addr == "" {
return nil, fmt.Errorf("addr is invalid, %s", addr)
}
g = &Grpcall{}
opts = append(opts,
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 64 * time.Second,
Timeout: 64 * time.Second,
}),
grpc.WithBlock(),
grpc.FailOnNonTempDialError(true),
)
if conn, err := grpc.Dial(addr, opts...); err != nil {
return nil, err
} else {
g.conn = conn
}
if len(protosets) == 0 {
g.pDescriptor, err = NewServProtoDescriptor(g.conn)
} else {
g.pDescriptor, err = NewFileProtoDescriptor(protosets)
}
return g, err
}

118
grpcall_test.go Normal file
View File

@ -0,0 +1,118 @@
package grpcall
import (
"context"
"fmt"
"os"
"testing"
"google.golang.org/grpc"
)
var gc *Grpcall
func TestMain(m *testing.M) {
if _gc, err := NewGrpcall("127.0.0.1:8888", nil, grpc.WithInsecure()); err != nil {
// if _gc, err := NewGrpcall("192.168.0.6:6644", nil, grpc.WithInsecure()); err != nil {
panic(err)
} else {
gc = _gc
}
defer gc.Close()
// 执行测试
code := m.Run()
// 在这里执行一些清理工作,例如关闭数据库连接等
// 退出测试
os.Exit(code)
}
func TestGetServices(t *testing.T) {
t.Log(gc.GetServices())
}
func TestGetServiceMethods(t *testing.T) {
if s, err := gc.GetServices(); err != nil {
t.Error(err)
} else {
t.Log(gc.GetServiceMethods(s[0]))
}
}
func TestInvokeUnary(t *testing.T) {
ctx := context.Background()
if resp, err := gc.Invoke(ctx, "User.UserResourceStatus", "TestInt64Value", `123`, nil); err != nil {
t.Error(err)
} else {
t.Log(resp.Data())
}
}
func TestInvokeServStream(t *testing.T) {
ctx := context.Background()
if resp, err := gc.Invoke(ctx, "User.UserResourceStatus", "GetEvent", `{}`, nil); err != nil {
t.Error(err)
} else if recv, err := resp.Recv(); err != nil {
t.Error(err)
} else if done, err := resp.Done(); err != nil {
t.Error(err)
} else {
flag := make(chan bool)
go func() {
for {
select {
case msg := <-recv:
fmt.Println("msg", msg, resp.Header())
case err := <-done:
if err != nil {
t.Error(err)
}
flag <- true
return
}
}
}()
<-flag
}
}
func TestInvokeBidiStream(t *testing.T) {
ctx := context.Background()
if resp, err := gc.Invoke(ctx, "User.UserResourceStatus", "TestBidiStream", `"hello"`, nil); err != nil {
t.Error(err)
} else if send, err := resp.Send(); err != nil {
t.Error(err)
} else if recv, err := resp.Recv(); err != nil {
t.Error(err)
} else if done, err := resp.Done(); err != nil {
t.Error(err)
} else {
flag := make(chan bool)
go func() {
for {
select {
case msg := <-recv:
fmt.Println("msg", msg, resp.Header())
case err := <-done:
if err != nil {
t.Error(err)
}
flag <- true
return
}
}
}()
for i := 0; i < 10; i++ {
send <- fmt.Sprintf(`"abc %d"`, i)
}
<-flag
}
}

136
proto_descriptor.go Normal file
View File

@ -0,0 +1,136 @@
package grpcall
import (
"context"
"fmt"
"sort"
"sync"
"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/dynamic"
"github.com/jhump/protoreflect/grpcreflect"
"github.com/samber/lo"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
)
type ProtoDescriptor interface {
// gRPC 服务列表
ListServices() ([]string, error)
// 查找 proto 中定义的消息、服务或枚举等结构的描述符
FindSymbol(string) (desc.Descriptor, error)
// 获取消息类型的所有扩展字段的描述符
AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error)
// 获取 proto 文件描述信息
GetAllFiles() ([]*desc.FileDescriptor, error)
}
type WithFilesProtoDescriptor interface {
GetAllFiles() ([]*desc.FileDescriptor, error)
}
type ServProtoDescriptor struct {
client *grpcreflect.Client
}
type FileProtoDescriptor struct {
files map[string]*desc.FileDescriptor
er *dynamic.ExtensionRegistry
erInit sync.Once
}
func NewServProtoDescriptor(client *grpc.ClientConn) (ProtoDescriptor, error) {
ctx := context.Background()
pd := &ServProtoDescriptor{
client: grpcreflect.NewClient(ctx, grpc_reflection_v1alpha.NewServerReflectionClient(client)),
}
return pd, nil
}
func (this ServProtoDescriptor) ListServices() ([]string, error) {
return this.client.ListServices()
}
func (this ServProtoDescriptor) FindSymbol(name string) (desc.Descriptor, error) {
file, err := this.client.FileContainingSymbol(name)
if err != nil {
return nil, err
}
d := file.FindSymbol(name)
if d == nil {
return nil, fmt.Errorf("Symbol Not Found, %s", name)
}
return d, nil
}
func (this ServProtoDescriptor) AllExtensionsForType(typ string) ([]*desc.FieldDescriptor, error) {
exts := []*desc.FieldDescriptor{}
nums, err := this.client.AllExtensionNumbersForType(typ)
if err != nil {
return nil, err
}
for _, fieldNum := range nums {
ext, err := this.client.ResolveExtension(typ, fieldNum)
if err != nil {
return nil, err
}
exts = append(exts, ext)
}
return exts, nil
}
func (this ServProtoDescriptor) GetAllFiles() ([]*desc.FileDescriptor, error) {
var files []*desc.FileDescriptor
if services, err := this.ListServices(); err != nil {
return nil, err
} else {
temp_fields := map[string]*desc.FileDescriptor{}
for _, name := range services {
d, err := this.FindSymbol(name)
if err != nil {
return nil, err
}
this.addFileDescriptorsToCache(d.GetFile(), temp_fields)
}
files = lo.Values(temp_fields)
}
sort.Sort(FileDescriptorSorter(files))
return files, nil
}
func (this ServProtoDescriptor) addFileDescriptorsToCache(file *desc.FileDescriptor, cache map[string]*desc.FileDescriptor) error {
if _, ok := cache[file.GetName()]; ok {
return nil
} else {
cache[file.GetName()] = file
}
for _, dep := range file.GetDependencies() {
this.addFileDescriptorsToCache(dep, cache)
}
return nil
}
func NewFileProtoDescriptor(protosets []string) (ProtoDescriptor, error) {
return nil, nil
}
func (this FileProtoDescriptor) GetAllFiles() ([]*desc.FileDescriptor, error) {
files := lo.Values(this.files)
sort.Sort(FileDescriptorSorter(files))
return files, nil
}

35
request.go Normal file
View File

@ -0,0 +1,35 @@
package grpcall
import (
"encoding/json"
"fmt"
"strings"
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
)
type RequestData func(msg proto.Message) error
type RequestParser interface {
Next(msg proto.Message) error
}
type JsonRequestParser struct {
dec *json.Decoder
unmarshaler jsonpb.Unmarshaler
}
func (this JsonRequestParser) Next(msg proto.Message) error {
if err := this.unmarshaler.UnmarshalNext(this.dec, msg); err != nil {
return fmt.Errorf("unmarshal request json data error, %s", err)
}
return nil
}
func NewJsonRequestParser(resolver jsonpb.AnyResolver, data string) RequestParser {
return &JsonRequestParser{
dec: json.NewDecoder(strings.NewReader(data)),
unmarshaler: jsonpb.Unmarshaler{AnyResolver: resolver},
}
}

66
request_test.go Normal file
View File

@ -0,0 +1,66 @@
package grpcall
import (
"testing"
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/ptypes/wrappers"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
func TestJsonpbUnmarshalString(t *testing.T) {
msg := &wrappers.Int64Value{}
if err := jsonpb.UnmarshalString(`123`, msg); err != nil {
t.Error(err)
}
t.Log(msg)
}
func TestJsonpbMarshalString(t *testing.T) {
msg := &wrappers.Int64Value{Value: 123}
marshaler := &jsonpb.Marshaler{}
if json, err := marshaler.MarshalToString(msg); err != nil {
t.Error(err)
} else {
t.Log(json) // "123"
}
}
func TestJsonpbMarshalEmpty(t *testing.T) {
msg := &emptypb.Empty{}
marshaler := &jsonpb.Marshaler{}
if json, err := marshaler.MarshalToString(msg); err != nil {
t.Error(err)
} else {
t.Log(json) // {}
}
}
func TestNewJsonRequestParser(t *testing.T) {
var str_msg wrapperspb.StringValue
inData := `"abc"`
if err := NewJsonRequestParser(nil, inData).Next(&str_msg); err != nil {
t.Error(err)
}
t.Log("string", str_msg.GetValue())
var int_msg wrapperspb.Int64Value
inData = `"10"` // or `10`
if err := NewJsonRequestParser(nil, inData).Next(&int_msg); err != nil {
t.Error(err)
}
t.Log("int", int_msg.GetValue())
}

73
response.go Normal file
View File

@ -0,0 +1,73 @@
package grpcall
import (
"context"
"fmt"
"github.com/jhump/protoreflect/desc"
"google.golang.org/grpc/metadata"
)
type Response struct {
method *desc.MethodDescriptor
header *metadata.MD
data string
send chan string
recv chan string
done chan error
cancel context.CancelFunc
}
func (this Response) Header() metadata.MD {
return *this.header
}
func (this Response) Data() (string, error) {
if this.method.IsClientStreaming() || this.method.IsServerStreaming() {
return "", fmt.Errorf("%q is %s, cannot use unary data", this.method.GetFullyQualifiedName(), this.GetCategory())
}
return this.data, nil
}
// 发送数据流
func (this Response) Send() (chan<- string, error) {
if !this.method.IsClientStreaming() || !this.method.IsServerStreaming() {
return nil, fmt.Errorf("%q is %s, cannot use send streaming func", this.method.GetFullyQualifiedName(), this.GetCategory())
}
return this.send, nil
}
// 接收数据流
func (this Response) Recv() (<-chan string, error) {
if !this.method.IsServerStreaming() {
return nil, fmt.Errorf("%q is %s, cannot use streaming func", this.method.GetFullyQualifiedName(), this.GetCategory())
}
return this.recv, nil
}
// 数据是否响应结束
func (this Response) Done() (<-chan error, error) {
if !this.method.IsServerStreaming() {
return nil, fmt.Errorf("%q is %s, cannot use streaming func", this.method.GetFullyQualifiedName(), this.GetCategory())
}
return this.done, nil
}
// 取消接收数据
func (this Response) Cancel() {
if this.method.IsServerStreaming() {
this.cancel()
}
}
func (this Response) GetCategory() string {
if this.method.IsClientStreaming() && this.method.IsServerStreaming() {
return "bidi-streaming"
} else if this.method.IsClientStreaming() {
return "client-streaming"
} else if this.method.IsServerStreaming() {
return "server-streaming"
} else {
return "unary"
}
}

41
utils.go Normal file
View File

@ -0,0 +1,41 @@
package grpcall
import (
"encoding/base64"
"github.com/jhump/protoreflect/desc"
)
type FileDescriptorSorter []*desc.FileDescriptor
func (this FileDescriptorSorter) Len() int {
return len(this)
}
func (this FileDescriptorSorter) Less(i, j int) bool {
return this[i].GetName() < this[j].GetName()
}
func (this FileDescriptorSorter) Swap(i, j int) {
this[i], this[j] = this[j], this[i]
}
var base64s = []*base64.Encoding{base64.StdEncoding, base64.URLEncoding, base64.RawStdEncoding, base64.RawURLEncoding}
func metadata_decode(val string) (string, error) {
var firstErr error
var b []byte
// we are lenient and can accept any of the flavors of base64 encoding
for _, d := range base64s {
var err error
b, err = d.DecodeString(val)
if err != nil {
if firstErr == nil {
firstErr = err
}
continue
}
return string(b), nil
}
return "", firstErr
}