first commit
This commit is contained in:
286
python/gobridge/__init__.py
Normal file
286
python/gobridge/__init__.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
gobridge - Python 端库,配合 Go 侧 gobridge 使用
|
||||
|
||||
用法::
|
||||
|
||||
from gobridge import gobridge, run
|
||||
from typing import Iterator
|
||||
|
||||
@gobridge
|
||||
def add(a: int, b: int) -> int:
|
||||
return a + b
|
||||
|
||||
@gobridge
|
||||
def range_gen(start: int, stop: int) -> Iterator[int]:
|
||||
for i in range(start, stop):
|
||||
yield i # 对应 Go 侧 Invoke[chan int]
|
||||
|
||||
@gobridge
|
||||
def sum_stream(numbers: Iterator[int]) -> int:
|
||||
return sum(numbers) # 对应 Go 侧传入 chan int 参数
|
||||
|
||||
# ctx 取消时,框架自动向该线程注入 InterruptedError,无需在函数中检查
|
||||
@gobridge
|
||||
def slow_compute(n: int) -> int:
|
||||
total = 0
|
||||
for i in range(n):
|
||||
total += i # ctx 取消时这里会抛 InterruptedError
|
||||
return total
|
||||
|
||||
run()
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import signal
|
||||
import socket
|
||||
import struct
|
||||
import threading
|
||||
|
||||
_exposed: dict = {}
|
||||
|
||||
|
||||
def gobridge(fn):
|
||||
"""装饰器:将函数暴露给 Go 侧调用"""
|
||||
_exposed[fn.__name__] = fn
|
||||
return fn
|
||||
|
||||
|
||||
def _raise_in_thread(thread_id: int, exc_type: type) -> bool:
|
||||
"""向指定线程注入异常(在下一条字节码指令时触发)。
|
||||
|
||||
对纯 Python 代码及大多数 I/O 操作有效;
|
||||
长时间不释放 GIL 的 C 扩展(如大规模 numpy 运算)无法被中断。
|
||||
"""
|
||||
ret = ctypes.pythonapi.PyThreadState_SetAsyncExc(
|
||||
ctypes.c_ulong(thread_id),
|
||||
ctypes.py_object(exc_type),
|
||||
)
|
||||
return ret == 1
|
||||
|
||||
|
||||
# ─── 底层 IO ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _recv_exactly(sock: socket.socket, n: int) -> bytes:
|
||||
buf = bytearray()
|
||||
while len(buf) < n:
|
||||
chunk = sock.recv(n - len(buf))
|
||||
if not chunk:
|
||||
raise ConnectionError("connection closed")
|
||||
buf.extend(chunk)
|
||||
return bytes(buf)
|
||||
|
||||
|
||||
def _read_msg(sock: socket.socket):
|
||||
"""读取一条消息,返回 dict;连接断开返回 None"""
|
||||
try:
|
||||
header = _recv_exactly(sock, 4)
|
||||
length = struct.unpack(">I", header)[0]
|
||||
body = _recv_exactly(sock, length)
|
||||
return json.loads(body)
|
||||
except (ConnectionError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def _write_msg(sock: socket.socket, msg: dict):
|
||||
"""将 dict 序列化为 JSON 帧并发送"""
|
||||
data = json.dumps(msg, ensure_ascii=False).encode()
|
||||
header = struct.pack(">I", len(data))
|
||||
sock.sendall(header + data)
|
||||
|
||||
|
||||
# ─── 连接多路分发 ────────────────────────────────────────────────────────────
|
||||
|
||||
class _ConnMux:
|
||||
"""单一 reader 线程将消息分发到对应队列,消除多线程读 socket 的竞争。
|
||||
|
||||
消息路由规则:
|
||||
call → call_q(容量 1,主线程阻塞读取)
|
||||
chunk / end → chunk_q(_ChunkIter 消费)
|
||||
cancel → 直接触发已注册线程的 InterruptedError
|
||||
"""
|
||||
|
||||
def __init__(self, conn: socket.socket):
|
||||
self.conn = conn
|
||||
self.call_q: queue.Queue = queue.Queue(1)
|
||||
self.chunk_q: queue.Queue = queue.Queue()
|
||||
self._active_tids: dict[int, int] = {} # msg_id → thread id
|
||||
self._lock = threading.Lock()
|
||||
threading.Thread(target=self._reader, daemon=True).start()
|
||||
|
||||
def _reader(self):
|
||||
while True:
|
||||
msg = _read_msg(self.conn)
|
||||
if msg is None:
|
||||
# 连接关闭:唤醒主循环,并中断所有正在执行的函数
|
||||
self.call_q.put(None)
|
||||
with self._lock:
|
||||
for tid in self._active_tids.values():
|
||||
_raise_in_thread(tid, InterruptedError)
|
||||
return
|
||||
t = msg.get("type")
|
||||
if t == "call":
|
||||
self.call_q.put(msg)
|
||||
elif t in ("chunk", "end"):
|
||||
self.chunk_q.put(msg)
|
||||
elif t == "cancel":
|
||||
mid = msg.get("id")
|
||||
with self._lock:
|
||||
tid = self._active_tids.get(mid)
|
||||
if tid is not None:
|
||||
_raise_in_thread(tid, InterruptedError)
|
||||
|
||||
def register(self, msg_id: int, thread_id: int):
|
||||
with self._lock:
|
||||
self._active_tids[msg_id] = thread_id
|
||||
|
||||
def unregister(self, msg_id: int):
|
||||
with self._lock:
|
||||
self._active_tids.pop(msg_id, None)
|
||||
|
||||
def write(self, msg: dict):
|
||||
try:
|
||||
_write_msg(self.conn, msg)
|
||||
except OSError:
|
||||
pass # 连接已被 Go 侧关闭
|
||||
|
||||
|
||||
# ─── 连接处理 ───────────────────────────────────────────────────────────────
|
||||
|
||||
def _handle_conn(conn: socket.socket):
|
||||
"""每个连接在独立线程中串行处理消息"""
|
||||
mux = _ConnMux(conn)
|
||||
try:
|
||||
while True:
|
||||
msg = mux.call_q.get()
|
||||
if msg is None:
|
||||
break
|
||||
_dispatch(mux, msg)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _dispatch(mux: _ConnMux, msg: dict):
|
||||
"""处理一条 call 消息"""
|
||||
msg_id = msg["id"]
|
||||
method = msg.get("method", "")
|
||||
fn = _exposed.get(method)
|
||||
|
||||
if fn is None:
|
||||
mux.write({"id": msg_id, "type": "error", "error": f"unknown method: {method}"})
|
||||
return
|
||||
|
||||
args: list = list(msg.get("args") or [])
|
||||
stream_input: bool = msg.get("stream_input", False)
|
||||
stream_arg_idx: int = msg.get("stream_arg_idx", 0)
|
||||
|
||||
# 登记当前线程 id,cancel 消息到达时向该线程注入 InterruptedError
|
||||
mux.register(msg_id, threading.current_thread().ident)
|
||||
|
||||
chunk_iter_instance = None
|
||||
if stream_input:
|
||||
chunk_iter_instance = _ChunkIter(mux)
|
||||
while len(args) <= stream_arg_idx:
|
||||
args.append(None)
|
||||
args[stream_arg_idx] = chunk_iter_instance
|
||||
|
||||
try:
|
||||
result = fn(*args)
|
||||
|
||||
if inspect.isgenerator(result):
|
||||
try:
|
||||
for item in result:
|
||||
mux.write({"id": msg_id, "type": "chunk", "data": item})
|
||||
finally:
|
||||
mux.write({"id": msg_id, "type": "end"})
|
||||
else:
|
||||
mux.write({"id": msg_id, "type": "result", "data": result})
|
||||
|
||||
except InterruptedError:
|
||||
pass # ctx 取消,连接已被 Go 侧关闭,无需回写
|
||||
|
||||
except Exception as e:
|
||||
mux.write({"id": msg_id, "type": "error", "error": str(e)})
|
||||
|
||||
finally:
|
||||
mux.unregister(msg_id)
|
||||
if chunk_iter_instance is not None:
|
||||
chunk_iter_instance.drain()
|
||||
|
||||
|
||||
class _ChunkIter:
|
||||
"""从 _ConnMux.chunk_q 读取数据块的迭代器(对应 Go 侧 chan 输入)。
|
||||
|
||||
所有 socket 读操作已在 _ConnMux reader 线程中完成,此处无竞争。
|
||||
ctx 取消时 _raise_in_thread 会中断阻塞在 queue.get() 的线程。
|
||||
"""
|
||||
|
||||
def __init__(self, mux: _ConnMux):
|
||||
self._mux = mux
|
||||
self._done = False
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self._done:
|
||||
raise StopIteration
|
||||
msg = self._mux.chunk_q.get()
|
||||
if msg is None or msg.get("type") == "end":
|
||||
self._done = True
|
||||
raise StopIteration
|
||||
if msg.get("type") == "error":
|
||||
self._done = True
|
||||
raise RuntimeError(msg.get("error", "stream error"))
|
||||
return msg.get("data")
|
||||
|
||||
def drain(self):
|
||||
"""排空未消费的数据块"""
|
||||
if self._done:
|
||||
return
|
||||
try:
|
||||
for _ in self:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ─── 主入口 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def run():
|
||||
"""启动 UDS 服务,阻塞直到进程退出。
|
||||
|
||||
子进程忽略 SIGINT,由 Go 主进程统一处理信号,
|
||||
避免 Ctrl+C 时打印多余的 KeyboardInterrupt 堆栈。
|
||||
"""
|
||||
sock_path = os.environ.get("GOBRIDGE_SOCKET_PATH")
|
||||
if not sock_path:
|
||||
raise RuntimeError("GOBRIDGE_SOCKET_PATH environment variable is not set")
|
||||
|
||||
# 子进程忽略 SIGINT,Go 主进程会主动 kill 子进程
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
|
||||
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
server.bind(sock_path)
|
||||
server.listen(64)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
conn, _ = server.accept()
|
||||
except (KeyboardInterrupt, SystemExit, OSError):
|
||||
break
|
||||
t = threading.Thread(target=_handle_conn, args=(conn,), daemon=True)
|
||||
t.start()
|
||||
finally:
|
||||
server.close()
|
||||
try:
|
||||
os.unlink(sock_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
Reference in New Issue
Block a user