From 65e06fdb77e9278044f6b530f918d6f36001ffae Mon Sep 17 00:00:00 2001 From: what Date: Tue, 30 May 2023 17:30:06 +0800 Subject: [PATCH] first commit --- README.md | 21 ++ go.mod | 20 ++ go.sum | 29 +++ grpcall.go | 483 ++++++++++++++++++++++++++++++++++++++++++++ grpcall_test.go | 118 +++++++++++ proto_descriptor.go | 136 +++++++++++++ request.go | 35 ++++ request_test.go | 66 ++++++ response.go | 73 +++++++ utils.go | 41 ++++ 10 files changed, 1022 insertions(+) create mode 100644 README.md create mode 100644 go.mod create mode 100644 go.sum create mode 100644 grpcall.go create mode 100644 grpcall_test.go create mode 100644 proto_descriptor.go create mode 100644 request.go create mode 100644 request_test.go create mode 100644 response.go create mode 100644 utils.go diff --git a/README.md b/README.md new file mode 100644 index 0000000..b98b13f --- /dev/null +++ b/README.md @@ -0,0 +1,21 @@ +# grpcall + + +## 生成 protoset file 命令 + +``` +protoc --go_out=. \ + --go-grpc_out=. \ + --descriptor_set_out=device.protoset \ + --include_imports \ + ./proto3/device.proto +``` + +## protoc 版本 +``` +protoc --version // libprotoc 3.20.3 +``` + +## Credits + +[https://github.com/rfyiamcool/grpcall](https://github.com/rfyiamcool/grpcall) \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..18808ea --- /dev/null +++ b/go.mod @@ -0,0 +1,20 @@ +module git.fsdpf.net/go/grpcall + +go 1.19 + +require ( + github.com/golang/protobuf v1.5.3 + github.com/jhump/protoreflect v1.15.1 + github.com/samber/lo v1.38.1 + google.golang.org/grpc v1.55.0 + google.golang.org/protobuf v1.30.0 +) + +require ( + github.com/bufbuild/protocompile v0.4.0 // indirect + golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect + golang.org/x/net v0.8.0 // indirect + golang.org/x/sys v0.6.0 // indirect + golang.org/x/text v0.8.0 // indirect + google.golang.org/genproto v0.0.0-20230306155012-7f2fa6fef1f4 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7c89c45 --- /dev/null +++ b/go.sum @@ -0,0 +1,29 @@ +github.com/bufbuild/protocompile v0.4.0 h1:LbFKd2XowZvQ/kajzguUp2DC9UEIQhIq77fZZlaQsNA= +github.com/bufbuild/protocompile v0.4.0/go.mod h1:3v93+mbWn/v3xzN+31nwkJfrEpAUwp+BagBSZWx+TP8= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/jhump/protoreflect v1.15.1 h1:HUMERORf3I3ZdX05WaQ6MIpd/NJ434hTp5YiKgfCL6c= +github.com/jhump/protoreflect v1.15.1/go.mod h1:jD/2GMKKE6OqX8qTjhADU1e6DShO+gavG9e0Q693nKo= +github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= +github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= +golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= +golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= +golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= +golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= +golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto v0.0.0-20230306155012-7f2fa6fef1f4 h1:DdoeryqhaXp1LtT/emMP1BRJPHHKFi5akj/nbx/zNTA= +google.golang.org/genproto v0.0.0-20230306155012-7f2fa6fef1f4/go.mod h1:NWraEVixdDnqcqQ30jipen1STv2r/n24Wb7twVTGR4s= +google.golang.org/grpc v1.55.0 h1:3Oj82/tFSCeUrRTg/5E/7d/W5A1tj6Ky1ABAuZuv5ag= +google.golang.org/grpc v1.55.0/go.mod h1:iYEXKGkEBhg1PjZQvoYEVPTDkHo1/bjTnfwTeGONTY8= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= diff --git a/grpcall.go b/grpcall.go new file mode 100644 index 0000000..2647420 --- /dev/null +++ b/grpcall.go @@ -0,0 +1,483 @@ +package grpcall + +import ( + "context" + "fmt" + "io" + "sort" + "strings" + "time" + + "github.com/golang/protobuf/jsonpb" + "github.com/jhump/protoreflect/desc" + "github.com/jhump/protoreflect/dynamic" + "github.com/jhump/protoreflect/dynamic/grpcdynamic" + "github.com/samber/lo" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +type Grpcall struct { + conn *grpc.ClientConn + pDescriptor ProtoDescriptor +} + +// 构建请求函数 +func (this Grpcall) Invoke(ctx context.Context, service, method, data string, headers []string) (*Response, error) { + // 获取 gRPC 服务方法描述信息 + mtd, err := this.GetServiceMethodDescriptor(service, method) + if err != nil { + return nil, err + } + + md := this.MakeMetadata(headers) + + ctx = metadata.NewOutgoingContext(ctx, md) + + anyResolver, err := this.GetAnyResolver() + + if err != nil { + return nil, err + } + + // service.method message 构建器 + msgFactory, err := this.MakeServiceMethodMessageFactory(mtd) + + if err != nil { + return nil, err + } + + stub := grpcdynamic.NewStubWithMessageFactory(this.conn, msgFactory) + + if mtd.IsClientStreaming() && mtd.IsServerStreaming() { + return this.invokeBidiStream(ctx, stub, mtd, msgFactory, anyResolver, data) + } + + if mtd.IsServerStreaming() { + return this.invokeServStream(ctx, stub, mtd, msgFactory, anyResolver, data) + } + + return this.invokeUnary(ctx, stub, mtd, msgFactory, anyResolver, data) +} + +func (this Grpcall) invokeUnary(ctx context.Context, stub grpcdynamic.Stub, method *desc.MethodDescriptor, msgFactory *dynamic.MessageFactory, anyResolver jsonpb.AnyResolver, data string) (*Response, error) { + // 创建请求 message + req := msgFactory.NewMessage(method.GetInputType()) + // 将数据格式化到 message + if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil { + return nil, err + } + + respHeaders := &metadata.MD{} + + respMsg, err := stub.InvokeRpc(ctx, method, req, grpc.Header(respHeaders)) + + respStatus, ok := status.FromError(err) + + if !ok || respStatus.Code() != codes.OK { + return nil, fmt.Errorf("grpc call %q failed, %s", method.GetFullyQualifiedName(), err) + } + + marshaler := jsonpb.Marshaler{ + EmitDefaults: true, + AnyResolver: anyResolver, + } + + respData, err := marshaler.MarshalToString(respMsg) + + resp := &Response{ + method: method, + header: respHeaders, + data: respData, + } + + return resp, err +} + +func (this Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub, method *desc.MethodDescriptor, msgFactory *dynamic.MessageFactory, anyResolver jsonpb.AnyResolver, data string) (*Response, error) { + // 创建请求 message + req := msgFactory.NewMessage(method.GetInputType()) + // 将数据格式化到 message + if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil { + return nil, err + } + + cancel_ctx, cancel := context.WithCancel(ctx) + + stream, err := stub.InvokeRpcServerStream(cancel_ctx, method, req) + + if err != nil { + cancel() + return nil, err + } + + resp := &Response{ + method: method, + header: &metadata.MD{}, + recv: make(chan string), + done: make(chan error), + cancel: cancel, + } + + go func() { + msgParser := jsonpb.Marshaler{ + EmitDefaults: true, + AnyResolver: anyResolver, + } + + defer func() { + close(resp.done) + close(resp.recv) + cancel() + }() + + for { + // 从流中接收消息 + msg, err := stream.RecvMsg() + + if err == io.EOF { + resp.done <- nil + return + } + + if err != nil { + resp.done <- err + return + } + + if *resp.header, err = stream.Header(); err != nil { + resp.done <- err + return + } + + if data, err := msgParser.MarshalToString(msg); err != nil { + resp.done <- err + return + } else { + resp.recv <- data + } + + msg.Reset() + } + }() + + return resp, nil +} + +func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub, method *desc.MethodDescriptor, msgFactory *dynamic.MessageFactory, anyResolver jsonpb.AnyResolver, data string) (*Response, error) { + // 创建请求 message + req := msgFactory.NewMessage(method.GetInputType()) + // 将数据格式化到 message + if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil { + return nil, err + } + + cancel_ctx, cancel := context.WithCancel(ctx) + + stream, err := stub.InvokeRpcBidiStream(cancel_ctx, method) + + if err != nil { + cancel() + return nil, err + } + + if err := stream.SendMsg(req); err != nil { + cancel() + return nil, err + } + + resp := &Response{ + method: method, + header: &metadata.MD{}, + recv: make(chan string), + send: make(chan string), + done: make(chan error), + cancel: cancel, + } + + // send + go func() { + defer func() { + stream.CloseSend() + close(resp.send) + cancel() + }() + + for { + select { + case msg, ok := <-resp.send: + if !ok { + resp.done <- nil + return + } + + // 将数据格式化到 message + if err := NewJsonRequestParser(anyResolver, msg).Next(req); err != nil { + resp.done <- err + return + } + + if err := stream.SendMsg(req); err != nil { + resp.done <- err + return + } + + case <-cancel_ctx.Done(): + resp.done <- nil + return + case <-ctx.Done(): + resp.done <- nil + return + } + } + }() + + // recv + go func() { + msgParser := jsonpb.Marshaler{ + EmitDefaults: true, + AnyResolver: anyResolver, + } + + defer func() { + close(resp.done) + close(resp.recv) + cancel() + }() + + for { + // 从流中接收消息 + msg, err := stream.RecvMsg() + + if err == io.EOF { + resp.done <- nil + return + } + + if err != nil { + resp.done <- err + return + } + + if *resp.header, err = stream.Header(); err != nil { + resp.done <- err + return + } + + if data, err := msgParser.MarshalToString(msg); err != nil { + resp.done <- err + return + } else { + resp.recv <- data + } + + msg.Reset() + } + }() + + return resp, nil +} + +func (this Grpcall) GetClientConn() *grpc.ClientConn { + return this.conn +} + +func (this Grpcall) GetAnyResolver() (jsonpb.AnyResolver, error) { + files, err := this.GetAllFilesDescriptor() + + if err != nil { + return nil, err + } + + dExtension := dynamic.ExtensionRegistry{} + + for _, fd := range files { + dExtension.AddExtensionsFromFile(fd) + } + + mf := dynamic.NewMessageFactoryWithExtensionRegistry(&dExtension) + + return dynamic.AnyResolver(mf, files...), nil +} + +func (this Grpcall) GetProtoDescriptor() ProtoDescriptor { + return this.pDescriptor +} + +func (this Grpcall) GetAllFilesDescriptor() ([]*desc.FileDescriptor, error) { + return this.pDescriptor.GetAllFiles() +} + +// 获取所有服务 +func (this Grpcall) GetServices() ([]string, error) { + return this.pDescriptor.ListServices() +} + +// 获取服务下的所有方法 +func (this Grpcall) GetServiceMethods(name string) ([]string, error) { + if dsc, err := this.pDescriptor.FindSymbol(name); err != nil { + return nil, err + } else if sd, ok := dsc.(*desc.ServiceDescriptor); !ok { + return nil, fmt.Errorf("%s, not found methods", name) + } else { + methods := make([]string, 0, len(sd.GetMethods())) + + for _, method := range sd.GetMethods() { + methods = append(methods, method.GetFullyQualifiedName()) + } + + sort.Strings(methods) + return methods, nil + } +} + +// 获取 gRPC 服务描述 +func (this Grpcall) GetServiceDescriptor(service string) (*desc.ServiceDescriptor, error) { + if service == "" { + return nil, fmt.Errorf("service name is invalid") + } + + dsc, err := this.pDescriptor.FindSymbol(service) + + if err != nil { + return nil, err + } + + sd, ok := dsc.(*desc.ServiceDescriptor) + if !ok { + return nil, fmt.Errorf("server not expose service, %q", service) + } + + return sd, nil +} + +// 获取 gRPC service.method 描述 +func (this Grpcall) GetServiceMethodDescriptor(service, method string) (*desc.MethodDescriptor, error) { + serv, err := this.GetServiceDescriptor(service) + + if err != nil { + return nil, err + } + + if mtd := serv.FindMethodByName(method); mtd != nil { + return mtd, nil + } + + return nil, fmt.Errorf("service %q does not include a method named %q", service, method) +} + +// 生成 service.method message +func (this *Grpcall) MakeServiceMethodMessageFactory(method *desc.MethodDescriptor) (*dynamic.MessageFactory, error) { + // 扩展字段的信息,包括扩展字段的名称、标识符、类型等。用于注册扩展、查找扩展、解析扩展以及处理扩展字段的值。 + var ext dynamic.ExtensionRegistry + + excluded := []string{} + + if err := this.RegServiceMethoddMessageExtensions(&ext, method.GetInputType(), excluded); err != nil { + return nil, fmt.Errorf("register service.method extensions for message %s: %v", method.GetInputType().GetFullyQualifiedName(), err) + } + + if err := this.RegServiceMethoddMessageExtensions(&ext, method.GetOutputType(), excluded); err != nil { + return nil, fmt.Errorf("register service.method extensions for message %s: %v", method.GetOutputType().GetFullyQualifiedName(), err) + } + + return dynamic.NewMessageFactoryWithExtensionRegistry(&ext), nil +} + +// 注册 service.method message extensions +func (this Grpcall) RegServiceMethoddMessageExtensions(msgExt *dynamic.ExtensionRegistry, msg *desc.MessageDescriptor, excluded []string) error { + msgTypeName := msg.GetFullyQualifiedName() + + // 避免重复注册 + if lo.Contains(excluded, msgTypeName) { + return nil + } else { + excluded = append(excluded, msgTypeName) + } + + // 注册 message extensions 信息 + if len(msg.GetExtensionRanges()) > 0 { + fds, err := this.pDescriptor.AllExtensionsForType(msgTypeName) + + if err != nil { + return fmt.Errorf("failed to query for extensions of type %s: %v", msgTypeName, err) + } + + for i := 0; i < len(fds); i++ { + if err := msgExt.AddExtension(fds[i]); err != nil { + return fmt.Errorf("could not register extension %s of type %s: %v", fds[i].GetFullyQualifiedName(), msgTypeName, err) + } + } + } + + // 递归注册 message 所有字段 + for _, msgField := range msg.GetFields() { + if msgField.GetMessageType() == nil { + continue + } + if err := this.RegServiceMethoddMessageExtensions(msgExt, msgField.GetMessageType(), excluded); err != nil { + return err + } + } + + return nil +} + +func (this Grpcall) MakeMetadata(headers []string) metadata.MD { + md := make(metadata.MD) + for _, part := range headers { + if part != "" { + pieces := strings.SplitN(part, ":", 2) + if len(pieces) == 1 { + pieces = append(pieces, "") // if no value was specified, just make it "" (maybe the header value doesn't matter) + } + headerName := strings.ToLower(strings.TrimSpace(pieces[0])) + val := strings.TrimSpace(pieces[1]) + if strings.HasSuffix(headerName, "-bin") { + if v, err := metadata_decode(val); err == nil { + val = v + } + } + md[headerName] = append(md[headerName], val) + } + } + return md +} + +func (this Grpcall) Close() error { + return this.conn.Close() +} + +func NewGrpcall(addr string, protosets []string, opts ...grpc.DialOption) (g *Grpcall, err error) { + if addr == "" { + return nil, fmt.Errorf("addr is invalid, %s", addr) + } + + g = &Grpcall{} + + opts = append(opts, + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: 64 * time.Second, + Timeout: 64 * time.Second, + }), + grpc.WithBlock(), + grpc.FailOnNonTempDialError(true), + ) + + if conn, err := grpc.Dial(addr, opts...); err != nil { + return nil, err + } else { + g.conn = conn + } + + if len(protosets) == 0 { + g.pDescriptor, err = NewServProtoDescriptor(g.conn) + } else { + g.pDescriptor, err = NewFileProtoDescriptor(protosets) + } + + return g, err +} diff --git a/grpcall_test.go b/grpcall_test.go new file mode 100644 index 0000000..7a5fe3b --- /dev/null +++ b/grpcall_test.go @@ -0,0 +1,118 @@ +package grpcall + +import ( + "context" + "fmt" + "os" + "testing" + + "google.golang.org/grpc" +) + +var gc *Grpcall + +func TestMain(m *testing.M) { + if _gc, err := NewGrpcall("127.0.0.1:8888", nil, grpc.WithInsecure()); err != nil { + // if _gc, err := NewGrpcall("192.168.0.6:6644", nil, grpc.WithInsecure()); err != nil { + panic(err) + } else { + gc = _gc + } + + defer gc.Close() + + // 执行测试 + code := m.Run() + + // 在这里执行一些清理工作,例如关闭数据库连接等 + + // 退出测试 + os.Exit(code) +} + +func TestGetServices(t *testing.T) { + t.Log(gc.GetServices()) +} + +func TestGetServiceMethods(t *testing.T) { + if s, err := gc.GetServices(); err != nil { + t.Error(err) + } else { + t.Log(gc.GetServiceMethods(s[0])) + } +} + +func TestInvokeUnary(t *testing.T) { + ctx := context.Background() + + if resp, err := gc.Invoke(ctx, "User.UserResourceStatus", "TestInt64Value", `123`, nil); err != nil { + t.Error(err) + } else { + t.Log(resp.Data()) + } +} + +func TestInvokeServStream(t *testing.T) { + ctx := context.Background() + + if resp, err := gc.Invoke(ctx, "User.UserResourceStatus", "GetEvent", `{}`, nil); err != nil { + t.Error(err) + } else if recv, err := resp.Recv(); err != nil { + t.Error(err) + } else if done, err := resp.Done(); err != nil { + t.Error(err) + } else { + flag := make(chan bool) + go func() { + for { + select { + case msg := <-recv: + fmt.Println("msg", msg, resp.Header()) + case err := <-done: + if err != nil { + t.Error(err) + } + flag <- true + return + } + } + }() + <-flag + } +} +func TestInvokeBidiStream(t *testing.T) { + ctx := context.Background() + + if resp, err := gc.Invoke(ctx, "User.UserResourceStatus", "TestBidiStream", `"hello"`, nil); err != nil { + t.Error(err) + } else if send, err := resp.Send(); err != nil { + t.Error(err) + } else if recv, err := resp.Recv(); err != nil { + t.Error(err) + } else if done, err := resp.Done(); err != nil { + t.Error(err) + } else { + flag := make(chan bool) + + go func() { + for { + select { + case msg := <-recv: + fmt.Println("msg", msg, resp.Header()) + case err := <-done: + if err != nil { + t.Error(err) + } + flag <- true + return + } + } + }() + + for i := 0; i < 10; i++ { + send <- fmt.Sprintf(`"abc %d"`, i) + } + + <-flag + } +} diff --git a/proto_descriptor.go b/proto_descriptor.go new file mode 100644 index 0000000..04affc4 --- /dev/null +++ b/proto_descriptor.go @@ -0,0 +1,136 @@ +package grpcall + +import ( + "context" + "fmt" + "sort" + "sync" + + "github.com/jhump/protoreflect/desc" + "github.com/jhump/protoreflect/dynamic" + "github.com/jhump/protoreflect/grpcreflect" + "github.com/samber/lo" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" +) + +type ProtoDescriptor interface { + // gRPC 服务列表 + ListServices() ([]string, error) + // 查找 proto 中定义的消息、服务或枚举等结构的描述符 + FindSymbol(string) (desc.Descriptor, error) + // 获取消息类型的所有扩展字段的描述符 + AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) + // 获取 proto 文件描述信息 + GetAllFiles() ([]*desc.FileDescriptor, error) +} + +type WithFilesProtoDescriptor interface { + GetAllFiles() ([]*desc.FileDescriptor, error) +} + +type ServProtoDescriptor struct { + client *grpcreflect.Client +} + +type FileProtoDescriptor struct { + files map[string]*desc.FileDescriptor + er *dynamic.ExtensionRegistry + erInit sync.Once +} + +func NewServProtoDescriptor(client *grpc.ClientConn) (ProtoDescriptor, error) { + ctx := context.Background() + + pd := &ServProtoDescriptor{ + client: grpcreflect.NewClient(ctx, grpc_reflection_v1alpha.NewServerReflectionClient(client)), + } + + return pd, nil +} + +func (this ServProtoDescriptor) ListServices() ([]string, error) { + return this.client.ListServices() +} + +func (this ServProtoDescriptor) FindSymbol(name string) (desc.Descriptor, error) { + file, err := this.client.FileContainingSymbol(name) + + if err != nil { + return nil, err + } + + d := file.FindSymbol(name) + + if d == nil { + return nil, fmt.Errorf("Symbol Not Found, %s", name) + } + + return d, nil +} + +func (this ServProtoDescriptor) AllExtensionsForType(typ string) ([]*desc.FieldDescriptor, error) { + exts := []*desc.FieldDescriptor{} + + nums, err := this.client.AllExtensionNumbersForType(typ) + + if err != nil { + return nil, err + } + + for _, fieldNum := range nums { + ext, err := this.client.ResolveExtension(typ, fieldNum) + if err != nil { + return nil, err + } + exts = append(exts, ext) + } + + return exts, nil +} + +func (this ServProtoDescriptor) GetAllFiles() ([]*desc.FileDescriptor, error) { + var files []*desc.FileDescriptor + + if services, err := this.ListServices(); err != nil { + return nil, err + } else { + temp_fields := map[string]*desc.FileDescriptor{} + for _, name := range services { + d, err := this.FindSymbol(name) + if err != nil { + return nil, err + } + this.addFileDescriptorsToCache(d.GetFile(), temp_fields) + } + files = lo.Values(temp_fields) + } + + sort.Sort(FileDescriptorSorter(files)) + + return files, nil +} + +func (this ServProtoDescriptor) addFileDescriptorsToCache(file *desc.FileDescriptor, cache map[string]*desc.FileDescriptor) error { + if _, ok := cache[file.GetName()]; ok { + return nil + } else { + cache[file.GetName()] = file + } + + for _, dep := range file.GetDependencies() { + this.addFileDescriptorsToCache(dep, cache) + } + + return nil +} + +func NewFileProtoDescriptor(protosets []string) (ProtoDescriptor, error) { + return nil, nil +} + +func (this FileProtoDescriptor) GetAllFiles() ([]*desc.FileDescriptor, error) { + files := lo.Values(this.files) + sort.Sort(FileDescriptorSorter(files)) + return files, nil +} diff --git a/request.go b/request.go new file mode 100644 index 0000000..a9dafce --- /dev/null +++ b/request.go @@ -0,0 +1,35 @@ +package grpcall + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" +) + +type RequestData func(msg proto.Message) error + +type RequestParser interface { + Next(msg proto.Message) error +} + +type JsonRequestParser struct { + dec *json.Decoder + unmarshaler jsonpb.Unmarshaler +} + +func (this JsonRequestParser) Next(msg proto.Message) error { + if err := this.unmarshaler.UnmarshalNext(this.dec, msg); err != nil { + return fmt.Errorf("unmarshal request json data error, %s", err) + } + return nil +} + +func NewJsonRequestParser(resolver jsonpb.AnyResolver, data string) RequestParser { + return &JsonRequestParser{ + dec: json.NewDecoder(strings.NewReader(data)), + unmarshaler: jsonpb.Unmarshaler{AnyResolver: resolver}, + } +} diff --git a/request_test.go b/request_test.go new file mode 100644 index 0000000..511a668 --- /dev/null +++ b/request_test.go @@ -0,0 +1,66 @@ +package grpcall + +import ( + "testing" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/ptypes/wrappers" + "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +func TestJsonpbUnmarshalString(t *testing.T) { + msg := &wrappers.Int64Value{} + + if err := jsonpb.UnmarshalString(`123`, msg); err != nil { + t.Error(err) + } + + t.Log(msg) +} + +func TestJsonpbMarshalString(t *testing.T) { + msg := &wrappers.Int64Value{Value: 123} + + marshaler := &jsonpb.Marshaler{} + + if json, err := marshaler.MarshalToString(msg); err != nil { + t.Error(err) + } else { + t.Log(json) // "123" + } +} + +func TestJsonpbMarshalEmpty(t *testing.T) { + msg := &emptypb.Empty{} + + marshaler := &jsonpb.Marshaler{} + + if json, err := marshaler.MarshalToString(msg); err != nil { + t.Error(err) + } else { + t.Log(json) // {} + } +} + +func TestNewJsonRequestParser(t *testing.T) { + var str_msg wrapperspb.StringValue + + inData := `"abc"` + + if err := NewJsonRequestParser(nil, inData).Next(&str_msg); err != nil { + t.Error(err) + } + + t.Log("string", str_msg.GetValue()) + + var int_msg wrapperspb.Int64Value + + inData = `"10"` // or `10` + + if err := NewJsonRequestParser(nil, inData).Next(&int_msg); err != nil { + t.Error(err) + } + + t.Log("int", int_msg.GetValue()) +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..7db359a --- /dev/null +++ b/response.go @@ -0,0 +1,73 @@ +package grpcall + +import ( + "context" + "fmt" + + "github.com/jhump/protoreflect/desc" + "google.golang.org/grpc/metadata" +) + +type Response struct { + method *desc.MethodDescriptor + header *metadata.MD + data string + send chan string + recv chan string + done chan error + cancel context.CancelFunc +} + +func (this Response) Header() metadata.MD { + return *this.header +} + +func (this Response) Data() (string, error) { + if this.method.IsClientStreaming() || this.method.IsServerStreaming() { + return "", fmt.Errorf("%q is %s, cannot use unary data", this.method.GetFullyQualifiedName(), this.GetCategory()) + } + return this.data, nil +} + +// 发送数据流 +func (this Response) Send() (chan<- string, error) { + if !this.method.IsClientStreaming() || !this.method.IsServerStreaming() { + return nil, fmt.Errorf("%q is %s, cannot use send streaming func", this.method.GetFullyQualifiedName(), this.GetCategory()) + } + return this.send, nil +} + +// 接收数据流 +func (this Response) Recv() (<-chan string, error) { + if !this.method.IsServerStreaming() { + return nil, fmt.Errorf("%q is %s, cannot use streaming func", this.method.GetFullyQualifiedName(), this.GetCategory()) + } + return this.recv, nil +} + +// 数据是否响应结束 +func (this Response) Done() (<-chan error, error) { + if !this.method.IsServerStreaming() { + return nil, fmt.Errorf("%q is %s, cannot use streaming func", this.method.GetFullyQualifiedName(), this.GetCategory()) + } + return this.done, nil +} + +// 取消接收数据 +func (this Response) Cancel() { + if this.method.IsServerStreaming() { + this.cancel() + } +} + +func (this Response) GetCategory() string { + if this.method.IsClientStreaming() && this.method.IsServerStreaming() { + return "bidi-streaming" + } else if this.method.IsClientStreaming() { + return "client-streaming" + } else if this.method.IsServerStreaming() { + return "server-streaming" + } else { + return "unary" + } +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..5748324 --- /dev/null +++ b/utils.go @@ -0,0 +1,41 @@ +package grpcall + +import ( + "encoding/base64" + + "github.com/jhump/protoreflect/desc" +) + +type FileDescriptorSorter []*desc.FileDescriptor + +func (this FileDescriptorSorter) Len() int { + return len(this) +} + +func (this FileDescriptorSorter) Less(i, j int) bool { + return this[i].GetName() < this[j].GetName() +} + +func (this FileDescriptorSorter) Swap(i, j int) { + this[i], this[j] = this[j], this[i] +} + +var base64s = []*base64.Encoding{base64.StdEncoding, base64.URLEncoding, base64.RawStdEncoding, base64.RawURLEncoding} + +func metadata_decode(val string) (string, error) { + var firstErr error + var b []byte + // we are lenient and can accept any of the flavors of base64 encoding + for _, d := range base64s { + var err error + b, err = d.DecodeString(val) + if err != nil { + if firstErr == nil { + firstErr = err + } + continue + } + return string(b), nil + } + return "", firstErr +}