package grpcall import ( "context" "fmt" "io" "sort" "strings" "time" "github.com/golang/protobuf/jsonpb" "github.com/jhump/protoreflect/desc" "github.com/jhump/protoreflect/dynamic" "github.com/jhump/protoreflect/dynamic/grpcdynamic" "github.com/samber/lo" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) type Grpcall struct { conn *grpc.ClientConn pDescriptor ProtoDescriptor } // 构建请求函数 func (this Grpcall) Invoke(ctx context.Context, service, method, data string, headers []string) (*Response, error) { // 获取 gRPC 服务方法描述信息 mtd, err := this.GetServiceMethodDescriptor(service, method) if err != nil { return nil, err } md := this.MakeMetadata(headers) ctx = metadata.NewOutgoingContext(ctx, md) anyResolver, err := this.GetAnyResolver() if err != nil { return nil, err } // service.method message 构建器 msgFactory, err := this.MakeServiceMethodMessageFactory(mtd) if err != nil { return nil, err } stub := grpcdynamic.NewStubWithMessageFactory(this.conn, msgFactory) if mtd.IsClientStreaming() && mtd.IsServerStreaming() { return this.invokeBidiStream(ctx, stub, mtd, msgFactory, anyResolver, data) } if mtd.IsServerStreaming() { return this.invokeServStream(ctx, stub, mtd, msgFactory, anyResolver, data) } return this.invokeUnary(ctx, stub, mtd, msgFactory, anyResolver, data) } func (this Grpcall) invokeUnary(ctx context.Context, stub grpcdynamic.Stub, method *desc.MethodDescriptor, msgFactory *dynamic.MessageFactory, anyResolver jsonpb.AnyResolver, data string) (*Response, error) { // 创建请求 message req := msgFactory.NewMessage(method.GetInputType()) // 将数据格式化到 message if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil { return nil, err } respHeaders := &metadata.MD{} respMsg, err := stub.InvokeRpc(ctx, method, req, grpc.Header(respHeaders)) respStatus, ok := status.FromError(err) if !ok || respStatus.Code() != codes.OK { return nil, fmt.Errorf("grpc call %q failed, %s", method.GetFullyQualifiedName(), err) } marshaler := jsonpb.Marshaler{ EmitDefaults: true, AnyResolver: anyResolver, } respData, err := marshaler.MarshalToString(respMsg) resp := &Response{ method: method, header: respHeaders, data: respData, } return resp, err } func (this Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub, method *desc.MethodDescriptor, msgFactory *dynamic.MessageFactory, anyResolver jsonpb.AnyResolver, data string) (*Response, error) { // 创建请求 message req := msgFactory.NewMessage(method.GetInputType()) // 将数据格式化到 message if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil { return nil, err } cancel_ctx, cancel := context.WithCancel(ctx) stream, err := stub.InvokeRpcServerStream(cancel_ctx, method, req) if err != nil { cancel() return nil, err } resp := &Response{ method: method, header: &metadata.MD{}, recv: make(chan string), done: make(chan error), cancel: cancel, } go func() { msgParser := jsonpb.Marshaler{ EmitDefaults: true, AnyResolver: anyResolver, } defer func() { close(resp.done) close(resp.recv) cancel() }() for { // 从流中接收消息 msg, err := stream.RecvMsg() if err == io.EOF { resp.done <- nil return } if err != nil { resp.done <- err return } if *resp.header, err = stream.Header(); err != nil { resp.done <- err return } if data, err := msgParser.MarshalToString(msg); err != nil { resp.done <- err return } else { resp.recv <- data } msg.Reset() } }() return resp, nil } func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub, method *desc.MethodDescriptor, msgFactory *dynamic.MessageFactory, anyResolver jsonpb.AnyResolver, data string) (*Response, error) { // 创建请求 message req := msgFactory.NewMessage(method.GetInputType()) // 将数据格式化到 message if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil { return nil, err } cancel_ctx, cancel := context.WithCancel(ctx) stream, err := stub.InvokeRpcBidiStream(cancel_ctx, method) if err != nil { cancel() return nil, err } if err := stream.SendMsg(req); err != nil { cancel() return nil, err } resp := &Response{ method: method, header: &metadata.MD{}, recv: make(chan string), send: make(chan string), done: make(chan error), cancel: cancel, } // send go func() { defer func() { stream.CloseSend() close(resp.send) cancel() }() for { select { case msg, ok := <-resp.send: if !ok { resp.done <- nil return } // 将数据格式化到 message if err := NewJsonRequestParser(anyResolver, msg).Next(req); err != nil { resp.done <- err return } if err := stream.SendMsg(req); err != nil { resp.done <- err return } case <-cancel_ctx.Done(): resp.done <- nil return case <-ctx.Done(): resp.done <- nil return } } }() // recv go func() { msgParser := jsonpb.Marshaler{ EmitDefaults: true, AnyResolver: anyResolver, } defer func() { close(resp.done) close(resp.recv) cancel() }() for { // 从流中接收消息 msg, err := stream.RecvMsg() if err == io.EOF { resp.done <- nil return } if err != nil { resp.done <- err return } if *resp.header, err = stream.Header(); err != nil { resp.done <- err return } if data, err := msgParser.MarshalToString(msg); err != nil { resp.done <- err return } else { resp.recv <- data } msg.Reset() } }() return resp, nil } func (this Grpcall) GetClientConn() *grpc.ClientConn { return this.conn } func (this Grpcall) GetAnyResolver() (jsonpb.AnyResolver, error) { files, err := this.GetAllFilesDescriptor() if err != nil { return nil, err } dExtension := dynamic.ExtensionRegistry{} for _, fd := range files { dExtension.AddExtensionsFromFile(fd) } mf := dynamic.NewMessageFactoryWithExtensionRegistry(&dExtension) return dynamic.AnyResolver(mf, files...), nil } func (this Grpcall) GetProtoDescriptor() ProtoDescriptor { return this.pDescriptor } func (this Grpcall) GetAllFilesDescriptor() ([]*desc.FileDescriptor, error) { return this.pDescriptor.GetAllFiles() } // 获取所有服务 func (this Grpcall) GetServices() ([]string, error) { return this.pDescriptor.ListServices() } // 获取服务下的所有方法 func (this Grpcall) GetServiceMethods(name string) ([]string, error) { if dsc, err := this.pDescriptor.FindSymbol(name); err != nil { return nil, err } else if sd, ok := dsc.(*desc.ServiceDescriptor); !ok { return nil, fmt.Errorf("%s, not found methods", name) } else { methods := make([]string, 0, len(sd.GetMethods())) for _, method := range sd.GetMethods() { methods = append(methods, method.GetFullyQualifiedName()) } sort.Strings(methods) return methods, nil } } // 获取 gRPC 服务描述 func (this Grpcall) GetServiceDescriptor(service string) (*desc.ServiceDescriptor, error) { if service == "" { return nil, fmt.Errorf("service name is invalid") } dsc, err := this.pDescriptor.FindSymbol(service) if err != nil { return nil, err } sd, ok := dsc.(*desc.ServiceDescriptor) if !ok { return nil, fmt.Errorf("server not expose service, %q", service) } return sd, nil } // 获取 gRPC service.method 描述 func (this Grpcall) GetServiceMethodDescriptor(service, method string) (*desc.MethodDescriptor, error) { serv, err := this.GetServiceDescriptor(service) if err != nil { return nil, err } if mtd := serv.FindMethodByName(method); mtd != nil { return mtd, nil } return nil, fmt.Errorf("service %q does not include a method named %q", service, method) } // 生成 service.method message func (this *Grpcall) MakeServiceMethodMessageFactory(method *desc.MethodDescriptor) (*dynamic.MessageFactory, error) { // 扩展字段的信息,包括扩展字段的名称、标识符、类型等。用于注册扩展、查找扩展、解析扩展以及处理扩展字段的值。 var ext dynamic.ExtensionRegistry excluded := []string{} if err := this.RegServiceMethoddMessageExtensions(&ext, method.GetInputType(), excluded); err != nil { return nil, fmt.Errorf("register service.method extensions for message %s: %v", method.GetInputType().GetFullyQualifiedName(), err) } if err := this.RegServiceMethoddMessageExtensions(&ext, method.GetOutputType(), excluded); err != nil { return nil, fmt.Errorf("register service.method extensions for message %s: %v", method.GetOutputType().GetFullyQualifiedName(), err) } return dynamic.NewMessageFactoryWithExtensionRegistry(&ext), nil } // 注册 service.method message extensions func (this Grpcall) RegServiceMethoddMessageExtensions(msgExt *dynamic.ExtensionRegistry, msg *desc.MessageDescriptor, excluded []string) error { msgTypeName := msg.GetFullyQualifiedName() // 避免重复注册 if lo.Contains(excluded, msgTypeName) { return nil } else { excluded = append(excluded, msgTypeName) } // 注册 message extensions 信息 if len(msg.GetExtensionRanges()) > 0 { fds, err := this.pDescriptor.AllExtensionsForType(msgTypeName) if err != nil { return fmt.Errorf("failed to query for extensions of type %s: %v", msgTypeName, err) } for i := 0; i < len(fds); i++ { if err := msgExt.AddExtension(fds[i]); err != nil { return fmt.Errorf("could not register extension %s of type %s: %v", fds[i].GetFullyQualifiedName(), msgTypeName, err) } } } // 递归注册 message 所有字段 for _, msgField := range msg.GetFields() { if msgField.GetMessageType() == nil { continue } if err := this.RegServiceMethoddMessageExtensions(msgExt, msgField.GetMessageType(), excluded); err != nil { return err } } return nil } func (this Grpcall) MakeMetadata(headers []string) metadata.MD { md := make(metadata.MD) for _, part := range headers { if part != "" { pieces := strings.SplitN(part, ":", 2) if len(pieces) == 1 { pieces = append(pieces, "") // if no value was specified, just make it "" (maybe the header value doesn't matter) } headerName := strings.ToLower(strings.TrimSpace(pieces[0])) val := strings.TrimSpace(pieces[1]) if strings.HasSuffix(headerName, "-bin") { if v, err := metadata_decode(val); err == nil { val = v } } md[headerName] = append(md[headerName], val) } } return md } func (this Grpcall) Close() error { return this.conn.Close() } func NewGrpcall(addr string, protosets []string, opts ...grpc.DialOption) (g *Grpcall, err error) { if addr == "" { return nil, fmt.Errorf("addr is invalid, %s", addr) } g = &Grpcall{} opts = append(opts, grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 64 * time.Second, Timeout: 64 * time.Second, }), grpc.WithBlock(), grpc.FailOnNonTempDialError(true), ) if conn, err := grpc.Dial(addr, opts...); err != nil { return nil, err } else { g.conn = conn } if len(protosets) == 0 { g.pDescriptor, err = NewServProtoDescriptor(g.conn) } else { g.pDescriptor, err = NewFileProtoDescriptor(protosets) } return g, err }