diff --git a/pkg/kthena-router/datastore/store.go b/pkg/kthena-router/datastore/store.go index c85240aac..4b283d8c3 100644 --- a/pkg/kthena-router/datastore/store.go +++ b/pkg/kthena-router/datastore/store.go @@ -662,26 +662,29 @@ func (s *store) DecrPodOnFlightRequests(podName types.NamespacedName) { func (s *store) AddOrUpdateModelServer(ms *aiv1alpha1.ModelServer, pods sets.Set[types.NamespacedName]) error { name := utils.GetNamespaceName(ms) - var modelServerObj *modelServer - if value, ok := s.modelServer.Load(name); !ok { - modelServerObj = newModelServer(ms) - // New object — no concurrent access yet, safe to write without lock + + actual, loaded := s.modelServer.Load(name) + if !loaded { + newObj := newModelServer(ms) if len(pods) != 0 { - modelServerObj.pods = pods + newObj.pods = pods.Copy() } - } else { - modelServerObj = value.(*modelServer) + actual, loaded = s.modelServer.LoadOrStore(name, newObj) + } + + modelServerObj := actual.(*modelServer) + + if loaded { // Existing object — concurrent readers may access modelServer and pods, // so we must hold the lock to prevent data races. modelServerObj.mutex.Lock() modelServerObj.modelServer = ms if len(pods) != 0 { // do not operate s.pods here, which are done within pod handler - modelServerObj.pods = pods + modelServerObj.pods = pods.Copy() } modelServerObj.mutex.Unlock() } - s.modelServer.Store(name, modelServerObj) return nil } @@ -818,36 +821,38 @@ func (s *store) AddOrUpdatePod(pod *corev1.Pod, modelServers []*aiv1alpha1.Model } } - if value, ok := s.pods.Load(podName); ok { - // Update existing pod in place — preserve runtime metrics and models. - oldPodInfo := value.(*PodInfo) - oldModelServers := oldPodInfo.GetModelServers() - // Handle the case where the pod no longer belongs to some model servers - oldPodLabels := oldPodInfo.GetPodLabels() - for msName := range oldModelServers.Difference(newModelServers) { - if value, ok := s.modelServer.Load(msName); ok { - ms := value.(*modelServer) - ms.deletePod(podName) - // Remove from PDGroup categorizations - ms.removePodFromPDGroups(podName, oldPodLabels) - } + actual, loaded := s.pods.Load(podName) + if !loaded { + newPodInfo := &PodInfo{ + Pod: pod, + engine: engine, + modelServer: newModelServers, + models: sets.New[string](), + } + actual, loaded = s.pods.LoadOrStore(podName, newPodInfo) + if !loaded { + // We were the one that successfully stored the new pod + s.updatePodMetrics(newPodInfo) + s.updatePodModels(newPodInfo) + return nil } - - oldPodInfo.UpdatePod(pod, engine, newModelServers) - return nil } - // New pod — create PodInfo and fetch initial metrics. - newPodInfo := &PodInfo{ - Pod: pod, - engine: engine, - modelServer: newModelServers, - models: sets.New[string](), + // Update existing pod in place — preserve runtime metrics and models. + oldPodInfo := actual.(*PodInfo) + oldModelServers := oldPodInfo.GetModelServers() + // Handle the case where the pod no longer belongs to some model servers + oldPodLabels := oldPodInfo.GetPodLabels() + for msName := range oldModelServers.Difference(newModelServers) { + if value, ok := s.modelServer.Load(msName); ok { + ms := value.(*modelServer) + ms.deletePod(podName) + // Remove from PDGroup categorizations + ms.removePodFromPDGroups(podName, oldPodLabels) + } } - s.pods.Store(podName, newPodInfo) - s.updatePodMetrics(newPodInfo) - s.updatePodModels(newPodInfo) + oldPodInfo.UpdatePod(pod, engine, newModelServers) return nil } diff --git a/pkg/kthena-router/datastore/store_test.go b/pkg/kthena-router/datastore/store_test.go index e1edc4d89..23040f7a1 100644 --- a/pkg/kthena-router/datastore/store_test.go +++ b/pkg/kthena-router/datastore/store_test.go @@ -2243,3 +2243,81 @@ func TestMatchModelServer_GatewayScoped(t *testing.T) { }) } } + +func TestStoreAddOrUpdateModelServer_ConcurrentAccess(t *testing.T) { + s := &store{ + modelServer: sync.Map{}, + } + var wg sync.WaitGroup + errCh := make(chan error, 100) + // Simulate many concurrent goroutines trying to add the same ModelServer + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ms := &aiv1alpha1.ModelServer{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "default", + Name: "concurrent-model1", + }, + } + pods := sets.New[types.NamespacedName](types.NamespacedName{Namespace: "default", Name: "pod1"}) + errCh <- s.AddOrUpdateModelServer(ms, pods) + }() + } + wg.Wait() + close(errCh) + for err := range errCh { + assert.NoError(t, err) + } + + // Verify the model server is safely in the datastore and no panic occurred + msName := types.NamespacedName{Namespace: "default", Name: "concurrent-model1"} + value, ok := s.modelServer.Load(msName) + assert.True(t, ok) + msInfo := value.(*modelServer) + assert.NotNil(t, msInfo) + assert.True(t, msInfo.pods.Contains(types.NamespacedName{Namespace: "default", Name: "pod1"})) +} + +func TestStoreAddOrUpdatePod_ConcurrentAccess(t *testing.T) { + s := &store{ + modelServer: sync.Map{}, + pods: sync.Map{}, + } + var wg sync.WaitGroup + errCh := make(chan error, 100) + // Simulate many concurrent goroutines trying to add the same Pod + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "default", + Name: "concurrent-pod1", + }, + } + ms := &aiv1alpha1.ModelServer{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "default", + Name: "model1", + }, + } + errCh <- s.AddOrUpdatePod(pod, []*aiv1alpha1.ModelServer{ms}) + }() + } + wg.Wait() + close(errCh) + for err := range errCh { + assert.NoError(t, err) + } + + // Verify the pod is safely in the datastore and no panic occurred + podName := types.NamespacedName{Namespace: "default", Name: "concurrent-pod1"} + value, ok := s.pods.Load(podName) + assert.True(t, ok) + podInfo := value.(*PodInfo) + assert.NotNil(t, podInfo) + assert.True(t, podInfo.modelServer.Contains(types.NamespacedName{Namespace: "default", Name: "model1"})) +}