enhance: Introduce batch subscription in msgdispatcher #39863

Open · wants to merge 13 commits into base: master
11 changes: 10 additions & 1 deletion internal/datacoord/channel_manager.go
@@ -19,6 +19,7 @@ package datacoord
import (
"context"
"fmt"
"sort"
"sync"
"time"

@@ -545,7 +546,15 @@ func (m *ChannelManagerImpl) advanceToNotifies(ctx context.Context, toNotifies [
zap.Int("total operation count", len(nodeAssign.Channels)),
zap.Strings("channel names", chNames),
)
for _, ch := range nodeAssign.Channels {

// Sort watch tasks by seek position to minimize lag between
// positions during batch subscription in the dispatcher.
channels := lo.Values(nodeAssign.Channels)
sort.Slice(channels, func(i, j int) bool {
return channels[i].GetWatchInfo().GetVchan().GetSeekPosition().GetTimestamp() <
channels[j].GetWatchInfo().GetVchan().GetSeekPosition().GetTimestamp()
})
for _, ch := range channels {
innerCh := ch
tmpWatchInfo := typeutil.Clone(innerCh.GetWatchInfo())
tmpWatchInfo.Vchan = m.h.GetDataVChanPositions(innerCh, allPartitionID)
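The sort added above orders the watch tasks by their seek-position timestamp so that channels subscribed in the same batch start from nearby offsets. A minimal, runnable sketch of that ordering, using a hypothetical channelTask type in place of the real watch-info protos (the real code reads the timestamp from GetWatchInfo().GetVchan().GetSeekPosition()):

```go
package main

import (
	"fmt"
	"sort"
)

// channelTask is a hypothetical stand-in for a channel watch task.
type channelTask struct {
	name   string
	seekTs uint64 // seek position timestamp
}

// sortBySeekPosition orders tasks so that channels with adjacent seek
// positions sit next to each other, keeping the lag between positions
// inside one dispatcher batch small.
func sortBySeekPosition(tasks []channelTask) {
	sort.Slice(tasks, func(i, j int) bool {
		return tasks[i].seekTs < tasks[j].seekTs
	})
}

func main() {
	tasks := []channelTask{
		{name: "ch-2", seekTs: 300},
		{name: "ch-0", seekTs: 100},
		{name: "ch-1", seekTs: 200},
	}
	sortBySeekPosition(tasks)
	// Struct values print without field names:
	fmt.Println(tasks) // [{ch-0 100} {ch-1 200} {ch-2 300}]
}
```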
@@ -103,6 +103,7 @@ func createNewInputFromDispatcher(initCtx context.Context,
retry.MaxSleepTime(paramtable.Get().MQCfg.RetryTimeout.GetAsDuration(time.Second))) // 5 minutes
if err != nil {
log.Warn("datanode consume failed after retried", zap.Error(err))
dispatcherClient.Deregister(vchannel)
return nil, err
}

@@ -130,6 +131,7 @@ func createNewInputFromDispatcher(initCtx context.Context,
retry.MaxSleepTime(paramtable.Get().MQCfg.RetryTimeout.GetAsDuration(time.Second))) // 5 minutes
if err != nil {
log.Warn("datanode consume failed after retried", zap.Error(err))
dispatcherClient.Deregister(vchannel)
return nil, err
}

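Both hunks above apply the same cleanup rule: if registration still fails after the retry loop, the vchannel is deregistered so the dispatcher is not left holding a half-registered target. A minimal sketch of that pattern, assuming a simplified dispatcher interface rather than the real msgdispatcher.Client (whose Register takes a *StreamConfig and returns a <-chan *MsgPack):

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// dispatcher is a hypothetical, trimmed-down interface used only for this sketch.
type dispatcher interface {
	Register(ctx context.Context, vchannel string) (<-chan string, error)
	Deregister(vchannel string)
}

// registerOrCleanup registers a vchannel and, on failure, immediately
// deregisters it so no partial state survives the failed attempt.
func registerOrCleanup(ctx context.Context, d dispatcher, vchannel string) (<-chan string, error) {
	ch, err := d.Register(ctx, vchannel)
	if err != nil {
		d.Deregister(vchannel) // best-effort cleanup, mirroring the hunks above
		return nil, fmt.Errorf("register %s: %w", vchannel, err)
	}
	return ch, nil
}

// fakeDispatcher always fails registration, to exercise the cleanup path.
type fakeDispatcher struct{ deregistered []string }

func (f *fakeDispatcher) Register(ctx context.Context, vchannel string) (<-chan string, error) {
	return nil, errors.New("mock register error")
}

func (f *fakeDispatcher) Deregister(vchannel string) {
	f.deregistered = append(f.deregistered, vchannel)
}

func main() {
	d := &fakeDispatcher{}
	_, err := registerOrCleanup(context.Background(), d, "vchannel-0")
	fmt.Println(err)            // register vchannel-0: mock register error
	fmt.Println(d.deregistered) // [vchannel-0]
}
```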
30 changes: 10 additions & 20 deletions internal/querynodev2/services_test.go
@@ -311,11 +311,11 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() {
}

// mocks
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Chan().Return(suite.msgChan)
suite.msgStream.EXPECT().Close()
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil).Maybe()
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.msgStream.EXPECT().Chan().Return(suite.msgChan).Maybe()
suite.msgStream.EXPECT().Close().Maybe()

// watchDmChannels
status, err := suite.node.WatchDmChannels(ctx, req)
@@ -363,11 +363,11 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() {
}

// mocks
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Chan().Return(suite.msgChan)
suite.msgStream.EXPECT().Close()
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil).Maybe()
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.msgStream.EXPECT().Chan().Return(suite.msgChan).Maybe()
suite.msgStream.EXPECT().Close().Maybe()

// watchDmChannels
status, err := suite.node.WatchDmChannels(ctx, req)
@@ -498,16 +498,6 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() {
suite.ErrorIs(merr.Error(status), merr.ErrChannelReduplicate)
suite.node.unsubscribingChannels.Remove(suite.vchannel)

// init msgstream failed
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil)
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil)
suite.msgStream.EXPECT().Close().Return()
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock error")).Once()

status, err = suite.node.WatchDmChannels(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode())

// load growing failed
badSegmentReq := typeutil.Clone(req)
for _, info := range badSegmentReq.SegmentInfos {
1 change: 1 addition & 0 deletions internal/util/pipeline/stream_pipeline.go
@@ -139,6 +139,7 @@ func (p *streamPipeline) ConsumeMsgStream(ctx context.Context, position *msgpb.M
retry.MaxSleepTime(paramtable.Get().MQCfg.RetryTimeout.GetAsDuration(time.Second))) // 5 minutes
if err != nil {
log.Error("dispatcher register failed after retried", zap.String("channel", position.ChannelName), zap.Error(err))
p.dispatcher.Deregister(p.vChannel)
return WrapErrRegDispather(err)
}

24 changes: 13 additions & 11 deletions pkg/mq/msgdispatcher/client.go
@@ -82,32 +82,32 @@ func NewClient(factory msgstream.Factory, role string, nodeID int64) Client {
}

func (c *client) Register(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error) {
start := time.Now()
vchannel := streamConfig.VChannel
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
pchannel := funcutil.ToPhysicalChannel(vchannel)
start := time.Now()

log := log.Ctx(ctx).With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))

c.managerMut.Lock(pchannel)
defer c.managerMut.Unlock(pchannel)

var manager DispatcherManager
manager, ok := c.managers.Get(pchannel)
if !ok {
manager = NewDispatcherManager(pchannel, c.role, c.nodeID, c.factory)
c.managers.Insert(pchannel, manager)
go manager.Run()
}

// Check if the consumer number limit has been reached.
limit := paramtable.Get().MQCfg.MaxDispatcherNumPerPchannel.GetAsInt()
if manager.NumConsumer() >= limit {
return nil, merr.WrapErrTooManyConsumers(vchannel, fmt.Sprintf("limit=%d", limit))
}

// Begin to register
ch, err := manager.Add(ctx, streamConfig)
if err != nil {
if manager.NumTarget() == 0 {
manager.Close()
c.managers.Remove(pchannel)
}
log.Error("register failed", zap.Error(err))
return nil, err
}
@@ -116,13 +116,15 @@ func (c *client) Register(ctx context.Context, streamConfig *StreamConfig) (<-ch
}

func (c *client) Deregister(vchannel string) {
pchannel := funcutil.ToPhysicalChannel(vchannel)
start := time.Now()
pchannel := funcutil.ToPhysicalChannel(vchannel)

c.managerMut.Lock(pchannel)
defer c.managerMut.Unlock(pchannel)

if manager, ok := c.managers.Get(pchannel); ok {
manager.Remove(vchannel)
if manager.NumTarget() == 0 {
if manager.NumTarget() == 0 && manager.NumConsumer() == 0 {
manager.Close()
c.managers.Remove(pchannel)
}
@@ -132,12 +134,12 @@ func (c *client) Deregister(vchannel string) {
}

func (c *client) Close() {
log := log.With(zap.String("role", c.role),
zap.Int64("nodeID", c.nodeID))
log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID))

c.managers.Range(func(pchannel string, manager DispatcherManager) bool {
c.managerMut.Lock(pchannel)
defer c.managerMut.Unlock(pchannel)

log.Info("close manager", zap.String("channel", pchannel))
c.managers.Remove(pchannel)
manager.Close()
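The client.go changes above do two things: Register rejects a new vchannel once the per-pchannel consumer limit (MaxDispatcherNumPerPchannel) is reached, and Deregister only tears the manager down when it has neither targets nor consumers left. A minimal sketch of that bookkeeping, with a hypothetical manager/registry pair and a plain mutex standing in for the real per-key lock and DispatcherManager:

```go
package main

import (
	"fmt"
	"sync"
)

// manager is a hypothetical stand-in for DispatcherManager: it only tracks
// how many vchannels (targets) and consumers it currently serves.
type manager struct {
	targets   map[string]struct{}
	consumers int
}

// registry mimics the client's per-pchannel bookkeeping.
type registry struct {
	mu       sync.Mutex
	managers map[string]*manager
	limit    int // maximum consumers per pchannel
}

func (r *registry) register(pchannel, vchannel string) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	m, ok := r.managers[pchannel]
	if !ok {
		m = &manager{targets: map[string]struct{}{}}
		r.managers[pchannel] = m
	}
	// Reject the registration once the consumer limit is reached,
	// mirroring the MaxDispatcherNumPerPchannel check above.
	if m.consumers >= r.limit {
		return fmt.Errorf("too many consumers on %s (limit=%d)", pchannel, r.limit)
	}
	m.targets[vchannel] = struct{}{}
	m.consumers++ // simplification: one consumer per target
	return nil
}

func (r *registry) deregister(pchannel, vchannel string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	m, ok := r.managers[pchannel]
	if !ok {
		return
	}
	delete(m.targets, vchannel)
	m.consumers--
	// Drop the manager only when it serves neither targets nor consumers,
	// matching the NumTarget() == 0 && NumConsumer() == 0 condition.
	if len(m.targets) == 0 && m.consumers == 0 {
		delete(r.managers, pchannel)
	}
}

func main() {
	r := &registry{managers: map[string]*manager{}, limit: 2}
	fmt.Println(r.register("p0", "p0_v1")) // <nil>
	fmt.Println(r.register("p0", "p0_v2")) // <nil>
	fmt.Println(r.register("p0", "p0_v3")) // too many consumers on p0 (limit=2)
	r.deregister("p0", "p0_v1")
	r.deregister("p0", "p0_v2")
	fmt.Println(len(r.managers)) // 0
}
```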
65 changes: 42 additions & 23 deletions pkg/mq/msgdispatcher/client_test.go
@@ -22,48 +22,66 @@ import (
"math/rand"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"go.uber.org/atomic"

"github.com/milvus-io/milvus/pkg/mq/common"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

func TestClient(t *testing.T) {
client := NewClient(newMockFactory(), typeutil.ProxyRole, 1)
factory := newMockFactory()
client := NewClient(factory, typeutil.ProxyRole, 1)
assert.NotNil(t, client)
_, err := client.Register(context.Background(), NewStreamConfig("mock_vchannel_0", nil, common.SubscriptionPositionUnknown))
defer client.Close()

pchannel := fmt.Sprintf("by-dev-rootcoord-dml_%d", rand.Int63())

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
producer, err := newMockProducer(factory, pchannel)
assert.NoError(t, err)
go produceTimeTick(ctx, producer)

_, err = client.Register(ctx, NewStreamConfig(fmt.Sprintf("%s_v1", pchannel), nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
_, err = client.Register(context.Background(), NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))

_, err = client.Register(ctx, NewStreamConfig(fmt.Sprintf("%s_v2", pchannel), nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
assert.NotPanics(t, func() {
client.Deregister("mock_vchannel_0")
client.Close()
})

t.Run("with timeout ctx", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Millisecond)
defer cancel()
<-time.After(2 * time.Millisecond)

client := NewClient(newMockFactory(), typeutil.DataNodeRole, 1)
defer client.Close()
assert.NotNil(t, client)
_, err := client.Register(ctx, NewStreamConfig("mock_vchannel_1", nil, common.SubscriptionPositionUnknown))
assert.Error(t, err)
})

client.Deregister(fmt.Sprintf("%s_v1", pchannel))
client.Deregister(fmt.Sprintf("%s_v2", pchannel))
}

func TestClient_Concurrency(t *testing.T) {
client1 := NewClient(newMockFactory(), typeutil.ProxyRole, 1)
factory := newMockFactory()
client1 := NewClient(factory, typeutil.ProxyRole, 1)
assert.NotNil(t, client1)
defer client1.Close()

paramtable.Get().Save(paramtable.Get().MQCfg.TargetBufSize.Key, "65536")
defer paramtable.Get().Reset(paramtable.Get().MQCfg.TargetBufSize.Key)

paramtable.Get().Save(paramtable.Get().MQCfg.MaxDispatcherNumPerPchannel.Key, "65536")
defer paramtable.Get().Reset(paramtable.Get().MQCfg.MaxDispatcherNumPerPchannel.Key)

pchannel := fmt.Sprintf("by-dev-rootcoord-dml_%d", rand.Int63())

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
producer, err := newMockProducer(factory, pchannel)
assert.NoError(t, err)
go produceTimeTick(ctx, producer)
t.Logf("start to produce time tick to pchannel %s", pchannel)

wg := &sync.WaitGroup{}
const total = 100
deregisterCount := atomic.NewInt32(0)
for i := 0; i < total; i++ {
vchannel := fmt.Sprintf("mock-vchannel-%d-%d", i, rand.Int())
i := i
vchannel := fmt.Sprintf("%s_vchannel-%d-%d", pchannel, i, rand.Int())
wg.Add(1)
go func() {
_, err := client1.Register(context.Background(), NewStreamConfig(vchannel, nil, common.SubscriptionPositionUnknown))
@@ -76,7 +94,8 @@ func TestClient_Concurrency(t *testing.T) {
}()
}
wg.Wait()
expected := int(total - deregisterCount.Load())
// expected := int(total - deregisterCount.Load())
expected := 1

c := client1.(*client)
n := c.managers.Len()
Expand Down