// Package gobridge provides a bidirectional communication bridge between Go
// and Python, supporting four call modes: regular calls, streaming output,
// streaming input, and bidirectional streaming.
package gobridge
|
||
|
||
import (
	"context"
	"encoding/json"
	"fmt"
	"net"
	"reflect"
	"sync"
	"time"
)
|
||
|
||
// Invoke 调用 Python 暴露的函数,支持四种模式:
|
||
//
|
||
// 普通调用: Invoke[int](ctx, pool, "Add", 3, 4)
|
||
// 流式输出: Invoke[chan int](ctx, pool, "RangeGen", 1, 10) // Python yield → Go channel
|
||
// 流式输入: Invoke[int](ctx, pool, "SumStream", inputChan) // Go channel → Python Iterator
|
||
// 双向流: Invoke[chan int](ctx, pool, "Transform", inputChan) // 两端均为流
|
||
//
|
||
// ctx 取消时会立即中断与 Python 的通信并返回 ctx.Err()。
|
||
// 对于流式输出/双向流,ctx 取消会关闭返回的 channel。
|
||
func Invoke[R any](ctx context.Context, pool *Pool, method string, args ...any) (R, error) {
|
||
rt := reflect.TypeFor[R]()
|
||
|
||
// 查找 chan 类型的输入参数
|
||
streamArgIdx := -1
|
||
var streamCh reflect.Value
|
||
for i, arg := range args {
|
||
if arg != nil {
|
||
rv := reflect.ValueOf(arg)
|
||
if rv.Kind() == reflect.Chan {
|
||
streamArgIdx = i
|
||
streamCh = rv
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
switch {
|
||
case rt.Kind() == reflect.Chan && streamArgIdx >= 0:
|
||
return invokeStreamBoth[R](ctx, pool, method, streamArgIdx, streamCh, rt, args...)
|
||
case rt.Kind() == reflect.Chan:
|
||
return invokeStreamOut[R](ctx, pool, method, rt, args...)
|
||
case streamArgIdx >= 0:
|
||
return invokeStreamIn[R](ctx, pool, method, streamArgIdx, streamCh, args...)
|
||
default:
|
||
return invokeRegular[R](ctx, pool, method, args...)
|
||
}
|
||
}
|
||
|
||
// watchCtx 启动一个 goroutine 监听 ctx:
|
||
// - ctx 取消时先发送 cancel 消息(Python 侧收到后注入 InterruptedError)
|
||
// - 再关闭连接,解除阻塞中的读写操作
|
||
//
|
||
// 返回 stop 函数,必须在 conn 归还连接池前调用,可安全多次调用。
|
||
func watchCtx(ctx context.Context, conn net.Conn, id uint64) (stop func()) {
|
||
done := make(chan struct{})
|
||
var once sync.Once
|
||
go func() {
|
||
select {
|
||
case <-ctx.Done():
|
||
writeMsg(conn, Message{ID: id, Type: TypeCancel}) //nolint
|
||
conn.Close()
|
||
case <-done:
|
||
}
|
||
}()
|
||
return func() { once.Do(func() { close(done) }) }
|
||
}
|
||
|
||
// chanRecv performs a single receive from ch while also watching ctx.Done().
//
// It returns, in order: the received value, whether ch was still open (false
// once ch has been closed), and whether the ctx case was chosen instead. When
// the ctx case wins, the value is the zero reflect.Value and the open flag is
// false. If both cases are ready, reflect.Select picks one at random.
func chanRecv(ctx context.Context, ch reflect.Value) (reflect.Value, bool, bool) {
	const (
		fromChan = iota
		fromCtx
	)

	cases := make([]reflect.SelectCase, 2)
	cases[fromChan] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: ch}
	cases[fromCtx] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ctx.Done())}

	idx, received, open := reflect.Select(cases)
	switch idx {
	case fromCtx:
		return reflect.Value{}, false, true
	default:
		return received, open, false
	}
}
|
||
|
||
// contextErr chooses which error to report after a failed I/O step.
//
// If ctx is already done, the failure was most likely caused by the ctx
// watcher closing the connection, so ctx.Err() is returned in place of err.
// Otherwise err is returned unchanged (including when it is nil).
func contextErr(ctx context.Context, err error) error {
	ctxErr := ctx.Err()
	if ctxErr == nil {
		return err
	}
	return ctxErr
}
|
||
|
||
func invokeRegular[R any](ctx context.Context, pool *Pool, method string, args ...any) (R, error) {
|
||
var zero R
|
||
|
||
argsJSON, err := json.Marshal(args)
|
||
if err != nil {
|
||
return zero, fmt.Errorf("marshal args: %w", err)
|
||
}
|
||
|
||
conn, w, err := pool.acquire(ctx)
|
||
if err != nil {
|
||
return zero, err
|
||
}
|
||
|
||
id := pool.reqID.Add(1)
|
||
stop := watchCtx(ctx, conn, id)
|
||
defer stop()
|
||
|
||
if err := writeMsg(conn, Message{
|
||
ID: id,
|
||
Type: TypeCall,
|
||
Method: method,
|
||
Args: argsJSON,
|
||
}); err != nil {
|
||
w.release(conn, false)
|
||
return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
|
||
}
|
||
|
||
resp, err := readMsg(conn)
|
||
if err != nil {
|
||
w.release(conn, false)
|
||
return zero, contextErr(ctx, fmt.Errorf("read response: %w", err))
|
||
}
|
||
|
||
stop()
|
||
w.release(conn, true)
|
||
|
||
if resp.Type == TypeError {
|
||
return zero, fmt.Errorf("remote error: %s", resp.Error)
|
||
}
|
||
|
||
var result R
|
||
if err := json.Unmarshal(resp.Data, &result); err != nil {
|
||
return zero, fmt.Errorf("unmarshal result: %w", err)
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
func invokeStreamOut[R any](ctx context.Context, pool *Pool, method string, rt reflect.Type, args ...any) (R, error) {
|
||
var zero R
|
||
|
||
argsJSON, err := json.Marshal(args)
|
||
if err != nil {
|
||
return zero, fmt.Errorf("marshal args: %w", err)
|
||
}
|
||
|
||
conn, w, err := pool.acquire(ctx)
|
||
if err != nil {
|
||
return zero, err
|
||
}
|
||
|
||
id := pool.reqID.Add(1)
|
||
if err := writeMsg(conn, Message{
|
||
ID: id,
|
||
Type: TypeCall,
|
||
Method: method,
|
||
Args: argsJSON,
|
||
}); err != nil {
|
||
w.release(conn, false)
|
||
return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
|
||
}
|
||
|
||
ch := reflect.MakeChan(rt, 64)
|
||
|
||
go func() {
|
||
stop := watchCtx(ctx, conn, id)
|
||
defer func() {
|
||
stop()
|
||
ch.Close()
|
||
w.release(conn, ctx.Err() == nil)
|
||
}()
|
||
for {
|
||
msg, err := readMsg(conn)
|
||
if err != nil || msg.Type == TypeEnd || msg.Type == TypeError {
|
||
return
|
||
}
|
||
if msg.Type == TypeChunk {
|
||
val := reflect.New(rt.Elem())
|
||
if err := json.Unmarshal(msg.Data, val.Interface()); err != nil {
|
||
return
|
||
}
|
||
ch.Send(val.Elem())
|
||
}
|
||
}
|
||
}()
|
||
|
||
return ch.Interface().(R), nil
|
||
}
|
||
|
||
func invokeStreamIn[R any](ctx context.Context, pool *Pool, method string, streamArgIdx int, streamCh reflect.Value, args ...any) (R, error) {
|
||
var zero R
|
||
|
||
jsonArgs := make([]any, len(args))
|
||
copy(jsonArgs, args)
|
||
jsonArgs[streamArgIdx] = nil
|
||
|
||
argsJSON, err := json.Marshal(jsonArgs)
|
||
if err != nil {
|
||
return zero, fmt.Errorf("marshal args: %w", err)
|
||
}
|
||
|
||
conn, w, err := pool.acquire(ctx)
|
||
if err != nil {
|
||
return zero, err
|
||
}
|
||
|
||
id := pool.reqID.Add(1)
|
||
stop := watchCtx(ctx, conn, id)
|
||
defer stop()
|
||
|
||
if err := writeMsg(conn, Message{
|
||
ID: id,
|
||
Type: TypeCall,
|
||
Method: method,
|
||
Args: argsJSON,
|
||
StreamInput: true,
|
||
StreamArgIdx: streamArgIdx,
|
||
}); err != nil {
|
||
w.release(conn, false)
|
||
return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
|
||
}
|
||
|
||
for {
|
||
val, ok, cancelled := chanRecv(ctx, streamCh)
|
||
if cancelled {
|
||
w.release(conn, false)
|
||
return zero, ctx.Err()
|
||
}
|
||
if !ok {
|
||
break
|
||
}
|
||
chunkData, err := json.Marshal(val.Interface())
|
||
if err != nil {
|
||
w.release(conn, false)
|
||
return zero, fmt.Errorf("marshal chunk: %w", err)
|
||
}
|
||
if err := writeMsg(conn, Message{ID: id, Type: TypeChunk, Data: chunkData}); err != nil {
|
||
w.release(conn, false)
|
||
return zero, contextErr(ctx, fmt.Errorf("write chunk: %w", err))
|
||
}
|
||
}
|
||
|
||
if err := writeMsg(conn, Message{ID: id, Type: TypeEnd}); err != nil {
|
||
w.release(conn, false)
|
||
return zero, contextErr(ctx, fmt.Errorf("write end: %w", err))
|
||
}
|
||
|
||
resp, err := readMsg(conn)
|
||
if err != nil {
|
||
w.release(conn, false)
|
||
return zero, contextErr(ctx, fmt.Errorf("read response: %w", err))
|
||
}
|
||
|
||
stop()
|
||
w.release(conn, true)
|
||
|
||
if resp.Type == TypeError {
|
||
return zero, fmt.Errorf("remote error: %s", resp.Error)
|
||
}
|
||
|
||
var result R
|
||
if err := json.Unmarshal(resp.Data, &result); err != nil {
|
||
return zero, fmt.Errorf("unmarshal result: %w", err)
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
func invokeStreamBoth[R any](ctx context.Context, pool *Pool, method string, streamArgIdx int, streamCh reflect.Value, rt reflect.Type, args ...any) (R, error) {
|
||
var zero R
|
||
|
||
jsonArgs := make([]any, len(args))
|
||
copy(jsonArgs, args)
|
||
jsonArgs[streamArgIdx] = nil
|
||
|
||
argsJSON, err := json.Marshal(jsonArgs)
|
||
if err != nil {
|
||
return zero, fmt.Errorf("marshal args: %w", err)
|
||
}
|
||
|
||
conn, w, err := pool.acquire(ctx)
|
||
if err != nil {
|
||
return zero, err
|
||
}
|
||
|
||
id := pool.reqID.Add(1)
|
||
if err := writeMsg(conn, Message{
|
||
ID: id,
|
||
Type: TypeCall,
|
||
Method: method,
|
||
Args: argsJSON,
|
||
StreamInput: true,
|
||
StreamArgIdx: streamArgIdx,
|
||
}); err != nil {
|
||
w.release(conn, false)
|
||
return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
|
||
}
|
||
|
||
outCh := reflect.MakeChan(rt, 64)
|
||
|
||
// 写入 goroutine:输入 channel → Python chunks
|
||
go func() {
|
||
for {
|
||
val, ok, cancelled := chanRecv(ctx, streamCh)
|
||
if cancelled || !ok {
|
||
break
|
||
}
|
||
data, err := json.Marshal(val.Interface())
|
||
if err != nil {
|
||
break
|
||
}
|
||
if err := writeMsg(conn, Message{ID: id, Type: TypeChunk, Data: data}); err != nil {
|
||
break
|
||
}
|
||
}
|
||
writeMsg(conn, Message{ID: id, Type: TypeEnd}) //nolint
|
||
}()
|
||
|
||
// 读取 goroutine:Python chunks → 输出 channel
|
||
go func() {
|
||
stop := watchCtx(ctx, conn, id)
|
||
defer func() {
|
||
stop()
|
||
outCh.Close()
|
||
w.release(conn, ctx.Err() == nil)
|
||
}()
|
||
for {
|
||
msg, err := readMsg(conn)
|
||
if err != nil || msg.Type == TypeEnd || msg.Type == TypeError {
|
||
return
|
||
}
|
||
if msg.Type == TypeChunk {
|
||
val := reflect.New(rt.Elem())
|
||
if err := json.Unmarshal(msg.Data, val.Interface()); err != nil {
|
||
return
|
||
}
|
||
outCh.Send(val.Elem())
|
||
}
|
||
}
|
||
}()
|
||
|
||
return outCh.Interface().(R), nil
|
||
}
|