Skip to content

Commit 50bcff3

Browse files
authoredApr 2, 2024··
RUN-16744 Support KWOK (#70)
1 parent 6f72a74 commit 50bcff3

File tree

22 files changed

+778
-479
lines changed

22 files changed

+778
-479
lines changed
 

‎cmd/device-plugin/main.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func main() {
4141
initNvidiaSmi()
4242
initPreloaders()
4343

44-
devicePlugin := deviceplugin.NewDevicePlugin(topology)
44+
devicePlugin := deviceplugin.NewDevicePlugin(topology, kubeClient)
4545
if err = devicePlugin.Serve(); err != nil {
4646
log.Printf("Failed to serve device plugin: %s\n", err)
4747
os.Exit(1)

‎cmd/status-updater/main.go

+4
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@ package main
22

33
import (
44
"github.com/run-ai/fake-gpu-operator/internal/common/app"
5+
"github.com/run-ai/fake-gpu-operator/internal/common/config"
56
status_updater "github.com/run-ai/fake-gpu-operator/internal/status-updater"
67
)
78

89
func main() {
10+
requiredEnvVars := []string{"TOPOLOGY_CM_NAME", "TOPOLOGY_CM_NAMESPACE", "FAKE_GPU_OPERATOR_NAMESPACE"}
11+
config.ValidateConfig(requiredEnvVars)
12+
913
appRunner := app.NewAppRunner(&status_updater.StatusUpdaterApp{})
1014
appRunner.Run()
1115
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
{{- define "fake-gpu-operator.device-plugin.common.metadata.labels" -}}
2+
app: device-plugin
3+
{{- end -}}
4+
5+
{{- define "fake-gpu-operator.device-plugin.common.metadata.annotations" -}}
6+
openshift.io/scc: hostmount-anyuid
7+
{{- end -}}
8+
9+
{{- define "fake-gpu-operator.device-plugin.common.metadata.name" -}}
10+
device-plugin
11+
{{- end -}}
12+
13+
{{- define "fake-gpu-operator.device-plugin.common.podSelector" }}
14+
matchLabels:
15+
app: device-plugin
16+
component: device-plugin
17+
{{- end }}
18+
19+
{{- define "fake-gpu-operator.device-plugin.common.podTemplate.metadata" }}
20+
annotations:
21+
checksum/initialTopology: {{ include (print $.Template.BasePath "/topology-cm.yml") . | sha256sum }}
22+
labels:
23+
app: device-plugin
24+
component: device-plugin
25+
{{- end }}
26+
27+
{{- define "fake-gpu-operator.device-plugin.common.podTemplate.spec" }}
28+
containers:
29+
- image: "{{ .Values.devicePlugin.image.repository }}:{{ .Values.devicePlugin.image.tag }}"
30+
imagePullPolicy: "{{ .Values.devicePlugin.image.pullPolicy }}"
31+
resources:
32+
{{- toYaml .Values.devicePlugin.resources | nindent 12 }}
33+
env:
34+
- name: NODE_NAME
35+
valueFrom:
36+
fieldRef:
37+
fieldPath: spec.nodeName
38+
- name: TOPOLOGY_CM_NAME
39+
value: topology
40+
- name: TOPOLOGY_CM_NAMESPACE
41+
value: "{{ .Release.Namespace }}"
42+
name: nvidia-device-plugin-ctr
43+
securityContext:
44+
privileged: true
45+
terminationMessagePath: /dev/termination-log
46+
terminationMessagePolicy: File
47+
volumeMounts:
48+
- mountPath: /runai/bin
49+
name: runai-bin-directory
50+
- mountPath: /runai/shared
51+
name: runai-shared-directory
52+
- mountPath: /var/lib/kubelet/device-plugins
53+
name: device-plugin
54+
dnsPolicy: ClusterFirst
55+
restartPolicy: Always
56+
serviceAccountName: nvidia-device-plugin
57+
terminationGracePeriodSeconds: 30
58+
tolerations:
59+
- effect: NoSchedule
60+
key: nvidia.com/gpu
61+
operator: Exists
62+
imagePullSecrets:
63+
- name: gcr-secret
64+
volumes:
65+
- hostPath:
66+
path: /var/lib/kubelet/device-plugins
67+
type: ""
68+
name: device-plugin
69+
- hostPath:
70+
path: /var/lib/runai/bin
71+
type: DirectoryOrCreate
72+
name: runai-bin-directory
73+
- hostPath:
74+
path: /var/lib/runai/shared
75+
type: DirectoryOrCreate
76+
name: runai-shared-directory
77+
{{- end }}
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,16 @@
11
apiVersion: apps/v1
22
kind: DaemonSet
33
metadata:
4-
{{- if .Values.environment.openshift }}
5-
annotations:
6-
openshift.io/scc: hostmount-anyuid
7-
{{- end }}
4+
name: {{ include "fake-gpu-operator.device-plugin.common.metadata.name" . }}
85
labels:
9-
app: device-plugin
10-
name: device-plugin
6+
{{- include "fake-gpu-operator.device-plugin.common.metadata.labels" . | nindent 4 }}
117
spec:
128
selector:
13-
matchLabels:
14-
app: device-plugin
15-
component: device-plugin
9+
{{- include "fake-gpu-operator.device-plugin.common.podSelector" . | nindent 4 }}
1610
template:
1711
metadata:
18-
annotations:
19-
checksum/initialTopology: {{ include (print $.Template.BasePath "/topology-cm.yml") . | sha256sum }}
20-
labels:
21-
app: device-plugin
22-
component: device-plugin
12+
{{- include "fake-gpu-operator.device-plugin.common.podTemplate.metadata" . | nindent 6 }}
2313
spec:
24-
containers:
25-
- image: "{{ .Values.devicePlugin.image.repository }}:{{ .Values.devicePlugin.image.tag }}"
26-
imagePullPolicy: "{{ .Values.devicePlugin.image.pullPolicy }}"
27-
resources:
28-
{{- toYaml .Values.devicePlugin.resources | nindent 12 }}
29-
env:
30-
- name: NODE_NAME
31-
valueFrom:
32-
fieldRef:
33-
fieldPath: spec.nodeName
34-
- name: TOPOLOGY_CM_NAME
35-
value: topology
36-
- name: TOPOLOGY_CM_NAMESPACE
37-
value: "{{ .Release.Namespace }}"
38-
imagePullPolicy: Always
39-
name: nvidia-device-plugin-ctr
40-
securityContext:
41-
privileged: true
42-
terminationMessagePath: /dev/termination-log
43-
terminationMessagePolicy: File
44-
volumeMounts:
45-
- mountPath: /runai/bin
46-
name: runai-bin-directory
47-
- mountPath: /runai/shared
48-
name: runai-shared-directory
49-
- mountPath: /var/lib/kubelet/device-plugins
50-
name: device-plugin
51-
dnsPolicy: ClusterFirst
14+
{{- include "fake-gpu-operator.device-plugin.common.podTemplate.spec" . | nindent 6 }}
5215
nodeSelector:
5316
nvidia.com/gpu.deploy.device-plugin: "true"
54-
restartPolicy: Always
55-
serviceAccountName: nvidia-device-plugin
56-
terminationGracePeriodSeconds: 30
57-
tolerations:
58-
- effect: NoSchedule
59-
key: nvidia.com/gpu
60-
operator: Exists
61-
imagePullSecrets:
62-
- name: gcr-secret
63-
volumes:
64-
- hostPath:
65-
path: /var/lib/kubelet/device-plugins
66-
type: ""
67-
name: device-plugin
68-
- hostPath:
69-
path: /var/lib/runai/bin
70-
type: DirectoryOrCreate
71-
name: runai-bin-directory
72-
- hostPath:
73-
path: /var/lib/runai/shared
74-
type: DirectoryOrCreate
75-
name: runai-shared-directory
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
apiVersion: apps/v1
2+
kind: Deployment
3+
metadata:
4+
name: {{ include "fake-gpu-operator.device-plugin.common.metadata.name" . }}
5+
labels:
6+
{{- include "fake-gpu-operator.device-plugin.common.metadata.labels" . | nindent 4 }}
7+
run.ai/fake-node-deployment-template: "true"
8+
spec:
9+
replicas: 0
10+
selector:
11+
{{- include "fake-gpu-operator.device-plugin.common.podSelector" . | nindent 4 }}
12+
template:
13+
metadata:
14+
{{- include "fake-gpu-operator.device-plugin.common.podTemplate.metadata" . | nindent 6 }}
15+
spec:
16+
{{- include "fake-gpu-operator.device-plugin.common.podTemplate.spec" . | nindent 6 }}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
{{- define "fake-gpu-operator.status-exporter.common.metadata.labels" -}}
2+
app: nvidia-dcgm-exporter
3+
component: status-exporter
4+
app.kubernetes.io/name: nvidia-container-toolkit
5+
{{- end -}}
6+
7+
{{- define "fake-gpu-operator.status-exporter.common.metadata.name" -}}
8+
nvidia-dcgm-exporter
9+
{{- end -}}
10+
11+
{{- define "fake-gpu-operator.status-exporter.common.podSelector" -}}
12+
matchLabels:
13+
app: nvidia-dcgm-exporter
14+
{{- end -}}
15+
16+
{{- define "fake-gpu-operator.status-exporter.common.podTemplate.metadata" -}}
17+
labels:
18+
app: nvidia-dcgm-exporter
19+
app.kubernetes.io/name: nvidia-container-toolkit
20+
{{- end -}}
21+
22+
{{- define "fake-gpu-operator.status-exporter.common.podTemplate.spec" -}}
23+
containers:
24+
- image: "{{ .Values.statusExporter.image.repository }}:{{ .Values.statusExporter.image.tag }}"
25+
imagePullPolicy: "{{ .Values.statusExporter.image.pullPolicy }}"
26+
resources:
27+
{{- toYaml .Values.statusExporter.resources | nindent 8 }}
28+
name: nvidia-dcgm-exporter
29+
env:
30+
- name: NODE_NAME
31+
valueFrom:
32+
fieldRef:
33+
fieldPath: spec.nodeName
34+
- name: TOPOLOGY_CM_NAME
35+
value: topology
36+
- name: TOPOLOGY_CM_NAMESPACE
37+
value: "{{ .Release.Namespace }}"
38+
- name: TOPOLOGY_MAX_EXPORT_INTERVAL
39+
value: "{{ .Values.statusExporter.topologyMaxExportInterval }}"
40+
ports:
41+
- containerPort: 9400
42+
name: http
43+
volumeMounts:
44+
- mountPath: /runai/proc
45+
name: runai-proc-directory
46+
restartPolicy: Always
47+
schedulerName: default-scheduler
48+
serviceAccount: status-exporter
49+
serviceAccountName: status-exporter
50+
tolerations:
51+
- effect: NoSchedule
52+
key: nvidia.com/gpu
53+
operator: Exists
54+
imagePullSecrets:
55+
- name: gcr-secret
56+
volumes:
57+
- name: runai-proc-directory
58+
hostPath:
59+
path: /var/lib/runai/proc
60+
type: DirectoryOrCreate
61+
{{- end -}}
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,16 @@
11
apiVersion: apps/v1
22
kind: DaemonSet
33
metadata:
4+
name: {{ include "fake-gpu-operator.status-exporter.common.metadata.name" . }}
45
labels:
5-
app: nvidia-dcgm-exporter
6-
component: status-exporter
7-
# this label would make the deployment pod to mimic the container-toolkit, on top of mimicking the dcgm-exporter.
8-
app.kubernetes.io/name: nvidia-container-toolkit
9-
name: nvidia-dcgm-exporter
10-
6+
{{- include "fake-gpu-operator.status-exporter.common.metadata.labels" . | nindent 4 }}
117
spec:
128
selector:
13-
matchLabels:
14-
app: nvidia-dcgm-exporter
9+
{{- include "fake-gpu-operator.status-exporter.common.podSelector" . | nindent 4 }}
1510
template:
1611
metadata:
17-
creationTimestamp: null
18-
labels:
19-
app: nvidia-dcgm-exporter
20-
app.kubernetes.io/name: nvidia-container-toolkit
12+
{{- include "fake-gpu-operator.status-exporter.common.podTemplate.metadata" . | nindent 6 }}
2113
spec:
22-
containers:
23-
- image: "{{ .Values.statusExporter.image.repository }}:{{ .Values.statusExporter.image.tag }}"
24-
imagePullPolicy: "{{ .Values.statusExporter.image.pullPolicy }}"
25-
resources:
26-
{{- toYaml .Values.statusExporter.resources | nindent 12 }}
27-
name: nvidia-dcgm-exporter
28-
env:
29-
- name: NODE_NAME
30-
valueFrom:
31-
fieldRef:
32-
fieldPath: spec.nodeName
33-
- name: TOPOLOGY_CM_NAME
34-
value: topology
35-
- name: TOPOLOGY_CM_NAMESPACE
36-
value: "{{ .Release.Namespace }}"
37-
- name: TOPOLOGY_MAX_EXPORT_INTERVAL
38-
value: "{{ .Values.statusExporter.topologyMaxExportInterval }}"
39-
ports:
40-
- containerPort: 9400
41-
name: http
42-
volumeMounts:
43-
- mountPath: /runai/proc
44-
name: runai-proc-directory
14+
{{- include "fake-gpu-operator.status-exporter.common.podTemplate.spec" . | nindent 6 }}
4515
nodeSelector:
46-
nvidia.com/gpu.deploy.dcgm-exporter: "true"
47-
restartPolicy: Always
48-
schedulerName: default-scheduler
49-
serviceAccount: status-exporter
50-
serviceAccountName: status-exporter
51-
tolerations:
52-
- effect: NoSchedule
53-
key: nvidia.com/gpu
54-
operator: Exists
55-
imagePullSecrets:
56-
- name: gcr-secret
57-
volumes:
58-
- name: runai-proc-directory
59-
hostPath:
60-
path: /var/lib/runai/proc
61-
type: DirectoryOrCreate
62-
updateStrategy:
63-
rollingUpdate:
64-
maxSurge: 0
65-
maxUnavailable: 1
66-
type: RollingUpdate
16+
nvidia.com/gpu.deploy.dcgm-exporter: "true"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
apiVersion: apps/v1
2+
kind: Deployment
3+
metadata:
4+
name: {{ include "fake-gpu-operator.status-exporter.common.metadata.name" . }}
5+
labels:
6+
{{- include "fake-gpu-operator.status-exporter.common.metadata.labels" . | nindent 4 }}
7+
run.ai/fake-node-deployment-template: "true"
8+
spec:
9+
replicas: 0
10+
selector:
11+
{{- include "fake-gpu-operator.status-exporter.common.podSelector" . | nindent 4 }}
12+
template:
13+
metadata:
14+
{{- include "fake-gpu-operator.status-exporter.common.podTemplate.metadata" . | nindent 6 }}
15+
spec:
16+
{{- include "fake-gpu-operator.status-exporter.common.podTemplate.spec" . | nindent 6 }}

‎deploy/fake-gpu-operator/templates/status-updater/clusterrole.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ rules:
1212
- get
1313
- list
1414
- watch
15+
- patch
1516
- apiGroups:
1617
- ""
1718
resources:

‎deploy/fake-gpu-operator/templates/status-updater/deployment.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ spec:
2929
value: topology
3030
- name: TOPOLOGY_CM_NAMESPACE
3131
value: "{{ .Release.Namespace }}"
32+
- name: FAKE_GPU_OPERATOR_NAMESPACE
33+
value: "{{ .Release.Namespace }}"
3234
restartPolicy: Always
3335
serviceAccountName: status-updater
3436
imagePullSecrets:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
apiVersion: rbac.authorization.k8s.io/v1
2+
kind: Role
3+
metadata:
4+
name: fake-status-updater
5+
rules:
6+
- apiGroups:
7+
- apps
8+
resources:
9+
- deployments
10+
verbs:
11+
- update
12+
- list
13+
- get
14+
- watch
15+
- create
16+
- delete
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
apiVersion: rbac.authorization.k8s.io/v1
2+
kind: RoleBinding
3+
metadata:
4+
name: fake-status-updater
5+
roleRef:
6+
kind: Role
7+
apiGroup: rbac.authorization.k8s.io
8+
name: fake-status-updater
9+
subjects:
10+
- kind: ServiceAccount
11+
name: status-updater
12+
namespace: "{{ .Release.Namespace }}"

‎go.mod

+34-35
Original file line numberDiff line numberDiff line change
@@ -10,96 +10,95 @@ require (
1010
github.com/onsi/ginkgo/v2 v2.17.1
1111
github.com/onsi/gomega v1.30.0
1212
github.com/otiai10/copy v1.7.0
13-
github.com/prometheus/client_golang v1.14.0
14-
github.com/prometheus/client_model v0.3.0
13+
github.com/prometheus/client_golang v1.18.0
14+
github.com/prometheus/client_model v0.5.0
1515
github.com/spf13/viper v1.14.0
1616
github.com/tidwall/gjson v1.14.1
1717
golang.org/x/net v0.20.0
18-
google.golang.org/grpc v1.56.3
18+
google.golang.org/grpc v1.58.3
1919
gopkg.in/yaml.v3 v3.0.1
20-
k8s.io/api v0.26.0
21-
k8s.io/apimachinery v0.26.0
22-
k8s.io/client-go v0.26.0
20+
k8s.io/api v0.29.3
21+
k8s.io/apimachinery v0.29.3
22+
k8s.io/client-go v0.29.3
2323
k8s.io/kubelet v0.24.0
24-
sigs.k8s.io/controller-runtime v0.14.1
24+
sigs.k8s.io/controller-runtime v0.17.2
2525
)
2626

2727
require (
28-
github.com/emicklei/go-restful/v3 v3.9.0 // indirect
29-
github.com/evanphx/json-patch/v5 v5.6.0 // indirect
28+
github.com/emicklei/go-restful/v3 v3.11.0 // indirect
29+
github.com/evanphx/json-patch/v5 v5.8.0 // indirect
3030
github.com/go-playground/locales v0.14.0 // indirect
3131
github.com/go-playground/universal-translator v0.18.0 // indirect
3232
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
3333
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
34-
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
34+
github.com/google/gnostic-models v0.6.8 // indirect
35+
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 // indirect
3536
github.com/hashicorp/errwrap v1.0.0 // indirect
3637
github.com/hashicorp/hcl v1.0.0 // indirect
3738
github.com/imdario/mergo v0.3.6 // indirect
3839
github.com/leodido/go-urn v1.2.1 // indirect
3940
github.com/magiconair/properties v1.8.6 // indirect
41+
github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect
4042
github.com/pelletier/go-toml v1.9.5 // indirect
4143
github.com/pelletier/go-toml/v2 v2.0.5 // indirect
4244
github.com/pmezard/go-difflib v1.0.0 // indirect
43-
github.com/rogpeppe/go-internal v1.8.0 // indirect
4445
github.com/spf13/afero v1.9.2 // indirect
4546
github.com/spf13/cast v1.5.0 // indirect
4647
github.com/spf13/jwalterweatherman v1.1.0 // indirect
4748
github.com/spf13/pflag v1.0.5 // indirect
4849
github.com/subosito/gotenv v1.4.1 // indirect
50+
golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect
4951
golang.org/x/tools v0.17.0 // indirect
50-
gomodules.xyz/jsonpatch/v2 v2.2.0 // indirect
51-
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
52+
gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect
53+
google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d // indirect
5254
gopkg.in/go-playground/assert.v1 v1.2.1 // indirect
5355
gopkg.in/ini.v1 v1.67.0 // indirect
5456
gopkg.in/yaml.v2 v2.4.0 // indirect
55-
k8s.io/apiextensions-apiserver v0.26.0 // indirect
56-
k8s.io/component-base v0.26.0 // indirect
57+
k8s.io/apiextensions-apiserver v0.29.0 // indirect
58+
k8s.io/component-base v0.29.0 // indirect
5759
)
5860

5961
require (
6062
github.com/beorn7/perks v1.0.1 // indirect
6163
github.com/cespare/xxhash/v2 v2.2.0 // indirect
6264
github.com/davecgh/go-spew v1.1.1 // indirect
6365
github.com/evanphx/json-patch v4.12.0+incompatible // indirect
64-
github.com/fsnotify/fsnotify v1.6.0 // indirect
66+
github.com/fsnotify/fsnotify v1.7.0 // indirect
6567
github.com/go-logr/logr v1.4.1 // indirect
66-
github.com/go-openapi/jsonpointer v0.19.5 // indirect
67-
github.com/go-openapi/jsonreference v0.20.0 // indirect
68-
github.com/go-openapi/swag v0.19.14 // indirect
68+
github.com/go-openapi/jsonpointer v0.19.6 // indirect
69+
github.com/go-openapi/jsonreference v0.20.2 // indirect
70+
github.com/go-openapi/swag v0.22.3 // indirect
6971
github.com/go-playground/validator v9.31.0+incompatible
7072
github.com/gogo/protobuf v1.3.2 // indirect
71-
github.com/golang/protobuf v1.5.3 // indirect
72-
github.com/google/gnostic v0.5.7-v3refs // indirect
73+
github.com/golang/protobuf v1.5.4 // indirect
7374
github.com/google/go-cmp v0.6.0 // indirect
74-
github.com/google/gofuzz v1.1.0 // indirect
75+
github.com/google/gofuzz v1.2.0 // indirect
7576
github.com/josharian/intern v1.0.0 // indirect
7677
github.com/json-iterator/go v1.1.12 // indirect
77-
github.com/mailru/easyjson v0.7.6 // indirect
78+
github.com/mailru/easyjson v0.7.7 // indirect
7879
github.com/mattn/go-runewidth v0.0.13 // indirect
79-
github.com/matttproud/golang_protobuf_extensions v1.0.2 // indirect
8080
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
8181
github.com/modern-go/reflect2 v1.0.2 // indirect
8282
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
8383
github.com/pkg/errors v0.9.1 // indirect
84-
github.com/prometheus/common v0.37.0 // indirect
85-
github.com/prometheus/procfs v0.8.0 // indirect
84+
github.com/prometheus/common v0.45.0 // indirect
85+
github.com/prometheus/procfs v0.12.0 // indirect
8686
github.com/rivo/uniseg v0.2.0 // indirect
87-
github.com/stretchr/testify v1.8.1
87+
github.com/stretchr/testify v1.8.4
8888
github.com/tidwall/match v1.1.1 // indirect
8989
github.com/tidwall/pretty v1.2.0 // indirect
90-
golang.org/x/oauth2 v0.7.0 // indirect
90+
golang.org/x/oauth2 v0.12.0 // indirect
9191
golang.org/x/sys v0.16.0 // indirect
9292
golang.org/x/term v0.16.0 // indirect
9393
golang.org/x/text v0.14.0 // indirect
9494
golang.org/x/time v0.3.0 // indirect
9595
google.golang.org/appengine v1.6.7 // indirect
96-
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect
97-
google.golang.org/protobuf v1.30.0 // indirect
96+
google.golang.org/protobuf v1.33.0 // indirect
9897
gopkg.in/inf.v0 v0.9.1 // indirect
99-
k8s.io/klog/v2 v2.80.1 // indirect
100-
k8s.io/kube-openapi v0.0.0-20221012153701-172d655c2280 // indirect
98+
k8s.io/klog/v2 v2.110.1 // indirect
99+
k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 // indirect
101100
k8s.io/utils v0.0.0-20240310230437-4693a0247e57
102-
sigs.k8s.io/json v0.0.0-20220713155537-f223a00ba0e2 // indirect
103-
sigs.k8s.io/structured-merge-diff/v4 v4.2.3 // indirect
104-
sigs.k8s.io/yaml v1.3.0 // indirect
101+
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect
102+
sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect
103+
sigs.k8s.io/yaml v1.4.0 // indirect
105104
)

‎go.sum

+82-82
Large diffs are not rendered by default.

‎internal/common/constants/constants.go

+12-3
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,21 @@ const (
66
PodGroupNameAnnotation = "pod-group-name"
77
ReservationPodGpuIdxAnnotation = "run.ai/reserve_for_gpu_index"
88
MigMappingAnnotation = "run.ai/mig-mapping"
9+
KwokNodeAnnotation = "kwok.x-k8s.io/node"
910

10-
GpuGroupLabel = "runai-gpu-group"
11-
GpuProductLabel = "nvidia.com/gpu.product"
12-
MigConfigStateLabel = "nvidia.com/mig.config.state"
11+
GpuGroupLabel = "runai-gpu-group"
12+
GpuProductLabel = "nvidia.com/gpu.product"
13+
MigConfigStateLabel = "nvidia.com/mig.config.state"
14+
FakeNodeDeploymentTemplateLabel = "run.ai/fake-node-deployment-template"
1315

1416
ReservationNs = "runai-reservation"
1517

1618
GpuResourceName = "nvidia.com/gpu"
19+
20+
// GuyTodo: Use these constants in the code
21+
EnvFakeNode = "FAKE_NODE"
22+
EnvNodeName = "NODE_NAME"
23+
EnvTopologyCmName = "TOPOLOGY_CM_NAME"
24+
EnvTopologyCmNamespace = "TOPOLOGY_CM_NAMESPACE"
25+
EnvFakeGpuOperatorNs = "FAKE_GPU_OPERATOR_NAMESPACE"
1726
)

‎internal/common/kubeclient/kubeclient.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"k8s.io/apimachinery/pkg/watch"
1111
"k8s.io/client-go/kubernetes"
1212
"k8s.io/client-go/rest"
13+
14+
ctrl "sigs.k8s.io/controller-runtime"
1315
)
1416

1517
type KubeClientInterface interface {
@@ -28,11 +30,12 @@ type KubeClient struct {
2830
func NewKubeClient(config *rest.Config, stop chan struct{}) *KubeClient {
2931
if config == nil {
3032
var err error
31-
config, err = rest.InClusterConfig()
33+
config, err = ctrl.GetConfig()
3234
if err != nil {
3335
log.Fatalf("Error getting in cluster config to init kubeclient: %e", err)
3436
}
3537
}
38+
3639
clientset := kubernetes.NewForConfigOrDie(config)
3740
return &KubeClient{
3841
ClientSet: clientset,
+13-208
Original file line numberDiff line numberDiff line change
@@ -1,229 +1,34 @@
11
package deviceplugin
22

33
import (
4-
"fmt"
5-
"log"
6-
"net"
7-
"os"
8-
"path"
9-
"strings"
10-
"time"
11-
12-
"github.com/google/uuid"
4+
"github.com/run-ai/fake-gpu-operator/internal/common/constants"
135
"github.com/run-ai/fake-gpu-operator/internal/common/topology"
14-
"golang.org/x/net/context"
15-
"google.golang.org/grpc"
16-
"google.golang.org/grpc/credentials/insecure"
17-
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
6+
"github.com/spf13/viper"
7+
"k8s.io/client-go/kubernetes"
188
)
199

2010
const (
2111
resourceName = "nvidia.com/gpu"
22-
serverSock = pluginapi.DevicePluginPath + "fake-nvidia-gpu.sock"
2312
)
2413

25-
type DevicePlugin struct {
26-
devs []*pluginapi.Device
27-
socket string
28-
29-
stop chan interface{}
30-
health chan *pluginapi.Device
31-
server *grpc.Server
14+
// Interface is the behavior shared by the device plugin implementations
// returned from NewDevicePlugin.
type Interface interface {
	// Serve starts the plugin and reports any error hit while doing so.
	Serve() error
}
3317

34-
func NewDevicePlugin(topology *topology.NodeTopology) *DevicePlugin {
18+
func NewDevicePlugin(topology *topology.NodeTopology, kubeClient kubernetes.Interface) Interface {
3519
if topology == nil {
3620
panic("topology is nil")
3721
}
3822

39-
return &DevicePlugin{
40-
devs: createDevices(getGpuCount(topology)),
41-
socket: serverSock,
42-
}
43-
}
44-
45-
func getGpuCount(nodeTopology *topology.NodeTopology) int {
46-
return len(nodeTopology.Gpus)
47-
}
48-
49-
func (m *DevicePlugin) GetDevicePluginOptions(context.Context, *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) {
50-
return &pluginapi.DevicePluginOptions{}, nil
51-
}
52-
53-
func dial(unixSocketPath string, timeout time.Duration) (*grpc.ClientConn, error) {
54-
ctx := context.Background()
55-
ctx, cancel := context.WithTimeout(ctx, timeout)
56-
defer cancel()
57-
58-
c, err := grpc.DialContext(
59-
ctx,
60-
unixSocketPath,
61-
grpc.WithTransportCredentials(insecure.NewCredentials()),
62-
grpc.WithBlock(),
63-
grpc.WithContextDialer(func(_ context.Context, addr string) (net.Conn, error) {
64-
return net.DialTimeout("unix", addr, timeout)
65-
}),
66-
)
67-
68-
if err != nil {
69-
return nil, err
70-
}
71-
72-
return c, nil
73-
}
74-
75-
func createDevices(devCount int) []*pluginapi.Device {
76-
var devs []*pluginapi.Device
77-
for i := 0; i < devCount; i++ {
78-
u, _ := uuid.NewRandom()
79-
devs = append(devs, &pluginapi.Device{
80-
ID: u.String(),
81-
Health: pluginapi.Healthy,
82-
})
83-
}
84-
return devs
85-
}
86-
87-
func (m *DevicePlugin) Start() error {
88-
err := m.cleanup()
89-
if err != nil {
90-
return err
91-
}
92-
93-
sock, err := net.Listen("unix", m.socket)
94-
if err != nil {
95-
return err
96-
}
97-
98-
m.server = grpc.NewServer([]grpc.ServerOption{}...)
99-
pluginapi.RegisterDevicePluginServer(m.server, m)
100-
101-
go func() {
102-
err := m.server.Serve(sock)
103-
if err != nil {
104-
log.Println(err)
105-
}
106-
}()
107-
108-
// Wait for server to start by launching a blocking connexion
109-
conn, err := dial(m.socket, 5*time.Second)
110-
if err != nil {
111-
return err
112-
}
113-
conn.Close()
114-
115-
return nil
116-
}
117-
118-
func (m *DevicePlugin) Stop() error {
119-
if m.server == nil {
120-
return nil
121-
}
122-
123-
m.server.Stop()
124-
m.server = nil
125-
close(m.stop)
126-
127-
return m.cleanup()
128-
}
129-
130-
func (m *DevicePlugin) Register(kubeletEndpoint, resourceName string) error {
131-
conn, err := dial(kubeletEndpoint, 5*time.Second)
132-
if err != nil {
133-
return err
134-
}
135-
defer conn.Close()
136-
137-
client := pluginapi.NewRegistrationClient(conn)
138-
reqt := &pluginapi.RegisterRequest{
139-
Version: pluginapi.Version,
140-
Endpoint: path.Base(m.socket),
141-
ResourceName: resourceName,
142-
}
143-
144-
_, err = client.Register(context.Background(), reqt)
145-
if err != nil {
146-
return err
147-
}
148-
return nil
149-
}
150-
151-
func (m *DevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error {
152-
err := s.Send(&pluginapi.ListAndWatchResponse{Devices: m.devs})
153-
if err != nil {
154-
fmt.Printf("Failed to send devices to Kubelet: %v\n", err)
155-
}
156-
157-
for {
158-
select {
159-
case <-m.stop:
160-
return nil
161-
case d := <-m.health:
162-
// FIXME: there is no way to recover from the Unhealthy state.
163-
d.Health = pluginapi.Unhealthy
164-
err := s.Send(&pluginapi.ListAndWatchResponse{Devices: m.devs})
165-
if err != nil {
166-
log.Printf("failed to send unhealthy update: %v", err)
167-
}
23+
if viper.GetBool(constants.EnvFakeNode) {
24+
return &FakeNodeDevicePlugin{
25+
kubeClient: kubeClient,
26+
gpuCount: getGpuCount(topology),
16827
}
16928
}
170-
}
171-
172-
func (m *DevicePlugin) GetPreferredAllocation(context.Context, *pluginapi.PreferredAllocationRequest) (*pluginapi.PreferredAllocationResponse, error) {
173-
return &pluginapi.PreferredAllocationResponse{}, nil
174-
}
175-
176-
func (m *DevicePlugin) Allocate(ctx context.Context, reqs *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) {
177-
responses := pluginapi.AllocateResponse{}
178-
for _, req := range reqs.ContainerRequests {
179-
response := pluginapi.ContainerAllocateResponse{
180-
Envs: map[string]string{
181-
"MOCK_NVIDIA_VISIBLE_DEVICES": strings.Join(req.DevicesIDs, ","),
182-
},
183-
Mounts: []*pluginapi.Mount{
184-
{
185-
ContainerPath: "/bin/nvidia-smi",
186-
HostPath: "/var/lib/runai/bin/nvidia-smi",
187-
},
188-
},
189-
}
19029

191-
responses.ContainerResponses = append(responses.ContainerResponses, &response)
192-
}
193-
194-
return &responses, nil
195-
}
196-
197-
func (m *DevicePlugin) PreStartContainer(context.Context, *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) {
198-
return &pluginapi.PreStartContainerResponse{}, nil
199-
}
200-
201-
func (m *DevicePlugin) cleanup() error {
202-
if err := os.Remove(m.socket); err != nil && !os.IsNotExist(err) {
203-
return err
204-
}
205-
206-
return nil
207-
}
208-
209-
func (m *DevicePlugin) Serve() error {
210-
err := m.Start()
211-
if err != nil {
212-
log.Printf("Could not start device plugin: %s", err)
213-
return err
214-
}
215-
log.Println("Starting to serve on", m.socket)
216-
217-
err = m.Register(pluginapi.KubeletSocket, resourceName)
218-
if err != nil {
219-
log.Printf("Could not register device plugin: %s", err)
220-
stopErr := m.Stop()
221-
if stopErr != nil {
222-
log.Printf("Could not stop device plugin: %s", stopErr)
223-
}
224-
return err
30+
return &RealNodeDevicePlugin{
31+
devs: createDevices(getGpuCount(topology)),
32+
socket: serverSock,
22533
}
226-
log.Println("Registered device plugin with Kubelet")
227-
228-
return nil
22934
}

‎internal/deviceplugin/fake_node.go

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package deviceplugin
2+
3+
import (
4+
"fmt"
5+
"os"
6+
7+
"github.com/run-ai/fake-gpu-operator/internal/common/constants"
8+
"golang.org/x/net/context"
9+
"k8s.io/apimachinery/pkg/types"
10+
"k8s.io/client-go/kubernetes"
11+
12+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
13+
)
14+
15+
type FakeNodeDevicePlugin struct {
16+
kubeClient kubernetes.Interface
17+
gpuCount int
18+
}
19+
20+
func (f *FakeNodeDevicePlugin) Serve() error {
21+
patch := fmt.Sprintf(`{"status": {"capacity": {"%s": "%d"}, "allocatable": {"%s": "%d"}}}`, resourceName, f.gpuCount, resourceName, f.gpuCount)
22+
_, err := f.kubeClient.CoreV1().Nodes().Patch(context.TODO(), os.Getenv(constants.EnvNodeName), types.MergePatchType, []byte(patch), metav1.PatchOptions{}, "status")
23+
if err != nil {
24+
return fmt.Errorf("failed to update node capacity and allocatable: %v", err)
25+
}
26+
27+
return nil
28+
}

‎internal/deviceplugin/real_node.go

+217
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
package deviceplugin
2+
3+
import (
4+
"fmt"
5+
"log"
6+
"net"
7+
"os"
8+
"path"
9+
"strings"
10+
"time"
11+
12+
"github.com/google/uuid"
13+
"github.com/run-ai/fake-gpu-operator/internal/common/topology"
14+
"golang.org/x/net/context"
15+
"google.golang.org/grpc"
16+
"google.golang.org/grpc/credentials/insecure"
17+
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
18+
)
19+
20+
const (
21+
serverSock = pluginapi.DevicePluginPath + "fake-nvidia-gpu.sock"
22+
)
23+
24+
type RealNodeDevicePlugin struct {
25+
devs []*pluginapi.Device
26+
socket string
27+
28+
stop chan interface{}
29+
health chan *pluginapi.Device
30+
server *grpc.Server
31+
}
32+
33+
func getGpuCount(nodeTopology *topology.NodeTopology) int {
34+
return len(nodeTopology.Gpus)
35+
}
36+
37+
func (m *RealNodeDevicePlugin) GetDevicePluginOptions(context.Context, *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) {
38+
return &pluginapi.DevicePluginOptions{}, nil
39+
}
40+
41+
func dial(unixSocketPath string, timeout time.Duration) (*grpc.ClientConn, error) {
42+
ctx := context.Background()
43+
ctx, cancel := context.WithTimeout(ctx, timeout)
44+
defer cancel()
45+
46+
c, err := grpc.DialContext(
47+
ctx,
48+
unixSocketPath,
49+
grpc.WithTransportCredentials(insecure.NewCredentials()),
50+
grpc.WithBlock(),
51+
grpc.WithContextDialer(func(_ context.Context, addr string) (net.Conn, error) {
52+
return net.DialTimeout("unix", addr, timeout)
53+
}),
54+
)
55+
56+
if err != nil {
57+
return nil, err
58+
}
59+
60+
return c, nil
61+
}
62+
63+
func createDevices(devCount int) []*pluginapi.Device {
64+
var devs []*pluginapi.Device
65+
for i := 0; i < devCount; i++ {
66+
u, _ := uuid.NewRandom()
67+
devs = append(devs, &pluginapi.Device{
68+
ID: u.String(),
69+
Health: pluginapi.Healthy,
70+
})
71+
}
72+
return devs
73+
}
74+
75+
func (m *RealNodeDevicePlugin) Start() error {
76+
err := m.cleanup()
77+
if err != nil {
78+
return err
79+
}
80+
81+
sock, err := net.Listen("unix", m.socket)
82+
if err != nil {
83+
return err
84+
}
85+
86+
m.server = grpc.NewServer([]grpc.ServerOption{}...)
87+
pluginapi.RegisterDevicePluginServer(m.server, m)
88+
89+
go func() {
90+
err := m.server.Serve(sock)
91+
if err != nil {
92+
log.Println(err)
93+
}
94+
}()
95+
96+
// Wait for server to start by launching a blocking connexion
97+
conn, err := dial(m.socket, 5*time.Second)
98+
if err != nil {
99+
return err
100+
}
101+
conn.Close()
102+
103+
return nil
104+
}
105+
106+
func (m *RealNodeDevicePlugin) Stop() error {
107+
if m.server == nil {
108+
return nil
109+
}
110+
111+
m.server.Stop()
112+
m.server = nil
113+
close(m.stop)
114+
115+
return m.cleanup()
116+
}
117+
118+
func (m *RealNodeDevicePlugin) Register(kubeletEndpoint, resourceName string) error {
119+
conn, err := dial(kubeletEndpoint, 5*time.Second)
120+
if err != nil {
121+
return err
122+
}
123+
defer conn.Close()
124+
125+
client := pluginapi.NewRegistrationClient(conn)
126+
reqt := &pluginapi.RegisterRequest{
127+
Version: pluginapi.Version,
128+
Endpoint: path.Base(m.socket),
129+
ResourceName: resourceName,
130+
}
131+
132+
_, err = client.Register(context.Background(), reqt)
133+
if err != nil {
134+
return err
135+
}
136+
return nil
137+
}
138+
139+
func (m *RealNodeDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error {
140+
err := s.Send(&pluginapi.ListAndWatchResponse{Devices: m.devs})
141+
if err != nil {
142+
fmt.Printf("Failed to send devices to Kubelet: %v\n", err)
143+
}
144+
145+
for {
146+
select {
147+
case <-m.stop:
148+
return nil
149+
case d := <-m.health:
150+
// FIXME: there is no way to recover from the Unhealthy state.
151+
d.Health = pluginapi.Unhealthy
152+
err := s.Send(&pluginapi.ListAndWatchResponse{Devices: m.devs})
153+
if err != nil {
154+
log.Printf("failed to send unhealthy update: %v", err)
155+
}
156+
}
157+
}
158+
}
159+
160+
func (m *RealNodeDevicePlugin) GetPreferredAllocation(context.Context, *pluginapi.PreferredAllocationRequest) (*pluginapi.PreferredAllocationResponse, error) {
161+
return &pluginapi.PreferredAllocationResponse{}, nil
162+
}
163+
164+
func (m *RealNodeDevicePlugin) Allocate(ctx context.Context, reqs *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) {
165+
responses := pluginapi.AllocateResponse{}
166+
for _, req := range reqs.ContainerRequests {
167+
response := pluginapi.ContainerAllocateResponse{
168+
Envs: map[string]string{
169+
"MOCK_NVIDIA_VISIBLE_DEVICES": strings.Join(req.DevicesIDs, ","),
170+
},
171+
Mounts: []*pluginapi.Mount{
172+
{
173+
ContainerPath: "/bin/nvidia-smi",
174+
HostPath: "/var/lib/runai/bin/nvidia-smi",
175+
},
176+
},
177+
}
178+
179+
responses.ContainerResponses = append(responses.ContainerResponses, &response)
180+
}
181+
182+
return &responses, nil
183+
}
184+
185+
func (m *RealNodeDevicePlugin) PreStartContainer(context.Context, *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) {
186+
return &pluginapi.PreStartContainerResponse{}, nil
187+
}
188+
189+
func (m *RealNodeDevicePlugin) cleanup() error {
190+
if err := os.Remove(m.socket); err != nil && !os.IsNotExist(err) {
191+
return err
192+
}
193+
194+
return nil
195+
}
196+
197+
func (m *RealNodeDevicePlugin) Serve() error {
198+
err := m.Start()
199+
if err != nil {
200+
log.Printf("Could not start device plugin: %s", err)
201+
return err
202+
}
203+
log.Println("Starting to serve on", m.socket)
204+
205+
err = m.Register(pluginapi.KubeletSocket, resourceName)
206+
if err != nil {
207+
log.Printf("Could not register device plugin: %s", err)
208+
stopErr := m.Stop()
209+
if stopErr != nil {
210+
log.Printf("Could not stop device plugin: %s", stopErr)
211+
}
212+
return err
213+
}
214+
log.Println("Registered device plugin with Kubelet")
215+
216+
return nil
217+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package node
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"os"
7+
8+
"github.com/run-ai/fake-gpu-operator/internal/common/constants"
9+
appsv1 "k8s.io/api/apps/v1"
10+
v1 "k8s.io/api/core/v1"
11+
"k8s.io/apimachinery/pkg/api/errors"
12+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
13+
"k8s.io/utils/ptr"
14+
)
15+
16+
func (p *NodeHandler) applyFakeNodeDeployments(node *v1.Node) error {
17+
if !isFakeNode(node) {
18+
return nil
19+
}
20+
21+
deployments, err := p.generateFakeNodeDeployments(node)
22+
if err != nil {
23+
return fmt.Errorf("failed to get fake node deployments: %w", err)
24+
}
25+
26+
for _, deployment := range deployments {
27+
err := p.applyDeployment(deployment)
28+
if err != nil {
29+
return fmt.Errorf("failed to apply deployment: %w", err)
30+
}
31+
}
32+
33+
return nil
34+
}
35+
36+
func (p *NodeHandler) deleteFakeNodeDeployments(node *v1.Node) error {
37+
if !isFakeNode(node) {
38+
return nil
39+
}
40+
41+
deployments, err := p.generateFakeNodeDeployments(node)
42+
if err != nil {
43+
return fmt.Errorf("failed to get fake node deployments: %w", err)
44+
}
45+
46+
for _, deployment := range deployments {
47+
err := p.kubeClient.AppsV1().Deployments(deployment.Namespace).Delete(context.TODO(), deployment.Name, metav1.DeleteOptions{})
48+
if err != nil && !errors.IsNotFound(err) {
49+
return fmt.Errorf("failed to delete deployment %s: %w", deployment.Name, err)
50+
}
51+
}
52+
53+
return nil
54+
}
55+
56+
func (p *NodeHandler) generateFakeNodeDeployments(node *v1.Node) ([]appsv1.Deployment, error) {
57+
deploymentTemplates, err := p.kubeClient.AppsV1().Deployments(os.Getenv(constants.EnvFakeGpuOperatorNs)).List(context.TODO(), metav1.ListOptions{
58+
LabelSelector: fmt.Sprintf("%s=true", constants.FakeNodeDeploymentTemplateLabel),
59+
})
60+
if err != nil {
61+
return nil, fmt.Errorf("failed to list deployments: %w", err)
62+
}
63+
64+
deployments := []appsv1.Deployment{}
65+
for i := range deploymentTemplates.Items {
66+
deployments = append(deployments, *generateFakeNodeDeploymentFromTemplate(&deploymentTemplates.Items[i], node))
67+
}
68+
69+
return deployments, nil
70+
}
71+
72+
func (p *NodeHandler) applyDeployment(deployment appsv1.Deployment) error {
73+
existingDeployment, err := p.kubeClient.AppsV1().Deployments(deployment.Namespace).Get(context.TODO(), deployment.Name, metav1.GetOptions{})
74+
if err != nil && !errors.IsNotFound(err) {
75+
return fmt.Errorf("failed to get deployment %s: %w", deployment.Name, err)
76+
}
77+
78+
if errors.IsNotFound(err) {
79+
deployment.ResourceVersion = ""
80+
_, err := p.kubeClient.AppsV1().Deployments(deployment.Namespace).Create(context.TODO(), &deployment, metav1.CreateOptions{})
81+
if err != nil {
82+
return fmt.Errorf("failed to create deployment %s: %w", deployment.Name, err)
83+
}
84+
} else {
85+
deployment.UID = existingDeployment.UID
86+
deployment.ResourceVersion = existingDeployment.ResourceVersion
87+
_, err := p.kubeClient.AppsV1().Deployments(deployment.Namespace).Update(context.TODO(), &deployment, metav1.UpdateOptions{})
88+
if err != nil {
89+
return fmt.Errorf("failed to update deployment %s: %w", deployment.Name, err)
90+
}
91+
}
92+
93+
return nil
94+
}
95+
96+
func generateFakeNodeDeploymentFromTemplate(template *appsv1.Deployment, node *v1.Node) *appsv1.Deployment {
97+
deployment := template.DeepCopy()
98+
99+
delete(deployment.Labels, constants.FakeNodeDeploymentTemplateLabel)
100+
deployment.Name = fmt.Sprintf("%s-%s", deployment.Name, node.Name)
101+
deployment.Spec.Replicas = ptr.To(int32(1))
102+
deployment.Spec.Template.Spec.Containers[0].Env = append(deployment.Spec.Template.Spec.Containers[0].Env, v1.EnvVar{
103+
Name: constants.EnvNodeName,
104+
Value: node.Name,
105+
}, v1.EnvVar{
106+
Name: constants.EnvFakeNode,
107+
Value: "true",
108+
})
109+
110+
return deployment
111+
}
112+
113+
func isFakeNode(node *v1.Node) bool {
114+
return node != nil && node.Annotations[constants.KwokNodeAnnotation] == "fake"
115+
}

‎internal/status-updater/handlers/node/handler.go

+8-29
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"fmt"
55
"log"
66

7-
"github.com/google/uuid"
87
"github.com/run-ai/fake-gpu-operator/internal/common/topology"
98
v1 "k8s.io/api/core/v1"
109
"k8s.io/apimachinery/pkg/api/errors"
@@ -31,28 +30,14 @@ func NewNodeHandler(kubeClient kubernetes.Interface) *NodeHandler {
3130
func (p *NodeHandler) HandleAdd(node *v1.Node) error {
3231
log.Printf("Handling node addition: %s\n", node.Name)
3332

34-
nodeTopology, _ := topology.GetNodeTopologyFromCM(p.kubeClient, node.Name)
35-
if nodeTopology != nil {
36-
return nil
37-
}
38-
39-
baseTopology, err := topology.GetBaseTopologyFromCM(p.kubeClient)
33+
err := p.createNodeTopologyCM(node)
4034
if err != nil {
41-
return fmt.Errorf("failed to get base topology: %w", err)
42-
}
43-
44-
nodeAutofillSettings := baseTopology.Config.NodeAutofill
45-
46-
nodeTopology = &topology.NodeTopology{
47-
GpuMemory: nodeAutofillSettings.GpuMemory,
48-
GpuProduct: nodeAutofillSettings.GpuProduct,
49-
Gpus: generateGpuDetails(nodeAutofillSettings.GpuCount, node.Name),
50-
MigStrategy: nodeAutofillSettings.MigStrategy,
35+
return fmt.Errorf("failed to create node topology ConfigMap: %w", err)
5136
}
5237

53-
err = topology.CreateNodeTopologyCM(p.kubeClient, nodeTopology, node.Name)
38+
err = p.applyFakeNodeDeployments(node)
5439
if err != nil {
55-
return fmt.Errorf("failed to create node topology: %w", err)
40+
return fmt.Errorf("failed to apply fake node deployments: %w", err)
5641
}
5742

5843
return nil
@@ -66,16 +51,10 @@ func (p *NodeHandler) HandleDelete(node *v1.Node) error {
6651
return fmt.Errorf("failed to delete node topology: %w", err)
6752
}
6853

69-
return nil
70-
}
71-
72-
func generateGpuDetails(gpuCount int, nodeName string) []topology.GpuDetails {
73-
gpus := make([]topology.GpuDetails, gpuCount)
74-
for idx := range gpus {
75-
gpus[idx] = topology.GpuDetails{
76-
ID: fmt.Sprintf("GPU-%s", uuid.NewSHA1(uuid.Nil, []byte(fmt.Sprintf("%s-%d", nodeName, idx)))),
77-
}
54+
err = p.deleteFakeNodeDeployments(node)
55+
if err != nil {
56+
return fmt.Errorf("failed to delete fake node deployments: %w", err)
7857
}
7958

80-
return gpus
59+
return nil
8160
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package node
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/google/uuid"
7+
"github.com/run-ai/fake-gpu-operator/internal/common/topology"
8+
v1 "k8s.io/api/core/v1"
9+
)
10+
11+
func (p *NodeHandler) createNodeTopologyCM(node *v1.Node) error {
12+
nodeTopology, _ := topology.GetNodeTopologyFromCM(p.kubeClient, node.Name)
13+
if nodeTopology != nil {
14+
return nil
15+
}
16+
17+
baseTopology, err := topology.GetBaseTopologyFromCM(p.kubeClient)
18+
if err != nil {
19+
return fmt.Errorf("failed to get base topology: %w", err)
20+
}
21+
22+
nodeAutofillSettings := baseTopology.Config.NodeAutofill
23+
24+
nodeTopology = &topology.NodeTopology{
25+
GpuMemory: nodeAutofillSettings.GpuMemory,
26+
GpuProduct: nodeAutofillSettings.GpuProduct,
27+
Gpus: generateGpuDetails(nodeAutofillSettings.GpuCount, node.Name),
28+
MigStrategy: nodeAutofillSettings.MigStrategy,
29+
}
30+
31+
err = topology.CreateNodeTopologyCM(p.kubeClient, nodeTopology, node.Name)
32+
if err != nil {
33+
return fmt.Errorf("failed to create node topology: %w", err)
34+
}
35+
36+
return nil
37+
}
38+
39+
func generateGpuDetails(gpuCount int, nodeName string) []topology.GpuDetails {
40+
gpus := make([]topology.GpuDetails, gpuCount)
41+
for idx := range gpus {
42+
gpus[idx] = topology.GpuDetails{
43+
ID: fmt.Sprintf("GPU-%s", uuid.NewSHA1(uuid.Nil, []byte(fmt.Sprintf("%s-%d", nodeName, idx)))),
44+
}
45+
}
46+
47+
return gpus
48+
}

0 commit comments

Comments
 (0)
Please sign in to comment.