484 lines
11 KiB
Go
484 lines
11 KiB
Go
package grpcall
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/golang/protobuf/jsonpb"
|
|
"github.com/jhump/protoreflect/desc"
|
|
"github.com/jhump/protoreflect/dynamic"
|
|
"github.com/jhump/protoreflect/dynamic/grpcdynamic"
|
|
"github.com/samber/lo"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/keepalive"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
// Grpcall is a dynamic gRPC client: it holds one client connection plus the
// descriptor source used to resolve services, methods and message types at
// runtime (no generated stubs are required).
type Grpcall struct {
	// conn is the client connection every RPC issued by this Grpcall uses.
	conn *grpc.ClientConn

	// pDescriptor resolves services, symbols, files and extensions. NewGrpcall
	// builds it either from the server connection or from protoset files.
	pDescriptor ProtoDescriptor
}
|
|
|
|
// 构建请求函数
|
|
func (this Grpcall) Invoke(service, method, data string, headers []string) (*Response, error) {
|
|
// 获取 gRPC 服务方法描述信息
|
|
mtd, err := this.GetServiceMethodDescriptor(service, method)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
md := this.MakeMetadata(headers)
|
|
|
|
ctx := metadata.NewOutgoingContext(context.Background(), md)
|
|
|
|
anyResolver, err := this.GetAnyResolver()
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// service.method message 构建器
|
|
msgFactory, err := this.MakeServiceMethodMessageFactory(mtd)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
stub := grpcdynamic.NewStubWithMessageFactory(this.conn, msgFactory)
|
|
|
|
if mtd.IsClientStreaming() && mtd.IsServerStreaming() {
|
|
return this.invokeBidiStream(ctx, stub, mtd, msgFactory, anyResolver, data)
|
|
}
|
|
|
|
if mtd.IsServerStreaming() {
|
|
return this.invokeServStream(ctx, stub, mtd, msgFactory, anyResolver, data)
|
|
}
|
|
|
|
return this.invokeUnary(ctx, stub, mtd, msgFactory, anyResolver, data)
|
|
}
|
|
|
|
func (this Grpcall) invokeUnary(ctx context.Context, stub grpcdynamic.Stub, method *desc.MethodDescriptor, msgFactory *dynamic.MessageFactory, anyResolver jsonpb.AnyResolver, data string) (*Response, error) {
|
|
// 创建请求 message
|
|
req := msgFactory.NewMessage(method.GetInputType())
|
|
// 将数据格式化到 message
|
|
if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
respHeaders := &metadata.MD{}
|
|
|
|
respMsg, err := stub.InvokeRpc(ctx, method, req, grpc.Header(respHeaders))
|
|
|
|
respStatus, ok := status.FromError(err)
|
|
|
|
if !ok || respStatus.Code() != codes.OK {
|
|
return nil, fmt.Errorf("grpc call %q failed, %s", method.GetFullyQualifiedName(), err)
|
|
}
|
|
|
|
marshaler := jsonpb.Marshaler{
|
|
EmitDefaults: true,
|
|
AnyResolver: anyResolver,
|
|
}
|
|
|
|
respData, err := marshaler.MarshalToString(respMsg)
|
|
|
|
resp := &Response{
|
|
method: method,
|
|
header: respHeaders,
|
|
data: respData,
|
|
}
|
|
|
|
return resp, err
|
|
}
|
|
|
|
func (this Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub, method *desc.MethodDescriptor, msgFactory *dynamic.MessageFactory, anyResolver jsonpb.AnyResolver, data string) (*Response, error) {
|
|
// 创建请求 message
|
|
req := msgFactory.NewMessage(method.GetInputType())
|
|
// 将数据格式化到 message
|
|
if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cancel_ctx, cancel := context.WithCancel(ctx)
|
|
|
|
stream, err := stub.InvokeRpcServerStream(cancel_ctx, method, req)
|
|
|
|
if err != nil {
|
|
cancel()
|
|
return nil, err
|
|
}
|
|
|
|
resp := &Response{
|
|
method: method,
|
|
header: &metadata.MD{},
|
|
recv: make(chan string),
|
|
done: make(chan error),
|
|
cancel: cancel,
|
|
}
|
|
|
|
go func() {
|
|
msgParser := jsonpb.Marshaler{
|
|
EmitDefaults: true,
|
|
AnyResolver: anyResolver,
|
|
}
|
|
|
|
defer func() {
|
|
close(resp.done)
|
|
close(resp.recv)
|
|
cancel()
|
|
}()
|
|
|
|
for {
|
|
// 从流中接收消息
|
|
msg, err := stream.RecvMsg()
|
|
|
|
if err == io.EOF {
|
|
resp.done <- nil
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
resp.done <- err
|
|
return
|
|
}
|
|
|
|
if *resp.header, err = stream.Header(); err != nil {
|
|
resp.done <- err
|
|
return
|
|
}
|
|
|
|
if data, err := msgParser.MarshalToString(msg); err != nil {
|
|
resp.done <- err
|
|
return
|
|
} else {
|
|
resp.recv <- data
|
|
}
|
|
|
|
msg.Reset()
|
|
}
|
|
}()
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub, method *desc.MethodDescriptor, msgFactory *dynamic.MessageFactory, anyResolver jsonpb.AnyResolver, data string) (*Response, error) {
|
|
// 创建请求 message
|
|
req := msgFactory.NewMessage(method.GetInputType())
|
|
// 将数据格式化到 message
|
|
if err := NewJsonRequestParser(anyResolver, data).Next(req); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cancel_ctx, cancel := context.WithCancel(ctx)
|
|
|
|
stream, err := stub.InvokeRpcBidiStream(cancel_ctx, method)
|
|
|
|
if err != nil {
|
|
cancel()
|
|
return nil, err
|
|
}
|
|
|
|
if err := stream.SendMsg(req); err != nil {
|
|
cancel()
|
|
return nil, err
|
|
}
|
|
|
|
resp := &Response{
|
|
method: method,
|
|
header: &metadata.MD{},
|
|
recv: make(chan string),
|
|
send: make(chan string),
|
|
done: make(chan error),
|
|
cancel: cancel,
|
|
}
|
|
|
|
// send
|
|
go func() {
|
|
defer func() {
|
|
stream.CloseSend()
|
|
close(resp.send)
|
|
cancel()
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case msg, ok := <-resp.send:
|
|
if !ok {
|
|
resp.done <- nil
|
|
return
|
|
}
|
|
|
|
// 将数据格式化到 message
|
|
if err := NewJsonRequestParser(anyResolver, msg).Next(req); err != nil {
|
|
resp.done <- err
|
|
return
|
|
}
|
|
|
|
if err := stream.SendMsg(req); err != nil {
|
|
resp.done <- err
|
|
return
|
|
}
|
|
|
|
case <-cancel_ctx.Done():
|
|
resp.done <- nil
|
|
return
|
|
case <-ctx.Done():
|
|
resp.done <- nil
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
// recv
|
|
go func() {
|
|
msgParser := jsonpb.Marshaler{
|
|
EmitDefaults: true,
|
|
AnyResolver: anyResolver,
|
|
}
|
|
|
|
defer func() {
|
|
close(resp.done)
|
|
close(resp.recv)
|
|
cancel()
|
|
}()
|
|
|
|
for {
|
|
// 从流中接收消息
|
|
msg, err := stream.RecvMsg()
|
|
|
|
if err == io.EOF {
|
|
resp.done <- nil
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
resp.done <- err
|
|
return
|
|
}
|
|
|
|
if *resp.header, err = stream.Header(); err != nil {
|
|
resp.done <- err
|
|
return
|
|
}
|
|
|
|
if data, err := msgParser.MarshalToString(msg); err != nil {
|
|
resp.done <- err
|
|
return
|
|
} else {
|
|
resp.recv <- data
|
|
}
|
|
|
|
msg.Reset()
|
|
}
|
|
}()
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func (this Grpcall) GetClientConn() *grpc.ClientConn {
|
|
return this.conn
|
|
}
|
|
|
|
func (this Grpcall) GetAnyResolver() (jsonpb.AnyResolver, error) {
|
|
files, err := this.GetAllFilesDescriptor()
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
dExtension := dynamic.ExtensionRegistry{}
|
|
|
|
for _, fd := range files {
|
|
dExtension.AddExtensionsFromFile(fd)
|
|
}
|
|
|
|
mf := dynamic.NewMessageFactoryWithExtensionRegistry(&dExtension)
|
|
|
|
return dynamic.AnyResolver(mf, files...), nil
|
|
}
|
|
|
|
func (this Grpcall) GetProtoDescriptor() ProtoDescriptor {
|
|
return this.pDescriptor
|
|
}
|
|
|
|
func (this Grpcall) GetAllFilesDescriptor() ([]*desc.FileDescriptor, error) {
|
|
return this.pDescriptor.GetAllFiles()
|
|
}
|
|
|
|
// 获取所有服务
|
|
func (this Grpcall) GetServices() ([]string, error) {
|
|
return this.pDescriptor.ListServices()
|
|
}
|
|
|
|
// 获取服务下的所有方法
|
|
func (this Grpcall) GetServiceMethods(name string) ([]string, error) {
|
|
if dsc, err := this.pDescriptor.FindSymbol(name); err != nil {
|
|
return nil, err
|
|
} else if sd, ok := dsc.(*desc.ServiceDescriptor); !ok {
|
|
return nil, fmt.Errorf("%s, not found methods", name)
|
|
} else {
|
|
methods := make([]string, 0, len(sd.GetMethods()))
|
|
|
|
for _, method := range sd.GetMethods() {
|
|
methods = append(methods, method.GetFullyQualifiedName())
|
|
}
|
|
|
|
sort.Strings(methods)
|
|
return methods, nil
|
|
}
|
|
}
|
|
|
|
// 获取 gRPC 服务描述
|
|
func (this Grpcall) GetServiceDescriptor(service string) (*desc.ServiceDescriptor, error) {
|
|
if service == "" {
|
|
return nil, fmt.Errorf("service name is invalid")
|
|
}
|
|
|
|
dsc, err := this.pDescriptor.FindSymbol(service)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sd, ok := dsc.(*desc.ServiceDescriptor)
|
|
if !ok {
|
|
return nil, fmt.Errorf("server not expose service, %q", service)
|
|
}
|
|
|
|
return sd, nil
|
|
}
|
|
|
|
// 获取 gRPC service.method 描述
|
|
func (this Grpcall) GetServiceMethodDescriptor(service, method string) (*desc.MethodDescriptor, error) {
|
|
serv, err := this.GetServiceDescriptor(service)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if mtd := serv.FindMethodByName(method); mtd != nil {
|
|
return mtd, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("service %q does not include a method named %q", service, method)
|
|
}
|
|
|
|
// 生成 service.method message
|
|
func (this *Grpcall) MakeServiceMethodMessageFactory(method *desc.MethodDescriptor) (*dynamic.MessageFactory, error) {
|
|
// 扩展字段的信息,包括扩展字段的名称、标识符、类型等。用于注册扩展、查找扩展、解析扩展以及处理扩展字段的值。
|
|
var ext dynamic.ExtensionRegistry
|
|
|
|
excluded := []string{}
|
|
|
|
if err := this.RegServiceMethoddMessageExtensions(&ext, method.GetInputType(), excluded); err != nil {
|
|
return nil, fmt.Errorf("register service.method extensions for message %s: %v", method.GetInputType().GetFullyQualifiedName(), err)
|
|
}
|
|
|
|
if err := this.RegServiceMethoddMessageExtensions(&ext, method.GetOutputType(), excluded); err != nil {
|
|
return nil, fmt.Errorf("register service.method extensions for message %s: %v", method.GetOutputType().GetFullyQualifiedName(), err)
|
|
}
|
|
|
|
return dynamic.NewMessageFactoryWithExtensionRegistry(&ext), nil
|
|
}
|
|
|
|
// 注册 service.method message extensions
|
|
func (this Grpcall) RegServiceMethoddMessageExtensions(msgExt *dynamic.ExtensionRegistry, msg *desc.MessageDescriptor, excluded []string) error {
|
|
msgTypeName := msg.GetFullyQualifiedName()
|
|
|
|
// 避免重复注册
|
|
if lo.Contains(excluded, msgTypeName) {
|
|
return nil
|
|
} else {
|
|
excluded = append(excluded, msgTypeName)
|
|
}
|
|
|
|
// 注册 message extensions 信息
|
|
if len(msg.GetExtensionRanges()) > 0 {
|
|
fds, err := this.pDescriptor.AllExtensionsForType(msgTypeName)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to query for extensions of type %s: %v", msgTypeName, err)
|
|
}
|
|
|
|
for i := 0; i < len(fds); i++ {
|
|
if err := msgExt.AddExtension(fds[i]); err != nil {
|
|
return fmt.Errorf("could not register extension %s of type %s: %v", fds[i].GetFullyQualifiedName(), msgTypeName, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// 递归注册 message 所有字段
|
|
for _, msgField := range msg.GetFields() {
|
|
if msgField.GetMessageType() == nil {
|
|
continue
|
|
}
|
|
if err := this.RegServiceMethoddMessageExtensions(msgExt, msgField.GetMessageType(), excluded); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (this Grpcall) MakeMetadata(headers []string) metadata.MD {
|
|
md := make(metadata.MD)
|
|
for _, part := range headers {
|
|
if part != "" {
|
|
pieces := strings.SplitN(part, ":", 2)
|
|
if len(pieces) == 1 {
|
|
pieces = append(pieces, "") // if no value was specified, just make it "" (maybe the header value doesn't matter)
|
|
}
|
|
headerName := strings.ToLower(strings.TrimSpace(pieces[0]))
|
|
val := strings.TrimSpace(pieces[1])
|
|
if strings.HasSuffix(headerName, "-bin") {
|
|
if v, err := metadata_decode(val); err == nil {
|
|
val = v
|
|
}
|
|
}
|
|
md[headerName] = append(md[headerName], val)
|
|
}
|
|
}
|
|
return md
|
|
}
|
|
|
|
func (this Grpcall) Close() error {
|
|
return this.conn.Close()
|
|
}
|
|
|
|
func NewGrpcall(addr string, protosets []string, opts ...grpc.DialOption) (g *Grpcall, err error) {
|
|
if addr == "" {
|
|
return nil, fmt.Errorf("addr is invalid, %s", addr)
|
|
}
|
|
|
|
g = &Grpcall{}
|
|
|
|
opts = append(opts,
|
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
|
Time: 64 * time.Second,
|
|
Timeout: 64 * time.Second,
|
|
}),
|
|
grpc.WithBlock(),
|
|
grpc.FailOnNonTempDialError(true),
|
|
)
|
|
|
|
if conn, err := grpc.Dial(addr, opts...); err != nil {
|
|
return nil, err
|
|
} else {
|
|
g.conn = conn
|
|
}
|
|
|
|
if len(protosets) == 0 {
|
|
g.pDescriptor, err = NewServProtoDescriptor(g.conn)
|
|
} else {
|
|
g.pDescriptor, err = NewFileProtoDescriptor(protosets)
|
|
}
|
|
|
|
return g, err
|
|
}
|