From 308faa95ffa7fa1318cb9d6fffa246762e939b3c Mon Sep 17 00:00:00 2001 From: what Date: Fri, 10 Apr 2026 16:38:37 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=B0=86=20Pool=20=E9=87=8D?= =?UTF-8?q?=E6=9E=84=E4=B8=BA=E6=8E=A5=E5=8F=A3=EF=BC=8C=E6=8F=90=E5=8F=96?= =?UTF-8?q?=20pool=20=E4=B8=BA=E5=85=B7=E4=BD=93=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client.go | 18 +++++++++--------- pool.go | 24 ++++++++++++++++++------ 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/client.go b/client.go index e979058..a6b850d 100644 --- a/client.go +++ b/client.go @@ -20,7 +20,7 @@ import ( // // ctx 取消时会立即中断与 Python 的通信并返回 ctx.Err()。 // 对于流式输出/双向流,ctx 取消会关闭返回的 channel。 -func Invoke[R any](ctx context.Context, pool *Pool, method string, args ...any) (R, error) { +func Invoke[R any](ctx context.Context, pool Pool, method string, args ...any) (R, error) { rt := reflect.TypeFor[R]() // 查找 chan 类型的输入参数 @@ -89,7 +89,7 @@ func contextErr(ctx context.Context, err error) error { return err } -func invokeRegular[R any](ctx context.Context, pool *Pool, method string, args ...any) (R, error) { +func invokeRegular[R any](ctx context.Context, pool Pool, method string, args ...any) (R, error) { var zero R argsJSON, err := json.Marshal(args) @@ -102,7 +102,7 @@ func invokeRegular[R any](ctx context.Context, pool *Pool, method string, args . return zero, err } - id := pool.reqID.Add(1) + id := pool.nextReqID() stop := watchCtx(ctx, conn, id) defer stop() @@ -136,7 +136,7 @@ func invokeRegular[R any](ctx context.Context, pool *Pool, method string, args . return result, nil } -func invokeStreamOut[R any](ctx context.Context, pool *Pool, method string, rt reflect.Type, args ...any) (R, error) { +func invokeStreamOut[R any](ctx context.Context, pool Pool, method string, rt reflect.Type, args ...any) (R, error) { var zero R argsJSON, err := json.Marshal(args) @@ -149,7 +149,7 @@ func invokeStreamOut[R any](ctx context.Context, pool *Pool, method string, rt r return zero, err } - id := pool.reqID.Add(1) + id := pool.nextReqID() if err := writeMsg(conn, Message{ ID: id, Type: TypeCall, @@ -187,7 +187,7 @@ func invokeStreamOut[R any](ctx context.Context, pool *Pool, method string, rt r return ch.Interface().(R), nil } -func invokeStreamIn[R any](ctx context.Context, pool *Pool, method string, streamArgIdx int, streamCh reflect.Value, args ...any) (R, error) { +func invokeStreamIn[R any](ctx context.Context, pool Pool, method string, streamArgIdx int, streamCh reflect.Value, args ...any) (R, error) { var zero R jsonArgs := make([]any, len(args)) @@ -204,7 +204,7 @@ func invokeStreamIn[R any](ctx context.Context, pool *Pool, method string, strea return zero, err } - id := pool.reqID.Add(1) + id := pool.nextReqID() stop := watchCtx(ctx, conn, id) defer stop() @@ -265,7 +265,7 @@ func invokeStreamIn[R any](ctx context.Context, pool *Pool, method string, strea return result, nil } -func invokeStreamBoth[R any](ctx context.Context, pool *Pool, method string, streamArgIdx int, streamCh reflect.Value, rt reflect.Type, args ...any) (R, error) { +func invokeStreamBoth[R any](ctx context.Context, pool Pool, method string, streamArgIdx int, streamCh reflect.Value, rt reflect.Type, args ...any) (R, error) { var zero R jsonArgs := make([]any, len(args)) @@ -282,7 +282,7 @@ func invokeStreamBoth[R any](ctx context.Context, pool *Pool, method string, str return zero, err } - id := pool.reqID.Add(1) + id := pool.nextReqID() if err := writeMsg(conn, Message{ ID: id, Type: TypeCall, diff --git a/pool.go b/pool.go index 25c3850..eb451c1 100644 --- a/pool.go +++ b/pool.go @@ -72,8 +72,16 @@ func WithStderr(w io.Writer) Option { return func(c *poolConfig) { c.stderr = w } } -// Pool 管理多个 Python worker 进程及其连接池 -type Pool struct { +// Pool 是 Python worker 进程池的接口 +type Pool interface { + // Close 关闭所有 worker 进程和连接 + Close() + acquire(ctx context.Context) (net.Conn, *worker, error) + nextReqID() uint64 +} + +// pool 是 Pool 的具体实现 +type pool struct { workers []*worker idx atomic.Uint64 reqID atomic.Uint64 @@ -88,7 +96,7 @@ type Pool struct { // gobridge.WithScriptArgs("worker.py"), // gobridge.WithWorkDir("./worker"), // ) -func NewPool(script string, opts ...Option) (*Pool, error) { +func NewPool(script string, opts ...Option) (Pool, error) { cfg := &poolConfig{ workers: 2, maxConnsPerWorker: 4, @@ -114,19 +122,23 @@ func NewPool(script string, opts ...Option) (*Pool, error) { } workers[i] = w } - return &Pool{workers: workers}, nil + return &pool{workers: workers}, nil } // acquire 以轮询方式从进程池取出一个可用连接 -func (p *Pool) acquire(ctx context.Context) (net.Conn, *worker, error) { +func (p *pool) acquire(ctx context.Context) (net.Conn, *worker, error) { idx := p.idx.Add(1) % uint64(len(p.workers)) w := p.workers[idx] conn, err := w.acquire(ctx) return conn, w, err } +func (p *pool) nextReqID() uint64 { + return p.reqID.Add(1) +} + // Close 关闭所有 worker 进程和连接 -func (p *Pool) Close() { +func (p *pool) Close() { for _, w := range p.workers { w.stop() }