137 lines
3.1 KiB
Go
137 lines
3.1 KiB
Go
package grpcall
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sort"
|
|
"sync"
|
|
|
|
"github.com/jhump/protoreflect/desc"
|
|
"github.com/jhump/protoreflect/dynamic"
|
|
"github.com/jhump/protoreflect/grpcreflect"
|
|
"github.com/samber/lo"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
|
|
)
|
|
|
|
type ProtoDescriptor interface {
|
|
// gRPC 服务列表
|
|
ListServices() ([]string, error)
|
|
// 查找 proto 中定义的消息、服务或枚举等结构的描述符
|
|
FindSymbol(string) (desc.Descriptor, error)
|
|
// 获取消息类型的所有扩展字段的描述符
|
|
AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error)
|
|
// 获取 proto 文件描述信息
|
|
GetAllFiles() ([]*desc.FileDescriptor, error)
|
|
}
|
|
|
|
// WithFilesProtoDescriptor exposes only the GetAllFiles method, which returns
// the descriptors of the .proto files. The value type FileProtoDescriptor has
// a GetAllFiles method with this signature and therefore satisfies it.
type WithFilesProtoDescriptor interface {
	GetAllFiles() ([]*desc.FileDescriptor, error)
}
|
|
|
|
// ServProtoDescriptor answers descriptor queries by asking a server's gRPC
// reflection service. Instances are created by NewServProtoDescriptor.
type ServProtoDescriptor struct {
	// client sends the reflection requests (services, symbols, extensions).
	client *grpcreflect.Client
}
|
|
|
|
// FileProtoDescriptor holds file descriptors supplied directly instead of
// being fetched over reflection. It is presumably meant to be filled from
// protoset files (see NewFileProtoDescriptor's protosets parameter). TODO
// confirm. In this file only GetAllFiles is defined for it, and nothing here
// constructs one yet.
type FileProtoDescriptor struct {
	// files holds the descriptors returned by GetAllFiles. The keys are
	// presumably .proto file names; verify against the code that fills it.
	files map[string]*desc.FileDescriptor

	// er and erInit are not used in this file. They presumably implement a
	// lazily-built extension registry initialized once via erInit. TODO
	// confirm.
	er     *dynamic.ExtensionRegistry
	erInit sync.Once
}
|
|
|
|
func NewServProtoDescriptor(client *grpc.ClientConn) (ProtoDescriptor, error) {
|
|
ctx := context.Background()
|
|
|
|
pd := &ServProtoDescriptor{
|
|
client: grpcreflect.NewClient(ctx, grpc_reflection_v1alpha.NewServerReflectionClient(client)),
|
|
}
|
|
|
|
return pd, nil
|
|
}
|
|
|
|
func (this ServProtoDescriptor) ListServices() ([]string, error) {
|
|
return this.client.ListServices()
|
|
}
|
|
|
|
func (this ServProtoDescriptor) FindSymbol(name string) (desc.Descriptor, error) {
|
|
file, err := this.client.FileContainingSymbol(name)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
d := file.FindSymbol(name)
|
|
|
|
if d == nil {
|
|
return nil, fmt.Errorf("Symbol Not Found, %s", name)
|
|
}
|
|
|
|
return d, nil
|
|
}
|
|
|
|
func (this ServProtoDescriptor) AllExtensionsForType(typ string) ([]*desc.FieldDescriptor, error) {
|
|
exts := []*desc.FieldDescriptor{}
|
|
|
|
nums, err := this.client.AllExtensionNumbersForType(typ)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, fieldNum := range nums {
|
|
ext, err := this.client.ResolveExtension(typ, fieldNum)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
exts = append(exts, ext)
|
|
}
|
|
|
|
return exts, nil
|
|
}
|
|
|
|
func (this ServProtoDescriptor) GetAllFiles() ([]*desc.FileDescriptor, error) {
|
|
var files []*desc.FileDescriptor
|
|
|
|
if services, err := this.ListServices(); err != nil {
|
|
return nil, err
|
|
} else {
|
|
temp_fields := map[string]*desc.FileDescriptor{}
|
|
for _, name := range services {
|
|
d, err := this.FindSymbol(name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
this.addFileDescriptorsToCache(d.GetFile(), temp_fields)
|
|
}
|
|
files = lo.Values(temp_fields)
|
|
}
|
|
|
|
sort.Sort(FileDescriptorSorter(files))
|
|
|
|
return files, nil
|
|
}
|
|
|
|
func (this ServProtoDescriptor) addFileDescriptorsToCache(file *desc.FileDescriptor, cache map[string]*desc.FileDescriptor) error {
|
|
if _, ok := cache[file.GetName()]; ok {
|
|
return nil
|
|
} else {
|
|
cache[file.GetName()] = file
|
|
}
|
|
|
|
for _, dep := range file.GetDependencies() {
|
|
this.addFileDescriptorsToCache(dep, cache)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func NewFileProtoDescriptor(protosets []string) (ProtoDescriptor, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (this FileProtoDescriptor) GetAllFiles() ([]*desc.FileDescriptor, error) {
|
|
files := lo.Values(this.files)
|
|
sort.Sort(FileDescriptorSorter(files))
|
|
return files, nil
|
|
}
|