Files
gobridge/client.go
2026-04-10 09:24:28 +08:00

343 lines
8.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Package gobridge 提供 Go 与 Python 之间的双向通信桥接,
// 支持普通调用、流式输出、流式输入和双向流四种模式。
package gobridge
import (
"context"
"encoding/json"
"fmt"
"net"
"reflect"
"sync"
)
// Invoke calls a function exposed by the Python side. The call mode is chosen
// from the result type R and the arguments:
//
//	Regular call:     Invoke[int](ctx, pool, "Add", 3, 4)
//	Streaming output: Invoke[chan int](ctx, pool, "RangeGen", 1, 10)     // Python yield → Go channel
//	Streaming input:  Invoke[int](ctx, pool, "SumStream", inputChan)      // Go channel → Python Iterator
//	Bidirectional:    Invoke[chan int](ctx, pool, "Transform", inputChan) // both ends are streams
//
// R selects streaming output when it is a channel type; an argument whose
// dynamic kind is a channel is sent as streaming input. At most one channel
// argument is supported, and it must be non-nil and receivable. If R is a
// channel type it must be receivable as well. Violations are reported as an
// error instead of panicking inside reflect.
//
// Cancelling ctx immediately interrupts communication with Python and the call
// returns ctx.Err(). For streaming output and bidirectional calls, cancelling
// ctx closes the returned channel.
func Invoke[R any](ctx context.Context, pool *Pool, method string, args ...any) (R, error) {
	var zero R
	rt := reflect.TypeFor[R]()
	if rt.Kind() == reflect.Chan && rt.ChanDir()&reflect.RecvDir == 0 {
		return zero, fmt.Errorf("invoke %s: result type %v is a send-only channel", method, rt)
	}

	// Locate the (single) channel argument that becomes the input stream.
	streamArgIdx := -1
	var streamCh reflect.Value
	for i, arg := range args {
		if arg == nil {
			continue
		}
		rv := reflect.ValueOf(arg)
		if rv.Kind() != reflect.Chan {
			continue
		}
		switch {
		case streamArgIdx >= 0:
			return zero, fmt.Errorf("invoke %s: args %d and %d are both channels; only one stream input is supported", method, streamArgIdx, i)
		case rv.IsNil():
			// A nil channel never delivers a value; receiving from it would hang.
			return zero, fmt.Errorf("invoke %s: arg %d is a nil channel", method, i)
		case rv.Type().ChanDir()&reflect.RecvDir == 0:
			// reflect.Select panics on a receive from a send-only channel.
			return zero, fmt.Errorf("invoke %s: arg %d is a send-only channel", method, i)
		}
		streamArgIdx = i
		streamCh = rv
	}

	switch {
	case rt.Kind() == reflect.Chan && streamArgIdx >= 0:
		return invokeStreamBoth[R](ctx, pool, method, streamArgIdx, streamCh, rt, args...)
	case rt.Kind() == reflect.Chan:
		return invokeStreamOut[R](ctx, pool, method, rt, args...)
	case streamArgIdx >= 0:
		return invokeStreamIn[R](ctx, pool, method, streamArgIdx, streamCh, args...)
	default:
		return invokeRegular[R](ctx, pool, method, args...)
	}
}
// watchCtx starts a goroutine that watches ctx on behalf of request id.
//
// If ctx is done before stop is called, the goroutine first sends a TypeCancel
// message for id (the Python side is expected to inject InterruptedError on
// receipt) and then closes conn. Closing conn unblocks any read or write in
// progress on it.
//
// The returned stop function must be called before conn is handed back to the
// pool. It is safe to call more than once. It blocks until the watcher
// goroutine has exited, so once stop returns the watcher no longer touches
// conn. If ctx was already done, conn may have been closed; callers should
// check ctx.Err() before treating the connection as reusable.
//
// NOTE(review): the cancel write has no deadline. If the peer has stopped
// reading and the socket buffer is full, it can block (and stop with it) until
// the write fails. Consider conn.SetWriteDeadline before sending the cancel.
func watchCtx(ctx context.Context, conn net.Conn, id uint64) (stop func()) {
	done := make(chan struct{})   // closed by stop: the owner no longer needs the watch
	exited := make(chan struct{}) // closed when the watcher goroutine returns
	go func() {
		defer close(exited)
		select {
		case <-ctx.Done():
			// Best effort: conn is closed right after regardless of whether the
			// cancel message could be delivered.
			_ = writeMsg(conn, Message{ID: id, Type: TypeCancel})
			_ = conn.Close()
		case <-done:
		}
	}()
	var once sync.Once
	return func() {
		once.Do(func() { close(done) })
		// Wait so the watcher cannot write to or close conn after the caller
		// has returned it to the pool.
		<-exited
	}
}
// chanRecv performs a reflective receive on ch that can be abandoned when ctx
// is done.
//
// It returns the received value, whether that value came from an open channel
// (false once ch is closed), and whether the wait ended because ctx was done.
// When the third result is true, the first two are the zero reflect.Value and
// false.
func chanRecv(ctx context.Context, ch reflect.Value) (reflect.Value, bool, bool) {
	const (
		valueCase = iota
		doneCase
	)
	cases := make([]reflect.SelectCase, 2)
	cases[valueCase] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: ch}
	cases[doneCase] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ctx.Done())}

	picked, got, open := reflect.Select(cases)
	switch picked {
	case doneCase:
		return reflect.Value{}, false, true
	default:
		return got, open, false
	}
}
// contextErr picks the error to report for a failed I/O operation. Once ctx is
// done, ctx.Err() is returned in place of err, because the ctx watcher closes
// the connection on cancellation and the raw I/O error would hide that cause.
// Otherwise err is returned unchanged.
func contextErr(ctx context.Context, err error) error {
	switch cause := ctx.Err(); cause {
	case nil:
		return err
	default:
		return cause
	}
}
// invokeRegular performs a plain request/response call. It sends args as a
// JSON array in a single TypeCall message and decodes the reply's Data into R.
//
// A TypeError reply is returned as a "remote error". I/O failures are reported
// as ctx.Err() when ctx is done, because the ctx watcher closes the connection
// on cancellation.
func invokeRegular[R any](ctx context.Context, pool *Pool, method string, args ...any) (R, error) {
	var zero R
	if args == nil {
		// A nil slice marshals as JSON null; always send an array.
		args = []any{}
	}
	argsJSON, err := json.Marshal(args)
	if err != nil {
		return zero, fmt.Errorf("marshal args: %w", err)
	}
	conn, w, err := pool.acquire(ctx)
	if err != nil {
		return zero, err
	}
	id := pool.reqID.Add(1)
	stop := watchCtx(ctx, conn, id)
	defer stop()
	if err := writeMsg(conn, Message{
		ID:     id,
		Type:   TypeCall,
		Method: method,
		Args:   argsJSON,
	}); err != nil {
		stop() // watchCtx contract: stop before the conn goes back to the pool
		w.release(conn, false)
		return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
	}
	resp, err := readMsg(conn)
	if err != nil {
		stop()
		w.release(conn, false)
		return zero, contextErr(ctx, fmt.Errorf("read response: %w", err))
	}
	stop()
	// The response is complete. If ctx was cancelled meanwhile, the watcher may
	// have sent TypeCancel and closed conn, so only pool it while ctx is live.
	w.release(conn, ctx.Err() == nil)
	if resp.Type == TypeError {
		return zero, fmt.Errorf("remote error: %s", resp.Error)
	}
	var result R
	if err := json.Unmarshal(resp.Data, &result); err != nil {
		return zero, fmt.Errorf("unmarshal result: %w", err)
	}
	return result, nil
}
// invokeStreamOut starts a call whose result is a stream. Each TypeChunk the
// Python side emits is decoded into rt's element type and delivered on the
// returned channel (buffer 64). The channel is closed when the stream ends,
// fails, or ctx is done.
//
// rt may be a bidirectional or receive-only channel type. The channel is
// created bidirectional and converted to rt, because reflect.MakeChan rejects
// unidirectional types.
//
// The connection is returned to the pool as healthy only when the server ended
// the stream itself (TypeEnd or TypeError) and ctx is still live. A read or
// decode failure mid-stream would otherwise leave unread chunks on it for the
// next request.
//
// NOTE(review): errors (remote TypeError, read or decode failures) are not
// reported to the consumer; the channel just closes.
func invokeStreamOut[R any](ctx context.Context, pool *Pool, method string, rt reflect.Type, args ...any) (R, error) {
	var zero R
	if args == nil {
		// A nil slice marshals as JSON null; always send an array.
		args = []any{}
	}
	argsJSON, err := json.Marshal(args)
	if err != nil {
		return zero, fmt.Errorf("marshal args: %w", err)
	}
	conn, w, err := pool.acquire(ctx)
	if err != nil {
		return zero, err
	}
	id := pool.reqID.Add(1)
	// Watch ctx from the start so a stalled call write is interruptible too.
	stop := watchCtx(ctx, conn, id)
	if err := writeMsg(conn, Message{
		ID:     id,
		Type:   TypeCall,
		Method: method,
		Args:   argsJSON,
	}); err != nil {
		stop()
		w.release(conn, false)
		return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
	}
	ch := reflect.MakeChan(reflect.ChanOf(reflect.BothDir, rt.Elem()), 64)
	go func() {
		// clean is set only when the server terminated the stream itself.
		clean := false
		defer func() {
			stop()
			ch.Close()
			w.release(conn, clean && ctx.Err() == nil)
		}()
		ctxDone := reflect.ValueOf(ctx.Done())
		for {
			msg, err := readMsg(conn)
			if err != nil {
				return
			}
			switch msg.Type {
			case TypeEnd, TypeError:
				clean = true
				return
			case TypeChunk:
				val := reflect.New(rt.Elem())
				if err := json.Unmarshal(msg.Data, val.Interface()); err != nil {
					return
				}
				// Deliver unless ctx is done first, so a consumer that stopped
				// reading (full buffer) cannot pin this goroutine forever.
				chosen, _, _ := reflect.Select([]reflect.SelectCase{
					{Dir: reflect.SelectSend, Chan: ch, Send: val.Elem()},
					{Dir: reflect.SelectRecv, Chan: ctxDone},
				})
				if chosen == 1 {
					return
				}
			}
		}
	}()
	return ch.Convert(rt).Interface().(R), nil
}
// invokeStreamIn performs a call whose argument at streamArgIdx is an input
// stream. That argument is sent as JSON null inside Args. Every value received
// from streamCh is then sent as a TypeChunk, followed by TypeEnd, and the single
// reply is decoded into R.
//
// If ctx is done while waiting for input, the call returns ctx.Err(). I/O
// failures are reported as ctx.Err() when ctx is done. A TypeError reply is
// returned as a "remote error".
func invokeStreamIn[R any](ctx context.Context, pool *Pool, method string, streamArgIdx int, streamCh reflect.Value, args ...any) (R, error) {
	var zero R
	jsonArgs := make([]any, len(args))
	copy(jsonArgs, args)
	jsonArgs[streamArgIdx] = nil // the channel travels as chunks, not as a JSON argument
	argsJSON, err := json.Marshal(jsonArgs)
	if err != nil {
		return zero, fmt.Errorf("marshal args: %w", err)
	}
	conn, w, err := pool.acquire(ctx)
	if err != nil {
		return zero, err
	}
	id := pool.reqID.Add(1)
	stop := watchCtx(ctx, conn, id)
	defer stop()
	// fail stops the watcher before discarding the connection (watchCtx's
	// "stop before release" contract) and returns err.
	fail := func(err error) (R, error) {
		stop()
		w.release(conn, false)
		return zero, err
	}
	if err := writeMsg(conn, Message{
		ID:           id,
		Type:         TypeCall,
		Method:       method,
		Args:         argsJSON,
		StreamInput:  true,
		StreamArgIdx: streamArgIdx,
	}); err != nil {
		return fail(contextErr(ctx, fmt.Errorf("write call: %w", err)))
	}
	for {
		val, ok, cancelled := chanRecv(ctx, streamCh)
		if cancelled {
			return fail(ctx.Err())
		}
		if !ok {
			break // input channel closed: all input sent
		}
		chunkData, err := json.Marshal(val.Interface())
		if err != nil {
			return fail(fmt.Errorf("marshal chunk: %w", err))
		}
		if err := writeMsg(conn, Message{ID: id, Type: TypeChunk, Data: chunkData}); err != nil {
			return fail(contextErr(ctx, fmt.Errorf("write chunk: %w", err)))
		}
	}
	if err := writeMsg(conn, Message{ID: id, Type: TypeEnd}); err != nil {
		return fail(contextErr(ctx, fmt.Errorf("write end: %w", err)))
	}
	resp, err := readMsg(conn)
	if err != nil {
		return fail(contextErr(ctx, fmt.Errorf("read response: %w", err)))
	}
	stop()
	// If ctx was cancelled while the response was in flight, the watcher may
	// have sent TypeCancel and closed conn; only pool it while ctx is live.
	w.release(conn, ctx.Err() == nil)
	if resp.Type == TypeError {
		return zero, fmt.Errorf("remote error: %s", resp.Error)
	}
	var result R
	if err := json.Unmarshal(resp.Data, &result); err != nil {
		return zero, fmt.Errorf("unmarshal result: %w", err)
	}
	return result, nil
}
// invokeStreamBoth performs a bidirectional streaming call. Values received
// from streamCh are forwarded to Python as TypeChunk messages (then TypeEnd),
// while Python's chunks are decoded into rt's element type and delivered on the
// returned channel (buffer 64). That channel is closed when the output stream
// ends, fails, or ctx is done.
//
// Two goroutines share the connection: a writer that pumps the input channel
// and a reader that fills the output channel. The reader owns the connection's
// lifetime. When it finishes, it tells the writer to stop and waits for it
// before handing conn back to the pool, so nothing ever writes into a
// connection that another request is using. conn is pooled as healthy only if
// both directions terminated cleanly and ctx is still live.
//
// rt may be a bidirectional or receive-only channel type; see the comment at
// the MakeChan call below.
//
// NOTE(review): the ctx watcher and the writer goroutine may call writeMsg
// concurrently; this is only safe if writeMsg emits each frame atomically.
// Stream errors are not reported to the consumer; the channel just closes.
func invokeStreamBoth[R any](ctx context.Context, pool *Pool, method string, streamArgIdx int, streamCh reflect.Value, rt reflect.Type, args ...any) (R, error) {
	var zero R
	jsonArgs := make([]any, len(args))
	copy(jsonArgs, args)
	jsonArgs[streamArgIdx] = nil // the channel travels as chunks, not as a JSON argument
	argsJSON, err := json.Marshal(jsonArgs)
	if err != nil {
		return zero, fmt.Errorf("marshal args: %w", err)
	}
	conn, w, err := pool.acquire(ctx)
	if err != nil {
		return zero, err
	}
	id := pool.reqID.Add(1)
	// Watch ctx from the start so a stalled call write is interruptible too.
	stop := watchCtx(ctx, conn, id)
	if err := writeMsg(conn, Message{
		ID:           id,
		Type:         TypeCall,
		Method:       method,
		Args:         argsJSON,
		StreamInput:  true,
		StreamArgIdx: streamArgIdx,
	}); err != nil {
		stop()
		w.release(conn, false)
		return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
	}

	// reflect.MakeChan rejects unidirectional types, so build chan T and
	// convert to rt (which may be <-chan T) on return.
	outCh := reflect.MakeChan(reflect.ChanOf(reflect.BothDir, rt.Elem()), 64)
	ctxDone := reflect.ValueOf(ctx.Done())
	// readerDone is closed once the reader has stopped consuming the connection.
	readerDone := make(chan struct{})
	// writerClean receives true if the writer drained the input and sent TypeEnd.
	writerClean := make(chan bool, 1)

	// Writer goroutine: input channel → Python chunks.
	go func() {
		clean := false
		defer func() { writerClean <- clean }()
		cases := []reflect.SelectCase{
			{Dir: reflect.SelectRecv, Chan: streamCh},
			{Dir: reflect.SelectRecv, Chan: ctxDone},
			{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(readerDone)},
		}
		for {
			chosen, val, ok := reflect.Select(cases)
			if chosen != 0 {
				// ctx is done (the watcher closes conn) or the reader already
				// finished; either way, stop feeding input.
				return
			}
			if !ok {
				break // input channel closed: all input sent
			}
			data, err := json.Marshal(val.Interface())
			if err != nil {
				// Abort rather than let Python compute on silently truncated
				// input; closing conn makes the reader's readMsg fail.
				_ = conn.Close()
				return
			}
			if err := writeMsg(conn, Message{ID: id, Type: TypeChunk, Data: data}); err != nil {
				_ = conn.Close()
				return
			}
		}
		if err := writeMsg(conn, Message{ID: id, Type: TypeEnd}); err != nil {
			_ = conn.Close()
			return
		}
		clean = true
	}()

	// Reader goroutine: Python chunks → output channel.
	go func() {
		// readClean is set only when the server terminated the stream itself.
		readClean := false
		defer func() {
			outCh.Close()
			// Stop the writer and wait for it, so it never writes to conn after
			// conn has been handed back to the pool.
			close(readerDone)
			inputDone := false
			select {
			case inputDone = <-writerClean:
			default:
				// Python finished while input was still being sent, so the
				// connection's protocol state is unknown. Close it (which also
				// unblocks a writer stuck in writeMsg) and never reuse it.
				_ = conn.Close()
				<-writerClean
			}
			stop()
			w.release(conn, readClean && inputDone && ctx.Err() == nil)
		}()
		for {
			msg, err := readMsg(conn)
			if err != nil {
				return
			}
			switch msg.Type {
			case TypeEnd, TypeError:
				readClean = true
				return
			case TypeChunk:
				val := reflect.New(rt.Elem())
				if err := json.Unmarshal(msg.Data, val.Interface()); err != nil {
					return
				}
				// Deliver unless ctx is done first, so a consumer that stopped
				// reading (full buffer) cannot pin this goroutine forever.
				chosen, _, _ := reflect.Select([]reflect.SelectCase{
					{Dir: reflect.SelectSend, Chan: outCh, Send: val.Elem()},
					{Dir: reflect.SelectRecv, Chan: ctxDone},
				})
				if chosen == 1 {
					return
				}
			}
		}
	}()
	return outCh.Convert(rt).Interface().(R), nil
}