grpcall/proto_descriptor.go

137 lines
3.1 KiB
Go
Raw Normal View History

2023-05-30 17:30:06 +08:00
package grpcall
import (
"context"
"fmt"
"sort"
"sync"
"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/dynamic"
"github.com/jhump/protoreflect/grpcreflect"
"github.com/samber/lo"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
)
type (
	// ProtoDescriptor exposes protobuf descriptor lookups for gRPC services.
	ProtoDescriptor interface {
		// ListServices returns the list of gRPC service names.
		ListServices() ([]string, error)
		// FindSymbol finds the descriptor of a message, service, enum or other
		// element defined in the proto definitions.
		FindSymbol(name string) (desc.Descriptor, error)
		// AllExtensionsForType returns the descriptors of all extension fields
		// of the given message type.
		AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error)
		// GetAllFiles returns the proto file descriptors.
		GetAllFiles() ([]*desc.FileDescriptor, error)
	}

	// WithFilesProtoDescriptor is the subset of ProtoDescriptor that only
	// exposes GetAllFiles.
	WithFilesProtoDescriptor interface {
		// GetAllFiles returns the proto file descriptors.
		GetAllFiles() ([]*desc.FileDescriptor, error)
	}
)
type (
	// ServProtoDescriptor answers descriptor queries by asking a server through
	// a gRPC reflection client.
	ServProtoDescriptor struct {
		client *grpcreflect.Client // reflection client used for every lookup
	}

	// FileProtoDescriptor holds file descriptors together with an extension
	// registry. It contains a sync.Once and therefore must not be copied after
	// first use.
	FileProtoDescriptor struct {
		files  map[string]*desc.FileDescriptor // presumably keyed by file name — confirm
		er     *dynamic.ExtensionRegistry
		erInit sync.Once // presumably guards one-time initialisation of er — confirm
	}
)
// NewServProtoDescriptor returns a ProtoDescriptor backed by the gRPC server
// reflection service (v1alpha) reachable over conn.
//
// A nil conn is rejected with an error up front rather than deferring the
// failure to the first reflection call.
//
// NOTE(review): the grpcreflect.Client is created with context.Background()
// and is never Reset, so its reflection stream is presumably held until the
// connection closes — consider exposing a Close/Reset on the descriptor.
func NewServProtoDescriptor(conn *grpc.ClientConn) (ProtoDescriptor, error) {
	if conn == nil {
		return nil, fmt.Errorf("grpcall: nil *grpc.ClientConn")
	}
	stub := grpc_reflection_v1alpha.NewServerReflectionClient(conn)
	return &ServProtoDescriptor{
		client: grpcreflect.NewClient(context.Background(), stub),
	}, nil
}
// ListServices returns the service names reported by the reflection client.
// Any error from the client is returned unchanged.
func (s ServProtoDescriptor) ListServices() ([]string, error) {
	names, err := s.client.ListServices()
	return names, err
}
// FindSymbol returns the descriptor for the named element (message, service,
// enum, ...). It first asks the reflection server for the file that contains
// the symbol, then looks the symbol up inside that file.
//
// Errors from the reflection client are returned exactly as produced (not
// wrapped). If the file is found but does not define the symbol, a
// "symbol not found" error is returned.
func (s ServProtoDescriptor) FindSymbol(name string) (desc.Descriptor, error) {
	file, err := s.client.FileContainingSymbol(name)
	if err != nil {
		return nil, err
	}
	d := file.FindSymbol(name)
	if d == nil {
		// Go error strings are lowercase and unpunctuated; %q makes an empty or
		// odd symbol name visible.
		return nil, fmt.Errorf("symbol not found: %q", name)
	}
	return d, nil
}
// AllExtensionsForType returns the descriptors of every extension field that
// the reflection server reports for the message type typ.
//
// The result is an empty, non-nil slice when there are no extensions. The
// first failing lookup aborts the call and its error is returned.
func (s ServProtoDescriptor) AllExtensionsForType(typ string) ([]*desc.FieldDescriptor, error) {
	numbers, err := s.client.AllExtensionNumbersForType(typ)
	if err != nil {
		return nil, err
	}
	result := make([]*desc.FieldDescriptor, 0, len(numbers))
	for i := range numbers {
		field, resolveErr := s.client.ResolveExtension(typ, numbers[i])
		if resolveErr != nil {
			return nil, resolveErr
		}
		result = append(result, field)
	}
	return result, nil
}
// GetAllFiles returns every file descriptor reachable from the services the
// server lists. That is the file defining each service plus, via
// addFileDescriptorsToCache, its dependencies.
//
// Each file appears once (deduplicated by file name), and the result is
// ordered with FileDescriptorSorter. The first error from listing services,
// resolving a service or walking its dependencies is returned.
func (s ServProtoDescriptor) GetAllFiles() ([]*desc.FileDescriptor, error) {
	services, err := s.ListServices()
	if err != nil {
		return nil, err
	}
	byName := make(map[string]*desc.FileDescriptor)
	for _, name := range services {
		d, err := s.FindSymbol(name)
		if err != nil {
			return nil, err
		}
		// Check the dependency walk's error instead of discarding it.
		if err := s.addFileDescriptorsToCache(d.GetFile(), byName); err != nil {
			return nil, err
		}
	}
	files := lo.Values(byName)
	sort.Sort(FileDescriptorSorter(files))
	return files, nil
}
// addFileDescriptorsToCache stores file in cache under its name, then
// recursively does the same for every file it depends on.
//
// A file whose name is already in cache is skipped, which also stops the walk
// from revisiting shared dependencies. Errors from the recursive calls are
// propagated instead of being silently dropped.
func (s ServProtoDescriptor) addFileDescriptorsToCache(file *desc.FileDescriptor, cache map[string]*desc.FileDescriptor) error {
	name := file.GetName()
	if _, seen := cache[name]; seen {
		return nil
	}
	cache[name] = file
	for _, dep := range file.GetDependencies() {
		if err := s.addFileDescriptorsToCache(dep, cache); err != nil {
			return err
		}
	}
	return nil
}
// NewFileProtoDescriptor is presumably intended to build a ProtoDescriptor
// from the given protoset files. It is not implemented yet.
//
// It returns an explicit error rather than (nil, nil). A nil interface paired
// with a nil error would let callers proceed and fail later on first use.
func NewFileProtoDescriptor(protosets []string) (ProtoDescriptor, error) {
	// TODO: load the protosets into a descriptor implementation.
	return nil, fmt.Errorf("grpcall: loading descriptors from protoset files %q is not implemented", protosets)
}
// GetAllFiles returns every file descriptor held in f.files, ordered with
// FileDescriptorSorter. It never returns an error.
//
// The receiver is a pointer because FileProtoDescriptor contains a sync.Once,
// which must not be copied. A value receiver copies it on every call, and
// go vet's copylocks check reports that.
func (f *FileProtoDescriptor) GetAllFiles() ([]*desc.FileDescriptor, error) {
	files := lo.Values(f.files)
	sort.Sort(FileDescriptorSorter(files))
	return files, nil
}