// Package gobridge 提供 Go 与 Python 之间的双向通信桥接, // 支持普通调用、流式输出、流式输入和双向流四种模式。 package gobridge import ( "context" "encoding/json" "fmt" "net" "reflect" "sync" ) // Invoke 调用 Python 暴露的函数,支持四种模式: // // 普通调用: Invoke[int](ctx, pool, "Add", 3, 4) // 流式输出: Invoke[chan int](ctx, pool, "RangeGen", 1, 10) // Python yield → Go channel // 流式输入: Invoke[int](ctx, pool, "SumStream", inputChan) // Go channel → Python Iterator // 双向流: Invoke[chan int](ctx, pool, "Transform", inputChan) // 两端均为流 // // ctx 取消时会立即中断与 Python 的通信并返回 ctx.Err()。 // 对于流式输出/双向流,ctx 取消会关闭返回的 channel。 func Invoke[R any](ctx context.Context, pool Pool, method string, args ...any) (R, error) { rt := reflect.TypeFor[R]() // 查找 chan 类型的输入参数 streamArgIdx := -1 var streamCh reflect.Value for i, arg := range args { if arg != nil { rv := reflect.ValueOf(arg) if rv.Kind() == reflect.Chan { streamArgIdx = i streamCh = rv break } } } switch { case rt.Kind() == reflect.Chan && streamArgIdx >= 0: return invokeStreamBoth[R](ctx, pool, method, streamArgIdx, streamCh, rt, args...) case rt.Kind() == reflect.Chan: return invokeStreamOut[R](ctx, pool, method, rt, args...) case streamArgIdx >= 0: return invokeStreamIn[R](ctx, pool, method, streamArgIdx, streamCh, args...) default: return invokeRegular[R](ctx, pool, method, args...) } } // watchCtx 启动一个 goroutine 监听 ctx: // - ctx 取消时先发送 cancel 消息(Python 侧收到后注入 InterruptedError) // - 再关闭连接,解除阻塞中的读写操作 // // 返回 stop 函数,必须在 conn 归还连接池前调用,可安全多次调用。 func watchCtx(ctx context.Context, conn net.Conn, id uint64) (stop func()) { done := make(chan struct{}) var once sync.Once go func() { select { case <-ctx.Done(): writeMsg(conn, Message{ID: id, Type: TypeCancel}) //nolint conn.Close() case <-done: } }() return func() { once.Do(func() { close(done) }) } } // chanRecv 从 ch 接收一个值,同时监听 ctx.Done()。 // 返回 (值, channel是否open, ctx是否已取消)。 func chanRecv(ctx context.Context, ch reflect.Value) (reflect.Value, bool, bool) { chosen, val, ok := reflect.Select([]reflect.SelectCase{ {Dir: reflect.SelectRecv, Chan: ch}, {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ctx.Done())}, }) if chosen == 1 { return reflect.Value{}, false, true } return val, ok, false } // contextErr 在 io 错误时优先返回 ctx 的错误原因 func contextErr(ctx context.Context, err error) error { if e := ctx.Err(); e != nil { return e } return err } func invokeRegular[R any](ctx context.Context, pool Pool, method string, args ...any) (R, error) { var zero R argsJSON, err := json.Marshal(args) if err != nil { return zero, fmt.Errorf("marshal args: %w", err) } conn, w, err := pool.acquire(ctx) if err != nil { return zero, err } id := pool.nextReqID() stop := watchCtx(ctx, conn, id) defer stop() if err := writeMsg(conn, Message{ ID: id, Type: TypeCall, Method: method, Args: argsJSON, }); err != nil { w.release(conn, false) return zero, contextErr(ctx, fmt.Errorf("write call: %w", err)) } resp, err := readMsg(conn) if err != nil { w.release(conn, false) return zero, contextErr(ctx, fmt.Errorf("read response: %w", err)) } stop() w.release(conn, true) if resp.Type == TypeError { return zero, fmt.Errorf("remote error: %s", resp.Error) } var result R if err := json.Unmarshal(resp.Data, &result); err != nil { return zero, fmt.Errorf("unmarshal result: %w", err) } return result, nil } func invokeStreamOut[R any](ctx context.Context, pool Pool, method string, rt reflect.Type, args ...any) (R, error) { var zero R argsJSON, err := json.Marshal(args) if err != nil { return zero, fmt.Errorf("marshal args: %w", err) } conn, w, err := pool.acquire(ctx) if err != nil { return zero, err } id := pool.nextReqID() if err := writeMsg(conn, Message{ ID: id, Type: TypeCall, Method: method, Args: argsJSON, }); err != nil { w.release(conn, false) return zero, contextErr(ctx, fmt.Errorf("write call: %w", err)) } ch := reflect.MakeChan(rt, 64) go func() { stop := watchCtx(ctx, conn, id) defer func() { stop() ch.Close() w.release(conn, ctx.Err() == nil) }() for { msg, err := readMsg(conn) if err != nil || msg.Type == TypeEnd || msg.Type == TypeError { return } if msg.Type == TypeChunk { val := reflect.New(rt.Elem()) if err := json.Unmarshal(msg.Data, val.Interface()); err != nil { return } ch.Send(val.Elem()) } } }() return ch.Interface().(R), nil } func invokeStreamIn[R any](ctx context.Context, pool Pool, method string, streamArgIdx int, streamCh reflect.Value, args ...any) (R, error) { var zero R jsonArgs := make([]any, len(args)) copy(jsonArgs, args) jsonArgs[streamArgIdx] = nil argsJSON, err := json.Marshal(jsonArgs) if err != nil { return zero, fmt.Errorf("marshal args: %w", err) } conn, w, err := pool.acquire(ctx) if err != nil { return zero, err } id := pool.nextReqID() stop := watchCtx(ctx, conn, id) defer stop() if err := writeMsg(conn, Message{ ID: id, Type: TypeCall, Method: method, Args: argsJSON, StreamInput: true, StreamArgIdx: streamArgIdx, }); err != nil { w.release(conn, false) return zero, contextErr(ctx, fmt.Errorf("write call: %w", err)) } for { val, ok, cancelled := chanRecv(ctx, streamCh) if cancelled { w.release(conn, false) return zero, ctx.Err() } if !ok { break } chunkData, err := json.Marshal(val.Interface()) if err != nil { w.release(conn, false) return zero, fmt.Errorf("marshal chunk: %w", err) } if err := writeMsg(conn, Message{ID: id, Type: TypeChunk, Data: chunkData}); err != nil { w.release(conn, false) return zero, contextErr(ctx, fmt.Errorf("write chunk: %w", err)) } } if err := writeMsg(conn, Message{ID: id, Type: TypeEnd}); err != nil { w.release(conn, false) return zero, contextErr(ctx, fmt.Errorf("write end: %w", err)) } resp, err := readMsg(conn) if err != nil { w.release(conn, false) return zero, contextErr(ctx, fmt.Errorf("read response: %w", err)) } stop() w.release(conn, true) if resp.Type == TypeError { return zero, fmt.Errorf("remote error: %s", resp.Error) } var result R if err := json.Unmarshal(resp.Data, &result); err != nil { return zero, fmt.Errorf("unmarshal result: %w", err) } return result, nil } func invokeStreamBoth[R any](ctx context.Context, pool Pool, method string, streamArgIdx int, streamCh reflect.Value, rt reflect.Type, args ...any) (R, error) { var zero R jsonArgs := make([]any, len(args)) copy(jsonArgs, args) jsonArgs[streamArgIdx] = nil argsJSON, err := json.Marshal(jsonArgs) if err != nil { return zero, fmt.Errorf("marshal args: %w", err) } conn, w, err := pool.acquire(ctx) if err != nil { return zero, err } id := pool.nextReqID() if err := writeMsg(conn, Message{ ID: id, Type: TypeCall, Method: method, Args: argsJSON, StreamInput: true, StreamArgIdx: streamArgIdx, }); err != nil { w.release(conn, false) return zero, contextErr(ctx, fmt.Errorf("write call: %w", err)) } outCh := reflect.MakeChan(rt, 64) // 写入 goroutine:输入 channel → Python chunks go func() { for { val, ok, cancelled := chanRecv(ctx, streamCh) if cancelled || !ok { break } data, err := json.Marshal(val.Interface()) if err != nil { break } if err := writeMsg(conn, Message{ID: id, Type: TypeChunk, Data: data}); err != nil { break } } writeMsg(conn, Message{ID: id, Type: TypeEnd}) //nolint }() // 读取 goroutine:Python chunks → 输出 channel go func() { stop := watchCtx(ctx, conn, id) defer func() { stop() outCh.Close() w.release(conn, ctx.Err() == nil) }() for { msg, err := readMsg(conn) if err != nil || msg.Type == TypeEnd || msg.Type == TypeError { return } if msg.Type == TypeChunk { val := reflect.New(rt.Elem()) if err := json.Unmarshal(msg.Data, val.Interface()); err != nil { return } outCh.Send(val.Elem()) } } }() return outCh.Interface().(R), nil }