Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 39 additions & 34 deletions pkg/kthena-router/datastore/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -662,26 +662,29 @@ func (s *store) DecrPodOnFlightRequests(podName types.NamespacedName) {

func (s *store) AddOrUpdateModelServer(ms *aiv1alpha1.ModelServer, pods sets.Set[types.NamespacedName]) error {
name := utils.GetNamespaceName(ms)
var modelServerObj *modelServer
if value, ok := s.modelServer.Load(name); !ok {
modelServerObj = newModelServer(ms)
// New object — no concurrent access yet, safe to write without lock

actual, loaded := s.modelServer.Load(name)
if !loaded {
newObj := newModelServer(ms)
if len(pods) != 0 {
modelServerObj.pods = pods
newObj.pods = pods.Copy()
}
} else {
modelServerObj = value.(*modelServer)
actual, loaded = s.modelServer.LoadOrStore(name, newObj)
Comment on lines +668 to +672
}

modelServerObj := actual.(*modelServer)

if loaded {
// Existing object — concurrent readers may access modelServer and pods,
// so we must hold the lock to prevent data races.
modelServerObj.mutex.Lock()
modelServerObj.modelServer = ms
if len(pods) != 0 {
// do not operate s.pods here, which are done within pod handler
modelServerObj.pods = pods
modelServerObj.pods = pods.Copy()
}
modelServerObj.mutex.Unlock()
}
s.modelServer.Store(name, modelServerObj)
return nil
}

Expand Down Expand Up @@ -818,36 +821,38 @@ func (s *store) AddOrUpdatePod(pod *corev1.Pod, modelServers []*aiv1alpha1.Model
}
}

if value, ok := s.pods.Load(podName); ok {
// Update existing pod in place — preserve runtime metrics and models.
oldPodInfo := value.(*PodInfo)
oldModelServers := oldPodInfo.GetModelServers()
// Handle the case where the pod no longer belongs to some model servers
oldPodLabels := oldPodInfo.GetPodLabels()
for msName := range oldModelServers.Difference(newModelServers) {
if value, ok := s.modelServer.Load(msName); ok {
ms := value.(*modelServer)
ms.deletePod(podName)
// Remove from PDGroup categorizations
ms.removePodFromPDGroups(podName, oldPodLabels)
}
actual, loaded := s.pods.Load(podName)
if !loaded {
newPodInfo := &PodInfo{
Pod: pod,
engine: engine,
modelServer: newModelServers,
models: sets.New[string](),
}
actual, loaded = s.pods.LoadOrStore(podName, newPodInfo)
if !loaded {
// We were the one that successfully stored the new pod
s.updatePodMetrics(newPodInfo)
s.updatePodModels(newPodInfo)
return nil
}

oldPodInfo.UpdatePod(pod, engine, newModelServers)
return nil
}

// New pod — create PodInfo and fetch initial metrics.
newPodInfo := &PodInfo{
Pod: pod,
engine: engine,
modelServer: newModelServers,
models: sets.New[string](),
// Update existing pod in place — preserve runtime metrics and models.
oldPodInfo := actual.(*PodInfo)
oldModelServers := oldPodInfo.GetModelServers()
// Handle the case where the pod no longer belongs to some model servers
oldPodLabels := oldPodInfo.GetPodLabels()
for msName := range oldModelServers.Difference(newModelServers) {
if value, ok := s.modelServer.Load(msName); ok {
ms := value.(*modelServer)
ms.deletePod(podName)
// Remove from PDGroup categorizations
ms.removePodFromPDGroups(podName, oldPodLabels)
}
}
s.pods.Store(podName, newPodInfo)
s.updatePodMetrics(newPodInfo)
s.updatePodModels(newPodInfo)

oldPodInfo.UpdatePod(pod, engine, newModelServers)
return nil
}

Expand Down
78 changes: 78 additions & 0 deletions pkg/kthena-router/datastore/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2243,3 +2243,81 @@ func TestMatchModelServer_GatewayScoped(t *testing.T) {
})
}
}

func TestStoreAddOrUpdateModelServer_ConcurrentAccess(t *testing.T) {
s := &store{
modelServer: sync.Map{},
}
var wg sync.WaitGroup
errCh := make(chan error, 100)
// Simulate many concurrent goroutines trying to add the same ModelServer
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
ms := &aiv1alpha1.ModelServer{
ObjectMeta: metav1.ObjectMeta{
Namespace: "default",
Name: "concurrent-model1",
},
}
pods := sets.New[types.NamespacedName](types.NamespacedName{Namespace: "default", Name: "pod1"})
errCh <- s.AddOrUpdateModelServer(ms, pods)
}()
Comment on lines +2256 to +2266
}
wg.Wait()
close(errCh)
for err := range errCh {
assert.NoError(t, err)
}

// Verify the model server is safely in the datastore and no panic occurred
msName := types.NamespacedName{Namespace: "default", Name: "concurrent-model1"}
value, ok := s.modelServer.Load(msName)
assert.True(t, ok)
msInfo := value.(*modelServer)
assert.NotNil(t, msInfo)
assert.True(t, msInfo.pods.Contains(types.NamespacedName{Namespace: "default", Name: "pod1"}))
}

func TestStoreAddOrUpdatePod_ConcurrentAccess(t *testing.T) {
s := &store{
modelServer: sync.Map{},
pods: sync.Map{},
}
var wg sync.WaitGroup
errCh := make(chan error, 100)
// Simulate many concurrent goroutines trying to add the same Pod
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
pod := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Namespace: "default",
Name: "concurrent-pod1",
},
}
ms := &aiv1alpha1.ModelServer{
ObjectMeta: metav1.ObjectMeta{
Namespace: "default",
Name: "model1",
},
}
errCh <- s.AddOrUpdatePod(pod, []*aiv1alpha1.ModelServer{ms})
}()
Comment on lines +2293 to +2308
}
wg.Wait()
close(errCh)
for err := range errCh {
assert.NoError(t, err)
}

// Verify the pod is safely in the datastore and no panic occurred
podName := types.NamespacedName{Namespace: "default", Name: "concurrent-pod1"}
value, ok := s.pods.Load(podName)
assert.True(t, ok)
podInfo := value.(*PodInfo)
assert.NotNil(t, podInfo)
assert.True(t, podInfo.modelServer.Contains(types.NamespacedName{Namespace: "default", Name: "model1"}))
}
Loading