grpcall/grpcall.go

491 lines
12 KiB
Go

package grpcall
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"sort"
"strings"
"time"
"github.com/golang/protobuf/jsonpb"
"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/dynamic"
"github.com/jhump/protoreflect/dynamic/grpcdynamic"
"github.com/samber/lo"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// Grpcall is a dynamic gRPC client. It pairs an established client connection
// with a descriptor source, so that services and methods can be looked up and
// invoked by name using JSON-encoded request data.
type Grpcall struct {
// conn is the client connection that RPC stubs are built on and that Close
// shuts down.
conn *grpc.ClientConn
// pDescriptor is the descriptor source used to list services, find symbols,
// enumerate files and look up extensions.
pDescriptor ProtoDescriptor
}
// Invoke builds and issues a request against service/method.
//
// data is the JSON-encoded request message and headers are raw header strings
// that MakeMetadata turns into outgoing gRPC metadata. The call is dispatched
// by the method's streaming kind: bidirectional streaming methods go to
// invokeBidiStream, server-streaming methods go to invokeServStream, and
// everything else is handed to invokeUnary.
func (g Grpcall) Invoke(service, method, data string, headers []string) (*Response, error) {
	// Look up the descriptor of the requested service method.
	mtd, err := g.GetServiceMethodDescriptor(service, method)
	if err != nil {
		return nil, err
	}

	// Attach the caller-supplied headers as outgoing metadata.
	ctx := metadata.NewOutgoingContext(context.Background(), g.MakeMetadata(headers))

	resolver, err := g.GetAnyResolver()
	if err != nil {
		return nil, err
	}

	// Message factory for the method's input and output messages.
	factory, err := g.MakeServiceMethodMessageFactory(mtd)
	if err != nil {
		return nil, err
	}

	stub := grpcdynamic.NewStubWithMessageFactory(g.conn, factory)

	switch {
	case mtd.IsClientStreaming() && mtd.IsServerStreaming():
		return g.invokeBidiStream(ctx, stub, mtd, factory, resolver, data)
	case mtd.IsServerStreaming():
		return g.invokeServStream(ctx, stub, mtd, factory, resolver, data)
	default:
		return g.invokeUnary(ctx, stub, mtd, factory, resolver, data)
	}
}
// invokeUnary performs a unary RPC and exposes the reply as a JSON stream.
//
// data is parsed as JSON into a new request message of the method's input
// type. On success the returned Response carries the response headers and a
// pipe reader from which the JSON-marshaled reply can be read; marshaling runs
// in a background goroutine and any marshaling failure is delivered to the
// reader as the pipe's error.
//
// An error is returned when the request JSON cannot be parsed or when the RPC
// does not finish with status OK. The RPC error is wrapped with %w so callers
// can still inspect the underlying gRPC status.
func (g Grpcall) invokeUnary(ctx context.Context, stub grpcdynamic.Stub, method *desc.MethodDescriptor, msgFactory *dynamic.MessageFactory, anyResolver jsonpb.AnyResolver, data string) (*Response, error) {
	// Create the request message and fill it from the JSON payload.
	req := msgFactory.NewMessage(method.GetInputType())
	if err := NewJsonRequestParser(anyResolver, strings.NewReader(data)).Next(req); err != nil {
		return nil, err
	}

	respHeaders := &metadata.MD{}
	respMsg, err := stub.InvokeRpc(ctx, method, req, grpc.Header(respHeaders))
	if respStatus, ok := status.FromError(err); !ok || respStatus.Code() != codes.OK {
		return nil, fmt.Errorf("grpc call %q failed, %w", method.GetFullyQualifiedName(), err)
	}

	marshaler := jsonpb.Marshaler{
		EmitDefaults: true,
		AnyResolver:  anyResolver,
	}
	pReader, pWriter := io.Pipe()
	resp := &Response{
		method: method,
		header: respHeaders,
		recv:   pReader,
	}

	go func() {
		// Use a goroutine-local error: assigning to the enclosing err would
		// race with the return statement below.
		if marshalErr := marshaler.Marshal(pWriter, respMsg); marshalErr != nil {
			pWriter.CloseWithError(marshalErr)
			return
		}
		pWriter.Close()
	}()

	return resp, nil
}
// invokeServStream starts a server-streaming RPC.
//
// data is parsed as JSON into a single request message, which is sent when the
// stream is opened. The returned Response holds the read side of a pipe; a
// background goroutine receives messages from the stream, marshals each one to
// JSON and writes it to the pipe. The goroutine stops at end of stream (io.EOF)
// or on the first receive, header or marshal error, which is delivered to the
// reader as the pipe's error. The stream context is cancelled when the
// goroutine exits, and the same cancel function is stored on the Response.
func (g Grpcall) invokeServStream(ctx context.Context, stub grpcdynamic.Stub, method *desc.MethodDescriptor, msgFactory *dynamic.MessageFactory, anyResolver jsonpb.AnyResolver, data string) (*Response, error) {
	// Build the request message from the JSON payload.
	request := msgFactory.NewMessage(method.GetInputType())
	parser := NewJsonRequestParser(anyResolver, strings.NewReader(data))
	if err := parser.Next(request); err != nil {
		return nil, err
	}

	streamCtx, stop := context.WithCancel(ctx)
	stream, err := stub.InvokeRpcServerStream(streamCtx, method, request)
	if err != nil {
		stop()
		return nil, err
	}

	reader, writer := io.Pipe()
	resp := &Response{
		method: method,
		header: &metadata.MD{},
		recv:   reader,
		cancel: stop,
	}

	go func() {
		// Deferred calls run last-in first-out: cancel the stream, then close
		// the pipe.
		defer writer.Close()
		defer stop()

		marshaler := jsonpb.Marshaler{
			EmitDefaults: true,
			AnyResolver:  anyResolver,
		}
		for {
			// Receive the next message from the stream.
			msg, recvErr := stream.RecvMsg()
			switch {
			case recvErr == io.EOF:
				return
			case recvErr != nil:
				writer.CloseWithError(recvErr)
				return
			}

			// The header is stored even when Header reports an error.
			hdr, hdrErr := stream.Header()
			*resp.header = hdr
			if hdrErr != nil {
				writer.CloseWithError(hdrErr)
				return
			}

			if marshalErr := marshaler.Marshal(writer, msg); marshalErr != nil {
				writer.CloseWithError(marshalErr)
				return
			}
			msg.Reset()
		}
	}()

	return resp, nil
}
// invokeBidiStream starts a bidirectional-streaming RPC.
//
// data is parsed as JSON into the first request message, which is sent on the
// stream before this function returns. The returned Response carries:
//   - recv: the read side of a pipe to which received messages are written as JSON;
//   - send: an unbuffered channel of raw JSON payloads to send on the stream;
//   - sendCompleted: an unbuffered channel on which the outcome of each send is reported;
//   - cancel: the cancel function of the stream's context.
//
// Two goroutines are started: one drains resp.send and forwards messages to the
// stream, the other receives messages from the stream and writes them to the pipe.
func (this Grpcall) invokeBidiStream(ctx context.Context, stub grpcdynamic.Stub, method *desc.MethodDescriptor, msgFactory *dynamic.MessageFactory, anyResolver jsonpb.AnyResolver, data string) (*Response, error) {
// Create the request message.
req := msgFactory.NewMessage(method.GetInputType())
// Populate the message from the JSON payload.
if err := NewJsonRequestParser(anyResolver, strings.NewReader(data)).Next(req); err != nil {
return nil, err
}
cancel_ctx, cancel := context.WithCancel(ctx)
stream, err := stub.InvokeRpcBidiStream(cancel_ctx, method)
if err != nil {
cancel()
return nil, err
}
// Send the initial message synchronously so that a failure is returned to the caller.
if err := stream.SendMsg(req); err != nil {
cancel()
return nil, err
}
pRecvReader, pRecvWriter := io.Pipe()
resp := &Response{
method: method,
header: &metadata.MD{},
recv: pRecvReader,
send: make(chan []byte),
sendCompleted: make(chan error),
cancel: cancel,
}
// send: forwards payloads read from resp.send to the stream.
go func() {
// On exit: half-close the stream, cancel its context and close both channels.
// NOTE(review): this goroutine is the receiver of resp.send, yet it closes that
// channel; a producer sending on resp.send afterwards would panic. Confirm
// against the code that writes to resp.send.
defer func() {
stream.CloseSend()
cancel()
close(resp.sendCompleted)
close(resp.send)
}()
reqParser := jsonpb.Unmarshaler{AnyResolver: anyResolver}
for {
select {
// NOTE(review): sendCompleted is unbuffered, so the sends in the two Done
// branches below block until something receives from it. Confirm a receiver
// is always present at that point; otherwise this goroutine never exits.
case <-cancel_ctx.Done():
resp.sendCompleted <- nil
return
case <-ctx.Done():
resp.sendCompleted <- nil
return
case msg, ok := <-resp.send:
// A closed send channel is reported as an error and ends the loop.
if !ok {
resp.sendCompleted <- fmt.Errorf("read send message error")
return
}
// The same req message is reused for every payload.
if err := reqParser.UnmarshalNext(json.NewDecoder(bytes.NewReader(msg)), req); err != nil {
resp.sendCompleted <- err
return
}
if err := stream.SendMsg(req); err != nil {
resp.sendCompleted <- err
return
}
}
// Report a successful send.
resp.sendCompleted <- nil
}
}()
// recv: writes every message received from the stream to the pipe as JSON.
go func() {
msgParser := jsonpb.Marshaler{
EmitDefaults: true,
AnyResolver: anyResolver,
}
// On exit: cancel the stream context and close the pipe writer.
defer func() {
cancel()
pRecvWriter.Close()
}()
for {
// Receive the next message from the stream.
msg, err := stream.RecvMsg()
// io.EOF ends the loop without an error.
if err == io.EOF {
return
}
if err != nil {
pRecvWriter.CloseWithError(err)
return
}
// Store the stream header on the Response; this runs after every received message.
if *resp.header, err = stream.Header(); err != nil {
pRecvWriter.CloseWithError(err)
return
}
if err = msgParser.Marshal(pRecvWriter, msg); err != nil {
pRecvWriter.CloseWithError(err)
return
}
msg.Reset()
}
}()
return resp, nil
}
// GetClientConn returns the gRPC client connection held by this client.
func (g Grpcall) GetClientConn() *grpc.ClientConn {
	cc := g.conn
	return cc
}
// GetAnyResolver builds a jsonpb.AnyResolver from every file descriptor known
// to the descriptor source. The extensions declared in those files are added
// to the registry backing the resolver's message factory.
//
// It returns the error from GetAllFilesDescriptor if the files cannot be
// listed.
func (g Grpcall) GetAnyResolver() (jsonpb.AnyResolver, error) {
	fds, err := g.GetAllFilesDescriptor()
	if err != nil {
		return nil, err
	}

	registry := &dynamic.ExtensionRegistry{}
	for i := range fds {
		registry.AddExtensionsFromFile(fds[i])
	}

	factory := dynamic.NewMessageFactoryWithExtensionRegistry(registry)
	return dynamic.AnyResolver(factory, fds...), nil
}
// GetProtoDescriptor returns the descriptor source used by this client.
func (g Grpcall) GetProtoDescriptor() ProtoDescriptor {
	pd := g.pDescriptor
	return pd
}
// GetAllFilesDescriptor returns every file descriptor reported by the
// descriptor source, or the error the source returned.
func (g Grpcall) GetAllFilesDescriptor() ([]*desc.FileDescriptor, error) {
	files, err := g.pDescriptor.GetAllFiles()
	return files, err
}
// GetServices lists all services reported by the descriptor source.
func (g Grpcall) GetServices() ([]string, error) {
	services, err := g.pDescriptor.ListServices()
	return services, err
}
// GetServiceMethods returns the fully-qualified names of all methods of the
// service called name, sorted alphabetically.
//
// It fails when the symbol cannot be found or when it does not resolve to a
// service descriptor.
func (g Grpcall) GetServiceMethods(name string) ([]string, error) {
	symbol, err := g.pDescriptor.FindSymbol(name)
	if err != nil {
		return nil, err
	}

	svc, isService := symbol.(*desc.ServiceDescriptor)
	if !isService {
		return nil, fmt.Errorf("%s, not found methods", name)
	}

	mtds := svc.GetMethods()
	names := make([]string, len(mtds))
	for i, m := range mtds {
		names[i] = m.GetFullyQualifiedName()
	}
	sort.Strings(names)
	return names, nil
}
// GetServiceDescriptor resolves service to its service descriptor.
//
// It fails when service is empty, when the symbol cannot be found, or when the
// symbol is not a service.
func (g Grpcall) GetServiceDescriptor(service string) (*desc.ServiceDescriptor, error) {
	if len(service) == 0 {
		return nil, fmt.Errorf("service name is invalid")
	}

	symbol, err := g.pDescriptor.FindSymbol(service)
	if err != nil {
		return nil, err
	}

	if svc, isService := symbol.(*desc.ServiceDescriptor); isService {
		return svc, nil
	}
	return nil, fmt.Errorf("server not expose service, %q", service)
}
// GetServiceMethodDescriptor resolves the descriptor of method within service.
//
// It fails when the service cannot be resolved or when the service has no
// method with that name.
func (g Grpcall) GetServiceMethodDescriptor(service, method string) (*desc.MethodDescriptor, error) {
	svc, err := g.GetServiceDescriptor(service)
	if err != nil {
		return nil, err
	}

	found := svc.FindMethodByName(method)
	if found == nil {
		return nil, fmt.Errorf("service %q does not include a method named %q", service, method)
	}
	return found, nil
}
// MakeServiceMethodMessageFactory creates the message factory for method.
//
// The factory is backed by an extension registry that is filled with the
// extensions found for the method's input type and then its output type,
// including the message types of their fields. A registration failure is
// returned with the name of the message being processed.
func (g *Grpcall) MakeServiceMethodMessageFactory(method *desc.MethodDescriptor) (*dynamic.MessageFactory, error) {
	// The registry records extension fields so they can be looked up when
	// messages are parsed and marshaled.
	registry := &dynamic.ExtensionRegistry{}
	seen := []string{}

	for _, msg := range []*desc.MessageDescriptor{method.GetInputType(), method.GetOutputType()} {
		if err := g.RegServiceMethoddMessageExtensions(registry, msg, seen); err != nil {
			return nil, fmt.Errorf("register service.method extensions for message %s: %v", msg.GetFullyQualifiedName(), err)
		}
	}
	return dynamic.NewMessageFactoryWithExtensionRegistry(registry), nil
}
// RegServiceMethoddMessageExtensions registers, into msgExt, the extensions of
// msg and of every message type reachable through its fields.
//
// excluded lists fully-qualified message names that must be skipped. The slice
// is only read: visited names are tracked on a private copy, so the caller's
// slice and its backing array are never modified.
//
// Each message type is processed at most once per call, even when it is
// reachable through several fields or refers back to itself.
func (g Grpcall) RegServiceMethoddMessageExtensions(msgExt *dynamic.ExtensionRegistry, msg *desc.MessageDescriptor, excluded []string) error {
	visited := append([]string(nil), excluded...)
	return g.regMessageExtensionsOnce(msgExt, msg, &visited)
}

// regMessageExtensionsOnce is the recursive worker behind
// RegServiceMethoddMessageExtensions. visited is shared by pointer across the
// whole recursion so that a name recorded in one branch is seen by all others.
func (g Grpcall) regMessageExtensionsOnce(msgExt *dynamic.ExtensionRegistry, msg *desc.MessageDescriptor, visited *[]string) error {
	msgTypeName := msg.GetFullyQualifiedName()

	// Skip message types that were already handled.
	if lo.Contains(*visited, msgTypeName) {
		return nil
	}
	*visited = append(*visited, msgTypeName)

	// Only messages that declare extension ranges can have extensions.
	if len(msg.GetExtensionRanges()) > 0 {
		fds, err := g.pDescriptor.AllExtensionsForType(msgTypeName)
		if err != nil {
			return fmt.Errorf("failed to query for extensions of type %s: %v", msgTypeName, err)
		}
		for _, fd := range fds {
			if err := msgExt.AddExtension(fd); err != nil {
				return fmt.Errorf("could not register extension %s of type %s: %v", fd.GetFullyQualifiedName(), msgTypeName, err)
			}
		}
	}

	// Recurse into every field that is itself a message.
	for _, field := range msg.GetFields() {
		fieldMsg := field.GetMessageType()
		if fieldMsg == nil {
			continue
		}
		if err := g.regMessageExtensionsOnce(msgExt, fieldMsg, visited); err != nil {
			return err
		}
	}
	return nil
}
// MakeMetadata converts raw header strings into gRPC metadata.
//
// Each non-empty entry is split at its first ":" into a name and a value; an
// entry without ":" gets an empty value. Names are trimmed and lower-cased,
// values are trimmed. For names ending in "-bin" the value is passed through
// metadata_decode and replaced by the decoded result when decoding succeeds;
// otherwise the raw value is kept. Repeated names accumulate their values in
// input order.
func (g Grpcall) MakeMetadata(headers []string) metadata.MD {
	md := metadata.MD{}
	for _, header := range headers {
		if header == "" {
			continue
		}

		rawName, rawValue, _ := strings.Cut(header, ":")
		key := strings.ToLower(strings.TrimSpace(rawName))
		value := strings.TrimSpace(rawValue)

		if strings.HasSuffix(key, "-bin") {
			if decoded, err := metadata_decode(value); err == nil {
				value = decoded
			}
		}
		md[key] = append(md[key], value)
	}
	return md
}
// Close closes the underlying gRPC client connection and returns the error
// reported by that close, if any.
func (g Grpcall) Close() error {
	err := g.conn.Close()
	return err
}
// NewGrpcall dials addr and returns a client ready to invoke methods.
//
// opts are passed to grpc.Dial, followed by keepalive parameters (64s time and
// timeout), grpc.WithBlock and grpc.FailOnNonTempDialError(true). Because of
// WithBlock the call does not return until the dial succeeds or fails.
//
// When protosets is empty the descriptor source is created from the
// connection via NewServProtoDescriptor; otherwise it is created from the
// given protoset files via NewFileProtoDescriptor.
//
// On any error the returned *Grpcall is nil. If the descriptor source cannot
// be created, the connection that was just dialed is closed before returning
// so it is not leaked.
func NewGrpcall(addr string, protosets []string, opts ...grpc.DialOption) (g *Grpcall, err error) {
	if addr == "" {
		return nil, fmt.Errorf("addr is invalid, %s", addr)
	}

	// Copy the options before appending so that a caller-supplied slice with
	// spare capacity is never written through.
	dialOpts := make([]grpc.DialOption, 0, len(opts)+3)
	dialOpts = append(dialOpts, opts...)
	dialOpts = append(dialOpts,
		grpc.WithKeepaliveParams(keepalive.ClientParameters{
			Time:    64 * time.Second,
			Timeout: 64 * time.Second,
		}),
		grpc.WithBlock(),
		grpc.FailOnNonTempDialError(true),
	)

	conn, err := grpc.Dial(addr, dialOpts...)
	if err != nil {
		return nil, err
	}

	var pd ProtoDescriptor
	if len(protosets) == 0 {
		pd, err = NewServProtoDescriptor(conn)
	} else {
		pd, err = NewFileProtoDescriptor(protosets)
	}
	if err != nil {
		// Best-effort cleanup; the descriptor error is the one worth reporting.
		_ = conn.Close()
		return nil, err
	}

	return &Grpcall{conn: conn, pDescriptor: pd}, nil
}