Files
gobridge/pool.go
what b390effd8e feat: 添加 Python→Go 全双工回调支持(call_go)
- 新增 WithHandlers 选项,通过反射将 Go 结构体方法暴露给 Python
- 新增 callback/callback_result 消息类型,支持 Python 在处理中回调 Go
- client 侧新增 readResult,内联处理 callback,复用同一连接避免死锁
- Python 侧新增 call_go[T]() 泛型调用,支持 dataclass 自动构造
- 注入 GOBRIDGE_WORKER_ID/WORKER_COUNT 环境变量,支持多 worker 初始化分工
- 新增示例演示 Go→Python→Go→Python 四层全双工链路
- Python 包版本升至 0.1.1
2026-04-14 13:06:50 +08:00

273 lines
7.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package gobridge
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"reflect"
"sync"
"sync/atomic"
)
// poolConfig is the pool's internal configuration. NewPool seeds it with
// defaults and then lets each Option function fill in or override fields.
type poolConfig struct {
workers int // number of worker processes NewPool creates (NewPool default: 2)
maxConnsPerWorker int // maximum connections per worker process (NewPool default: 4)
pythonExe string // Python executable to launch (NewPool default: "python3")
scriptArgs []string // extra arguments; NewPool prepends the script path to this slice
workDir string // working directory for the child process, set by WithWorkDir
env []string // additional environment variables in "KEY=VALUE" form, set by WithEnv
socketDir string // directory for the UDS socket files (NewPool default: "/tmp")
stdout io.Writer // destination for child stdout, set by WithStdout
stderr io.Writer // destination for child stderr, set by WithStderr
handler any // Go handler value registered via WithHandlers; nil when none was given
}
// Option is a functional option accepted by NewPool; it mutates the pool's
// internal configuration before any worker is started.
type Option func(*poolConfig)
// WithWorkers sets the number of Python worker processes (default 2).
func WithWorkers(n int) Option {
	return func(cfg *poolConfig) {
		cfg.workers = n
	}
}
// WithMaxConns sets the maximum number of connections per worker process
// (default 4).
func WithMaxConns(n int) Option {
	return func(cfg *poolConfig) {
		cfg.maxConnsPerWorker = n
	}
}
// WithPythonExe sets the Python executable to launch (default "python3").
//
// For uv, the original note suggests: WithPythonExe("uv"), WithScriptArgs("run").
// NOTE(review): NewPool places the script path before the script args, so
// confirm how the worker builds its command line in uv mode.
func WithPythonExe(exe string) Option {
	return func(cfg *poolConfig) {
		cfg.pythonExe = exe
	}
}
// WithScriptArgs sets the additional arguments that follow the script path.
// The slice is stored as given; a later call replaces the earlier value.
func WithScriptArgs(args ...string) Option {
	return func(cfg *poolConfig) {
		cfg.scriptArgs = args
	}
}
// WithWorkDir sets the working directory of the child processes (by default
// they inherit the current process's directory).
func WithWorkDir(workDir string) Option {
	return func(cfg *poolConfig) {
		cfg.workDir = workDir
	}
}
// WithEnv sets additional environment variables, each in "KEY=VALUE" form.
// The slice is stored as given; a later call replaces the earlier value.
func WithEnv(env ...string) Option {
	return func(cfg *poolConfig) {
		cfg.env = env
	}
}
// WithSocketDir sets the directory for the UDS socket files (default "/tmp").
func WithSocketDir(dir string) Option {
	return func(cfg *poolConfig) {
		cfg.socketDir = dir
	}
}
// WithStdout sets where the child processes' standard output goes
// (default os.Stdout; pass io.Discard to silence it).
func WithStdout(w io.Writer) Option {
	return func(cfg *poolConfig) {
		cfg.stdout = w
	}
}
// WithStderr sets where the child processes' standard error goes
// (default os.Stderr; pass io.Discard to silence it).
func WithStderr(w io.Writer) Option {
	return func(cfg *poolConfig) {
		cfg.stderr = w
	}
}
// WithHandlers registers a Go handler value. All of its exported methods are
// exposed to Python automatically and can be invoked there through call_go().
//
//	type MyService struct{}
//	func (s *MyService) Multiply(ctx context.Context, a, b int) (int, error) { ... }
//
//	pool, err := gobridge.NewPool("worker.py", gobridge.WithHandlers(&MyService{}))
//	// Python: call_go("Multiply", 3, 4)
func WithHandlers(h any) Option {
	return func(cfg *poolConfig) {
		cfg.handler = h
	}
}
// ── Pool interface ───────────────────────────────────────────────────────────
// Pool is the interface to a pool of Python worker processes.
//
// It contains unexported methods, so only types inside this package can
// implement it; outside callers obtain a Pool from NewPool.
type Pool interface {
// Close shuts the pool down; the pool implementation stops every worker.
Close()
// acquire hands out a connection together with the worker it belongs to.
acquire(ctx context.Context) (net.Conn, *worker, error)
// nextReqID returns the next request identifier.
nextReqID() uint64
// callbackDispatch runs the Go handler named in msg and returns its result
// plus an error message string (empty on success in the pool implementation).
callbackDispatch(ctx context.Context, msg Message) (any, string)
}
// ── goHandler ────────────────────────────────────────────────────────────────
// goHandler describes one Go method callable from Python: the bound method
// value plus the signature facts the dispatcher needs to decode arguments and
// interpret return values.
type goHandler struct {
fn reflect.Value // bound method value that the dispatcher calls
hasCtx bool // true when the first parameter implements context.Context
inTypes []reflect.Type // parameter types, excluding a leading context parameter
outType reflect.Type // type of the non-error result; nil when there is none
hasErr bool // true when the last return value implements error
}
// ── pool (internal implementation) ───────────────────────────────────────────
// pool is the package's Pool implementation: a fixed set of workers chosen
// round-robin, a request-ID counter, and the table of Go handlers that can be
// called back from Python.
type pool struct {
workers []*worker // worker processes created by NewPool
idx atomic.Uint64 // counter used by acquire to pick the next worker
reqID atomic.Uint64 // counter behind nextReqID
mu sync.RWMutex // read-locked around handler lookups in callbackDispatch
handlers map[string]goHandler // method name -> handler; filled by bindHandler without taking mu
}
// NewPool creates the process pool and starts its workers.
//
//	pool, err := gobridge.NewPool("worker.py")
//	pool, err := gobridge.NewPool("worker.py", gobridge.WithWorkers(4))
//	pool, err := gobridge.NewPool("worker.py",
//		gobridge.WithHandlers(&MyService{}),
//		gobridge.WithWorkers(2),
//	)
//
// script is the Python script to run and must not be empty. Defaults are
// 2 workers, 4 connections per worker, "python3" as the executable and "/tmp"
// as the socket directory; opts override them in order (nil options are
// skipped).
//
// It returns an error when script is empty, when the configured worker count
// is below 1, or when a worker cannot be created; in the last case every
// worker that was already started is stopped first. It panics if a registered
// handler method has an unsupported return signature.
func NewPool(script string, opts ...Option) (Pool, error) {
	cfg := &poolConfig{
		workers:           2,
		maxConnsPerWorker: 4,
		pythonExe:         "python3",
		socketDir:         "/tmp",
	}
	for _, o := range opts {
		if o != nil {
			o(cfg)
		}
	}
	if script == "" {
		return nil, fmt.Errorf("gobridge: script must not be empty")
	}
	// A pool without workers cannot serve anything: worker selection takes the
	// counter modulo the worker count, and a negative count would panic in make.
	if cfg.workers < 1 {
		return nil, fmt.Errorf("gobridge: workers must be at least 1, got %d", cfg.workers)
	}
	cfg.scriptArgs = append([]string{script}, cfg.scriptArgs...)

	// Bind handlers before spawning any process so that a panic caused by an
	// invalid handler signature cannot leave running workers behind.
	p := &pool{handlers: make(map[string]goHandler)}
	if cfg.handler != nil {
		p.bindHandlers(cfg.handler)
	}

	workers := make([]*worker, cfg.workers)
	for i := range workers {
		w, err := newWorker(cfg, i)
		if err != nil {
			// Roll back the workers that did start.
			for j := range i {
				workers[j].stop()
			}
			return nil, fmt.Errorf("create worker %d: %w", i, err)
		}
		workers[i] = w
	}
	p.workers = workers
	return p, nil
}
// bindHandlers walks the method set of h via reflection and registers every
// exported method under its own name.
func (p *pool) bindHandlers(h any) {
	val := reflect.ValueOf(h)
	typ := val.Type()
	for idx := 0; idx < typ.NumMethod(); idx++ {
		method := typ.Method(idx)
		if method.IsExported() {
			// val.Method yields the method already bound to its receiver.
			p.bindHandler(method.Name, val.Method(idx))
		}
	}
}
// bindHandler inspects the signature of fn and stores the resulting goHandler
// in p.handlers under name.
//
// A leading parameter that implements context.Context is recorded as hasCtx
// and left out of inTypes. Supported return shapes are: nothing, a single
// value, a single error, or (value, error). Any other shape panics.
func (p *pool) bindHandler(name string, fn reflect.Value) {
	var (
		errIface = reflect.TypeFor[error]()
		ctxIface = reflect.TypeFor[context.Context]()
		sig      = fn.Type()
		entry    = goHandler{fn: fn}
	)

	// Parameters: skip a leading context, record the rest in order.
	first := 0
	if sig.NumIn() > 0 && sig.In(0).Implements(ctxIface) {
		entry.hasCtx = true
		first = 1
	}
	for i := first; i < sig.NumIn(); i++ {
		entry.inTypes = append(entry.inTypes, sig.In(i))
	}

	// Return values.
	numOut := sig.NumOut()
	if numOut > 2 {
		panic(fmt.Sprintf("gobridge: handler %s: must return at most (value, error)", name))
	}
	if numOut == 2 {
		if !sig.Out(1).Implements(errIface) {
			panic(fmt.Sprintf("gobridge: handler %s: second return value must implement error", name))
		}
		entry.outType = sig.Out(0)
		entry.hasErr = true
	}
	if numOut == 1 {
		// A lone return value is either the error or the result.
		if only := sig.Out(0); only.Implements(errIface) {
			entry.hasErr = true
		} else {
			entry.outType = only
		}
	}

	p.handlers[name] = entry
}
// acquire picks the next worker in round-robin order and asks it for a
// connection. It returns the connection, the chosen worker, and any error
// reported by the worker; the worker is returned even when the error is
// non-nil.
//
// A pool without workers yields an error instead of panicking on the modulo
// by zero.
func (p *pool) acquire(ctx context.Context) (net.Conn, *worker, error) {
	n := uint64(len(p.workers))
	if n == 0 {
		return nil, nil, fmt.Errorf("gobridge: pool has no workers")
	}
	w := p.workers[p.idx.Add(1)%n]
	conn, err := w.acquire(ctx)
	return conn, w, err
}
// nextReqID atomically increments the pool's request counter and returns the
// new value.
func (p *pool) nextReqID() uint64 {
	id := p.reqID.Add(1)
	return id
}
// callbackDispatch runs the Go handler named by msg.Method on behalf of a
// callback and returns (result, errMsg).
//
// msg.Args, when non-empty, must be a JSON array with exactly one element per
// handler parameter (a leading context parameter is not counted; ctx is
// supplied for it). Each element is decoded into the parameter's Go type.
//
// On success errMsg is "" and result is the handler's non-error return value,
// or nil when the handler has none. On failure result is nil and errMsg
// describes the problem: unknown handler, malformed arguments, argument count
// mismatch, or the handler's own error.
func (p *pool) callbackDispatch(ctx context.Context, msg Message) (any, string) {
	p.mu.RLock()
	h, ok := p.handlers[msg.Method]
	p.mu.RUnlock()
	if !ok {
		return nil, fmt.Sprintf("unknown go handler: %s", msg.Method)
	}

	var rawArgs []json.RawMessage
	if len(msg.Args) > 0 {
		if err := json.Unmarshal(msg.Args, &rawArgs); err != nil {
			return nil, fmt.Sprintf("unmarshal args: %v", err)
		}
	}
	if len(rawArgs) != len(h.inTypes) {
		return nil, fmt.Sprintf("arg count mismatch: want %d got %d", len(h.inTypes), len(rawArgs))
	}

	in := make([]reflect.Value, 0, len(h.inTypes)+1)
	if h.hasCtx {
		in = append(in, reflect.ValueOf(ctx))
	}
	for i, t := range h.inTypes {
		// Decode into a fresh *T, then pass the pointed-to T by value.
		v := reflect.New(t)
		if err := json.Unmarshal(rawArgs[i], v.Interface()); err != nil {
			return nil, fmt.Sprintf("unmarshal arg %d: %v", i, err)
		}
		in = append(in, v.Elem())
	}

	// For a variadic handler the final parameter type is a slice and the final
	// argument was decoded as that slice. Call would expect the individual
	// elements and panic on the type mismatch, so hand the slice over as-is.
	var out []reflect.Value
	if h.fn.Type().IsVariadic() {
		out = h.fn.CallSlice(in)
	} else {
		out = h.fn.Call(in)
	}

	if h.hasErr {
		// The error, when present, is always the last return value.
		errVal := out[len(out)-1]
		if !errVal.IsNil() {
			return nil, errVal.Interface().(error).Error()
		}
	}
	if h.outType == nil || len(out) == 0 {
		return nil, ""
	}
	return out[0].Interface(), ""
}
// Close stops every worker in the pool, in order.
func (p *pool) Close() {
	for i := range p.workers {
		p.workers[i].stop()
	}
}