Range-diff rd-192
- title
- chore(pubsub): add more tests
- description
-
Patch equal - old #1
603ca6e- new #1
603ca6e
- title
- feat(pubsub): round robin
- description
-
Patch changed - old #2
501c042- new #2
17e00b2
1: 603ca6e = 1: 603ca6e chore(pubsub): add more tests
2: 501c042 ! 2: 17e00b2 feat(pubsub): round robin
old
old:
pkg/pubsub/channel.go
new:pkg/pubsub/channel.go
Clients *syncmap.Map[string, *Client] handleOnce sync.Once cleanupOnce sync.Once + Dispatcher MessageDispatcher } func (c *Channel) GetClients() iter.Seq2[string, *Client] { } func (c *Channel) Handle() { + // If no dispatcher is set, use multicast as default + if c.Dispatcher == nil { + c.Dispatcher = &MulticastDispatcher{} + } + c.handleOnce.Do(func() { go func() { defer func() { case <-c.Done: return case data, ok := <-c.Data: - var wg sync.WaitGroup - for _, client := range c.GetClients() { - if client.Direction == ChannelDirectionInput || (client.ID == data.ClientID && !client.Replay) { - continue + if !ok { + // Channel is closing, close all client data channels + for _, client := range c.GetClients() { + client.onceData.Do(func() { + close(client.Data) + }) } - - wg.Add(1) - go func() { - defer wg.Done() - if !ok { - client.onceData.Do(func() { - close(client.Data) - }) - return - } - - select { - case client.Data <- data: - case <-client.Done: - case <-c.Done: - } - }() + return } - wg.Wait() + + // Collect eligible subscribers + subscribers := dispatcherForGetClients(c.GetClients(), data) + + // Dispatch message using the configured dispatcher + _ = c.Dispatcher.Dispatch(data, subscribers, c.Done) } } }()
new
old:
pkg/pubsub/channel.go
new:pkg/pubsub/channel.go
Clients *syncmap.Map[string, *Client] handleOnce sync.Once cleanupOnce sync.Once + Dispatcher MessageDispatcher } func (c *Channel) GetClients() iter.Seq2[string, *Client] { case <-c.Done: return case data, ok := <-c.Data: - var wg sync.WaitGroup - for _, client := range c.GetClients() { - if client.Direction == ChannelDirectionInput || (client.ID == data.ClientID && !client.Replay) { - continue + if !ok { + // Channel is closing, close all client data channels + for _, client := range c.GetClients() { + client.onceData.Do(func() { + close(client.Data) + }) } - - wg.Add(1) - go func() { - defer wg.Done() - if !ok { - client.onceData.Do(func() { - close(client.Data) - }) - return - } - - select { - case client.Data <- data: - case <-client.Done: - case <-c.Done: - } - }() + return } - wg.Wait() + + // Collect eligible subscribers + subscribers := dispatcherForGetClients(c.GetClients(), data) + + // Dispatch message using the configured dispatcher + _ = c.Dispatcher.Dispatch(data, subscribers, c.Done) } } }()
old
new:
pkg/pubsub/roundrobin.go
+package pubsub + +import ( + "context" + "errors" + "io" + "iter" + "log/slog" + "sync" + + "github.com/antoniomika/syncmap" +) + +/* +RoundRobin is a load-balancing broker that distributes published messages +to subscribers using a round-robin algorithm. + +Unlike Multicast which sends each message to all subscribers, RoundRobin +sends each message to exactly one subscriber, rotating through the available +subscribers for each published message. This provides load balancing for +message processing. + +It maintains independent round-robin state per channel/topic. +*/ +type RoundRobin struct { + Broker + Logger *slog.Logger +} + +func NewRoundRobin(logger *slog.Logger) *RoundRobin { + return &RoundRobin{ + Logger: logger, + Broker: &BaseBroker{ + Channels: syncmap.New[string, *Channel](), + Logger: logger.With(slog.Bool("broker", true)), + }, + } +} + +func (p *RoundRobin) getClients(direction ChannelDirection) iter.Seq2[string, *Client] { + return func(yield func(string, *Client) bool) { + for clientID, client := range p.GetClients() { + if client.Direction == direction { + yield(clientID, client) + } + } + } +} + +func (p *RoundRobin) GetPipes() iter.Seq2[string, *Client] { + return p.getClients(ChannelDirectionInputOutput) +} + +func (p *RoundRobin) GetPubs() iter.Seq2[string, *Client] { + return p.getClients(ChannelDirectionInput) +} + +func (p *RoundRobin) GetSubs() iter.Seq2[string, *Client] { + return p.getClients(ChannelDirectionOutput) +} + +func (p *RoundRobin) connect(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, direction ChannelDirection, blockWrite bool, replay, keepAlive bool) (error, error) { + client := NewClient(ID, rw, direction, blockWrite, replay, keepAlive) + + go func() { + <-ctx.Done() + client.Cleanup() + }() + + return p.Connect(client, channels) +} + +func (p *RoundRobin) Pipe(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, replay bool) (error, error) { + return p.connect(ctx, ID, rw, channels, ChannelDirectionInputOutput, false, replay, false) +} + +func (p *RoundRobin) Pub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, blockWrite bool) error { + return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionInput, blockWrite, false, false)) +} + +func (p *RoundRobin) Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, keepAlive bool) error { + return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionOutput, false, false, keepAlive)) +} + +// ensureChannel wraps BaseBroker.ensureChannel to set up round-robin dispatcher. +func (p *RoundRobin) ensureChannel(channel *Channel) *Channel { + baseBroker := p.Broker.(*BaseBroker) + dataChannel, _ := baseBroker.Channels.LoadOrStore(channel.Topic, channel) + // Set the round-robin dispatcher on the channel + if dataChannel.Dispatcher == nil { + dataChannel.Dispatcher = &RoundRobinDispatcher{} + } + dataChannel.Handle() + return dataChannel +} + +// Override Connect to use our custom ensureChannel. +func (p *RoundRobin) Connect(client *Client, channels []*Channel) (error, error) { + for _, channel := range channels { + dataChannel := p.ensureChannel(channel) + dataChannel.Clients.Store(client.ID, client) + client.Channels.Store(dataChannel.Topic, dataChannel) + defer func() { + client.Channels.Delete(channel.Topic) + dataChannel.Clients.Delete(client.ID) + + client.Cleanup() + + count := 0 + for _, cl := range dataChannel.GetClients() { + if cl.Direction == ChannelDirectionInput || cl.Direction == ChannelDirectionInputOutput { + count++ + } + } + + if count == 0 { + for _, cl := range dataChannel.GetClients() { + if !cl.KeepAlive { + cl.Cleanup() + } + } + } + + p.Cleanup() + }() + } + + baseBroker := p.Broker.(*BaseBroker) + return baseBroker.Connect(client, channels) +} + +// Cleanup delegates to BaseBroker. +func (p *RoundRobin) Cleanup() { + baseBroker := p.Broker.(*BaseBroker) + baseBroker.Cleanup() +} + +// RoundRobinDispatcher sends each message to a single subscriber in round-robin order. +type RoundRobinDispatcher struct { + index uint32 + mu sync.Mutex +} + +func (d *RoundRobinDispatcher) Dispatch(msg ChannelMessage, subscribers []*Client, channelDone chan struct{}) error { + // If no subscribers, nothing to dispatch + // BlockWrite behavior at publish time ensures subscribers are present when needed + if len(subscribers) == 0 { + return nil + } + + // Select the next subscriber in round-robin order + d.mu.Lock() + selectedIdx := int(d.index % uint32(len(subscribers))) + d.index++ + d.mu.Unlock() + + selectedClient := subscribers[selectedIdx] + + select { + case selectedClient.Data <- msg: + case <-selectedClient.Done: + case <-channelDone: + } + + return nil +} + +var _ PubSub = (*RoundRobin)(nil)
new
new:
pkg/pubsub/roundrobin.go
+package pubsub + +import ( + "slices" + "strings" + "sync" +) + +/* +RoundRobin is a load-balancing broker that distributes published messages +to subscribers using a round-robin algorithm. + +Unlike Multicast which sends each message to all subscribers, RoundRobin +sends each message to exactly one subscriber, rotating through the available +subscribers for each published message. This provides load balancing for +message processing. + +It maintains independent round-robin state per channel/topic. +*/ +type RoundRobinDispatcher struct { + index uint32 + mu sync.Mutex +} + +func (d *RoundRobinDispatcher) Dispatch(msg ChannelMessage, subscribers []*Client, channelDone chan struct{}) error { + // If no subscribers, nothing to dispatch + // BlockWrite behavior at publish time ensures subscribers are present when needed + if len(subscribers) == 0 { + return nil + } + + slices.SortFunc(subscribers, func(a, b *Client) int { + return strings.Compare(a.ID, b.ID) + }) + + // Select the next subscriber in round-robin order + d.mu.Lock() + selectedIdx := int(d.index % uint32(len(subscribers))) + d.index++ + d.mu.Unlock() + + selectedClient := subscribers[selectedIdx] + + select { + case selectedClient.Data <- msg: + case <-selectedClient.Done: + case <-channelDone: + } + + return nil +}
old
new
old:
pkg/apps/pipe/cli.go
new:pkg/apps/pipe/cli.go
block := pubCmd.Bool("b", true, "Block writes until a subscriber is available") timeout := pubCmd.Duration("t", 30*24*time.Hour, "Timeout as a Go duration to block for a subscriber to be available. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.") clean := pubCmd.Bool("c", false, "Don't send status messages") + broker := pubCmd.String("bk", "multicast", "Type of broker (e.g. multicast, round_robin)") if !flagCheck(pubCmd, topic, cmd.args) { return fmt.Errorf("invalid cmd args") "topic", topic, "access", *access, "clean", *clean, + "broker", *broker, ) var accessList []string throttledRW := newThrottledMonitorRW(rw, handler, cmd, name) + var bk psub.MessageDispatcher + bk = &psub.MulticastDispatcher{} + if *broker == "round_robin" { + fmt.Println("BROKER ROUND ROBIN") + bk = &psub.RoundRobinDispatcher{} + } + channel := psub.NewChannel(name) + channel.Dispatcher = bk + err := handler.PubSub.Pub( cmd.pipeCtx, clientID, throttledRW, - []*psub.Channel{ - psub.NewChannel(name), - }, + []*psub.Channel{channel}, *block, )
old
new
old:
pkg/pubsub/broker.go
new:pkg/pubsub/broker.go
func (b *BaseBroker) ensureChannel(channel *Channel) *Channel { dataChannel, _ := b.Channels.LoadOrStore(channel.Topic, channel) + // Allow overwriting the dispatcher + if channel.Dispatcher != nil && dataChannel.Dispatcher == nil { + dataChannel.Dispatcher = channel.Dispatcher + } + dataChannel.Handle() return dataChannel }