package gobridge import ( "context" "encoding/json" "fmt" "io" "net" "reflect" "sync" "sync/atomic" ) // poolConfig 是进程池内部配置,通过 Option 函数填充 type poolConfig struct { workers int maxConnsPerWorker int pythonExe string scriptArgs []string workDir string env []string socketDir string stdout io.Writer stderr io.Writer handler any } // Option 是 NewPool 的函数选项 type Option func(*poolConfig) // WithWorkers 设置 Python 进程数量(默认 2) func WithWorkers(n int) Option { return func(c *poolConfig) { c.workers = n } } // WithMaxConns 设置每个进程的最大连接数(默认 4) func WithMaxConns(n int) Option { return func(c *poolConfig) { c.maxConnsPerWorker = n } } // WithPythonExe 设置 Python 可执行文件(默认 "python3") // uv 模式:WithPythonExe("uv"), WithScriptArgs("run") func WithPythonExe(exe string) Option { return func(c *poolConfig) { c.pythonExe = exe } } // WithScriptArgs 设置脚本路径之后的附加参数 func WithScriptArgs(args ...string) Option { return func(c *poolConfig) { c.scriptArgs = args } } // WithWorkDir 设置子进程工作目录(默认继承当前进程) func WithWorkDir(workDir string) Option { return func(c *poolConfig) { c.workDir = workDir } } // WithEnv 设置附加环境变量,格式为 "KEY=VALUE" func WithEnv(env ...string) Option { return func(c *poolConfig) { c.env = env } } // WithSocketDir 设置 UDS socket 文件目录(默认 /tmp) func WithSocketDir(dir string) Option { return func(c *poolConfig) { c.socketDir = dir } } // WithStdout 设置子进程标准输出目标(默认 os.Stdout,传 io.Discard 可静默) func WithStdout(w io.Writer) Option { return func(c *poolConfig) { c.stdout = w } } // WithStderr 设置子进程标准错误目标(默认 os.Stderr,传 io.Discard 可静默) func WithStderr(w io.Writer) Option { return func(c *poolConfig) { c.stderr = w } } // WithHandlers 注册 Go handler struct,其所有公开方法自动暴露给 Python 通过 call_go() 调用。 // // type MyService struct{} // func (s *MyService) Multiply(ctx context.Context, a, b int) (int, error) { ... } // // pool, err := gobridge.NewPool("worker.py", gobridge.WithHandlers(&MyService{})) // // Python: call_go("Multiply", 3, 4) func WithHandlers(h any) Option { return func(c *poolConfig) { c.handler = h } } // ── Pool 接口 ──────────────────────────────────────────────────────────────── // Pool 是 Python worker 进程池的接口 type Pool interface { Close() acquire(ctx context.Context) (net.Conn, *worker, error) nextReqID() uint64 callbackDispatch(ctx context.Context, msg Message) (any, string) } // ── goHandler ──────────────────────────────────────────────────────────────── type goHandler struct { fn reflect.Value hasCtx bool inTypes []reflect.Type outType reflect.Type hasErr bool } // ── pool(内部实现)───────────────────────────────────────────────────────── type pool struct { workers []*worker idx atomic.Uint64 reqID atomic.Uint64 mu sync.RWMutex handlers map[string]goHandler } // NewPool 创建并启动进程池。 // // pool, err := gobridge.NewPool("worker.py") // pool, err := gobridge.NewPool("worker.py", gobridge.WithWorkers(4)) // pool, err := gobridge.NewPool("worker.py", // gobridge.WithHandlers(&MyService{}), // gobridge.WithWorkers(2), // ) func NewPool(script string, opts ...Option) (Pool, error) { cfg := &poolConfig{ workers: 2, maxConnsPerWorker: 4, pythonExe: "python3", socketDir: "/tmp", } for _, o := range opts { o(cfg) } if script == "" { return nil, fmt.Errorf("gobridge: script must not be empty") } cfg.scriptArgs = append([]string{script}, cfg.scriptArgs...) workers := make([]*worker, cfg.workers) for i := range workers { w, err := newWorker(cfg, i) if err != nil { for j := range i { workers[j].stop() } return nil, fmt.Errorf("create worker %d: %w", i, err) } workers[i] = w } p := &pool{workers: workers, handlers: make(map[string]goHandler)} if cfg.handler != nil { p.bindHandlers(cfg.handler) } return p, nil } func (p *pool) bindHandlers(h any) { rv := reflect.ValueOf(h) rt := rv.Type() for i := range rt.NumMethod() { m := rt.Method(i) if !m.IsExported() { continue } p.bindHandler(m.Name, rv.Method(i)) } } func (p *pool) bindHandler(name string, fn reflect.Value) { errType := reflect.TypeFor[error]() ctxType := reflect.TypeFor[context.Context]() ft := fn.Type() h := goHandler{fn: fn} startIdx := 0 if ft.NumIn() > 0 && ft.In(0).Implements(ctxType) { h.hasCtx = true startIdx = 1 } for i := startIdx; i < ft.NumIn(); i++ { h.inTypes = append(h.inTypes, ft.In(i)) } switch ft.NumOut() { case 0: case 1: if ft.Out(0).Implements(errType) { h.hasErr = true } else { h.outType = ft.Out(0) } case 2: if !ft.Out(1).Implements(errType) { panic(fmt.Sprintf("gobridge: handler %s: second return value must implement error", name)) } h.outType = ft.Out(0) h.hasErr = true default: panic(fmt.Sprintf("gobridge: handler %s: must return at most (value, error)", name)) } p.handlers[name] = h } func (p *pool) acquire(ctx context.Context) (net.Conn, *worker, error) { n := uint64(len(p.workers)) if n == 0 { return nil, nil, fmt.Errorf("gobridge: pool has no workers") } var idx uint64 if i := workerIndexFor(ctx, int(n)); i >= 0 { idx = uint64(i) % n // 防御:ctx 可能来自不同 pool } else { idx = p.idx.Add(1) % n } w := p.workers[idx] conn, err := w.acquire(ctx) return conn, w, err } func (p *pool) nextReqID() uint64 { return p.reqID.Add(1) } func (p *pool) callbackDispatch(ctx context.Context, msg Message) (any, string) { p.mu.RLock() h, ok := p.handlers[msg.Method] p.mu.RUnlock() if !ok { return nil, fmt.Sprintf("unknown go handler: %s", msg.Method) } var rawArgs []json.RawMessage if len(msg.Args) > 0 { if err := json.Unmarshal(msg.Args, &rawArgs); err != nil { return nil, fmt.Sprintf("unmarshal args: %v", err) } } if len(rawArgs) != len(h.inTypes) { return nil, fmt.Sprintf("arg count mismatch: want %d got %d", len(h.inTypes), len(rawArgs)) } in := make([]reflect.Value, 0, len(h.inTypes)+1) if h.hasCtx { in = append(in, reflect.ValueOf(ctx)) } for i, t := range h.inTypes { v := reflect.New(t) if err := json.Unmarshal(rawArgs[i], v.Interface()); err != nil { return nil, fmt.Sprintf("unmarshal arg %d: %v", i, err) } in = append(in, v.Elem()) } out := h.fn.Call(in) if h.hasErr { errVal := out[len(out)-1] if !errVal.IsNil() { return nil, errVal.Interface().(error).Error() } } if h.outType == nil || len(out) == 0 { return nil, "" } return out[0].Interface(), "" } func (p *pool) Close() { for _, w := range p.workers { w.stop() } }