Files
gobridge/client.go
what db71a904e0 feat: 添加全局错误日志捕获
- Python 侧 _dispatch 异常时输出完整堆栈到 stderr
- Go 侧 handler 返回 error 时打印日志
- 升级 Python 包版本至 0.1.4
2026-05-20 19:30:52 +08:00

387 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Package gobridge 提供 Go 与 Python 之间的双向通信桥接,
// 支持普通调用、流式输出、流式输入和双向流四种模式。
package gobridge
import (
"context"
"encoding/json"
"fmt"
"log"
"net"
"reflect"
"sync"
)
// Invoke calls a function exposed by the Python side. The call mode is chosen
// from the result type R and from whether any argument is a channel:
//
//	regular:       Invoke[int](ctx, pool, "Add", 3, 4)
//	stream out:    Invoke[chan int](ctx, pool, "RangeGen", 1, 10)      // Python yield → Go channel
//	stream in:     Invoke[int](ctx, pool, "SumStream", inputChan)      // Go channel → Python Iterator
//	bidirectional: Invoke[chan int](ctx, pool, "Transform", inputChan) // both ends stream
//
// Only the first channel-valued argument is used as the input stream.
//
// Cancelling ctx interrupts communication with Python immediately; calls that
// block for a result then return ctx.Err(). For stream-out and bidirectional
// calls, cancelling ctx closes the returned channel.
func Invoke[R any](ctx context.Context, pool Pool, method string, args ...any) (R, error) {
	resultType := reflect.TypeFor[R]()
	streamOut := resultType.Kind() == reflect.Chan

	// Locate the first channel argument; it becomes the input stream.
	inputIdx := -1
	var inputCh reflect.Value
	for idx := 0; idx < len(args) && inputIdx < 0; idx++ {
		if args[idx] == nil {
			continue
		}
		if v := reflect.ValueOf(args[idx]); v.Kind() == reflect.Chan {
			inputIdx, inputCh = idx, v
		}
	}

	if inputIdx < 0 {
		if streamOut {
			return invokeStreamOut[R](ctx, pool, method, resultType, args...)
		}
		return invokeRegular[R](ctx, pool, method, args...)
	}
	if streamOut {
		return invokeStreamBoth[R](ctx, pool, method, inputIdx, inputCh, resultType, args...)
	}
	return invokeStreamIn[R](ctx, pool, method, inputIdx, inputCh, args...)
}
// watchCtx starts a goroutine that watches ctx while the request with the
// given id is in flight on conn.
//
// If ctx is cancelled before stop is called, the goroutine first sends a
// TypeCancel message for id (the Python side is expected to inject an
// InterruptedError when it receives it) and then closes conn so that any read
// or write blocked on it returns.
//
// write is the caller's mutually-exclusive write function; routing the cancel
// message through it keeps it from interleaving with other writes on conn.
//
// The returned stop function must be called before conn is handed back to the
// pool and is safe to call more than once. Once stop has returned, the watcher
// will not start a cancellation. If it had already started one, ctx.Err() is
// non-nil. A caller that, after calling stop, releases conn as healthy only
// when ctx.Err() == nil therefore never returns a conn the watcher is about to
// write to or close.
func watchCtx(ctx context.Context, conn net.Conn, id uint64, write func(Message)) (stop func()) {
	var (
		mu      sync.Mutex // guards stopped
		stopped bool
	)
	done := make(chan struct{})
	go func() {
		select {
		case <-ctx.Done():
		case <-done:
			return
		}
		// select picks randomly when ctx.Done and done are both ready, so
		// ctx.Done can win even after stop was called. Re-check under the
		// lock so a conn that may already be back in the pool is never touched.
		mu.Lock()
		abort := stopped
		mu.Unlock()
		if abort {
			return
		}
		// NOTE(review): write takes the caller's write lock. If another write
		// on conn is blocked (e.g. the peer stopped reading and the socket
		// buffer is full), this call blocks too and conn is never closed;
		// consider bounding it with a write deadline.
		write(Message{ID: id, Type: TypeCancel})
		_ = conn.Close() // only used to unblock pending I/O; error is irrelevant
	}()
	return func() {
		mu.Lock()
		defer mu.Unlock()
		if !stopped {
			stopped = true
			close(done)
		}
	}
}
// chanRecv receives one value from ch while also watching ctx.Done().
//
// It returns the received value, whether ch was still open (false when the
// receive happened because ch is closed), and whether the wait ended because
// ctx was done. In that last case the value is the zero reflect.Value and the
// open flag is false.
func chanRecv(ctx context.Context, ch reflect.Value) (reflect.Value, bool, bool) {
	const (
		fromChan = iota
		fromCtx
	)
	cases := make([]reflect.SelectCase, 2)
	cases[fromChan] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: ch}
	cases[fromCtx] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ctx.Done())}

	idx, received, open := reflect.Select(cases)
	switch idx {
	case fromCtx:
		return reflect.Value{}, false, true
	default:
		return received, open, false
	}
}
// contextErr picks the error to report after an I/O failure. When ctx has been
// cancelled or has expired, ctx.Err() is returned in preference to err, since
// the I/O failure is then typically a side effect of the cancellation (watchCtx
// closes the connection). Otherwise err is returned unchanged.
func contextErr(ctx context.Context, err error) error {
	cause := ctx.Err()
	if cause == nil {
		cause = err
	}
	return cause
}
// readResult reads messages from conn until it gets one that is not a
// Python→Go callback, and returns that message. Callbacks received in the
// meantime are dispatched inline through pool.callbackDispatch and answered on
// the same conn. Nested py→go→py→go… calls therefore reuse one connection
// without extra goroutines.
//
// write is the caller's mutually-exclusive write function (the same one given
// to watchCtx), so callback replies never interleave with other writes.
//
// A callback fails when the handler reports an error or when its result cannot
// be JSON-encoded. A failed callback is logged and answered with a TypeError
// message.
func readResult(ctx context.Context, conn net.Conn, pool Pool, write func(Message)) (Message, error) {
	for {
		msg, err := readMsg(conn)
		if err != nil {
			return Message{}, err
		}
		if msg.Type != TypeCallback {
			return msg, nil
		}
		result, errStr := pool.callbackDispatch(ctx, msg)
		if errStr == "" {
			data, err := json.Marshal(result)
			if err == nil {
				write(Message{ID: msg.ID, Type: TypeCallbackResult, Data: data})
				continue
			}
			// An unencodable result must be reported as a failure; replying
			// with empty Data would not tell Python that the callback failed.
			errStr = fmt.Sprintf("marshal callback result: %v", err)
		}
		log.Printf("gobridge: handler %s error: %s", msg.Method, errStr)
		write(Message{ID: msg.ID, Type: TypeError, Error: errStr})
	}
}
// invokeRegular performs a plain request/response call of method with args and
// decodes the reply's Data into R. Python→Go callbacks that arrive before the
// reply are served inline by readResult on the same connection.
//
// The connection goes back to the pool as healthy only if the exchange
// completed, ctx is still live after the watcher has been stopped, and the
// reply is not a stream chunk. A chunk would mean further stream messages are
// still unread on the connection.
func invokeRegular[R any](ctx context.Context, pool Pool, method string, args ...any) (R, error) {
	var zero R
	argsJSON, err := json.Marshal(args)
	if err != nil {
		return zero, fmt.Errorf("marshal args: %w", err)
	}
	conn, w, err := pool.acquire(ctx)
	if err != nil {
		return zero, err
	}
	var mu sync.Mutex
	writeErr := func(msg Message) error {
		mu.Lock()
		defer mu.Unlock()
		return writeMsg(conn, msg)
	}
	// Used for callback replies and the cancel message; write errors are
	// ignored here, as before.
	write := func(msg Message) { _ = writeErr(msg) }
	id := pool.nextReqID()
	stop := watchCtx(ctx, conn, id, write)
	defer stop()
	if err := writeErr(Message{ID: id, Type: TypeCall, Method: method, Args: argsJSON}); err != nil {
		w.release(conn, false)
		return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
	}
	resp, err := readResult(ctx, conn, pool, write)
	if err != nil {
		w.release(conn, false)
		return zero, contextErr(ctx, fmt.Errorf("read response: %w", err))
	}
	stop()
	if resp.Type == TypeChunk {
		// The Python function streamed its output: the remaining chunks and
		// the end marker are still unread, so the conn cannot be reused.
		w.release(conn, false)
		return zero, fmt.Errorf("unexpected stream output from %s: use a channel result type", method)
	}
	// If ctx was cancelled while the reply was in flight, the watcher may be
	// sending a cancel message / closing conn, so it must not be reused.
	w.release(conn, ctx.Err() == nil)
	if resp.Type == TypeError {
		return zero, fmt.Errorf("remote error: %s", resp.Error)
	}
	var result R
	if err := json.Unmarshal(resp.Data, &result); err != nil {
		return zero, fmt.Errorf("unmarshal result: %w", err)
	}
	return result, nil
}
// invokeStreamOut calls method and streams the values Python yields back on a
// new channel of type rt (buffered, capacity 64), which is returned as R.
// rt may be a directional channel type such as <-chan T.
//
// A background goroutine does the following:
//   - decodes each TypeChunk into the channel;
//   - serves Python→Go callbacks inline;
//   - closes the channel when Python ends the stream (TypeEnd or TypeError),
//     on any read/decode error, or when ctx is cancelled. Cancellation also
//     stops a send to a consumer that stopped reading.
//
// A remote TypeError is logged, because the consumer only sees the channel
// close.
//
// The connection goes back to the pool as healthy only when the stream ended
// with TypeEnd/TypeError and ctx is still live. After a read or decode error,
// the remaining stream messages would otherwise be read by the next request
// that reuses the connection.
func invokeStreamOut[R any](ctx context.Context, pool Pool, method string, rt reflect.Type, args ...any) (R, error) {
	var zero R
	argsJSON, err := json.Marshal(args)
	if err != nil {
		return zero, fmt.Errorf("marshal args: %w", err)
	}
	conn, w, err := pool.acquire(ctx)
	if err != nil {
		return zero, err
	}
	id := pool.nextReqID()
	// No other goroutine uses conn yet, so this write needs no lock.
	if err := writeMsg(conn, Message{
		ID:     id,
		Type:   TypeCall,
		Method: method,
		Args:   argsJSON,
	}); err != nil {
		w.release(conn, false)
		return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
	}
	// reflect.MakeChan panics on directional types such as <-chan T, so build
	// a bidirectional channel and convert it to R on return.
	ch := reflect.MakeChan(reflect.ChanOf(reflect.BothDir, rt.Elem()), 64)
	go func() {
		var mu sync.Mutex
		write := func(msg Message) {
			mu.Lock()
			defer mu.Unlock()
			_ = writeMsg(conn, msg)
		}
		stop := watchCtx(ctx, conn, id, write)
		// clean records that the stream ended with a protocol-level end
		// message, so no messages for this call remain unread on conn.
		clean := false
		defer func() {
			stop()
			ch.Close()
			w.release(conn, clean && ctx.Err() == nil)
		}()
		ctxDone := reflect.ValueOf(ctx.Done())
		for {
			msg, err := readResult(ctx, conn, pool, write)
			if err != nil {
				return
			}
			switch msg.Type {
			case TypeEnd:
				clean = true
				return
			case TypeError:
				clean = true
				log.Printf("gobridge: stream %s remote error: %s", method, msg.Error)
				return
			case TypeChunk:
				val := reflect.New(rt.Elem())
				if err := json.Unmarshal(msg.Data, val.Interface()); err != nil {
					return
				}
				// Don't block forever on a consumer that stopped reading:
				// give up as soon as ctx is cancelled.
				chosen, _, _ := reflect.Select([]reflect.SelectCase{
					{Dir: reflect.SelectSend, Chan: ch, Send: val.Elem()},
					{Dir: reflect.SelectRecv, Chan: ctxDone},
				})
				if chosen == 1 {
					return
				}
			}
		}
	}()
	return ch.Convert(rt).Interface().(R), nil
}
// invokeStreamIn streams a channel argument to Python and reads back a single
// reply.
//
// The channel argument at streamArgIdx is sent as a sequence of TypeChunk
// messages terminated by TypeEnd; its slot in the JSON-encoded argument list
// is sent as null. A single reply is then read and decoded into R. Cancelling
// ctx while waiting on the input channel returns ctx.Err().
//
// The connection goes back to the pool as healthy only if the exchange
// completed, ctx is still live after the watcher has been stopped, and the
// reply is not a stream chunk.
//
// NOTE(review): while input is being written nothing reads conn. A Python→Go
// callback issued before the input is exhausted is therefore only served after
// all input has been sent. Confirm that the Python side never needs a callback
// reply to keep consuming input.
func invokeStreamIn[R any](ctx context.Context, pool Pool, method string, streamArgIdx int, streamCh reflect.Value, args ...any) (R, error) {
	var zero R
	// A channel cannot be JSON-encoded; its slot is sent as null.
	jsonArgs := make([]any, len(args))
	copy(jsonArgs, args)
	jsonArgs[streamArgIdx] = nil
	argsJSON, err := json.Marshal(jsonArgs)
	if err != nil {
		return zero, fmt.Errorf("marshal args: %w", err)
	}
	conn, w, err := pool.acquire(ctx)
	if err != nil {
		return zero, err
	}
	var mu sync.Mutex
	writeErr := func(msg Message) error {
		mu.Lock()
		defer mu.Unlock()
		return writeMsg(conn, msg)
	}
	// Used for callback replies and the cancel message; write errors are
	// ignored here, as before.
	write := func(msg Message) { _ = writeErr(msg) }
	id := pool.nextReqID()
	stop := watchCtx(ctx, conn, id, write)
	defer stop()
	if err := writeErr(Message{
		ID:           id,
		Type:         TypeCall,
		Method:       method,
		Args:         argsJSON,
		StreamInput:  true,
		StreamArgIdx: streamArgIdx,
	}); err != nil {
		w.release(conn, false)
		return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
	}
	for {
		val, ok, cancelled := chanRecv(ctx, streamCh)
		if cancelled {
			w.release(conn, false)
			return zero, ctx.Err()
		}
		if !ok {
			break
		}
		chunkData, err := json.Marshal(val.Interface())
		if err != nil {
			w.release(conn, false)
			return zero, fmt.Errorf("marshal chunk: %w", err)
		}
		if err := writeErr(Message{ID: id, Type: TypeChunk, Data: chunkData}); err != nil {
			w.release(conn, false)
			return zero, contextErr(ctx, fmt.Errorf("write chunk: %w", err))
		}
	}
	if err := writeErr(Message{ID: id, Type: TypeEnd}); err != nil {
		w.release(conn, false)
		return zero, contextErr(ctx, fmt.Errorf("write end: %w", err))
	}
	resp, err := readResult(ctx, conn, pool, write)
	if err != nil {
		w.release(conn, false)
		return zero, contextErr(ctx, fmt.Errorf("read response: %w", err))
	}
	stop()
	if resp.Type == TypeChunk {
		// The Python function streamed its output: the remaining chunks and
		// the end marker are still unread, so the conn cannot be reused.
		w.release(conn, false)
		return zero, fmt.Errorf("unexpected stream output from %s: use a channel result type", method)
	}
	// If ctx was cancelled while the reply was in flight, the watcher may be
	// sending a cancel message / closing conn, so it must not be reused.
	w.release(conn, ctx.Err() == nil)
	if resp.Type == TypeError {
		return zero, fmt.Errorf("remote error: %s", resp.Error)
	}
	var result R
	if err := json.Unmarshal(resp.Data, &result); err != nil {
		return zero, fmt.Errorf("unmarshal result: %w", err)
	}
	return result, nil
}
// invokeStreamBoth streams a channel argument to Python and streams Python's
// yielded values back. The channel argument at streamArgIdx goes out as
// TypeChunk messages (its JSON argument slot is null). The yielded values come
// back on a new channel of type rt (buffered, capacity 64), which is returned
// as R. rt may be a directional channel type such as <-chan T.
//
// Two goroutines share conn under one write lock:
//   - the writer forwards input values as TypeChunk messages and finishes with
//     TypeEnd;
//   - the reader decodes TypeChunk messages into the output channel, serves
//     Python→Go callbacks inline, and owns the connection's release.
//
// The output channel is closed when Python ends the stream (TypeEnd or
// TypeError), on any read/decode error, or when ctx is cancelled. A remote
// TypeError is logged.
//
// The connection goes back to the pool as healthy only if all of these hold:
//   - the stream ended cleanly;
//   - ctx is still live;
//   - the writer sent all input and TypeEnd without error.
//
// In that case the reader first waits for the writer, so the writer can never
// write to a connection that already belongs to another request.
func invokeStreamBoth[R any](ctx context.Context, pool Pool, method string, streamArgIdx int, streamCh reflect.Value, rt reflect.Type, args ...any) (R, error) {
	var zero R
	// A channel cannot be JSON-encoded; its slot is sent as null.
	jsonArgs := make([]any, len(args))
	copy(jsonArgs, args)
	jsonArgs[streamArgIdx] = nil
	argsJSON, err := json.Marshal(jsonArgs)
	if err != nil {
		return zero, fmt.Errorf("marshal args: %w", err)
	}
	conn, w, err := pool.acquire(ctx)
	if err != nil {
		return zero, err
	}
	id := pool.nextReqID()
	// No other goroutine uses conn yet, so this write needs no lock.
	if err := writeMsg(conn, Message{
		ID:           id,
		Type:         TypeCall,
		Method:       method,
		Args:         argsJSON,
		StreamInput:  true,
		StreamArgIdx: streamArgIdx,
	}); err != nil {
		w.release(conn, false)
		return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
	}
	// reflect.MakeChan panics on directional types such as <-chan T, so build
	// a bidirectional channel and convert it to R on return.
	outCh := reflect.MakeChan(reflect.ChanOf(reflect.BothDir, rt.Elem()), 64)
	var mu sync.Mutex
	writeErr := func(msg Message) error {
		mu.Lock()
		defer mu.Unlock()
		return writeMsg(conn, msg)
	}
	write := func(msg Message) { _ = writeErr(msg) }

	stopInput := make(chan struct{}) // closed by the reader once the output stream is over
	inputDone := make(chan struct{}) // closed by the writer when it exits
	inputOK := false                 // set by the writer before inputDone is closed

	// Writer goroutine: input channel → Python chunks.
	go func() {
		defer close(inputDone)
		cases := []reflect.SelectCase{
			{Dir: reflect.SelectRecv, Chan: streamCh},
			{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ctx.Done())},
			{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stopInput)},
		}
		var werr error
		for werr == nil {
			chosen, val, ok := reflect.Select(cases)
			if chosen == 2 {
				// Python already ended the call: more input (or TypeEnd)
				// would only leave stray messages; inputOK stays false.
				return
			}
			if chosen == 1 || !ok {
				break // ctx cancelled or input exhausted
			}
			data, err := json.Marshal(val.Interface())
			if err != nil {
				break // as before, an unencodable value ends the input early
			}
			werr = writeErr(Message{ID: id, Type: TypeChunk, Data: data})
		}
		endErr := writeErr(Message{ID: id, Type: TypeEnd})
		inputOK = werr == nil && endErr == nil
	}()

	// Reader goroutine: Python chunks → output channel, callbacks served inline.
	go func() {
		stop := watchCtx(ctx, conn, id, write)
		// clean records that the output stream ended with a protocol-level end
		// message, so no output messages for this call remain unread on conn.
		clean := false
		defer func() {
			stop()
			outCh.Close()
			close(stopInput)
			healthy := clean && ctx.Err() == nil
			if healthy {
				// Wait for the writer so it cannot write to conn after conn is
				// handed to another request.
				// NOTE(review): this assumes Python keeps reading conn after it
				// ends the call (a writer blocked in writeMsg would block here),
				// and that Python tolerates input chunks/TypeEnd that arrive
				// after it already finished. Confirm on the Python side.
				<-inputDone
				healthy = inputOK
			}
			w.release(conn, healthy)
		}()
		ctxDone := reflect.ValueOf(ctx.Done())
		for {
			msg, err := readResult(ctx, conn, pool, write)
			if err != nil {
				return
			}
			switch msg.Type {
			case TypeEnd:
				clean = true
				return
			case TypeError:
				clean = true
				log.Printf("gobridge: stream %s remote error: %s", method, msg.Error)
				return
			case TypeChunk:
				val := reflect.New(rt.Elem())
				if err := json.Unmarshal(msg.Data, val.Interface()); err != nil {
					return
				}
				// Don't block forever on a consumer that stopped reading:
				// give up as soon as ctx is cancelled.
				chosen, _, _ := reflect.Select([]reflect.SelectCase{
					{Dir: reflect.SelectSend, Chan: outCh, Send: val.Elem()},
					{Dir: reflect.SelectRecv, Chan: ctxDone},
				})
				if chosen == 1 {
					return
				}
			}
		}
	}()
	return outCh.Convert(rt).Interface().(R), nil
}