package grpcall import ( "context" "fmt" "sort" "sync" "github.com/jhump/protoreflect/desc" "github.com/jhump/protoreflect/dynamic" "github.com/jhump/protoreflect/grpcreflect" "github.com/samber/lo" "google.golang.org/grpc" "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" ) type ProtoDescriptor interface { // gRPC 服务列表 ListServices() ([]string, error) // 查找 proto 中定义的消息、服务或枚举等结构的描述符 FindSymbol(string) (desc.Descriptor, error) // 获取消息类型的所有扩展字段的描述符 AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) // 获取 proto 文件描述信息 GetAllFiles() ([]*desc.FileDescriptor, error) } type WithFilesProtoDescriptor interface { GetAllFiles() ([]*desc.FileDescriptor, error) } type ServProtoDescriptor struct { client *grpcreflect.Client } type FileProtoDescriptor struct { files map[string]*desc.FileDescriptor er *dynamic.ExtensionRegistry erInit sync.Once } func NewServProtoDescriptor(client *grpc.ClientConn) (ProtoDescriptor, error) { ctx := context.Background() pd := &ServProtoDescriptor{ client: grpcreflect.NewClient(ctx, grpc_reflection_v1alpha.NewServerReflectionClient(client)), } return pd, nil } func (this ServProtoDescriptor) ListServices() ([]string, error) { return this.client.ListServices() } func (this ServProtoDescriptor) FindSymbol(name string) (desc.Descriptor, error) { file, err := this.client.FileContainingSymbol(name) if err != nil { return nil, err } d := file.FindSymbol(name) if d == nil { return nil, fmt.Errorf("Symbol Not Found, %s", name) } return d, nil } func (this ServProtoDescriptor) AllExtensionsForType(typ string) ([]*desc.FieldDescriptor, error) { exts := []*desc.FieldDescriptor{} nums, err := this.client.AllExtensionNumbersForType(typ) if err != nil { return nil, err } for _, fieldNum := range nums { ext, err := this.client.ResolveExtension(typ, fieldNum) if err != nil { return nil, err } exts = append(exts, ext) } return exts, nil } func (this ServProtoDescriptor) GetAllFiles() ([]*desc.FileDescriptor, error) { var files []*desc.FileDescriptor if services, err := this.ListServices(); err != nil { return nil, err } else { temp_fields := map[string]*desc.FileDescriptor{} for _, name := range services { d, err := this.FindSymbol(name) if err != nil { return nil, err } this.addFileDescriptorsToCache(d.GetFile(), temp_fields) } files = lo.Values(temp_fields) } sort.Sort(FileDescriptorSorter(files)) return files, nil } func (this ServProtoDescriptor) addFileDescriptorsToCache(file *desc.FileDescriptor, cache map[string]*desc.FileDescriptor) error { if _, ok := cache[file.GetName()]; ok { return nil } else { cache[file.GetName()] = file } for _, dep := range file.GetDependencies() { this.addFileDescriptorsToCache(dep, cache) } return nil } func NewFileProtoDescriptor(protosets []string) (ProtoDescriptor, error) { return nil, nil } func (this FileProtoDescriptor) GetAllFiles() ([]*desc.FileDescriptor, error) { files := lo.Values(this.files) sort.Sort(FileDescriptorSorter(files)) return files, nil }