From 1322280daf74e81b05fc602053ecdcb7329996c3 Mon Sep 17 00:00:00 2001 From: what Date: Tue, 2 Jun 2026 19:21:47 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=97=A0=E8=AE=A2?= =?UTF-8?q?=E9=98=85=E8=80=85=E6=97=B6=E6=B6=88=E6=81=AF=E9=9D=99=E9=BB=98?= =?UTF-8?q?=E4=B8=A2=E5=A4=B1=E9=97=AE=E9=A2=98=EF=BC=8C=E5=AE=8C=E5=96=84?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 pending 缓冲区,publish 时若无订阅者则暂存消息 - subscribe 时自动将缓冲消息投入 channel,解决服务重启后恢复任务丢失的问题 - 去除 broadcast 5ms 超时导致的消息丢失 - chan bool 改为 chan struct{},RWMutex 改为 Mutex - 新增 broker_test.go,12 个单元测试覆盖核心场景(含 -race) - 为 client_test.go 中的无限循环 demo 添加 t.Skip() --- broker.go | 143 ++++++++++++----------------- broker_test.go | 238 +++++++++++++++++++++++++++++++++++++++++++++++++ client_test.go | 6 +- 3 files changed, 297 insertions(+), 90 deletions(-) create mode 100644 broker_test.go diff --git a/broker.go b/broker.go index 8fa394d..b2d7c43 100644 --- a/broker.go +++ b/broker.go @@ -3,102 +3,66 @@ package queue import ( "errors" "sync" - "time" ) type Broker struct { - exit chan bool // 关闭消息队列通道 - capacity int // 消息队列的容量 - topics map[string][]chan any // key: topic value : queue, 一个topic可以有多个订阅者,一个订阅者对应着一个通道 - sync.RWMutex // 同步锁 + exit chan struct{} + capacity int + topics map[string][]chan any + pending map[string][]any // 订阅前发布的消息缓冲,subscribe 时一次性投递 + mu sync.Mutex } -// 设置消息容量 -// @description 控制消息队列的大小 func (b *Broker) setConditions(capacity int) { + b.mu.Lock() b.capacity = capacity + b.mu.Unlock() } -// 关闭消息队列 func (b *Broker) close() { select { case <-b.exit: return default: close(b.exit) - b.Lock() + b.mu.Lock() b.topics = make(map[string][]chan any) - b.Unlock() + b.pending = make(map[string][]any) + b.mu.Unlock() } - return } -// 消息推送 -// @param topic 订阅的主题 -// @param msg 传递的消息 -func (b *Broker) publish(topic string, pub any) error { +// publish 推送消息;若暂无订阅者则缓冲,等待订阅者注册后投递。 +func (b *Broker) publish(topic string, msg any) error { select { case <-b.exit: return errors.New("broker closed") default: } - b.RLock() - subscribers, ok := b.topics[topic] - b.RUnlock() - if !ok { + b.mu.Lock() + subs := b.topics[topic] + if len(subs) == 0 { + b.pending[topic] = append(b.pending[topic], msg) + b.mu.Unlock() return nil } + // 持有锁期间只做列表复制,发送在锁外进行,避免阻塞其他 publish + chs := make([]chan any, len(subs)) + copy(chs, subs) + b.mu.Unlock() - b.broadcast(pub, subscribers) + for _, ch := range chs { + select { + case ch <- msg: + case <-b.exit: + return errors.New("broker closed") + } + } return nil } -// 消息广播 -// @description 对推送的消息进行广播,保证每一个订阅者都可以收到 -func (b *Broker) broadcast(msg any, subscribers []chan any) { - count := len(subscribers) - concurrency := 1 - - switch { - case count > 1000: - concurrency = 3 - case count > 100: - concurrency = 2 - default: - concurrency = 1 - } - - pub := func(start int) { - //采用Timer 而不是使用time.After 原因:time.After会产生内存泄漏 在计时器触发之前,垃圾回收器不会回收Timer - idleDuration := 5 * time.Millisecond - idleTimeout := time.NewTimer(idleDuration) - defer idleTimeout.Stop() - for j := start; j < count; j += concurrency { - if !idleTimeout.Stop() { - select { - case <-idleTimeout.C: - default: - } - } - idleTimeout.Reset(idleDuration) - select { - case subscribers[j] <- msg: - case <-idleTimeout.C: - case <-b.exit: - return - } - } - } - for i := 0; i < concurrency; i++ { - go pub(i) - } -} - -// 消息订阅 -// @description 传入订阅的主题,即可完成订阅 -// @param topic 订阅的主题 -// @return sub 通道用来接收数据 +// subscribe 订阅 topic,返回 channel;同时将该 topic 的缓冲消息立即投入 channel。 func (b *Broker) subscribe(topic string) (<-chan any, error) { select { case <-b.exit: @@ -106,16 +70,25 @@ func (b *Broker) subscribe(topic string) (<-chan any, error) { default: } - ch := make(chan any, b.capacity) - b.Lock() + b.mu.Lock() + capacity := b.capacity + if capacity <= 0 { + capacity = 10 + } + ch := make(chan any, capacity) b.topics[topic] = append(b.topics[topic], ch) - b.Unlock() + buffered := b.pending[topic] + delete(b.pending, topic) + b.mu.Unlock() + + // channel 刚创建必然不满,直接写入不会阻塞 + for _, msg := range buffered { + ch <- msg + } + return ch, nil } -// 取消订阅 -// @param topic 订阅的主题 -// @param sub 消息订阅的通道 func (b *Broker) unsubscribe(topic string, sub <-chan any) error { select { case <-b.exit: @@ -123,31 +96,25 @@ func (b *Broker) unsubscribe(topic string, sub <-chan any) error { default: } - b.RLock() - subscribers, ok := b.topics[topic] - b.RUnlock() + b.mu.Lock() + defer b.mu.Unlock() - if !ok { - return nil - } - // delete subscriber - b.Lock() - var newSubs []chan any - for _, subscriber := range subscribers { - if subscriber == sub { - continue + subs := b.topics[topic] + newSubs := subs[:0] + for _, s := range subs { + if s != sub { + newSubs = append(newSubs, s) } - newSubs = append(newSubs, subscriber) } - b.topics[topic] = newSubs - b.Unlock() return nil } func NewBroker() *Broker { return &Broker{ - exit: make(chan bool), - topics: make(map[string][]chan any), + exit: make(chan struct{}), + capacity: 10, + topics: make(map[string][]chan any), + pending: make(map[string][]any), } } diff --git a/broker_test.go b/broker_test.go new file mode 100644 index 0000000..0a132e7 --- /dev/null +++ b/broker_test.go @@ -0,0 +1,238 @@ +package queue + +import ( + "fmt" + "sync" + "sync/atomic" + "testing" + "time" +) + +// recv 从 channel 读取一条消息,超时则返回 nil。 +func recv(ch <-chan any, timeout time.Duration) any { + select { + case v := <-ch: + return v + case <-time.After(timeout): + return nil + } +} + +// TestPublishSubscribe 基本收发 +func TestPublishSubscribe(t *testing.T) { + b := NewBroker() + defer b.close() + + ch, err := b.subscribe("job-a") + if err != nil { + t.Fatalf("subscribe: %v", err) + } + + if err := b.publish("job-a", "hello"); err != nil { + t.Fatalf("publish: %v", err) + } + + got := recv(ch, time.Second) + if got != "hello" { + t.Fatalf("want %q, got %v", "hello", got) + } +} + +// TestPendingBuffer 先发布再订阅,消息不能丢失(队列核心保证) +func TestPendingBuffer(t *testing.T) { + b := NewBroker() + defer b.close() + + // 先 publish,此时无订阅者 + for i := 0; i < 3; i++ { + if err := b.publish("job-b", i); err != nil { + t.Fatalf("publish %d: %v", i, err) + } + } + + // 再 subscribe,应收到全部缓冲消息 + ch, _ := b.subscribe("job-b") + + for want := 0; want < 3; want++ { + got := recv(ch, time.Second) + if got != want { + t.Fatalf("pending msg[%d]: want %d, got %v", want, want, got) + } + } +} + +// TestPendingThenNormal 缓冲消息先于后续消息到达,顺序正确 +func TestPendingThenNormal(t *testing.T) { + b := NewBroker() + defer b.close() + + b.publish("job-c", "buffered") + + ch, _ := b.subscribe("job-c") + b.publish("job-c", "live") + + msgs := []any{recv(ch, time.Second), recv(ch, time.Second)} + if msgs[0] != "buffered" || msgs[1] != "live" { + t.Fatalf("order wrong: %v", msgs) + } +} + +// TestMultipleSubscribers 同一 topic 多个订阅者都能收到消息 +func TestMultipleSubscribers(t *testing.T) { + b := NewBroker() + defer b.close() + + ch1, _ := b.subscribe("broadcast") + ch2, _ := b.subscribe("broadcast") + + b.publish("broadcast", "msg") + + if recv(ch1, time.Second) != "msg" { + t.Fatal("ch1 did not receive message") + } + if recv(ch2, time.Second) != "msg" { + t.Fatal("ch2 did not receive message") + } +} + +// TestUnsubscribe 取消订阅后不再收到消息 +func TestUnsubscribe(t *testing.T) { + b := NewBroker() + defer b.close() + + ch, _ := b.subscribe("job-d") + b.unsubscribe("job-d", ch) + + b.publish("job-d", "should not arrive") + + if got := recv(ch, 100*time.Millisecond); got != nil { + t.Fatalf("after unsubscribe, still got %v", got) + } +} + +// TestClosedPublish 关闭后 publish 返回错误 +func TestClosedPublish(t *testing.T) { + b := NewBroker() + b.close() + + if err := b.publish("x", "msg"); err == nil { + t.Fatal("expected error after close, got nil") + } +} + +// TestClosedSubscribe 关闭后 subscribe 返回错误 +func TestClosedSubscribe(t *testing.T) { + b := NewBroker() + b.close() + + if _, err := b.subscribe("x"); err == nil { + t.Fatal("expected error after close, got nil") + } +} + +// TestConcurrentPublish 并发发布不丢消息,无 data race +func TestConcurrentPublish(t *testing.T) { + b := NewBroker() + defer b.close() + b.setConditions(100) + + const n = 50 + ch, _ := b.subscribe("concurrent") + + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + b.publish("concurrent", i) + }(i) + } + wg.Wait() + + var count int32 + done := make(chan struct{}) + go func() { + for recv(ch, 200*time.Millisecond) != nil { + atomic.AddInt32(&count, 1) + } + close(done) + }() + <-done + + if int(count) != n { + t.Fatalf("want %d messages, got %d", n, count) + } +} + +// TestSetConditions 容量设置生效(channel 满时不丢已缓冲的消息) +func TestSetConditions(t *testing.T) { + b := NewBroker() + defer b.close() + b.setConditions(5) + + ch, _ := b.subscribe("cap-test") + + for i := 0; i < 5; i++ { + if err := b.publish("cap-test", i); err != nil { + t.Fatalf("publish %d: %v", i, err) + } + } + + for want := 0; want < 5; want++ { + got := recv(ch, time.Second) + if got != want { + t.Fatalf("msg[%d]: want %d, got %v", want, want, got) + } + } +} + +// TestClientWrapper Client 封装与 Broker 行为一致 +func TestClientWrapper(t *testing.T) { + c := NewClient() + defer c.Close() + c.SetConditions(10) + + ch, err := c.Subscribe("wrap") + if err != nil { + t.Fatalf("Subscribe: %v", err) + } + + c.Publish("wrap", "ok") + + got := recv(ch, time.Second) + if got != "ok" { + t.Fatalf("want %q got %v", "ok", got) + } +} + +// TestGetPayload GetPayload 在 channel 关闭时返回 nil 而非阻塞 +func TestGetPayload(t *testing.T) { + c := NewClient() + ch, _ := c.Subscribe("gp") + c.Publish("gp", "payload") + + got := c.GetPayload(ch) + if got != "payload" { + t.Fatalf("want %q got %v", "payload", got) + } +} + +// TestPendingBufferMultipleTopics 多个 topic 的缓冲互不干扰 +func TestPendingBufferMultipleTopics(t *testing.T) { + b := NewBroker() + defer b.close() + + for i := 0; i < 5; i++ { + topic := fmt.Sprintf("topic-%d", i) + b.publish(topic, i*10) + } + + for i := 0; i < 5; i++ { + topic := fmt.Sprintf("topic-%d", i) + ch, _ := b.subscribe(topic) + got := recv(ch, time.Second) + if got != i*10 { + t.Fatalf("topic-%d: want %d, got %v", i, i*10, got) + } + } +} diff --git a/client_test.go b/client_test.go index 5521b42..f2edc83 100644 --- a/client_test.go +++ b/client_test.go @@ -10,6 +10,7 @@ const topic = "Golang梦工厂" // 一个topic 测试 func TestOnceTopic(t *testing.T) { + t.Skip("infinite loop demo, not a unit test") m := NewClient() defer m.Close() m.SetConditions(10) @@ -24,7 +25,7 @@ func TestOnceTopic(t *testing.T) { // 定时推送 func OncePub(c *Client) { - t := time.NewTicker(10 * time.Second) + t := time.NewTicker(1 * time.Second) defer t.Stop() for { select { @@ -47,8 +48,9 @@ func OnceSub(m <-chan interface{}, c *Client) { } } -//多个topic测试 +// 多个topic测试 func TestManyTopic(t *testing.T) { + t.Skip("infinite loop demo, not a unit test") m := NewClient() defer m.Close() m.SetConditions(10)