first commit
This commit is contained in:
342
client.go
Normal file
342
client.go
Normal file
@@ -0,0 +1,342 @@
|
||||
// Package gobridge 提供 Go 与 Python 之间的双向通信桥接,
|
||||
// 支持普通调用、流式输出、流式输入和双向流四种模式。
|
||||
package gobridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Invoke 调用 Python 暴露的函数,支持四种模式:
|
||||
//
|
||||
// 普通调用: Invoke[int](ctx, pool, "Add", 3, 4)
|
||||
// 流式输出: Invoke[chan int](ctx, pool, "RangeGen", 1, 10) // Python yield → Go channel
|
||||
// 流式输入: Invoke[int](ctx, pool, "SumStream", inputChan) // Go channel → Python Iterator
|
||||
// 双向流: Invoke[chan int](ctx, pool, "Transform", inputChan) // 两端均为流
|
||||
//
|
||||
// ctx 取消时会立即中断与 Python 的通信并返回 ctx.Err()。
|
||||
// 对于流式输出/双向流,ctx 取消会关闭返回的 channel。
|
||||
func Invoke[R any](ctx context.Context, pool *Pool, method string, args ...any) (R, error) {
|
||||
rt := reflect.TypeFor[R]()
|
||||
|
||||
// 查找 chan 类型的输入参数
|
||||
streamArgIdx := -1
|
||||
var streamCh reflect.Value
|
||||
for i, arg := range args {
|
||||
if arg != nil {
|
||||
rv := reflect.ValueOf(arg)
|
||||
if rv.Kind() == reflect.Chan {
|
||||
streamArgIdx = i
|
||||
streamCh = rv
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case rt.Kind() == reflect.Chan && streamArgIdx >= 0:
|
||||
return invokeStreamBoth[R](ctx, pool, method, streamArgIdx, streamCh, rt, args...)
|
||||
case rt.Kind() == reflect.Chan:
|
||||
return invokeStreamOut[R](ctx, pool, method, rt, args...)
|
||||
case streamArgIdx >= 0:
|
||||
return invokeStreamIn[R](ctx, pool, method, streamArgIdx, streamCh, args...)
|
||||
default:
|
||||
return invokeRegular[R](ctx, pool, method, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// watchCtx 启动一个 goroutine 监听 ctx:
|
||||
// - ctx 取消时先发送 cancel 消息(Python 侧收到后注入 InterruptedError)
|
||||
// - 再关闭连接,解除阻塞中的读写操作
|
||||
//
|
||||
// 返回 stop 函数,必须在 conn 归还连接池前调用,可安全多次调用。
|
||||
func watchCtx(ctx context.Context, conn net.Conn, id uint64) (stop func()) {
|
||||
done := make(chan struct{})
|
||||
var once sync.Once
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
writeMsg(conn, Message{ID: id, Type: TypeCancel}) //nolint
|
||||
conn.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
return func() { once.Do(func() { close(done) }) }
|
||||
}
|
||||
|
||||
// chanRecv 从 ch 接收一个值,同时监听 ctx.Done()。
|
||||
// 返回 (值, channel是否open, ctx是否已取消)。
|
||||
func chanRecv(ctx context.Context, ch reflect.Value) (reflect.Value, bool, bool) {
|
||||
chosen, val, ok := reflect.Select([]reflect.SelectCase{
|
||||
{Dir: reflect.SelectRecv, Chan: ch},
|
||||
{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ctx.Done())},
|
||||
})
|
||||
if chosen == 1 {
|
||||
return reflect.Value{}, false, true
|
||||
}
|
||||
return val, ok, false
|
||||
}
|
||||
|
||||
// contextErr 在 io 错误时优先返回 ctx 的错误原因
|
||||
func contextErr(ctx context.Context, err error) error {
|
||||
if e := ctx.Err(); e != nil {
|
||||
return e
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func invokeRegular[R any](ctx context.Context, pool *Pool, method string, args ...any) (R, error) {
|
||||
var zero R
|
||||
|
||||
argsJSON, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return zero, fmt.Errorf("marshal args: %w", err)
|
||||
}
|
||||
|
||||
conn, w, err := pool.acquire(ctx)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
id := pool.reqID.Add(1)
|
||||
stop := watchCtx(ctx, conn, id)
|
||||
defer stop()
|
||||
|
||||
if err := writeMsg(conn, Message{
|
||||
ID: id,
|
||||
Type: TypeCall,
|
||||
Method: method,
|
||||
Args: argsJSON,
|
||||
}); err != nil {
|
||||
w.release(conn, false)
|
||||
return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
|
||||
}
|
||||
|
||||
resp, err := readMsg(conn)
|
||||
if err != nil {
|
||||
w.release(conn, false)
|
||||
return zero, contextErr(ctx, fmt.Errorf("read response: %w", err))
|
||||
}
|
||||
|
||||
stop()
|
||||
w.release(conn, true)
|
||||
|
||||
if resp.Type == TypeError {
|
||||
return zero, fmt.Errorf("remote error: %s", resp.Error)
|
||||
}
|
||||
|
||||
var result R
|
||||
if err := json.Unmarshal(resp.Data, &result); err != nil {
|
||||
return zero, fmt.Errorf("unmarshal result: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func invokeStreamOut[R any](ctx context.Context, pool *Pool, method string, rt reflect.Type, args ...any) (R, error) {
|
||||
var zero R
|
||||
|
||||
argsJSON, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return zero, fmt.Errorf("marshal args: %w", err)
|
||||
}
|
||||
|
||||
conn, w, err := pool.acquire(ctx)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
id := pool.reqID.Add(1)
|
||||
if err := writeMsg(conn, Message{
|
||||
ID: id,
|
||||
Type: TypeCall,
|
||||
Method: method,
|
||||
Args: argsJSON,
|
||||
}); err != nil {
|
||||
w.release(conn, false)
|
||||
return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
|
||||
}
|
||||
|
||||
ch := reflect.MakeChan(rt, 64)
|
||||
|
||||
go func() {
|
||||
stop := watchCtx(ctx, conn, id)
|
||||
defer func() {
|
||||
stop()
|
||||
ch.Close()
|
||||
w.release(conn, ctx.Err() == nil)
|
||||
}()
|
||||
for {
|
||||
msg, err := readMsg(conn)
|
||||
if err != nil || msg.Type == TypeEnd || msg.Type == TypeError {
|
||||
return
|
||||
}
|
||||
if msg.Type == TypeChunk {
|
||||
val := reflect.New(rt.Elem())
|
||||
if err := json.Unmarshal(msg.Data, val.Interface()); err != nil {
|
||||
return
|
||||
}
|
||||
ch.Send(val.Elem())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch.Interface().(R), nil
|
||||
}
|
||||
|
||||
func invokeStreamIn[R any](ctx context.Context, pool *Pool, method string, streamArgIdx int, streamCh reflect.Value, args ...any) (R, error) {
|
||||
var zero R
|
||||
|
||||
jsonArgs := make([]any, len(args))
|
||||
copy(jsonArgs, args)
|
||||
jsonArgs[streamArgIdx] = nil
|
||||
|
||||
argsJSON, err := json.Marshal(jsonArgs)
|
||||
if err != nil {
|
||||
return zero, fmt.Errorf("marshal args: %w", err)
|
||||
}
|
||||
|
||||
conn, w, err := pool.acquire(ctx)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
id := pool.reqID.Add(1)
|
||||
stop := watchCtx(ctx, conn, id)
|
||||
defer stop()
|
||||
|
||||
if err := writeMsg(conn, Message{
|
||||
ID: id,
|
||||
Type: TypeCall,
|
||||
Method: method,
|
||||
Args: argsJSON,
|
||||
StreamInput: true,
|
||||
StreamArgIdx: streamArgIdx,
|
||||
}); err != nil {
|
||||
w.release(conn, false)
|
||||
return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
|
||||
}
|
||||
|
||||
for {
|
||||
val, ok, cancelled := chanRecv(ctx, streamCh)
|
||||
if cancelled {
|
||||
w.release(conn, false)
|
||||
return zero, ctx.Err()
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
chunkData, err := json.Marshal(val.Interface())
|
||||
if err != nil {
|
||||
w.release(conn, false)
|
||||
return zero, fmt.Errorf("marshal chunk: %w", err)
|
||||
}
|
||||
if err := writeMsg(conn, Message{ID: id, Type: TypeChunk, Data: chunkData}); err != nil {
|
||||
w.release(conn, false)
|
||||
return zero, contextErr(ctx, fmt.Errorf("write chunk: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := writeMsg(conn, Message{ID: id, Type: TypeEnd}); err != nil {
|
||||
w.release(conn, false)
|
||||
return zero, contextErr(ctx, fmt.Errorf("write end: %w", err))
|
||||
}
|
||||
|
||||
resp, err := readMsg(conn)
|
||||
if err != nil {
|
||||
w.release(conn, false)
|
||||
return zero, contextErr(ctx, fmt.Errorf("read response: %w", err))
|
||||
}
|
||||
|
||||
stop()
|
||||
w.release(conn, true)
|
||||
|
||||
if resp.Type == TypeError {
|
||||
return zero, fmt.Errorf("remote error: %s", resp.Error)
|
||||
}
|
||||
|
||||
var result R
|
||||
if err := json.Unmarshal(resp.Data, &result); err != nil {
|
||||
return zero, fmt.Errorf("unmarshal result: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func invokeStreamBoth[R any](ctx context.Context, pool *Pool, method string, streamArgIdx int, streamCh reflect.Value, rt reflect.Type, args ...any) (R, error) {
|
||||
var zero R
|
||||
|
||||
jsonArgs := make([]any, len(args))
|
||||
copy(jsonArgs, args)
|
||||
jsonArgs[streamArgIdx] = nil
|
||||
|
||||
argsJSON, err := json.Marshal(jsonArgs)
|
||||
if err != nil {
|
||||
return zero, fmt.Errorf("marshal args: %w", err)
|
||||
}
|
||||
|
||||
conn, w, err := pool.acquire(ctx)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
id := pool.reqID.Add(1)
|
||||
if err := writeMsg(conn, Message{
|
||||
ID: id,
|
||||
Type: TypeCall,
|
||||
Method: method,
|
||||
Args: argsJSON,
|
||||
StreamInput: true,
|
||||
StreamArgIdx: streamArgIdx,
|
||||
}); err != nil {
|
||||
w.release(conn, false)
|
||||
return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
|
||||
}
|
||||
|
||||
outCh := reflect.MakeChan(rt, 64)
|
||||
|
||||
// 写入 goroutine:输入 channel → Python chunks
|
||||
go func() {
|
||||
for {
|
||||
val, ok, cancelled := chanRecv(ctx, streamCh)
|
||||
if cancelled || !ok {
|
||||
break
|
||||
}
|
||||
data, err := json.Marshal(val.Interface())
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if err := writeMsg(conn, Message{ID: id, Type: TypeChunk, Data: data}); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
writeMsg(conn, Message{ID: id, Type: TypeEnd}) //nolint
|
||||
}()
|
||||
|
||||
// 读取 goroutine:Python chunks → 输出 channel
|
||||
go func() {
|
||||
stop := watchCtx(ctx, conn, id)
|
||||
defer func() {
|
||||
stop()
|
||||
outCh.Close()
|
||||
w.release(conn, ctx.Err() == nil)
|
||||
}()
|
||||
for {
|
||||
msg, err := readMsg(conn)
|
||||
if err != nil || msg.Type == TypeEnd || msg.Type == TypeError {
|
||||
return
|
||||
}
|
||||
if msg.Type == TypeChunk {
|
||||
val := reflect.New(rt.Elem())
|
||||
if err := json.Unmarshal(msg.Data, val.Interface()); err != nil {
|
||||
return
|
||||
}
|
||||
outCh.Send(val.Elem())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return outCh.Interface().(R), nil
|
||||
}
|
||||
Reference in New Issue
Block a user