Skip to content

Commit ff9efd2

Browse files
support resource claim cel builder
1 parent 4959c61 commit ff9efd2

File tree

4 files changed

+147
-27
lines changed

4 files changed

+147
-27
lines changed

internal/constants/constants.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,6 @@ const (
7474
// In remote vGPU mode, selected workload is set by user with /workload annotation or generated by system
7575
SelectedWorkloadAnnotation = Domain + "/selected-workload"
7676

77-
CELFilterExpressionAnnotation = Domain + "/cel-filter-expression"
78-
7977
WorkloadModeAnnotation = Domain + "/workload-mode"
8078
WorkloadModeDynamic = "dynamic"
8179
WorkloadModeFixed = "fixed"

internal/gpuallocator/gpuallocator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1355,7 +1355,7 @@ func (s *GpuAllocator) ComposeAllocationRequest(pod *v1.Pod) (*tfv1.AllocRequest
13551355
Limit: gpuLimitResource,
13561356

13571357
DisableCELFilter: disableCELFilter,
1358-
CELFilterExpression: pod.Annotations[constants.CELFilterExpressionAnnotation],
1358+
CELFilterExpression: pod.Annotations[constants.DRACelExpressionAnnotation],
13591359

13601360
Count: uint(count),
13611361
GPUModel: pod.Annotations[constants.GPUModelAnnotation],

internal/webhook/v1/pod_dra.go

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ func (p *DRAProcessor) HandleDRAAdmission(ctx context.Context, pod *corev1.Pod,
136136
return nil
137137
}
138138

139-
// TODO: support more attributes for filtering
139+
// BuildCELSelector constructs a CEL expression for DRA device selection based on TensorFusion requirements
140140
func BuildCELSelector(pod *corev1.Pod, tfInfo *utils.TensorFusionInfo) (string, error) {
141141
var conditions []string
142142

@@ -154,6 +154,28 @@ func BuildCELSelector(pod *corev1.Pod, tfInfo *utils.TensorFusionInfo) (string,
154154
conditions = append(conditions, fmt.Sprintf(`device.attributes["model"] == "%s"`, tfInfo.Profile.GPUModel))
155155
}
156156

157+
// 3. GPU count requirement (important for multi-GPU workloads)
158+
if tfInfo.Profile.GPUCount > 0 {
159+
conditions = append(conditions, fmt.Sprintf(`int(device.attributes["gpu_count"]) >= %d`, tfInfo.Profile.GPUCount))
160+
}
161+
162+
// 4. Pool name filter (for resource isolation and scheduling preferences)
163+
if tfInfo.Profile.PoolName != "" {
164+
conditions = append(conditions, fmt.Sprintf(`device.attributes["pool_name"] == "%s"`, tfInfo.Profile.PoolName))
165+
}
166+
167+
// 5. Workload name filter (for workload-specific device assignment)
168+
if tfInfo.WorkloadName != "" {
169+
conditions = append(conditions, fmt.Sprintf(`device.attributes["workload_name"] == "%s"`, tfInfo.WorkloadName))
170+
// Workload namespace is same as pod namespace in TensorFusion
171+
conditions = append(conditions, fmt.Sprintf(`device.attributes["workload_namespace"] == "%s"`, pod.Namespace))
172+
}
173+
174+
// 6. Pod namespace filter (for namespace-based device isolation)
175+
if pod.Namespace != "" {
176+
conditions = append(conditions, fmt.Sprintf(`device.attributes["pod_namespace"] == "%s"`, pod.Namespace))
177+
}
178+
157179
// Return a basic condition if no specific requirements
158180
if len(conditions) == 0 {
159181
// Simple condition that should work with most DRA drivers

internal/webhook/v1/pod_webhook_dra_test.go

Lines changed: 123 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -149,37 +149,137 @@ func TestDRAProcessor_HandleDRAAdmission(t *testing.T) {
149149
}
150150

151151
func TestBuildCELSelector(t *testing.T) {
152-
pod := &corev1.Pod{
153-
ObjectMeta: metav1.ObjectMeta{
154-
Name: "test-pod",
155-
Namespace: "test-namespace",
152+
tests := []struct {
153+
name string
154+
pod *corev1.Pod
155+
tfInfo *utils.TensorFusionInfo
156+
expectedConditions []string
157+
unexpectedConditions []string
158+
}{
159+
{
160+
name: "Basic resource filters",
161+
pod: &corev1.Pod{
162+
ObjectMeta: metav1.ObjectMeta{
163+
Name: "test-pod",
164+
Namespace: "test-namespace",
165+
},
166+
},
167+
tfInfo: &utils.TensorFusionInfo{
168+
Profile: &tfv1.WorkloadProfileSpec{
169+
GPUCount: 2,
170+
Resources: tfv1.Resources{
171+
Requests: tfv1.Resource{
172+
Tflops: resource.MustParse("20"),
173+
Vram: resource.MustParse("16Gi"),
174+
},
175+
},
176+
GPUModel: "H100",
177+
},
178+
},
179+
expectedConditions: []string{
180+
`device.attributes["tflops"].quantity >= quantity("20")`,
181+
`device.attributes["vram"].quantity >= quantity("16Gi")`,
182+
`device.attributes["model"] == "H100"`,
183+
`int(device.attributes["gpu_count"]) >= 2`,
184+
`device.attributes["pod_namespace"] == "test-namespace"`,
185+
},
156186
},
157-
}
158-
159-
tfInfo := &utils.TensorFusionInfo{
160-
Profile: &tfv1.WorkloadProfileSpec{
161-
GPUCount: 2,
162-
Resources: tfv1.Resources{
163-
Requests: tfv1.Resource{
164-
Tflops: resource.MustParse("20"),
165-
Vram: resource.MustParse("16Gi"),
187+
{
188+
name: "All filters including pool and workload",
189+
pod: &corev1.Pod{
190+
ObjectMeta: metav1.ObjectMeta{
191+
Name: "test-pod",
192+
Namespace: "production",
193+
},
194+
},
195+
tfInfo: &utils.TensorFusionInfo{
196+
Profile: &tfv1.WorkloadProfileSpec{
197+
GPUCount: 1,
198+
Resources: tfv1.Resources{
199+
Requests: tfv1.Resource{
200+
Tflops: resource.MustParse("10"),
201+
Vram: resource.MustParse("8Gi"),
202+
},
203+
},
204+
GPUModel: "A100",
205+
PoolName: "high-priority",
206+
},
207+
WorkloadName: "ml-training-job",
208+
},
209+
expectedConditions: []string{
210+
`device.attributes["tflops"].quantity >= quantity("10")`,
211+
`device.attributes["vram"].quantity >= quantity("8Gi")`,
212+
`device.attributes["model"] == "A100"`,
213+
`int(device.attributes["gpu_count"]) >= 1`,
214+
`device.attributes["pool_name"] == "high-priority"`,
215+
`device.attributes["workload_name"] == "ml-training-job"`,
216+
`device.attributes["workload_namespace"] == "production"`,
217+
`device.attributes["pod_namespace"] == "production"`,
218+
},
219+
},
220+
{
221+
name: "Zero resources fallback to default condition",
222+
pod: &corev1.Pod{
223+
ObjectMeta: metav1.ObjectMeta{
224+
Name: "test-pod",
225+
Namespace: "default",
226+
},
227+
},
228+
tfInfo: &utils.TensorFusionInfo{
229+
Profile: &tfv1.WorkloadProfileSpec{
230+
GPUCount: 0, // Zero count should not add condition
231+
Resources: tfv1.Resources{
232+
Requests: tfv1.Resource{
233+
// Zero resources
234+
},
235+
},
236+
},
237+
},
238+
expectedConditions: []string{
239+
`device.attributes["pod_namespace"] == "default"`,
240+
},
241+
},
242+
{
243+
name: "Empty resources fallback to basic condition",
244+
pod: &corev1.Pod{
245+
ObjectMeta: metav1.ObjectMeta{
246+
Name: "test-pod",
247+
Namespace: "",
248+
},
249+
},
250+
tfInfo: &utils.TensorFusionInfo{
251+
Profile: &tfv1.WorkloadProfileSpec{
252+
// All empty/zero values
166253
},
167254
},
168-
GPUModel: "H100",
255+
expectedConditions: []string{
256+
`device.attributes.exists("type")`,
257+
},
169258
},
170259
}
171260

172-
celExpression, err := BuildCELSelector(pod, tfInfo)
173-
require.NoError(t, err)
174-
require.NotEmpty(t, celExpression)
261+
for _, tt := range tests {
262+
t.Run(tt.name, func(t *testing.T) {
263+
celExpression, err := BuildCELSelector(tt.pod, tt.tfInfo)
264+
require.NoError(t, err)
265+
require.NotEmpty(t, celExpression)
266+
267+
// Verify expected conditions are present
268+
for _, condition := range tt.expectedConditions {
269+
assert.Contains(t, celExpression, condition, "Expected condition not found: %s", condition)
270+
}
175271

176-
// Verify it contains the expected resource filters
177-
assert.Contains(t, celExpression, `device.attributes["tflops"].quantity >= quantity("20")`)
178-
assert.Contains(t, celExpression, `device.attributes["vram"].quantity >= quantity("16Gi")`)
179-
assert.Contains(t, celExpression, `device.attributes["model"] == "H100"`)
272+
// Verify unexpected conditions are not present
273+
for _, condition := range tt.unexpectedConditions {
274+
assert.NotContains(t, celExpression, condition, "Unexpected condition found: %s", condition)
275+
}
180276

181-
// Verify conditions are combined with AND
182-
assert.Contains(t, celExpression, " && ")
277+
// Verify proper AND joining (unless it's the fallback condition)
278+
if len(tt.expectedConditions) > 1 {
279+
assert.Contains(t, celExpression, " && ", "Conditions should be joined with &&")
280+
}
281+
})
282+
}
183283
}
184284

185285
func TestHasDRAClaim(t *testing.T) {

0 commit comments

Comments
 (0)