Files
gobridge/worker.go
2026-04-10 09:24:28 +08:00

183 lines
3.6 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package gobridge
import (
"context"
"fmt"
"net"
"os"
"os/exec"
"sync"
"sync/atomic"
"time"
)
// worker supervises a single Python child process and a fixed-size pool of
// Unix-socket connections to it.
type worker struct {
	id  int         // index of this worker; also part of the socket file name
	cfg *poolConfig // shared pool configuration

	// mu guards cmd, which start() replaces every time the process is (re)started.
	mu  sync.Mutex
	cmd *exec.Cmd

	// sockPath is the socket path handed to the child (set by start());
	// conns holds idle connections and is buffered to cfg.maxConnsPerWorker.
	sockPath string
	conns    chan net.Conn

	// Shutdown state: stopped is set by stop() and read by monitor();
	// stopCh is closed exactly once (guarded by stopOnce) to wake goroutines
	// blocked in a select.
	stopped  atomic.Bool
	stopCh   chan struct{}
	stopOnce sync.Once
}
// newWorker builds worker number id, starts its Python process via start(),
// and then launches the monitor goroutine that restarts the process if it
// exits. If the initial start fails, the error is returned unchanged and no
// monitor goroutine is launched.
func newWorker(cfg *poolConfig, id int) (*worker, error) {
	w := new(worker)
	w.cfg, w.id = cfg, id
	w.conns = make(chan net.Conn, cfg.maxConnsPerWorker)
	w.stopCh = make(chan struct{})

	err := w.start()
	if err != nil {
		return nil, err
	}
	go w.monitor()
	return w, nil
}
// start launches the Python child process for this worker, waits for it to
// create its Unix socket, and fills w.conns with cfg.maxConnsPerWorker freshly
// dialed connections.
//
// It is called once by newWorker, before w is visible to other goroutines, and
// again by monitor() after every unexpected exit. On success, w.cmd is the new
// process, which monitor() waits on. On any failure, the child is killed and
// reaped, connections dialed by this call are closed, and an error is returned.
func (w *worker) start() error {
	// The path depends only on the pid and the worker id, so compute it on the
	// first call only. Re-assigning it on restarts (even to the same value)
	// would be a data race with release(), which reads w.sockPath without a lock.
	if w.sockPath == "" {
		w.sockPath = fmt.Sprintf("%s/gobridge-%d-%d.sock", w.cfg.socketDir, os.Getpid(), w.id)
	}
	sockPath := w.sockPath
	// Remove a stale socket left by an earlier process; "not exist" is expected.
	_ = os.Remove(sockPath)

	cmd := exec.Command(w.cfg.pythonExe, w.cfg.scriptArgs...)
	cmd.Dir = w.cfg.workDir
	cmd.Env = append(os.Environ(), w.cfg.env...)
	cmd.Env = append(cmd.Env, "GOBRIDGE_SOCKET_PATH="+sockPath)
	if w.cfg.stdout != nil {
		cmd.Stdout = w.cfg.stdout
	} else {
		cmd.Stdout = os.Stdout
	}
	if w.cfg.stderr != nil {
		cmd.Stderr = w.cfg.stderr
	} else {
		cmd.Stderr = os.Stderr
	}
	if err := cmd.Start(); err != nil {
		return fmt.Errorf("start python worker: %w", err)
	}

	// abort kills and reaps the child so a failed start leaves no zombie.
	abort := func() {
		_ = cmd.Process.Kill()
		_ = cmd.Wait()
	}

	// Publish cmd and check for a concurrent stop() under the same lock.
	// stop() sets w.stopped before it reads w.cmd under w.mu. So either stop()
	// sees this cmd and kills it, or we see stopped here and kill it ourselves.
	// Without this check, a process started during stop() could outlive the worker.
	w.mu.Lock()
	if w.stopped.Load() {
		w.mu.Unlock()
		abort()
		return fmt.Errorf("stopped while starting worker")
	}
	w.cmd = cmd
	w.mu.Unlock()

	// Wait (up to 10 seconds) for the child to create its socket file.
	deadline := time.Now().Add(10 * time.Second)
	for {
		if _, err := os.Stat(sockPath); err == nil {
			break
		}
		if !time.Now().Before(deadline) {
			abort()
			return fmt.Errorf("worker socket did not appear: %s", sockPath)
		}
		select {
		case <-w.stopCh:
			abort()
			return fmt.Errorf("stopped while waiting for socket")
		case <-time.After(50 * time.Millisecond):
		}
	}

	// Dial the whole pool before publishing any connection, so a failure
	// part-way through never hands callers a connection to a killed process.
	dialed := make([]net.Conn, 0, w.cfg.maxConnsPerWorker)
	for i := 0; i < w.cfg.maxConnsPerWorker; i++ {
		conn, err := net.DialTimeout("unix", sockPath, 5*time.Second)
		if err != nil {
			abort()
			for _, c := range dialed {
				c.Close()
			}
			return fmt.Errorf("connect to worker: %w", err)
		}
		dialed = append(dialed, conn)
	}

	// Publish without blocking: connections released back while the process
	// was restarting may already occupy slots, and blocking here would stall
	// monitor(). Surplus connections are closed.
	for _, conn := range dialed {
		select {
		case w.conns <- conn:
		default:
			conn.Close()
		}
	}
	return nil
}
// monitor waits on the current child process. Unless the worker has been
// stopped, it then restarts the child with exponential backoff between failed
// attempts: 100ms, 200ms, 400ms, ..., capped at 30s. It runs in its own
// goroutine (launched by newWorker) and returns once stop() has been called.
func (w *worker) monitor() {
	const (
		baseDelay = 100 * time.Millisecond
		maxDelay  = 30 * time.Second
	)
	for {
		w.mu.Lock()
		cmd := w.cmd
		w.mu.Unlock()
		// The exit status is irrelevant: any exit not requested via stop()
		// triggers a restart.
		_ = cmd.Wait()
		if w.stopped.Load() {
			return
		}

		// Best effort: start() also removes the socket before relaunching.
		_ = os.Remove(w.sockPath)

		// Close connections to the dead process. This must not block:
		// acquire() may take the last connection between a length check and
		// a receive, and a blocking receive would hang the restart forever.
	drain:
		for {
			select {
			case conn := <-w.conns:
				conn.Close()
			default:
				break drain
			}
		}

		// Doubling a Duration with a cap cannot overflow, unlike the previous
		// 1<<attempt, which wrapped to <=0 after ~63 failures and spun hot.
		delay := baseDelay
		for {
			if w.stopped.Load() {
				return
			}
			if err := w.start(); err == nil {
				break
			}
			select {
			case <-time.After(delay):
			case <-w.stopCh:
				return
			}
			delay = min(delay*2, maxDelay)
		}
	}
}
// acquire takes an idle connection from the pool. It blocks until a
// connection is available, ctx is done, or the worker is stopped. The caller
// must hand the connection back with release.
func (w *worker) acquire(ctx context.Context) (net.Conn, error) {
	select {
	case conn := <-w.conns:
		return conn, nil
	case <-w.stopCh:
		// stop() drains the pool and kills the process, so waiting would only
		// end when ctx does (never, for context.Background()).
		return nil, fmt.Errorf("worker %d stopped", w.id)
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// release returns conn to the pool.
//
// A healthy conn is re-queued. An unhealthy one is closed and, as a best
// effort, replaced by a freshly dialed connection so the pool keeps its size.
// If that dial fails, the pool is one connection short until it is refilled
// after the next restart.
//
// release never blocks. The pool may already be full, because it was refilled
// after a crash while this conn was checked out, or the worker may be stopped.
// In either case the surplus connection is closed instead of queued.
func (w *worker) release(conn net.Conn, healthy bool) {
	if w.stopped.Load() {
		conn.Close()
		return
	}
	if !healthy {
		conn.Close()
		fresh, err := net.DialTimeout("unix", w.sockPath, time.Second)
		if err != nil {
			return
		}
		conn = fresh
	}
	select {
	case w.conns <- conn:
	default:
		conn.Close()
	}
}
// stop shuts the worker down permanently. It:
//   - marks the worker stopped, so monitor() will not restart the process;
//   - closes stopCh to wake anything selecting on it;
//   - kills the current child;
//   - closes pooled connections;
//   - removes the socket file.
//
// It is safe to call more than once. It does not reap the child; monitor()
// does that via cmd.Wait().
func (w *worker) stop() {
	w.stopped.Store(true)
	w.stopOnce.Do(func() { close(w.stopCh) })

	w.mu.Lock()
	cmd := w.cmd
	w.mu.Unlock()
	if cmd != nil && cmd.Process != nil {
		// SIGINT goes first, but the Python side is expected to ignore it, so
		// in practice the Kill below is what ends the process. Errors (e.g.
		// the process already exited) are irrelevant during shutdown.
		_ = cmd.Process.Signal(os.Interrupt)
		time.Sleep(50 * time.Millisecond)
		_ = cmd.Process.Kill()
	}

	// Non-blocking drain: acquire() or monitor() may receive concurrently, and
	// a len-check followed by a blocking receive could hang here forever.
drain:
	for {
		select {
		case conn := <-w.conns:
			conn.Close()
		default:
			break drain
		}
	}

	// Best effort: the file may already be gone.
	_ = os.Remove(w.sockPath)
}