feat: 添加 Python→Go 全双工回调支持(call_go)

- 新增 WithHandlers 选项,通过反射将 Go 结构体方法暴露给 Python
- 新增 callback/callback_result 消息类型,支持 Python 在处理中回调 Go
- client 侧新增 readResult,内联处理 callback,复用同一连接避免死锁
- Python 侧新增 call_go[T]() 泛型调用,支持 dataclass 自动构造
- 注入 GOBRIDGE_WORKER_ID/WORKER_COUNT 环境变量,支持多 worker 初始化分工
- 新增示例演示 Go→Python→Go→Python 四层全双工链路
- Python 包版本升至 0.1.1
This commit is contained in:
2026-04-14 13:06:50 +08:00
parent 07e9239ac5
commit b390effd8e
8 changed files with 490 additions and 54 deletions

View File

@@ -53,14 +53,15 @@ func Invoke[R any](ctx context.Context, pool Pool, method string, args ...any) (
// - ctx 取消时先发送 cancel 消息Python 侧收到后注入 InterruptedError
// - 再关闭连接,解除阻塞中的读写操作
//
// write 是调用方提供的互斥写函数,保证与其他写操作不并发。
// 返回 stop 函数,必须在 conn 归还连接池前调用,可安全多次调用。
func watchCtx(ctx context.Context, conn net.Conn, id uint64) (stop func()) {
func watchCtx(ctx context.Context, conn net.Conn, id uint64, write func(Message)) (stop func()) {
done := make(chan struct{})
var once sync.Once
go func() {
select {
case <-ctx.Done():
writeMsg(conn, Message{ID: id, Type: TypeCancel}) //nolint
write(Message{ID: id, Type: TypeCancel})
conn.Close()
case <-done:
}
@@ -89,6 +90,30 @@ func contextErr(ctx context.Context, err error) error {
return err
}
// readResult 读取下一条非 callback 消息,期间内联处理所有 Python→Go 回调。
// 保证 py→go→py→go→... 全链路复用同一条连接,不产生额外线程。
// write 是调用方提供的互斥写函数,与 watchCtx 共享同一把锁,避免并发写。
func readResult(ctx context.Context, conn net.Conn, pool Pool, write func(Message)) (Message, error) {
for {
msg, err := readMsg(conn)
if err != nil {
return Message{}, err
}
if msg.Type != TypeCallback {
return msg, nil
}
result, errStr := pool.callbackDispatch(ctx, msg)
var resp Message
if errStr != "" {
resp = Message{ID: msg.ID, Type: TypeError, Error: errStr}
} else {
data, _ := json.Marshal(result)
resp = Message{ID: msg.ID, Type: TypeCallbackResult, Data: data}
}
write(resp)
}
}
func invokeRegular[R any](ctx context.Context, pool Pool, method string, args ...any) (R, error) {
var zero R
@@ -102,21 +127,22 @@ func invokeRegular[R any](ctx context.Context, pool Pool, method string, args ..
return zero, err
}
var mu sync.Mutex
write := func(msg Message) { mu.Lock(); writeMsg(conn, msg); mu.Unlock() } //nolint
id := pool.nextReqID()
stop := watchCtx(ctx, conn, id)
stop := watchCtx(ctx, conn, id, write)
defer stop()
if err := writeMsg(conn, Message{
ID: id,
Type: TypeCall,
Method: method,
Args: argsJSON,
}); err != nil {
mu.Lock()
err = writeMsg(conn, Message{ID: id, Type: TypeCall, Method: method, Args: argsJSON})
mu.Unlock()
if err != nil {
w.release(conn, false)
return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
}
resp, err := readMsg(conn)
resp, err := readResult(ctx, conn, pool, write)
if err != nil {
w.release(conn, false)
return zero, contextErr(ctx, fmt.Errorf("read response: %w", err))
@@ -163,14 +189,16 @@ func invokeStreamOut[R any](ctx context.Context, pool Pool, method string, rt re
ch := reflect.MakeChan(rt, 64)
go func() {
stop := watchCtx(ctx, conn, id)
var mu sync.Mutex
write := func(msg Message) { mu.Lock(); writeMsg(conn, msg); mu.Unlock() } //nolint
stop := watchCtx(ctx, conn, id, write)
defer func() {
stop()
ch.Close()
w.release(conn, ctx.Err() == nil)
}()
for {
msg, err := readMsg(conn)
msg, err := readResult(ctx, conn, pool, write)
if err != nil || msg.Type == TypeEnd || msg.Type == TypeError {
return
}
@@ -204,11 +232,19 @@ func invokeStreamIn[R any](ctx context.Context, pool Pool, method string, stream
return zero, err
}
var mu sync.Mutex
writeErr := func(msg Message) error {
mu.Lock()
defer mu.Unlock()
return writeMsg(conn, msg)
}
write := func(msg Message) { writeErr(msg) } //nolint
id := pool.nextReqID()
stop := watchCtx(ctx, conn, id)
stop := watchCtx(ctx, conn, id, write)
defer stop()
if err := writeMsg(conn, Message{
if err := writeErr(Message{
ID: id,
Type: TypeCall,
Method: method,
@@ -234,18 +270,18 @@ func invokeStreamIn[R any](ctx context.Context, pool Pool, method string, stream
w.release(conn, false)
return zero, fmt.Errorf("marshal chunk: %w", err)
}
if err := writeMsg(conn, Message{ID: id, Type: TypeChunk, Data: chunkData}); err != nil {
if err := writeErr(Message{ID: id, Type: TypeChunk, Data: chunkData}); err != nil {
w.release(conn, false)
return zero, contextErr(ctx, fmt.Errorf("write chunk: %w", err))
}
}
if err := writeMsg(conn, Message{ID: id, Type: TypeEnd}); err != nil {
if err := writeErr(Message{ID: id, Type: TypeEnd}); err != nil {
w.release(conn, false)
return zero, contextErr(ctx, fmt.Errorf("write end: %w", err))
}
resp, err := readMsg(conn)
resp, err := readResult(ctx, conn, pool, write)
if err != nil {
w.release(conn, false)
return zero, contextErr(ctx, fmt.Errorf("read response: %w", err))
@@ -297,6 +333,9 @@ func invokeStreamBoth[R any](ctx context.Context, pool Pool, method string, stre
outCh := reflect.MakeChan(rt, 64)
var mu sync.Mutex
write := func(msg Message) { mu.Lock(); writeMsg(conn, msg); mu.Unlock() } //nolint
// 写入 goroutine输入 channel → Python chunks
go func() {
for {
@@ -308,23 +347,26 @@ func invokeStreamBoth[R any](ctx context.Context, pool Pool, method string, stre
if err != nil {
break
}
if err := writeMsg(conn, Message{ID: id, Type: TypeChunk, Data: data}); err != nil {
mu.Lock()
err = writeMsg(conn, Message{ID: id, Type: TypeChunk, Data: data})
mu.Unlock()
if err != nil {
break
}
}
writeMsg(conn, Message{ID: id, Type: TypeEnd}) //nolint
write(Message{ID: id, Type: TypeEnd})
}()
// 读取 goroutinePython chunks → 输出 channel
// 读取 goroutinePython chunks → 输出 channel,内联处理 callback
go func() {
stop := watchCtx(ctx, conn, id)
stop := watchCtx(ctx, conn, id, write)
defer func() {
stop()
outCh.Close()
w.release(conn, ctx.Err() == nil)
}()
for {
msg, err := readMsg(conn)
msg, err := readResult(ctx, conn, pool, write)
if err != nil || msg.Type == TypeEnd || msg.Type == TypeError {
return
}

View File

@@ -32,6 +32,12 @@ func main() {
ctx := context.Background()
demoPool(ctx, pool)
demoServer(ctx, script)
}
func demoPool(ctx context.Context, pool gobridge.Pool) {
// ── 普通调用 ──────────────────────────────────────────────────────────
sum, err := gobridge.Invoke[int](ctx, pool, "add", 3, 4)
if err != nil {
@@ -139,3 +145,78 @@ func main() {
fmt.Printf(" %+v\n", u)
}
}
// goService 实现 Handler 接口,公开方法自动暴露给 Python 通过 call_go() 调用
type goService struct {
pool gobridge.Pool // 用于 EnrichName 内部再调 Python
}
func (s *goService) Multiply(ctx context.Context, a, b int) (int, error) {
return a * b, nil
}
func (s *goService) Log(msg string) {
fmt.Println("[Go Log]", msg)
}
// EnrichName 内部通过 Invoke 调用 Python 的 to_upper演示 Go→Python→Go→Python 四层链路
func (s *goService) EnrichName(ctx context.Context, name string) (string, error) {
upper, err := gobridge.Invoke[string](ctx, s.pool, "to_upper", name)
if err != nil {
return "", err
}
return "Hello, " + upper + "!", nil
}
func (s *goService) MakeUser(ctx context.Context, uid int) (User, error) {
return User{ID: uid, Name: fmt.Sprintf("user_%d", uid), Score: float64(uid) * 1.5}, nil
}
func demoServer(ctx context.Context, script string) {
fmt.Println("\n── Server 全双工示例 ─────────────────────────────────────────────")
svc := &goService{}
serv, err := gobridge.NewPool(script,
gobridge.WithWorkers(1),
gobridge.WithHandlers(svc),
)
if err != nil {
log.Fatal(err)
}
defer serv.Close()
svc.pool = serv // 注入 pool 供 EnrichName 内部调用
// ── 示例1Python 调用 Go Multiply ───────────────────────────────────────
result, err := gobridge.Invoke[int](ctx, serv, "compute_with_go_mul", 6, 7)
if err != nil {
log.Fatal(err)
}
fmt.Println("compute_with_go_mul(6, 7) =", result) // 42
// ── 示例2流式输出 + Go Log 回调 ────────────────────────────────────────
ch, err := gobridge.Invoke[chan int](ctx, serv, "squared_with_log", 4)
if err != nil {
log.Fatal(err)
}
fmt.Print("squared_with_log(4) =")
for v := range ch {
fmt.Print(" ", v)
}
fmt.Println() // 1 4 9 16
// ── 示例3Go→Python→Go→Python 四层全双工链路 ───────────────────────────
// full_chain("world") → call_go[str]("EnrichName","world") → Invoke to_upper("world") → "WORLD"
// ← "Hello, WORLD!" ← "Hello, WORLD!"
greeting, err := gobridge.Invoke[string](ctx, serv, "full_chain", "world")
if err != nil {
log.Fatal(err)
}
fmt.Println("full_chain(world) =", greeting) // Hello, WORLD!
// ── 示例4call_go[User] 将 Go 返回的 dict 自动构造为 dataclass ──────────
enriched, err := gobridge.Invoke[User](ctx, serv, "get_user_via_go", 12)
if err != nil {
log.Fatal(err)
}
fmt.Printf("get_user_via_go(12) = %+v\n", enriched) // {ID:12 Name:user_12 Score:18 Level:gold}
}

View File

@@ -2,9 +2,24 @@ import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "python"))
from gobridge import expose, run
import dataclasses
import threading
from typing import Iterator
from gobridge import expose, call_go, run, worker_id, worker_count
# ── worker_id / worker_count ──────────────────────────────────────────────────
# 只有 worker 0 才执行一次性初始化(如监听端口、建立长连接等),
# 其余 worker 跳过,避免端口冲突 / 重复连接。
print(f"[worker {worker_id}/{worker_count}] started", flush=True)
if worker_id == 0:
def _init_shared_resource():
# 示例:此处可启动 WebSocket 客户端、监听 TCP 端口等
print(f"[worker {worker_id}] shared resource initialized", flush=True)
threading.Thread(target=_init_shared_resource, daemon=True).start()
# ── 基础类型 ─────────────────────────────────────────────────────────────────
@expose
def add(a: int, b: int) -> int:
@@ -26,12 +41,20 @@ def sum_stream(numbers: Iterator[int]) -> int:
@expose
def double_stream(numbers: Iterator[int]) -> Iterator[int]:
"""双向流输入每个数yield 其平方,对应 Go 侧 Invoke[chan int](c, ctx, "double_stream", inputChan)"""
"""双向流输入每个数yield 其平方"""
for n in numbers:
yield n * n
# ── structdict类型 ──────────────────────────────────────────────────────
# ── structdataclass / dict类型 ───────────────────────────────────────────
@dataclasses.dataclass
class User:
id: int
name: str
score: float
level: str = ""
@expose
def get_user(uid: int) -> dict:
@@ -56,8 +79,6 @@ def enrich_users(users: list) -> list:
return result
# ── struct/slice 流式组合 ────────────────────────────────────────────────────
@expose
def gen_users(count: int) -> Iterator[dict]:
"""流式输出 structyield 多个 User对应 Go 侧 Invoke[chan User]"""
@@ -72,4 +93,48 @@ def process_users(users: Iterator[dict]) -> Iterator[dict]:
yield {"id": u["id"], "name": u["name"].upper(), "score": u["score"] * 2}
# ── Server 全双工示例 ────────────────────────────────────────────────────────
@expose
def compute_with_go_mul(a: int, b: int) -> int:
"""示例1call_go[int] 指定返回类型"""
return call_go[int]("Multiply", a, b)
@expose
def squared_with_log(n: int) -> Iterator[int]:
"""示例2流式输出每次 yield 前 call_go("Log") 回调 Go"""
for i in range(1, n + 1):
call_go("Log", f"yielding {i}² = {i * i}")
yield i * i
@expose
def to_upper(s: str) -> str:
"""辅助方法:被 Go 的 EnrichName handler 内部调用"""
return s.upper()
@expose
def full_chain(name: str) -> str:
"""示例3Go→Python→Go→Python 四层链路
full_chain("world")
→ call_go[str]("EnrichName", "world") # Python 调 Go
→ Invoke[string](ctx, serv, "to_upper", "world") # Go 再调 Python
"WORLD"
"Hello, WORLD!"
"Hello, WORLD!"
"""
return call_go[str]("EnrichName", name)
@expose
def get_user_via_go(uid: int) -> dict:
"""示例4call_go[User] 自动将 Go 返回的 dict 构造为 dataclass 实例"""
user = call_go[User]("MakeUser", uid) # Go 返回 {"id":..,"name":..,"score":..}
user.level = "gold" if user.score >= 10 else "silver"
return dataclasses.asdict(user)
run()

155
pool.go
View File

@@ -2,9 +2,12 @@ package gobridge
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"reflect"
"sync"
"sync/atomic"
)
@@ -19,6 +22,7 @@ type poolConfig struct {
socketDir string
stdout io.Writer
stderr io.Writer
handler any
}
// Option 是 NewPool 的函数选项
@@ -41,7 +45,6 @@ func WithPythonExe(exe string) Option {
}
// WithScriptArgs 设置脚本路径之后的附加参数
// uv 模式示例WithScriptArgs("run") → 执行 uv run <script>
func WithScriptArgs(args ...string) Option {
return func(c *poolConfig) { c.scriptArgs = args }
}
@@ -52,7 +55,6 @@ func WithWorkDir(workDir string) Option {
}
// WithEnv 设置附加环境变量,格式为 "KEY=VALUE"
// 与当前进程环境合并,同名时以此处为准
func WithEnv(env ...string) Option {
return func(c *poolConfig) { c.env = env }
}
@@ -72,29 +74,55 @@ func WithStderr(w io.Writer) Option {
return func(c *poolConfig) { c.stderr = w }
}
// WithHandlers 注册 Go handler struct其所有公开方法自动暴露给 Python 通过 call_go() 调用。
//
// type MyService struct{}
// func (s *MyService) Multiply(ctx context.Context, a, b int) (int, error) { ... }
//
// pool, err := gobridge.NewPool("worker.py", gobridge.WithHandlers(&MyService{}))
// // Python: call_go("Multiply", 3, 4)
func WithHandlers(h any) Option {
return func(c *poolConfig) { c.handler = h }
}
// ── Pool 接口 ────────────────────────────────────────────────────────────────
// Pool 是 Python worker 进程池的接口
type Pool interface {
// Close 关闭所有 worker 进程和连接
Close()
acquire(ctx context.Context) (net.Conn, *worker, error)
nextReqID() uint64
callbackDispatch(ctx context.Context, msg Message) (any, string)
}
// pool 是 Pool 的具体实现
// ── goHandler ────────────────────────────────────────────────────────────────
type goHandler struct {
fn reflect.Value
hasCtx bool
inTypes []reflect.Type
outType reflect.Type
hasErr bool
}
// ── pool内部实现─────────────────────────────────────────────────────────
type pool struct {
workers []*worker
idx atomic.Uint64
reqID atomic.Uint64
mu sync.RWMutex
handlers map[string]goHandler
}
// NewPool 创建并启动进程池
// NewPool 创建并启动进程池
//
// pool, err := gobridge.NewPool("worker.py")
// pool, err := gobridge.NewPool("worker.py", gobridge.WithWorkers(4))
// pool, err := gobridge.NewPool("run",
// gobridge.WithPythonExe("uv"),
// gobridge.WithScriptArgs("worker.py"),
// gobridge.WithWorkDir("./worker"),
// pool, err := gobridge.NewPool("worker.py",
// gobridge.WithHandlers(&MyService{}),
// gobridge.WithWorkers(2),
// )
func NewPool(script string, opts ...Option) (Pool, error) {
cfg := &poolConfig{
@@ -107,7 +135,7 @@ func NewPool(script string, opts ...Option) (Pool, error) {
o(cfg)
}
if script == "" {
return nil, fmt.Errorf("NewPool: script must not be empty")
return nil, fmt.Errorf("gobridge: script must not be empty")
}
cfg.scriptArgs = append([]string{script}, cfg.scriptArgs...)
@@ -115,17 +143,72 @@ func NewPool(script string, opts ...Option) (Pool, error) {
for i := range workers {
w, err := newWorker(cfg, i)
if err != nil {
for j := 0; j < i; j++ {
for j := range i {
workers[j].stop()
}
return nil, fmt.Errorf("create worker %d: %w", i, err)
}
workers[i] = w
}
return &pool{workers: workers}, nil
p := &pool{workers: workers, handlers: make(map[string]goHandler)}
if cfg.handler != nil {
p.bindHandlers(cfg.handler)
}
return p, nil
}
func (p *pool) bindHandlers(h any) {
rv := reflect.ValueOf(h)
rt := rv.Type()
for i := range rt.NumMethod() {
m := rt.Method(i)
if !m.IsExported() {
continue
}
p.bindHandler(m.Name, rv.Method(i))
}
}
func (p *pool) bindHandler(name string, fn reflect.Value) {
errType := reflect.TypeFor[error]()
ctxType := reflect.TypeFor[context.Context]()
ft := fn.Type()
h := goHandler{fn: fn}
startIdx := 0
if ft.NumIn() > 0 && ft.In(0).Implements(ctxType) {
h.hasCtx = true
startIdx = 1
}
for i := startIdx; i < ft.NumIn(); i++ {
h.inTypes = append(h.inTypes, ft.In(i))
}
switch ft.NumOut() {
case 0:
case 1:
if ft.Out(0).Implements(errType) {
h.hasErr = true
} else {
h.outType = ft.Out(0)
}
case 2:
if !ft.Out(1).Implements(errType) {
panic(fmt.Sprintf("gobridge: handler %s: second return value must implement error", name))
}
h.outType = ft.Out(0)
h.hasErr = true
default:
panic(fmt.Sprintf("gobridge: handler %s: must return at most (value, error)", name))
}
p.handlers[name] = h
}
// acquire 以轮询方式从进程池取出一个可用连接
func (p *pool) acquire(ctx context.Context) (net.Conn, *worker, error) {
idx := p.idx.Add(1) % uint64(len(p.workers))
w := p.workers[idx]
@@ -137,7 +220,51 @@ func (p *pool) nextReqID() uint64 {
return p.reqID.Add(1)
}
// Close 关闭所有 worker 进程和连接
func (p *pool) callbackDispatch(ctx context.Context, msg Message) (any, string) {
p.mu.RLock()
h, ok := p.handlers[msg.Method]
p.mu.RUnlock()
if !ok {
return nil, fmt.Sprintf("unknown go handler: %s", msg.Method)
}
var rawArgs []json.RawMessage
if len(msg.Args) > 0 {
if err := json.Unmarshal(msg.Args, &rawArgs); err != nil {
return nil, fmt.Sprintf("unmarshal args: %v", err)
}
}
if len(rawArgs) != len(h.inTypes) {
return nil, fmt.Sprintf("arg count mismatch: want %d got %d", len(h.inTypes), len(rawArgs))
}
in := make([]reflect.Value, 0, len(h.inTypes)+1)
if h.hasCtx {
in = append(in, reflect.ValueOf(ctx))
}
for i, t := range h.inTypes {
v := reflect.New(t)
if err := json.Unmarshal(rawArgs[i], v.Interface()); err != nil {
return nil, fmt.Sprintf("unmarshal arg %d: %v", i, err)
}
in = append(in, v.Elem())
}
out := h.fn.Call(in)
if h.hasErr {
errVal := out[len(out)-1]
if !errVal.IsNil() {
return nil, errVal.Interface().(error).Error()
}
}
if h.outType == nil || len(out) == 0 {
return nil, ""
}
return out[0].Interface(), ""
}
func (p *pool) Close() {
for _, w := range p.workers {
w.stop()

View File

@@ -10,6 +10,8 @@ const (
TypeChunk = "chunk" // 双向: 流数据块
TypeEnd = "end" // 双向: 流结束标记
TypeCancel = "cancel" // Go → Python: 取消正在执行的调用ctx 取消时发送)
TypeCallback = "callback" // Python → Go: 调用 Go 注册方法
TypeCallbackResult = "callback_result" // Go → Python: Go 方法调用结果
)
// Message 是 Go 与 Python 之间传输的消息结构

View File

@@ -3,7 +3,7 @@ gobridge - Python 端库,配合 Go 侧 gobridge 使用
用法::
from gobridge import expose, run
from gobridge import expose, call_go, run
from typing import Iterator
@expose
@@ -27,10 +27,17 @@ gobridge - Python 端库,配合 Go 侧 gobridge 使用
total += i # ctx 取消时这里会抛 InterruptedError
return total
# 全双工:在 handler 内调用 Go 注册的方法(需配合 gobridge.NewServer
@expose
def greet(name: str) -> str:
prefix = call_go("GetPrefix") # 调用 Go 注册的 GetPrefix 方法
return f"{prefix}, {name}!"
run()
"""
import ctypes
import dataclasses
import inspect
import json
import os
@@ -39,9 +46,40 @@ import signal
import socket
import struct
import threading
from typing import Any, Callable, TypeVar
T = TypeVar("T")
_exposed: dict = {}
# ─── Worker 标识 ──────────────────────────────────────────────────────────────
# 每个 worker 进程独有的序号0-based和总数由 Go 启动时通过环境变量注入。
# 用于避免多 worker 场景下的重复初始化(如监听端口、建立长连接等):
#
# if gobridge.worker_id == 0:
# start_websocket_server() # 只有 worker 0 才执行
#
worker_id: int = int(os.environ.get("GOBRIDGE_WORKER_ID", "0"))
worker_count: int = int(os.environ.get("GOBRIDGE_WORKER_COUNT", "1"))
# ─── 线程局部存储 ─────────────────────────────────────────────────────────────
# _local.mux 当前线程正在服务的 _ConnMux由 _dispatch 写入)
_local = threading.local()
# ─── call_go 状态 ─────────────────────────────────────────────────────────────
# _cb_pending: callback id → Queue用于接收 Go 发回的 callback_result
_cb_pending: dict[int, queue.Queue] = {}
_cb_lock = threading.Lock()
_cb_id_counter = 0
_cb_id_lock = threading.Lock()
def _next_cb_id() -> int:
global _cb_id_counter
with _cb_id_lock:
_cb_id_counter += 1
return _cb_id_counter
def expose(fn):
"""装饰器:将函数暴露给 Go 侧调用"""
@@ -49,6 +87,71 @@ def expose(fn):
return fn
def _cast(t: type, value: Any) -> Any:
"""将 Go 返回的 JSON 值转换为指定类型。
- dataclass用 dict 字段构造实例
- 其余类型JSON 已经是正确的 Python 原生类型,直接返回
"""
if value is None or t is type(None):
return value
if dataclasses.is_dataclass(t) and not isinstance(value, t) and isinstance(value, dict):
return t(**value)
return value
class _CallGoType:
"""实现 call_go 和 call_go[Type] 两种调用形式。
用法::
call_go("Multiply", 3, 4) # 返回 Any
call_go[int]("Multiply", 3, 4) # 返回 int类型检查器可感知
call_go[User]("GetUser", uid) # 返回 User dataclass 实例
"""
def __call__(self, method: str, *args) -> Any:
return self._invoke(method, args)
def __getitem__(self, t: type[T]) -> Callable[..., T]:
def _typed(method: str, *args) -> T:
return _cast(t, self._invoke(method, args))
return _typed
def _invoke(self, method: str, args: tuple) -> Any:
"""在 Python handler 内调用 Go 注册的方法(全双工回调)。
通过**同一条连接**发送 callback 消息并同步等待 Go 的回复,
保证整条调用链 Go→Python→Go→... 始终串行执行、不产生额外线程。
只能在 gobridge handler通过 @expose 注册、由 gobridge.NewServer 调用)内使用。
在流式输入 handler 中,应在迭代器耗尽后才调用(流式写入期间 Go 正在发送 chunk
"""
mux: "_ConnMux | None" = getattr(_local, "mux", None)
if mux is None:
raise RuntimeError("call_go() must be called within a gobridge handler")
cb_id = _next_cb_id()
result_q: queue.Queue = queue.Queue(1)
with _cb_lock:
_cb_pending[cb_id] = result_q
try:
mux.write({"id": cb_id, "type": "callback", "method": method, "args": list(args)})
# 阻塞等待 Go 回复ctx 取消时 _raise_in_thread 会注入 InterruptedError
resp = result_q.get()
finally:
with _cb_lock:
_cb_pending.pop(cb_id, None)
if resp is None or resp.get("type") == "error":
raise RuntimeError(resp.get("error", "go callback error") if resp else "connection closed")
return resp.get("data")
call_go = _CallGoType()
def _raise_in_thread(thread_id: int, exc_type: type) -> bool:
"""向指定线程注入异常(在下一条字节码指令时触发)。
@@ -115,11 +218,14 @@ class _ConnMux:
while True:
msg = _read_msg(self.conn)
if msg is None:
# 连接关闭:唤醒主循环,中断所有正在执行的函数
# 连接关闭:唤醒主循环,中断所有正在执行的函数,并唤醒所有 call_go 等待
self.call_q.put(None)
with self._lock:
for tid in self._active_tids.values():
_raise_in_thread(tid, InterruptedError)
with _cb_lock:
for q in _cb_pending.values():
q.put(None) # 通知 call_go 连接已关闭
return
t = msg.get("type")
if t == "call":
@@ -132,6 +238,13 @@ class _ConnMux:
tid = self._active_tids.get(mid)
if tid is not None:
_raise_in_thread(tid, InterruptedError)
elif t in ("callback_result", "error"):
# Go 对 call_go 的回复:按 id 路由到对应等待队列
mid = msg.get("id")
with _cb_lock:
q = _cb_pending.get(mid)
if q is not None:
q.put(msg)
def register(self, msg_id: int, thread_id: int):
with self._lock:
@@ -167,6 +280,8 @@ def _handle_conn(conn: socket.socket):
def _dispatch(mux: _ConnMux, msg: dict):
"""处理一条 call 消息"""
# 将当前连接的 mux 写入线程局部存储,供 call_go() 使用
_local.mux = mux
msg_id = msg["id"]
method = msg.get("method", "")
fn = _exposed.get(method)

View File

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "gobridge"
version = "0.1.0"
version = "0.1.1"
description = "Python 端库,配合 Go 侧 gobridge 使用"
requires-python = ">=3.10"

View File

@@ -47,7 +47,11 @@ func (w *worker) start() error {
cmd := exec.Command(w.cfg.pythonExe, w.cfg.scriptArgs...)
cmd.Dir = w.cfg.workDir
cmd.Env = append(os.Environ(), w.cfg.env...)
cmd.Env = append(cmd.Env, "GOBRIDGE_SOCKET_PATH="+sockPath)
cmd.Env = append(cmd.Env,
"GOBRIDGE_SOCKET_PATH="+sockPath,
fmt.Sprintf("GOBRIDGE_WORKER_ID=%d", w.id),
fmt.Sprintf("GOBRIDGE_WORKER_COUNT=%d", w.cfg.workers),
)
if w.cfg.stdout != nil {
cmd.Stdout = w.cfg.stdout
} else {