first commit
This commit is contained in:
418
README.md
Normal file
418
README.md
Normal file
@@ -0,0 +1,418 @@
|
|||||||
|
# gobridge
|
||||||
|
|
||||||
|
在 Go 与 Python 之间建立双向通信桥接,将 Go 的 `channel` 与 Python 的 `yield` 原生对接,支持普通调用、单向流、双向流。
|
||||||
|
|
||||||
|
底层通过 **Unix Domain Socket (UDS)** 通信,Go 侧维护 **Worker 进程池**,Python 侧以多线程方式并发处理请求。
|
||||||
|
|
||||||
|
## 特性
|
||||||
|
|
||||||
|
- **零配置序列化**:Go struct/slice ↔ Python dict/list 通过 JSON 自动互转
|
||||||
|
- **原生流语义**:Go `chan T` 对应 Python `Iterator[T]`,无需额外 API
|
||||||
|
- **进程池**:Go 自动启动并管理多个 Python 子进程,崩溃后自动重启
|
||||||
|
- **ctx 取消**:Go `context` 取消时自动中断 Python 计算,无需函数内检查
|
||||||
|
- **四种调用模式**:普通、流式输出、流式输入、双向流,同一个 `Invoke` 函数自动推断
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get git.fsdpf.net/go/gobridge
|
||||||
|
```
|
||||||
|
|
||||||
|
Python 端直接复制 `python/gobridge/` 目录到项目中,无需安装依赖(仅用标准库)。
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
**Python 端(worker.py):**
|
||||||
|
|
||||||
|
```python
|
||||||
|
from gobridge import gobridge, run
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
@gobridge
|
||||||
|
def add(a: int, b: int) -> int:
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
run()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Go 端:**
|
||||||
|
|
||||||
|
```go
|
||||||
|
pool, _ := gobridge.NewPool("worker.py")
|
||||||
|
defer pool.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
sum, _ := gobridge.Invoke[int](ctx, pool, "add", 3, 4)
|
||||||
|
fmt.Println(sum) // 7
|
||||||
|
```
|
||||||
|
|
||||||
|
## 四种调用模式
|
||||||
|
|
||||||
|
### 1. 普通调用
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Go
|
||||||
|
sum, err := gobridge.Invoke[int](ctx, pool, "add", 3, 4)
|
||||||
|
|
||||||
|
user, err := gobridge.Invoke[User](ctx, pool, "get_user", 42)
|
||||||
|
|
||||||
|
result, err := gobridge.Invoke[[]User](ctx, pool, "enrich_users", users)
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Python
|
||||||
|
@gobridge
|
||||||
|
def add(a: int, b: int) -> int:
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
@gobridge
|
||||||
|
def get_user(uid: int) -> dict:
|
||||||
|
return {"id": uid, "name": f"user_{uid}", "score": uid * 1.5}
|
||||||
|
|
||||||
|
@gobridge
|
||||||
|
def enrich_users(users: list) -> list:
|
||||||
|
for u in users:
|
||||||
|
u["level"] = "gold" if u["score"] >= 10 else "silver"
|
||||||
|
return users
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 流式输出(Python yield → Go channel)
|
||||||
|
|
||||||
|
返回类型为 `chan T` 时自动进入流式输出模式,Python 函数使用 `yield`,Go 侧通过 `range` 消费。
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Go
|
||||||
|
ch, err := gobridge.Invoke[chan int](ctx, pool, "range_gen", 1, 6)
|
||||||
|
for v := range ch {
|
||||||
|
fmt.Println(v) // 1 2 3 4 5
|
||||||
|
}
|
||||||
|
|
||||||
|
userCh, err := gobridge.Invoke[chan User](ctx, pool, "gen_users", 3)
|
||||||
|
for u := range userCh {
|
||||||
|
fmt.Println(u)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Python
|
||||||
|
@gobridge
|
||||||
|
def range_gen(start: int, stop: int) -> Iterator[int]:
|
||||||
|
for i in range(start, stop):
|
||||||
|
yield i
|
||||||
|
|
||||||
|
@gobridge
|
||||||
|
def gen_users(count: int) -> Iterator[dict]:
|
||||||
|
for i in range(1, count + 1):
|
||||||
|
yield {"id": i, "name": f"user_{i}", "score": float(i * 3)}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 流式输入(Go channel → Python Iterator)
|
||||||
|
|
||||||
|
参数中含 `chan T` 且返回非 `chan` 时自动进入流式输入模式。
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Go
|
||||||
|
inputCh := make(chan int, 10)
|
||||||
|
go func() {
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
inputCh <- i
|
||||||
|
}
|
||||||
|
close(inputCh)
|
||||||
|
}()
|
||||||
|
total, err := gobridge.Invoke[int](ctx, pool, "sum_stream", inputCh)
|
||||||
|
fmt.Println(total) // 15
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Python
|
||||||
|
@gobridge
|
||||||
|
def sum_stream(numbers: Iterator[int]) -> int:
|
||||||
|
return sum(numbers)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. 双向流(Go channel 输入 + Go channel 输出)
|
||||||
|
|
||||||
|
参数含 `chan T` 且返回类型也为 `chan R` 时自动进入双向流模式。
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Go
|
||||||
|
inCh := make(chan User, 5)
|
||||||
|
go func() {
|
||||||
|
for _, u := range users {
|
||||||
|
inCh <- u
|
||||||
|
}
|
||||||
|
close(inCh)
|
||||||
|
}()
|
||||||
|
outCh, err := gobridge.Invoke[chan User](ctx, pool, "process_users", inCh)
|
||||||
|
for u := range outCh {
|
||||||
|
fmt.Println(u)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Python
|
||||||
|
@gobridge
|
||||||
|
def process_users(users: Iterator[dict]) -> Iterator[dict]:
|
||||||
|
for u in users:
|
||||||
|
yield {"id": u["id"], "name": u["name"].upper(), "score": u["score"] * 2}
|
||||||
|
```
|
||||||
|
|
||||||
|
## ctx 取消
|
||||||
|
|
||||||
|
Go 的 `context` 取消会自动中断 Python 侧的执行,无需在 Python 函数中做任何检查:
|
||||||
|
|
||||||
|
```go
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// 超时或手动 cancel() 后,Python 计算立即中断,返回 context.DeadlineExceeded
|
||||||
|
result, err := gobridge.Invoke[int](ctx, pool, "slow_compute", 1000000)
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
@gobridge
|
||||||
|
def slow_compute(n: int) -> int:
|
||||||
|
total = 0
|
||||||
|
for i in range(n):
|
||||||
|
total += i # ctx 取消时此处自动抛出 InterruptedError,无需手动检查
|
||||||
|
return total
|
||||||
|
```
|
||||||
|
|
||||||
|
**实现机制:**
|
||||||
|
1. Go ctx 取消 → 发送 `cancel` 消息给 Python
|
||||||
|
2. Python 单 reader 线程收到 `cancel` → 向执行线程注入 `InterruptedError`(`PyThreadState_SetAsyncExc`)
|
||||||
|
3. Python 函数在下一条字节码指令处中断
|
||||||
|
4. Go 同时关闭连接,解除阻塞的读写操作
|
||||||
|
|
||||||
|
> 限制:长时间不释放 GIL 的 C 扩展(如大规模 numpy 矩阵运算)无法被中断,需等其释放 GIL 后才触发。
|
||||||
|
|
||||||
|
## 进程自动重启
|
||||||
|
|
||||||
|
Python worker 进程崩溃时自动重启,调用方无感知:
|
||||||
|
|
||||||
|
```
|
||||||
|
Python 进程崩溃
|
||||||
|
→ monitor goroutine 检测到退出
|
||||||
|
→ 排空失效连接
|
||||||
|
→ 指数退避重启(100ms → 200ms → ... → 30s)
|
||||||
|
→ 新进程就绪后恢复连接池
|
||||||
|
```
|
||||||
|
|
||||||
|
## 配置
|
||||||
|
|
||||||
|
`NewPool` 使用函数选项模式,第一个参数为脚本路径:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 最简调用
|
||||||
|
pool, err := gobridge.NewPool("worker.py")
|
||||||
|
|
||||||
|
// 完整配置
|
||||||
|
pool, err := gobridge.NewPool("worker.py",
|
||||||
|
gobridge.WithWorkers(4), // Python 进程数量,默认 2
|
||||||
|
gobridge.WithMaxConns(8), // 每进程最大连接数,默认 4
|
||||||
|
gobridge.WithPythonExe("python3"), // 可执行文件,默认 "python3"
|
||||||
|
gobridge.WithWorkDir("/path/to/workdir"), // 工作目录,默认继承当前进程
|
||||||
|
gobridge.WithEnv("PYTHONUNBUFFERED=1", "K=V"), // 附加环境变量,与当前进程环境合并
|
||||||
|
gobridge.WithSocketDir("/var/run/myapp"), // socket 文件目录,默认 /tmp
|
||||||
|
gobridge.WithStdout(os.Stdout), // 子进程 stdout,默认 os.Stdout
|
||||||
|
gobridge.WithStderr(os.Stderr), // 子进程 stderr,默认 os.Stderr
|
||||||
|
)
|
||||||
|
|
||||||
|
// 静默模式:丢弃子进程输出
|
||||||
|
pool, err := gobridge.NewPool("worker.py",
|
||||||
|
gobridge.WithStdout(io.Discard),
|
||||||
|
gobridge.WithStderr(io.Discard),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
| Option | 说明 | 默认值 |
|
||||||
|
|--------|------|--------|
|
||||||
|
| `WithWorkers(n)` | Python 进程数量 | `2` |
|
||||||
|
| `WithMaxConns(n)` | 每进程最大连接数 | `4` |
|
||||||
|
| `WithPythonExe(exe)` | 可执行文件 | `"python3"` |
|
||||||
|
| `WithScriptArgs(args...)` | 脚本路径之后的附加参数 | 无 |
|
||||||
|
| `WithWorkDir(dir)` | 子进程工作目录 | 继承当前进程 |
|
||||||
|
| `WithEnv(kv...)` | 附加环境变量 `"K=V"` | 无 |
|
||||||
|
| `WithSocketDir(dir)` | UDS socket 文件目录 | `"/tmp"` |
|
||||||
|
| `WithStdout(w)` | 子进程标准输出 | `os.Stdout` |
|
||||||
|
| `WithStderr(w)` | 子进程标准错误 | `os.Stderr` |
|
||||||
|
|
||||||
|
## 使用 uv 管理 Python 环境
|
||||||
|
|
||||||
|
推荐使用 [uv](https://github.com/astral-sh/uv) 管理 Python 版本和虚拟环境。
|
||||||
|
|
||||||
|
**方式一:`uv run`(推荐,无需手动激活环境)**
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 等价于执行:uv run worker.py
|
||||||
|
pool, err := gobridge.NewPool("run",
|
||||||
|
gobridge.WithPythonExe("uv"),
|
||||||
|
gobridge.WithScriptArgs("worker.py"),
|
||||||
|
gobridge.WithWorkDir("./worker"), // uv 项目目录(含 pyproject.toml)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**方式二:直接使用虚拟环境的 python**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd worker && uv sync
|
||||||
|
```
|
||||||
|
|
||||||
|
```go
|
||||||
|
venvPython, _ := exec.LookPath("worker/.venv/bin/python")
|
||||||
|
pool, err := gobridge.NewPool("worker/worker.py",
|
||||||
|
gobridge.WithPythonExe(venvPython),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**方式三:shell 脚本封装(适合 CI/部署)**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
#!/bin/sh
|
||||||
|
# run_worker.sh
|
||||||
|
cd "$(dirname "$0")"
|
||||||
|
exec uv run python worker.py
|
||||||
|
```
|
||||||
|
|
||||||
|
```go
|
||||||
|
pool, err := gobridge.NewPool("./run_worker.sh",
|
||||||
|
gobridge.WithPythonExe("/bin/sh"),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**典型项目结构:**
|
||||||
|
|
||||||
|
```
|
||||||
|
myproject/
|
||||||
|
├── main.go
|
||||||
|
├── go.mod
|
||||||
|
└── worker/
|
||||||
|
├── pyproject.toml
|
||||||
|
├── uv.lock
|
||||||
|
├── .venv/
|
||||||
|
└── worker.py
|
||||||
|
```
|
||||||
|
|
||||||
|
`pyproject.toml`:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[project]
|
||||||
|
name = "worker"
|
||||||
|
version = "0.1.0"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = []
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
gobridge = { path = "../../python", editable = true }
|
||||||
|
```
|
||||||
|
|
||||||
|
## 通信协议
|
||||||
|
|
||||||
|
### 整体架构
|
||||||
|
|
||||||
|
```
|
||||||
|
Go 进程
|
||||||
|
┌─────────────────────────────────────────────────────────┐
|
||||||
|
│ Invoke[R](ctx, pool, method, args...) │
|
||||||
|
│ │ │
|
||||||
|
│ ┌────▼─────────────────────────────────────────────┐ │
|
||||||
|
│ │ Pool │ │
|
||||||
|
│ │ workers[0] workers[1] ... workers[N-1] │ │
|
||||||
|
│ │ (轮询选择) │ │
|
||||||
|
│ └────┬─────────────────────────────────────────────┘ │
|
||||||
|
│ │ 每 worker 维护 M 个可复用连接 │
|
||||||
|
└───────┼─────────────────────────────────────────────────┘
|
||||||
|
│ Unix Domain Socket(每 worker 独立 .sock 文件)
|
||||||
|
┌───────▼──────────────┐ ┌──────────────────────────┐
|
||||||
|
│ Python 进程 0 │ │ Python 进程 1 │
|
||||||
|
│ worker.py │ │ worker.py │
|
||||||
|
└──────────────────────┘ └──────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### Python Worker 内部结构
|
||||||
|
|
||||||
|
```
|
||||||
|
Python 进程
|
||||||
|
┌──────────────────────────────────────────────────────────────┐
|
||||||
|
│ run() ── UDS server.accept() 循环 │
|
||||||
|
│ │ │
|
||||||
|
│ 每个连接 → 独立线程 _handle_conn() │
|
||||||
|
│ │
|
||||||
|
│ ┌─────────────────────────────────────────────────────┐ │
|
||||||
|
│ │ _handle_conn(连接线程) │ │
|
||||||
|
│ │ │ │
|
||||||
|
│ │ ┌──────────────────────────────────────────────┐ │ │
|
||||||
|
│ │ │ _ConnMux(单 reader 线程) │ │ │
|
||||||
|
│ │ │ │ │ │
|
||||||
|
│ │ │ socket ──► 读消息 │ │ │
|
||||||
|
│ │ │ │ │ │ │
|
||||||
|
│ │ │ ┌────────┼──────────┐ │ │ │
|
||||||
|
│ │ │ ▼ ▼ ▼ │ │ │
|
||||||
|
│ │ │ call_q chunk_q cancel │ │ │
|
||||||
|
│ │ └─────────┬────────┬──────────┼───────────────┘ │ │
|
||||||
|
│ │ │ │ │ │ │
|
||||||
|
│ │ ▼ │ ▼ │ │
|
||||||
|
│ │ 主循环读取 │ PyThreadState_SetAsyncExc │ │
|
||||||
|
│ │ _dispatch() │ → 执行线程抛 InterruptedError│ │
|
||||||
|
│ │ │ │ │ │
|
||||||
|
│ │ ┌──────▼──────┐ │ │ │
|
||||||
|
│ │ │ @gobridge fn│ │ │ │
|
||||||
|
│ │ │ │ │ │ │
|
||||||
|
│ │ │ 普通函数 │ │ │ │
|
||||||
|
│ │ │ return val ──────────────► result/error │ │
|
||||||
|
│ │ │ │ │ │ │
|
||||||
|
│ │ │ 生成器函数 │ │ │ │
|
||||||
|
│ │ │ yield val ───────────────► chunk × N │ │
|
||||||
|
│ │ │ │ │ + end │ │
|
||||||
|
│ │ │ 流式输入 │ │ │ │
|
||||||
|
│ │ │ Iterator ◄─┘ │ ← chunk_q │ │
|
||||||
|
│ │ └─────────────┘ │ │
|
||||||
|
│ └─────────────────────────────────────────────────────┘ │
|
||||||
|
└──────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
**消息帧:** `[4字节大端长度][JSON载荷]`
|
||||||
|
|
||||||
|
**消息类型:**
|
||||||
|
|
||||||
|
| type | 方向 | 含义 |
|
||||||
|
|----------|--------------|------------------------------|
|
||||||
|
| `call` | Go → Python | 调用请求 |
|
||||||
|
| `result` | Python → Go | 普通返回值 |
|
||||||
|
| `chunk` | 双向 | 流数据块 |
|
||||||
|
| `end` | 双向 | 流结束标记 |
|
||||||
|
| `error` | 双向 | 错误响应 |
|
||||||
|
| `cancel` | Go → Python | 取消请求,触发 InterruptedError |
|
||||||
|
|
||||||
|
## 项目结构
|
||||||
|
|
||||||
|
```
|
||||||
|
gobridge/
|
||||||
|
├── protocol.go # Message 结构与类型常量
|
||||||
|
├── framing.go # 帧读写(4字节长度前缀 + JSON)
|
||||||
|
├── worker.go # Python 子进程管理 + UDS 连接池 + 自动重启
|
||||||
|
├── pool.go # 多进程池(轮询负载均衡)+ Option 函数
|
||||||
|
├── client.go # Invoke[R] 泛型函数(四种模式自动推断)
|
||||||
|
├── example/
|
||||||
|
│ ├── main.go # 完整调用示例
|
||||||
|
│ └── worker.py # Python 函数示例
|
||||||
|
└── python/
|
||||||
|
└── gobridge/
|
||||||
|
└── __init__.py # Python 库(expose、run、_ConnMux)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 类型对应关系
|
||||||
|
|
||||||
|
| Go 类型 | Python 类型 |
|
||||||
|
|-----------|---------------|
|
||||||
|
| `int` | `int` |
|
||||||
|
| `float64` | `float` |
|
||||||
|
| `string` | `str` |
|
||||||
|
| `bool` | `bool` |
|
||||||
|
| `struct` | `dict` |
|
||||||
|
| `[]T` | `list` |
|
||||||
|
| `chan T` | `Iterator[T]` |
|
||||||
|
|
||||||
|
## 参考
|
||||||
|
|
||||||
|
本项目的进程池、UDS 通信、帧协议设计参考自 [pyproc](https://github.com/YuminosukeSato/pyproc),在此基础上增加了 Go channel 与 Python yield 的流式对接及 ctx 取消支持。
|
||||||
342
client.go
Normal file
342
client.go
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
// Package gobridge 提供 Go 与 Python 之间的双向通信桥接,
|
||||||
|
// 支持普通调用、流式输出、流式输入和双向流四种模式。
|
||||||
|
package gobridge
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Invoke 调用 Python 暴露的函数,支持四种模式:
|
||||||
|
//
|
||||||
|
// 普通调用: Invoke[int](ctx, pool, "Add", 3, 4)
|
||||||
|
// 流式输出: Invoke[chan int](ctx, pool, "RangeGen", 1, 10) // Python yield → Go channel
|
||||||
|
// 流式输入: Invoke[int](ctx, pool, "SumStream", inputChan) // Go channel → Python Iterator
|
||||||
|
// 双向流: Invoke[chan int](ctx, pool, "Transform", inputChan) // 两端均为流
|
||||||
|
//
|
||||||
|
// ctx 取消时会立即中断与 Python 的通信并返回 ctx.Err()。
|
||||||
|
// 对于流式输出/双向流,ctx 取消会关闭返回的 channel。
|
||||||
|
func Invoke[R any](ctx context.Context, pool *Pool, method string, args ...any) (R, error) {
|
||||||
|
rt := reflect.TypeFor[R]()
|
||||||
|
|
||||||
|
// 查找 chan 类型的输入参数
|
||||||
|
streamArgIdx := -1
|
||||||
|
var streamCh reflect.Value
|
||||||
|
for i, arg := range args {
|
||||||
|
if arg != nil {
|
||||||
|
rv := reflect.ValueOf(arg)
|
||||||
|
if rv.Kind() == reflect.Chan {
|
||||||
|
streamArgIdx = i
|
||||||
|
streamCh = rv
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case rt.Kind() == reflect.Chan && streamArgIdx >= 0:
|
||||||
|
return invokeStreamBoth[R](ctx, pool, method, streamArgIdx, streamCh, rt, args...)
|
||||||
|
case rt.Kind() == reflect.Chan:
|
||||||
|
return invokeStreamOut[R](ctx, pool, method, rt, args...)
|
||||||
|
case streamArgIdx >= 0:
|
||||||
|
return invokeStreamIn[R](ctx, pool, method, streamArgIdx, streamCh, args...)
|
||||||
|
default:
|
||||||
|
return invokeRegular[R](ctx, pool, method, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// watchCtx 启动一个 goroutine 监听 ctx:
|
||||||
|
// - ctx 取消时先发送 cancel 消息(Python 侧收到后注入 InterruptedError)
|
||||||
|
// - 再关闭连接,解除阻塞中的读写操作
|
||||||
|
//
|
||||||
|
// 返回 stop 函数,必须在 conn 归还连接池前调用,可安全多次调用。
|
||||||
|
func watchCtx(ctx context.Context, conn net.Conn, id uint64) (stop func()) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
var once sync.Once
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
writeMsg(conn, Message{ID: id, Type: TypeCancel}) //nolint
|
||||||
|
conn.Close()
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return func() { once.Do(func() { close(done) }) }
|
||||||
|
}
|
||||||
|
|
||||||
|
// chanRecv 从 ch 接收一个值,同时监听 ctx.Done()。
|
||||||
|
// 返回 (值, channel是否open, ctx是否已取消)。
|
||||||
|
func chanRecv(ctx context.Context, ch reflect.Value) (reflect.Value, bool, bool) {
|
||||||
|
chosen, val, ok := reflect.Select([]reflect.SelectCase{
|
||||||
|
{Dir: reflect.SelectRecv, Chan: ch},
|
||||||
|
{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ctx.Done())},
|
||||||
|
})
|
||||||
|
if chosen == 1 {
|
||||||
|
return reflect.Value{}, false, true
|
||||||
|
}
|
||||||
|
return val, ok, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// contextErr 在 io 错误时优先返回 ctx 的错误原因
|
||||||
|
func contextErr(ctx context.Context, err error) error {
|
||||||
|
if e := ctx.Err(); e != nil {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func invokeRegular[R any](ctx context.Context, pool *Pool, method string, args ...any) (R, error) {
|
||||||
|
var zero R
|
||||||
|
|
||||||
|
argsJSON, err := json.Marshal(args)
|
||||||
|
if err != nil {
|
||||||
|
return zero, fmt.Errorf("marshal args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, w, err := pool.acquire(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
|
||||||
|
id := pool.reqID.Add(1)
|
||||||
|
stop := watchCtx(ctx, conn, id)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
if err := writeMsg(conn, Message{
|
||||||
|
ID: id,
|
||||||
|
Type: TypeCall,
|
||||||
|
Method: method,
|
||||||
|
Args: argsJSON,
|
||||||
|
}); err != nil {
|
||||||
|
w.release(conn, false)
|
||||||
|
return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := readMsg(conn)
|
||||||
|
if err != nil {
|
||||||
|
w.release(conn, false)
|
||||||
|
return zero, contextErr(ctx, fmt.Errorf("read response: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
stop()
|
||||||
|
w.release(conn, true)
|
||||||
|
|
||||||
|
if resp.Type == TypeError {
|
||||||
|
return zero, fmt.Errorf("remote error: %s", resp.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result R
|
||||||
|
if err := json.Unmarshal(resp.Data, &result); err != nil {
|
||||||
|
return zero, fmt.Errorf("unmarshal result: %w", err)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func invokeStreamOut[R any](ctx context.Context, pool *Pool, method string, rt reflect.Type, args ...any) (R, error) {
|
||||||
|
var zero R
|
||||||
|
|
||||||
|
argsJSON, err := json.Marshal(args)
|
||||||
|
if err != nil {
|
||||||
|
return zero, fmt.Errorf("marshal args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, w, err := pool.acquire(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
|
||||||
|
id := pool.reqID.Add(1)
|
||||||
|
if err := writeMsg(conn, Message{
|
||||||
|
ID: id,
|
||||||
|
Type: TypeCall,
|
||||||
|
Method: method,
|
||||||
|
Args: argsJSON,
|
||||||
|
}); err != nil {
|
||||||
|
w.release(conn, false)
|
||||||
|
return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := reflect.MakeChan(rt, 64)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
stop := watchCtx(ctx, conn, id)
|
||||||
|
defer func() {
|
||||||
|
stop()
|
||||||
|
ch.Close()
|
||||||
|
w.release(conn, ctx.Err() == nil)
|
||||||
|
}()
|
||||||
|
for {
|
||||||
|
msg, err := readMsg(conn)
|
||||||
|
if err != nil || msg.Type == TypeEnd || msg.Type == TypeError {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msg.Type == TypeChunk {
|
||||||
|
val := reflect.New(rt.Elem())
|
||||||
|
if err := json.Unmarshal(msg.Data, val.Interface()); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ch.Send(val.Elem())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return ch.Interface().(R), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func invokeStreamIn[R any](ctx context.Context, pool *Pool, method string, streamArgIdx int, streamCh reflect.Value, args ...any) (R, error) {
|
||||||
|
var zero R
|
||||||
|
|
||||||
|
jsonArgs := make([]any, len(args))
|
||||||
|
copy(jsonArgs, args)
|
||||||
|
jsonArgs[streamArgIdx] = nil
|
||||||
|
|
||||||
|
argsJSON, err := json.Marshal(jsonArgs)
|
||||||
|
if err != nil {
|
||||||
|
return zero, fmt.Errorf("marshal args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, w, err := pool.acquire(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
|
||||||
|
id := pool.reqID.Add(1)
|
||||||
|
stop := watchCtx(ctx, conn, id)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
if err := writeMsg(conn, Message{
|
||||||
|
ID: id,
|
||||||
|
Type: TypeCall,
|
||||||
|
Method: method,
|
||||||
|
Args: argsJSON,
|
||||||
|
StreamInput: true,
|
||||||
|
StreamArgIdx: streamArgIdx,
|
||||||
|
}); err != nil {
|
||||||
|
w.release(conn, false)
|
||||||
|
return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
val, ok, cancelled := chanRecv(ctx, streamCh)
|
||||||
|
if cancelled {
|
||||||
|
w.release(conn, false)
|
||||||
|
return zero, ctx.Err()
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
chunkData, err := json.Marshal(val.Interface())
|
||||||
|
if err != nil {
|
||||||
|
w.release(conn, false)
|
||||||
|
return zero, fmt.Errorf("marshal chunk: %w", err)
|
||||||
|
}
|
||||||
|
if err := writeMsg(conn, Message{ID: id, Type: TypeChunk, Data: chunkData}); err != nil {
|
||||||
|
w.release(conn, false)
|
||||||
|
return zero, contextErr(ctx, fmt.Errorf("write chunk: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := writeMsg(conn, Message{ID: id, Type: TypeEnd}); err != nil {
|
||||||
|
w.release(conn, false)
|
||||||
|
return zero, contextErr(ctx, fmt.Errorf("write end: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := readMsg(conn)
|
||||||
|
if err != nil {
|
||||||
|
w.release(conn, false)
|
||||||
|
return zero, contextErr(ctx, fmt.Errorf("read response: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
stop()
|
||||||
|
w.release(conn, true)
|
||||||
|
|
||||||
|
if resp.Type == TypeError {
|
||||||
|
return zero, fmt.Errorf("remote error: %s", resp.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result R
|
||||||
|
if err := json.Unmarshal(resp.Data, &result); err != nil {
|
||||||
|
return zero, fmt.Errorf("unmarshal result: %w", err)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func invokeStreamBoth[R any](ctx context.Context, pool *Pool, method string, streamArgIdx int, streamCh reflect.Value, rt reflect.Type, args ...any) (R, error) {
|
||||||
|
var zero R
|
||||||
|
|
||||||
|
jsonArgs := make([]any, len(args))
|
||||||
|
copy(jsonArgs, args)
|
||||||
|
jsonArgs[streamArgIdx] = nil
|
||||||
|
|
||||||
|
argsJSON, err := json.Marshal(jsonArgs)
|
||||||
|
if err != nil {
|
||||||
|
return zero, fmt.Errorf("marshal args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, w, err := pool.acquire(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
|
||||||
|
id := pool.reqID.Add(1)
|
||||||
|
if err := writeMsg(conn, Message{
|
||||||
|
ID: id,
|
||||||
|
Type: TypeCall,
|
||||||
|
Method: method,
|
||||||
|
Args: argsJSON,
|
||||||
|
StreamInput: true,
|
||||||
|
StreamArgIdx: streamArgIdx,
|
||||||
|
}); err != nil {
|
||||||
|
w.release(conn, false)
|
||||||
|
return zero, contextErr(ctx, fmt.Errorf("write call: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
outCh := reflect.MakeChan(rt, 64)
|
||||||
|
|
||||||
|
// 写入 goroutine:输入 channel → Python chunks
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
val, ok, cancelled := chanRecv(ctx, streamCh)
|
||||||
|
if cancelled || !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(val.Interface())
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err := writeMsg(conn, Message{ID: id, Type: TypeChunk, Data: data}); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
writeMsg(conn, Message{ID: id, Type: TypeEnd}) //nolint
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 读取 goroutine:Python chunks → 输出 channel
|
||||||
|
go func() {
|
||||||
|
stop := watchCtx(ctx, conn, id)
|
||||||
|
defer func() {
|
||||||
|
stop()
|
||||||
|
outCh.Close()
|
||||||
|
w.release(conn, ctx.Err() == nil)
|
||||||
|
}()
|
||||||
|
for {
|
||||||
|
msg, err := readMsg(conn)
|
||||||
|
if err != nil || msg.Type == TypeEnd || msg.Type == TypeError {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msg.Type == TypeChunk {
|
||||||
|
val := reflect.New(rt.Elem())
|
||||||
|
if err := json.Unmarshal(msg.Data, val.Interface()); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
outCh.Send(val.Elem())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return outCh.Interface().(R), nil
|
||||||
|
}
|
||||||
141
example/main.go
Normal file
141
example/main.go
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"git.fsdpf.net/go/gobridge"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Score float64 `json:"score"`
|
||||||
|
Level string `json:"level,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
_, file, _, _ := runtime.Caller(0)
|
||||||
|
script := filepath.Join(filepath.Dir(file), "worker.py")
|
||||||
|
|
||||||
|
pool, err := gobridge.NewPool(script,
|
||||||
|
gobridge.WithWorkers(2),
|
||||||
|
gobridge.WithMaxConns(4),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer pool.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// ── 普通调用 ──────────────────────────────────────────────────────────
|
||||||
|
sum, err := gobridge.Invoke[int](ctx, pool, "add", 3, 4)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
fmt.Println("add(3, 4) =", sum) // 7
|
||||||
|
|
||||||
|
// ── 流式输出:Python yield → Go channel ──────────────────────────────
|
||||||
|
ch, err := gobridge.Invoke[chan int](ctx, pool, "range_gen", 1, 6)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
fmt.Print("range_gen(1, 6) =")
|
||||||
|
for v := range ch {
|
||||||
|
fmt.Print(" ", v)
|
||||||
|
}
|
||||||
|
fmt.Println() // 1 2 3 4 5
|
||||||
|
|
||||||
|
// ── 流式输入:Go channel → Python Iterator ───────────────────────────
|
||||||
|
inputCh := make(chan int, 10)
|
||||||
|
go func() {
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
inputCh <- i
|
||||||
|
}
|
||||||
|
close(inputCh)
|
||||||
|
}()
|
||||||
|
total, err := gobridge.Invoke[int](ctx, pool, "sum_stream", inputCh)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
fmt.Println("sum_stream(1..5) =", total) // 15
|
||||||
|
|
||||||
|
// ── 双向流:Go channel 输入 + Go channel 输出 ────────────────────────
|
||||||
|
inputCh2 := make(chan int, 10)
|
||||||
|
go func() {
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
inputCh2 <- i
|
||||||
|
}
|
||||||
|
close(inputCh2)
|
||||||
|
}()
|
||||||
|
outCh, err := gobridge.Invoke[chan int](ctx, pool, "double_stream", inputCh2)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
fmt.Print("double_stream(1..5) =")
|
||||||
|
for v := range outCh {
|
||||||
|
fmt.Print(" ", v)
|
||||||
|
}
|
||||||
|
fmt.Println() // 1 4 9 16 25
|
||||||
|
|
||||||
|
// ── struct 普通调用 ───────────────────────────────────────────────────
|
||||||
|
user, err := gobridge.Invoke[User](ctx, pool, "get_user", 42)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("get_user(42) = %+v\n", user)
|
||||||
|
|
||||||
|
// ── slice 输入,返回标量 ───────────────────────────────────────────────
|
||||||
|
users := []User{
|
||||||
|
{ID: 1, Name: "alice", Score: 5.0},
|
||||||
|
{ID: 2, Name: "bob", Score: 8.0},
|
||||||
|
{ID: 3, Name: "carol", Score: 12.0},
|
||||||
|
}
|
||||||
|
scoreSum, err := gobridge.Invoke[float64](ctx, pool, "total_score", users)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("total_score([alice,bob,carol]) = %.1f\n", scoreSum)
|
||||||
|
|
||||||
|
// ── slice 输入输出 ─────────────────────────────────────────────────────
|
||||||
|
enriched, err := gobridge.Invoke[[]User](ctx, pool, "enrich_users", users)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
fmt.Println("enrich_users:")
|
||||||
|
for _, u := range enriched {
|
||||||
|
fmt.Printf(" %+v\n", u)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 流式输出 struct:Python yield User → Go chan User ─────────────────
|
||||||
|
userCh, err := gobridge.Invoke[chan User](ctx, pool, "gen_users", 3)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
fmt.Print("gen_users(3) =")
|
||||||
|
for u := range userCh {
|
||||||
|
fmt.Printf(" {%d %s %.0f}", u.ID, u.Name, u.Score)
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
// ── 双向流 struct:Go chan User 输入 → Python 处理 → Go chan User 输出 ─
|
||||||
|
inCh := make(chan User, 5)
|
||||||
|
go func() {
|
||||||
|
for _, u := range users {
|
||||||
|
inCh <- u
|
||||||
|
}
|
||||||
|
close(inCh)
|
||||||
|
}()
|
||||||
|
procCh, err := gobridge.Invoke[chan User](ctx, pool, "process_users", inCh)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
fmt.Println("process_users:")
|
||||||
|
for u := range procCh {
|
||||||
|
fmt.Printf(" %+v\n", u)
|
||||||
|
}
|
||||||
|
}
|
||||||
75
example/worker.py
Normal file
75
example/worker.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "python"))
|
||||||
|
|
||||||
|
from gobridge import gobridge, run
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
|
||||||
|
@gobridge
|
||||||
|
def add(a: int, b: int) -> int:
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
|
||||||
|
@gobridge
|
||||||
|
def range_gen(start: int, stop: int) -> Iterator[int]:
|
||||||
|
"""流式输出:对应 Go 侧 Invoke[chan int]"""
|
||||||
|
for i in range(start, stop):
|
||||||
|
yield i
|
||||||
|
|
||||||
|
|
||||||
|
@gobridge
|
||||||
|
def sum_stream(numbers: Iterator[int]) -> int:
|
||||||
|
"""流式输入:对应 Go 侧传入 chan int 参数"""
|
||||||
|
return sum(numbers)
|
||||||
|
|
||||||
|
|
||||||
|
@gobridge
|
||||||
|
def double_stream(numbers: Iterator[int]) -> Iterator[int]:
|
||||||
|
"""双向流:输入每个数,yield 其平方,对应 Go 侧 Invoke[chan int](c, ctx, "double_stream", inputChan)"""
|
||||||
|
for n in numbers:
|
||||||
|
yield n * n
|
||||||
|
|
||||||
|
|
||||||
|
# ── struct(dict)类型 ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@gobridge
|
||||||
|
def get_user(uid: int) -> dict:
|
||||||
|
"""普通调用:返回一个 struct(Go 对应 User)"""
|
||||||
|
return {"id": uid, "name": f"user_{uid}", "score": uid * 1.5}
|
||||||
|
|
||||||
|
|
||||||
|
@gobridge
|
||||||
|
def total_score(users: list) -> float:
|
||||||
|
"""slice 输入:接收 []User,返回总分"""
|
||||||
|
return sum(u["score"] for u in users)
|
||||||
|
|
||||||
|
|
||||||
|
@gobridge
|
||||||
|
def enrich_users(users: list) -> list:
|
||||||
|
"""slice 输入输出:为每个 user 追加 level 字段"""
|
||||||
|
result = []
|
||||||
|
for u in users:
|
||||||
|
u = dict(u)
|
||||||
|
u["level"] = "gold" if u["score"] >= 10 else "silver"
|
||||||
|
result.append(u)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ── struct/slice 流式组合 ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@gobridge
|
||||||
|
def gen_users(count: int) -> Iterator[dict]:
|
||||||
|
"""流式输出 struct:yield 多个 User,对应 Go 侧 Invoke[chan User]"""
|
||||||
|
for i in range(1, count + 1):
|
||||||
|
yield {"id": i, "name": f"user_{i}", "score": float(i * 3)}
|
||||||
|
|
||||||
|
|
||||||
|
@gobridge
|
||||||
|
def process_users(users: Iterator[dict]) -> Iterator[dict]:
|
||||||
|
"""双向流 struct:输入流式 User,yield 处理后的 User"""
|
||||||
|
for u in users:
|
||||||
|
yield {"id": u["id"], "name": u["name"].upper(), "score": u["score"] * 2}
|
||||||
|
|
||||||
|
|
||||||
|
run()
|
||||||
39
framing.go
Normal file
39
framing.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package gobridge
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// writeMsg 将消息以 [4字节长度][JSON载荷] 格式写入连接
|
||||||
|
func writeMsg(conn net.Conn, msg Message) error {
|
||||||
|
data, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal message: %w", err)
|
||||||
|
}
|
||||||
|
header := make([]byte, 4)
|
||||||
|
binary.BigEndian.PutUint32(header, uint32(len(data)))
|
||||||
|
_, err = conn.Write(append(header, data...))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// readMsg 从连接中读取一条消息
|
||||||
|
func readMsg(conn net.Conn) (Message, error) {
|
||||||
|
header := make([]byte, 4)
|
||||||
|
if _, err := io.ReadFull(conn, header); err != nil {
|
||||||
|
return Message{}, fmt.Errorf("read header: %w", err)
|
||||||
|
}
|
||||||
|
length := binary.BigEndian.Uint32(header)
|
||||||
|
body := make([]byte, length)
|
||||||
|
if _, err := io.ReadFull(conn, body); err != nil {
|
||||||
|
return Message{}, fmt.Errorf("read body: %w", err)
|
||||||
|
}
|
||||||
|
var msg Message
|
||||||
|
if err := json.Unmarshal(body, &msg); err != nil {
|
||||||
|
return Message{}, fmt.Errorf("unmarshal message: %w", err)
|
||||||
|
}
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
133
pool.go
Normal file
133
pool.go
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
package gobridge
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// poolConfig 是进程池内部配置,通过 Option 函数填充
|
||||||
|
type poolConfig struct {
|
||||||
|
workers int
|
||||||
|
maxConnsPerWorker int
|
||||||
|
pythonExe string
|
||||||
|
scriptArgs []string
|
||||||
|
workDir string
|
||||||
|
env []string
|
||||||
|
socketDir string
|
||||||
|
stdout io.Writer
|
||||||
|
stderr io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option 是 NewPool 的函数选项
|
||||||
|
type Option func(*poolConfig)
|
||||||
|
|
||||||
|
// WithWorkers 设置 Python 进程数量(默认 2)
|
||||||
|
func WithWorkers(n int) Option {
|
||||||
|
return func(c *poolConfig) { c.workers = n }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMaxConns 设置每个进程的最大连接数(默认 4)
|
||||||
|
func WithMaxConns(n int) Option {
|
||||||
|
return func(c *poolConfig) { c.maxConnsPerWorker = n }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPythonExe 设置 Python 可执行文件(默认 "python3")
|
||||||
|
// uv 模式:WithPythonExe("uv"), WithScriptArgs("run")
|
||||||
|
func WithPythonExe(exe string) Option {
|
||||||
|
return func(c *poolConfig) { c.pythonExe = exe }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithScriptArgs 设置脚本路径之后的附加参数
|
||||||
|
// uv 模式示例:WithScriptArgs("run") → 执行 uv run <script>
|
||||||
|
func WithScriptArgs(args ...string) Option {
|
||||||
|
return func(c *poolConfig) { c.scriptArgs = args }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithWorkDir 设置子进程工作目录(默认继承当前进程)
|
||||||
|
func WithWorkDir(workDir string) Option {
|
||||||
|
return func(c *poolConfig) { c.workDir = workDir }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithEnv 设置附加环境变量,格式为 "KEY=VALUE"
|
||||||
|
// 与当前进程环境合并,同名时以此处为准
|
||||||
|
func WithEnv(env ...string) Option {
|
||||||
|
return func(c *poolConfig) { c.env = env }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSocketDir 设置 UDS socket 文件目录(默认 /tmp)
|
||||||
|
func WithSocketDir(dir string) Option {
|
||||||
|
return func(c *poolConfig) { c.socketDir = dir }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithStdout 设置子进程标准输出目标(默认 os.Stdout,传 io.Discard 可静默)
|
||||||
|
func WithStdout(w io.Writer) Option {
|
||||||
|
return func(c *poolConfig) { c.stdout = w }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithStderr 设置子进程标准错误目标(默认 os.Stderr,传 io.Discard 可静默)
|
||||||
|
func WithStderr(w io.Writer) Option {
|
||||||
|
return func(c *poolConfig) { c.stderr = w }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pool 管理多个 Python worker 进程及其连接池
|
||||||
|
type Pool struct {
|
||||||
|
workers []*worker
|
||||||
|
idx atomic.Uint64
|
||||||
|
reqID atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPool 创建并启动进程池
|
||||||
|
//
|
||||||
|
// pool, err := gobridge.NewPool("worker.py")
|
||||||
|
// pool, err := gobridge.NewPool("worker.py", gobridge.WithWorkers(4))
|
||||||
|
// pool, err := gobridge.NewPool("run",
|
||||||
|
// gobridge.WithPythonExe("uv"),
|
||||||
|
// gobridge.WithScriptArgs("worker.py"),
|
||||||
|
// gobridge.WithWorkDir("./worker"),
|
||||||
|
// )
|
||||||
|
func NewPool(script string, opts ...Option) (*Pool, error) {
|
||||||
|
cfg := &poolConfig{
|
||||||
|
workers: 2,
|
||||||
|
maxConnsPerWorker: 4,
|
||||||
|
pythonExe: "python3",
|
||||||
|
socketDir: "/tmp",
|
||||||
|
}
|
||||||
|
for _, o := range opts {
|
||||||
|
o(cfg)
|
||||||
|
}
|
||||||
|
if script == "" {
|
||||||
|
return nil, fmt.Errorf("NewPool: script must not be empty")
|
||||||
|
}
|
||||||
|
cfg.scriptArgs = append([]string{script}, cfg.scriptArgs...)
|
||||||
|
|
||||||
|
workers := make([]*worker, cfg.workers)
|
||||||
|
for i := range workers {
|
||||||
|
w, err := newWorker(cfg, i)
|
||||||
|
if err != nil {
|
||||||
|
for j := 0; j < i; j++ {
|
||||||
|
workers[j].stop()
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("create worker %d: %w", i, err)
|
||||||
|
}
|
||||||
|
workers[i] = w
|
||||||
|
}
|
||||||
|
return &Pool{workers: workers}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// acquire 以轮询方式从进程池取出一个可用连接
|
||||||
|
func (p *Pool) acquire(ctx context.Context) (net.Conn, *worker, error) {
|
||||||
|
idx := p.idx.Add(1) % uint64(len(p.workers))
|
||||||
|
w := p.workers[idx]
|
||||||
|
conn, err := w.acquire(ctx)
|
||||||
|
return conn, w, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 关闭所有 worker 进程和连接
|
||||||
|
func (p *Pool) Close() {
|
||||||
|
for _, w := range p.workers {
|
||||||
|
w.stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
26
protocol.go
Normal file
26
protocol.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package gobridge
|
||||||
|
|
||||||
|
import "encoding/json"
|
||||||
|
|
||||||
|
// 消息类型常量
|
||||||
|
const (
|
||||||
|
TypeCall = "call" // Go → Python: 调用请求
|
||||||
|
TypeResult = "result" // Python → Go: 调用结果
|
||||||
|
TypeError = "error" // 双向: 错误响应
|
||||||
|
TypeChunk = "chunk" // 双向: 流数据块
|
||||||
|
TypeEnd = "end" // 双向: 流结束标记
|
||||||
|
TypeCancel = "cancel" // Go → Python: 取消正在执行的调用(ctx 取消时发送)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Message 是 Go 与 Python 之间传输的消息结构
|
||||||
|
// 使用 4 字节大端长度前缀 + JSON 编码
|
||||||
|
type Message struct {
|
||||||
|
ID uint64 `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Method string `json:"method,omitempty"`
|
||||||
|
Args json.RawMessage `json:"args,omitempty"`
|
||||||
|
StreamInput bool `json:"stream_input,omitempty"` // 是否有流式输入参数
|
||||||
|
StreamArgIdx int `json:"stream_arg_idx,omitempty"` // 流式参数在 args 中的下标
|
||||||
|
Data json.RawMessage `json:"data,omitempty"` // 结果/数据块内容
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
286
python/gobridge/__init__.py
Normal file
286
python/gobridge/__init__.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
"""
|
||||||
|
gobridge - Python 端库,配合 Go 侧 gobridge 使用
|
||||||
|
|
||||||
|
用法::
|
||||||
|
|
||||||
|
from gobridge import gobridge, run
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
@gobridge
|
||||||
|
def add(a: int, b: int) -> int:
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
@gobridge
|
||||||
|
def range_gen(start: int, stop: int) -> Iterator[int]:
|
||||||
|
for i in range(start, stop):
|
||||||
|
yield i # 对应 Go 侧 Invoke[chan int]
|
||||||
|
|
||||||
|
@gobridge
|
||||||
|
def sum_stream(numbers: Iterator[int]) -> int:
|
||||||
|
return sum(numbers) # 对应 Go 侧传入 chan int 参数
|
||||||
|
|
||||||
|
# ctx 取消时,框架自动向该线程注入 InterruptedError,无需在函数中检查
|
||||||
|
@gobridge
|
||||||
|
def slow_compute(n: int) -> int:
|
||||||
|
total = 0
|
||||||
|
for i in range(n):
|
||||||
|
total += i # ctx 取消时这里会抛 InterruptedError
|
||||||
|
return total
|
||||||
|
|
||||||
|
run()
|
||||||
|
"""
|
||||||
|
|
||||||
|
import ctypes
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import queue
|
||||||
|
import signal
|
||||||
|
import socket
|
||||||
|
import struct
|
||||||
|
import threading
|
||||||
|
|
||||||
|
_exposed: dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
def gobridge(fn):
|
||||||
|
"""装饰器:将函数暴露给 Go 侧调用"""
|
||||||
|
_exposed[fn.__name__] = fn
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_in_thread(thread_id: int, exc_type: type) -> bool:
|
||||||
|
"""向指定线程注入异常(在下一条字节码指令时触发)。
|
||||||
|
|
||||||
|
对纯 Python 代码及大多数 I/O 操作有效;
|
||||||
|
长时间不释放 GIL 的 C 扩展(如大规模 numpy 运算)无法被中断。
|
||||||
|
"""
|
||||||
|
ret = ctypes.pythonapi.PyThreadState_SetAsyncExc(
|
||||||
|
ctypes.c_ulong(thread_id),
|
||||||
|
ctypes.py_object(exc_type),
|
||||||
|
)
|
||||||
|
return ret == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ─── 底层 IO ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _recv_exactly(sock: socket.socket, n: int) -> bytes:
|
||||||
|
buf = bytearray()
|
||||||
|
while len(buf) < n:
|
||||||
|
chunk = sock.recv(n - len(buf))
|
||||||
|
if not chunk:
|
||||||
|
raise ConnectionError("connection closed")
|
||||||
|
buf.extend(chunk)
|
||||||
|
return bytes(buf)
|
||||||
|
|
||||||
|
|
||||||
|
def _read_msg(sock: socket.socket):
|
||||||
|
"""读取一条消息,返回 dict;连接断开返回 None"""
|
||||||
|
try:
|
||||||
|
header = _recv_exactly(sock, 4)
|
||||||
|
length = struct.unpack(">I", header)[0]
|
||||||
|
body = _recv_exactly(sock, length)
|
||||||
|
return json.loads(body)
|
||||||
|
except (ConnectionError, OSError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _write_msg(sock: socket.socket, msg: dict):
|
||||||
|
"""将 dict 序列化为 JSON 帧并发送"""
|
||||||
|
data = json.dumps(msg, ensure_ascii=False).encode()
|
||||||
|
header = struct.pack(">I", len(data))
|
||||||
|
sock.sendall(header + data)
|
||||||
|
|
||||||
|
|
||||||
|
# ─── 连接多路分发 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _ConnMux:
|
||||||
|
"""单一 reader 线程将消息分发到对应队列,消除多线程读 socket 的竞争。
|
||||||
|
|
||||||
|
消息路由规则:
|
||||||
|
call → call_q(容量 1,主线程阻塞读取)
|
||||||
|
chunk / end → chunk_q(_ChunkIter 消费)
|
||||||
|
cancel → 直接触发已注册线程的 InterruptedError
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, conn: socket.socket):
|
||||||
|
self.conn = conn
|
||||||
|
self.call_q: queue.Queue = queue.Queue(1)
|
||||||
|
self.chunk_q: queue.Queue = queue.Queue()
|
||||||
|
self._active_tids: dict[int, int] = {} # msg_id → thread id
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
threading.Thread(target=self._reader, daemon=True).start()
|
||||||
|
|
||||||
|
def _reader(self):
|
||||||
|
while True:
|
||||||
|
msg = _read_msg(self.conn)
|
||||||
|
if msg is None:
|
||||||
|
# 连接关闭:唤醒主循环,并中断所有正在执行的函数
|
||||||
|
self.call_q.put(None)
|
||||||
|
with self._lock:
|
||||||
|
for tid in self._active_tids.values():
|
||||||
|
_raise_in_thread(tid, InterruptedError)
|
||||||
|
return
|
||||||
|
t = msg.get("type")
|
||||||
|
if t == "call":
|
||||||
|
self.call_q.put(msg)
|
||||||
|
elif t in ("chunk", "end"):
|
||||||
|
self.chunk_q.put(msg)
|
||||||
|
elif t == "cancel":
|
||||||
|
mid = msg.get("id")
|
||||||
|
with self._lock:
|
||||||
|
tid = self._active_tids.get(mid)
|
||||||
|
if tid is not None:
|
||||||
|
_raise_in_thread(tid, InterruptedError)
|
||||||
|
|
||||||
|
def register(self, msg_id: int, thread_id: int):
|
||||||
|
with self._lock:
|
||||||
|
self._active_tids[msg_id] = thread_id
|
||||||
|
|
||||||
|
def unregister(self, msg_id: int):
|
||||||
|
with self._lock:
|
||||||
|
self._active_tids.pop(msg_id, None)
|
||||||
|
|
||||||
|
def write(self, msg: dict):
|
||||||
|
try:
|
||||||
|
_write_msg(self.conn, msg)
|
||||||
|
except OSError:
|
||||||
|
pass # 连接已被 Go 侧关闭
|
||||||
|
|
||||||
|
|
||||||
|
# ─── 连接处理 ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _handle_conn(conn: socket.socket):
|
||||||
|
"""每个连接在独立线程中串行处理消息"""
|
||||||
|
mux = _ConnMux(conn)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
msg = mux.call_q.get()
|
||||||
|
if msg is None:
|
||||||
|
break
|
||||||
|
_dispatch(mux, msg)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _dispatch(mux: _ConnMux, msg: dict):
|
||||||
|
"""处理一条 call 消息"""
|
||||||
|
msg_id = msg["id"]
|
||||||
|
method = msg.get("method", "")
|
||||||
|
fn = _exposed.get(method)
|
||||||
|
|
||||||
|
if fn is None:
|
||||||
|
mux.write({"id": msg_id, "type": "error", "error": f"unknown method: {method}"})
|
||||||
|
return
|
||||||
|
|
||||||
|
args: list = list(msg.get("args") or [])
|
||||||
|
stream_input: bool = msg.get("stream_input", False)
|
||||||
|
stream_arg_idx: int = msg.get("stream_arg_idx", 0)
|
||||||
|
|
||||||
|
# 登记当前线程 id,cancel 消息到达时向该线程注入 InterruptedError
|
||||||
|
mux.register(msg_id, threading.current_thread().ident)
|
||||||
|
|
||||||
|
chunk_iter_instance = None
|
||||||
|
if stream_input:
|
||||||
|
chunk_iter_instance = _ChunkIter(mux)
|
||||||
|
while len(args) <= stream_arg_idx:
|
||||||
|
args.append(None)
|
||||||
|
args[stream_arg_idx] = chunk_iter_instance
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = fn(*args)
|
||||||
|
|
||||||
|
if inspect.isgenerator(result):
|
||||||
|
try:
|
||||||
|
for item in result:
|
||||||
|
mux.write({"id": msg_id, "type": "chunk", "data": item})
|
||||||
|
finally:
|
||||||
|
mux.write({"id": msg_id, "type": "end"})
|
||||||
|
else:
|
||||||
|
mux.write({"id": msg_id, "type": "result", "data": result})
|
||||||
|
|
||||||
|
except InterruptedError:
|
||||||
|
pass # ctx 取消,连接已被 Go 侧关闭,无需回写
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
mux.write({"id": msg_id, "type": "error", "error": str(e)})
|
||||||
|
|
||||||
|
finally:
|
||||||
|
mux.unregister(msg_id)
|
||||||
|
if chunk_iter_instance is not None:
|
||||||
|
chunk_iter_instance.drain()
|
||||||
|
|
||||||
|
|
||||||
|
class _ChunkIter:
|
||||||
|
"""从 _ConnMux.chunk_q 读取数据块的迭代器(对应 Go 侧 chan 输入)。
|
||||||
|
|
||||||
|
所有 socket 读操作已在 _ConnMux reader 线程中完成,此处无竞争。
|
||||||
|
ctx 取消时 _raise_in_thread 会中断阻塞在 queue.get() 的线程。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mux: _ConnMux):
|
||||||
|
self._mux = mux
|
||||||
|
self._done = False
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self._done:
|
||||||
|
raise StopIteration
|
||||||
|
msg = self._mux.chunk_q.get()
|
||||||
|
if msg is None or msg.get("type") == "end":
|
||||||
|
self._done = True
|
||||||
|
raise StopIteration
|
||||||
|
if msg.get("type") == "error":
|
||||||
|
self._done = True
|
||||||
|
raise RuntimeError(msg.get("error", "stream error"))
|
||||||
|
return msg.get("data")
|
||||||
|
|
||||||
|
def drain(self):
|
||||||
|
"""排空未消费的数据块"""
|
||||||
|
if self._done:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
for _ in self:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ─── 主入口 ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def run():
|
||||||
|
"""启动 UDS 服务,阻塞直到进程退出。
|
||||||
|
|
||||||
|
子进程忽略 SIGINT,由 Go 主进程统一处理信号,
|
||||||
|
避免 Ctrl+C 时打印多余的 KeyboardInterrupt 堆栈。
|
||||||
|
"""
|
||||||
|
sock_path = os.environ.get("GOBRIDGE_SOCKET_PATH")
|
||||||
|
if not sock_path:
|
||||||
|
raise RuntimeError("GOBRIDGE_SOCKET_PATH environment variable is not set")
|
||||||
|
|
||||||
|
# 子进程忽略 SIGINT,Go 主进程会主动 kill 子进程
|
||||||
|
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||||
|
|
||||||
|
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||||
|
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
server.bind(sock_path)
|
||||||
|
server.listen(64)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
conn, _ = server.accept()
|
||||||
|
except (KeyboardInterrupt, SystemExit, OSError):
|
||||||
|
break
|
||||||
|
t = threading.Thread(target=_handle_conn, args=(conn,), daemon=True)
|
||||||
|
t.start()
|
||||||
|
finally:
|
||||||
|
server.close()
|
||||||
|
try:
|
||||||
|
os.unlink(sock_path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
182
worker.go
Normal file
182
worker.go
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
package gobridge
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type worker struct {
|
||||||
|
cfg *poolConfig
|
||||||
|
id int
|
||||||
|
sockPath string
|
||||||
|
conns chan net.Conn
|
||||||
|
stopped atomic.Bool
|
||||||
|
stopCh chan struct{}
|
||||||
|
stopOnce sync.Once
|
||||||
|
|
||||||
|
mu sync.Mutex // 保护 cmd
|
||||||
|
cmd *exec.Cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWorker(cfg *poolConfig, id int) (*worker, error) {
|
||||||
|
w := &worker{
|
||||||
|
cfg: cfg,
|
||||||
|
id: id,
|
||||||
|
conns: make(chan net.Conn, cfg.maxConnsPerWorker),
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
if err := w.start(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
go w.monitor()
|
||||||
|
return w, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// start 启动 Python 子进程并预建连接池,返回当前 cmd 供 monitor() 等待
|
||||||
|
func (w *worker) start() error {
|
||||||
|
sockPath := fmt.Sprintf("%s/gobridge-%d-%d.sock", w.cfg.socketDir, os.Getpid(), w.id)
|
||||||
|
os.Remove(sockPath)
|
||||||
|
w.sockPath = sockPath
|
||||||
|
|
||||||
|
cmd := exec.Command(w.cfg.pythonExe, w.cfg.scriptArgs...)
|
||||||
|
cmd.Dir = w.cfg.workDir
|
||||||
|
cmd.Env = append(os.Environ(), w.cfg.env...)
|
||||||
|
cmd.Env = append(cmd.Env, "GOBRIDGE_SOCKET_PATH="+sockPath)
|
||||||
|
if w.cfg.stdout != nil {
|
||||||
|
cmd.Stdout = w.cfg.stdout
|
||||||
|
} else {
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
}
|
||||||
|
if w.cfg.stderr != nil {
|
||||||
|
cmd.Stderr = w.cfg.stderr
|
||||||
|
} else {
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return fmt.Errorf("start python worker: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.mu.Lock()
|
||||||
|
w.cmd = cmd
|
||||||
|
w.mu.Unlock()
|
||||||
|
|
||||||
|
// 等待 socket 文件出现(最多 10 秒)
|
||||||
|
deadline := time.Now().Add(10 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if _, err := os.Stat(sockPath); err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-w.stopCh:
|
||||||
|
cmd.Process.Kill()
|
||||||
|
cmd.Wait()
|
||||||
|
return fmt.Errorf("stopped while waiting for socket")
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(sockPath); err != nil {
|
||||||
|
cmd.Process.Kill()
|
||||||
|
cmd.Wait()
|
||||||
|
return fmt.Errorf("worker socket did not appear: %s", sockPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 预建连接
|
||||||
|
for i := 0; i < w.cfg.maxConnsPerWorker; i++ {
|
||||||
|
conn, err := net.DialTimeout("unix", sockPath, 5*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
cmd.Process.Kill()
|
||||||
|
cmd.Wait()
|
||||||
|
for len(w.conns) > 0 {
|
||||||
|
(<-w.conns).Close()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("connect to worker: %w", err)
|
||||||
|
}
|
||||||
|
w.conns <- conn
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// monitor 监控 Python 进程,崩溃时自动重启(指数退避,最长 30s)
|
||||||
|
func (w *worker) monitor() {
|
||||||
|
for {
|
||||||
|
// 等待当前进程退出
|
||||||
|
w.mu.Lock()
|
||||||
|
cmd := w.cmd
|
||||||
|
w.mu.Unlock()
|
||||||
|
cmd.Wait()
|
||||||
|
|
||||||
|
if w.stopped.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
os.Remove(w.sockPath)
|
||||||
|
|
||||||
|
// 排空失效连接
|
||||||
|
for len(w.conns) > 0 {
|
||||||
|
(<-w.conns).Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 指数退避重启
|
||||||
|
for attempt := 0; !w.stopped.Load(); attempt++ {
|
||||||
|
if err := w.start(); err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
delay := time.Duration(min(1<<attempt, 300)) * 100 * time.Millisecond
|
||||||
|
select {
|
||||||
|
case <-time.After(delay):
|
||||||
|
case <-w.stopCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *worker) acquire(ctx context.Context) (net.Conn, error) {
|
||||||
|
select {
|
||||||
|
case conn := <-w.conns:
|
||||||
|
return conn, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *worker) release(conn net.Conn, healthy bool) {
|
||||||
|
if !healthy {
|
||||||
|
conn.Close()
|
||||||
|
newConn, err := net.DialTimeout("unix", w.sockPath, time.Second)
|
||||||
|
if err == nil {
|
||||||
|
select {
|
||||||
|
case w.conns <- newConn:
|
||||||
|
default:
|
||||||
|
newConn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.conns <- conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *worker) stop() {
|
||||||
|
w.stopped.Store(true)
|
||||||
|
w.stopOnce.Do(func() { close(w.stopCh) })
|
||||||
|
|
||||||
|
w.mu.Lock()
|
||||||
|
cmd := w.cmd
|
||||||
|
w.mu.Unlock()
|
||||||
|
|
||||||
|
if cmd != nil && cmd.Process != nil {
|
||||||
|
cmd.Process.Signal(os.Interrupt) // 先发 SIGINT,Python 已忽略,等效于 Kill
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
cmd.Process.Kill()
|
||||||
|
}
|
||||||
|
for len(w.conns) > 0 {
|
||||||
|
(<-w.conns).Close()
|
||||||
|
}
|
||||||
|
os.Remove(w.sockPath)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user