
Commit defd7eb

Updating to rely on container runtime rather than storing job info in memory
1 parent 0680c3e commit defd7eb

2 files changed: +263 -83 lines

kubeflow/trainer/backends/container/backend.py

Lines changed: 172 additions & 83 deletions
@@ -39,7 +39,6 @@
 from __future__ import annotations

 from collections.abc import Iterator
-from dataclasses import dataclass
 from datetime import datetime
 import logging
 import os
@@ -67,23 +66,6 @@
 logger = logging.getLogger(__name__)


-@dataclass
-class _Node:
-    name: str
-    container_id: str
-    status: str = constants.TRAINJOB_CREATED
-
-
-@dataclass
-class _Job:
-    name: str
-    created: datetime
-    runtime: types.Runtime
-    network_id: str
-    nodes: list[_Node]
-    workdir_host: str
-
-
 class ContainerBackend(ExecutionBackend):
     """
     Unified container backend that auto-detects Docker or Podman.
@@ -94,7 +76,6 @@ class ContainerBackend(ExecutionBackend):

     def __init__(self, cfg: ContainerBackendConfig):
         self.cfg = cfg
-        self._jobs: dict[str, _Job] = {}
         self.label_prefix = "trainer.kubeflow.org"

         # Initialize the container client adapter
@@ -215,12 +196,17 @@ def train(

         network_id = self._adapter.create_network(
             name=f"{job_name}-net",
-            labels={f"{self.label_prefix}/trainjob-name": job_name},
+            labels={
+                f"{self.label_prefix}/trainjob-name": job_name,
+                f"{self.label_prefix}/runtime-name": runtime.name,
+                f"{self.label_prefix}/workdir": workdir,
+                f"{self.label_prefix}/created": datetime.now().isoformat(),
+            },
         )
         logger.info(f"Created network: {network_id}")

         # Create N containers (one per node)
-        containers: list[_Node] = []
+        container_ids: list[str] = []
         master_container_id = None
         master_ip = None

@@ -279,6 +265,8 @@ def train(
             labels = {
                 f"{self.label_prefix}/trainjob-name": job_name,
                 f"{self.label_prefix}/step": f"node-{rank}",
+                f"{self.label_prefix}/network-id": network_id,
+                f"{self.label_prefix}/num-nodes": str(num_nodes),
             }

             volumes = {
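
Together, these two hunks make every resource self-describing: the network carries the job-level metadata (runtime name, workdir, creation timestamp) and each container carries its step name, its network, and the node count. A minimal round-trip sketch, assuming a local Docker daemon and the docker-py SDK rather than this file's adapter (the job name and runtime name below are made up):

from datetime import datetime

import docker

PREFIX = "trainer.kubeflow.org"
client = docker.from_env()

# Write everything list_jobs/get_job will later need into labels at
# creation time, so no in-memory state has to survive a process restart.
network = client.networks.create(
    "demo-job-net",
    driver="bridge",
    labels={
        f"{PREFIX}/trainjob-name": "demo-job",          # hypothetical job
        f"{PREFIX}/runtime-name": "torch-distributed",  # hypothetical runtime
        f"{PREFIX}/created": datetime.now().isoformat(),
    },
)

# Read the metadata back via the equivalent of `docker network inspect`.
labels = client.networks.get(network.id).attrs.get("Labels") or {}
print(labels[f"{PREFIX}/created"])

network.remove()
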
@@ -302,7 +290,7 @@ def train(
             )

             logger.info(f"Started container {container_name} (ID: {container_id[:12]})")
-            containers.append(_Node(name=container_name, container_id=container_id))
+            container_ids.append(container_id)

             # If this is the master node and we're using Podman, get its IP address
             if rank == 0:
@@ -318,18 +306,8 @@ def train(
                         "Worker nodes will fall back to DNS resolution."
                     )

-        # Store job in backend
-        self._jobs[job_name] = _Job(
-            name=job_name,
-            created=datetime.now(),
-            runtime=runtime,
-            network_id=network_id,
-            nodes=containers,
-            workdir_host=workdir,
-        )
-
         logger.info(
-            f"Training job {job_name} created successfully with {len(containers)} container(s)"
+            f"Training job {job_name} created successfully with {len(container_ids)} container(s)"
         )
         return job_name

@@ -343,11 +321,11 @@ def train(

         try:
             # Stop and remove any containers that were created
-            if "containers" in locals():
-                for node in containers:
+            if "container_ids" in locals():
+                for container_id in container_ids:
                     with suppress(Exception):
-                        self._adapter.stop_container(node.container_id, timeout=5)
-                        self._adapter.remove_container(node.container_id, force=True)
+                        self._adapter.stop_container(container_id, timeout=5)
+                        self._adapter.remove_container(container_id, force=True)

             # Remove network if it was created
             if "network_id" in locals():
@@ -365,53 +343,146 @@ def train(
             raise

     def list_jobs(self, runtime: types.Runtime | None = None) -> list[types.TrainJob]:
+        """List all training jobs by querying container runtime."""
+        # Get all containers with our label prefix
+        filters = {"label": [f"{self.label_prefix}/trainjob-name"]}
+        containers = self._adapter.list_containers(filters=filters)
+
+        # Group containers by job name
+        jobs_map: dict[str, list[dict]] = {}
+        for container in containers:
+            job_name = container["labels"].get(f"{self.label_prefix}/trainjob-name")
+            if job_name:
+                if job_name not in jobs_map:
+                    jobs_map[job_name] = []
+                jobs_map[job_name].append(container)
+
         result: list[types.TrainJob] = []
-        for job in self._jobs.values():
-            if runtime and job.runtime.name != runtime.name:
+        for job_name, job_containers in jobs_map.items():
+            # Get metadata from first container's network
+            if not job_containers:
+                continue
+
+            network_id = job_containers[0]["labels"].get(f"{self.label_prefix}/network-id")
+            if not network_id:
+                continue
+
+            network_info = self._adapter.get_network(network_id)
+            if not network_info:
                 continue
+
+            network_labels = network_info.get("labels", {})
+            runtime_name = network_labels.get(f"{self.label_prefix}/runtime-name")
+
+            # Filter by runtime if specified
+            if runtime and runtime_name != runtime.name:
+                continue
+
+            # Get runtime object
+            try:
+                job_runtime = self.get_runtime(runtime_name) if runtime_name else None
+            except Exception:
+                job_runtime = None
+
+            if not job_runtime:
+                continue
+
+            # Parse creation timestamp
+            created_str = network_labels.get(f"{self.label_prefix}/created", "")
+            try:
+                from dateutil import parser
+                creation_timestamp = parser.isoparse(created_str)
+            except Exception:
+                creation_timestamp = datetime.now()
+
+            # Build steps from containers
             steps = []
-            for node in job.nodes:
+            for container in sorted(job_containers, key=lambda c: c["name"]):
+                step_name = container["labels"].get(f"{self.label_prefix}/step", "")
                 steps.append(
                     types.Step(
-                        name=node.name.split(f"{job.name}-")[-1],
-                        pod_name=node.name,
-                        status=self._container_status(node.container_id),
+                        name=step_name,
+                        pod_name=container["name"],
+                        status=self._container_status(container["id"]),
                     )
                 )
+
+            # Get num_nodes from labels
+            num_nodes = int(job_containers[0]["labels"].get(f"{self.label_prefix}/num-nodes", len(job_containers)))
+
             result.append(
                 types.TrainJob(
-                    name=job.name,
-                    creation_timestamp=job.created,
-                    runtime=job.runtime,
+                    name=job_name,
+                    creation_timestamp=creation_timestamp,
+                    runtime=job_runtime,
                     steps=steps,
-                    num_nodes=len(job.nodes),
-                    status=self._aggregate_status(job),
+                    num_nodes=num_nodes,
+                    status=self._aggregate_status_from_containers(job_containers),
                 )
             )
+
         return result

     def get_job(self, name: str) -> types.TrainJob:
-        job = self._jobs.get(name)
-        if not job:
+        """Get a specific training job by querying container runtime."""
+        # Find containers for this job
+        filters = {"label": [f"{self.label_prefix}/trainjob-name={name}"]}
+        containers = self._adapter.list_containers(filters=filters)
+
+        if not containers:
             raise ValueError(f"No TrainJob with name {name}")
-        # Refresh container statuses on demand
-        steps: list[types.Step] = []
-        for node in job.nodes:
-            status = self._container_status(node.container_id)
+
+        # Get metadata from network
+        network_id = containers[0]["labels"].get(f"{self.label_prefix}/network-id")
+        if not network_id:
+            raise ValueError(f"TrainJob {name} is missing network metadata")
+
+        network_info = self._adapter.get_network(network_id)
+        if not network_info:
+            raise ValueError(f"TrainJob {name} network not found")
+
+        network_labels = network_info.get("labels", {})
+        runtime_name = network_labels.get(f"{self.label_prefix}/runtime-name")
+
+        # Get runtime object
+        try:
+            job_runtime = self.get_runtime(runtime_name) if runtime_name else None
+        except Exception:
+            raise ValueError(f"Runtime {runtime_name} not found for job {name}")
+
+        if not job_runtime:
+            raise ValueError(f"Runtime {runtime_name} not found for job {name}")
+
+        # Parse creation timestamp
+        created_str = network_labels.get(f"{self.label_prefix}/created", "")
+        try:
+            from dateutil import parser
+            creation_timestamp = parser.isoparse(created_str)
+        except Exception:
+            creation_timestamp = datetime.now()
+
+        # Build steps from containers
+        steps = []
+        for container in sorted(containers, key=lambda c: c["name"]):
+            step_name = container["labels"].get(f"{self.label_prefix}/step", "")
             steps.append(
                 types.Step(
-                    name=node.name.split(f"{job.name}-")[-1],
-                    pod_name=node.name,
-                    status=status,
+                    name=step_name,
+                    pod_name=container["name"],
+                    status=self._container_status(container["id"]),
                 )
             )
+
+        # Get num_nodes from labels
+        num_nodes = int(containers[0]["labels"].get(f"{self.label_prefix}/num-nodes", len(containers)))
+
         return types.TrainJob(
-            name=job.name,
-            creation_timestamp=job.created,
-            runtime=job.runtime,
+            name=name,
+            creation_timestamp=creation_timestamp,
+            runtime=job_runtime,
             steps=steps,
-            num_nodes=len(job.nodes),
-            status=self._aggregate_status(job),
+            num_nodes=num_nodes,
+            status=self._aggregate_status_from_containers(containers),
         )

     def get_job_logs(
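
list_jobs and get_job now rebuild all state from two queries: a label-filtered container listing and a network inspect. A sketch of the core lookup, assuming docker-py; grouping containers by the value of the trainjob-name label recovers the per-job view that the deleted _jobs dict used to hold:

from collections import defaultdict

import docker

PREFIX = "trainer.kubeflow.org"
client = docker.from_env()

# One query finds every container belonging to any training job
# (all=True so stopped/finished containers still count).
containers = client.containers.list(
    all=True, filters={"label": f"{PREFIX}/trainjob-name"}
)

# Grouping by label value reconstructs the per-job view on demand.
jobs: defaultdict[str, list] = defaultdict(list)
for c in containers:
    jobs[c.labels[f"{PREFIX}/trainjob-name"]].append(c)

for job_name, job_containers in sorted(jobs.items()):
    steps = sorted(c.labels.get(f"{PREFIX}/step", "") for c in job_containers)
    print(f"{job_name}: {len(job_containers)} node(s), steps={steps}")

The trade-off is a runtime query per call instead of a dict lookup, in exchange for state that survives SDK restarts and always matches what is actually running.
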
@@ -420,19 +491,23 @@ def get_job_logs(
         follow: bool = False,
         step: str = constants.NODE + "-0",
     ) -> Iterator[str]:
-        job = self._jobs.get(name)
-        if not job:
+        """Get logs for a training job by querying container runtime."""
+        # Find containers for this job
+        filters = {"label": [f"{self.label_prefix}/trainjob-name={name}"]}
+        containers = self._adapter.list_containers(filters=filters)
+
+        if not containers:
             raise ValueError(f"No TrainJob with name {name}")

         want_all = step == constants.NODE + "-0"
-        for node in job.nodes:
-            node_step = node.name.split(f"{job.name}-")[-1]
-            if not want_all and node_step != step:
+        for container in sorted(containers, key=lambda c: c["name"]):
+            container_step = container["labels"].get(f"{self.label_prefix}/step", "")
+            if not want_all and container_step != step:
                 continue
             try:
-                yield from self._adapter.container_logs(node.container_id, follow)
+                yield from self._adapter.container_logs(container["id"], follow)
            except Exception as e:
-                logger.warning(f"Failed to get logs for {node.name}: {e}")
+                logger.warning(f"Failed to get logs for {container['name']}: {e}")
                 yield f"Error getting logs: {e}\n"

     def wait_for_job_status(
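
get_job_logs resolves a step to its container the same way, by label rather than by an in-memory node list. A sketch of a label-addressed log stream, assuming docker-py ("node-0" is the default step name from the diff; the function name is illustrative):

import docker

PREFIX = "trainer.kubeflow.org"
client = docker.from_env()


def stream_step_logs(job_name: str, step: str = "node-0"):
    """Yield decoded log lines for one step of a job."""
    matches = client.containers.list(
        all=True,
        filters={
            "label": [
                f"{PREFIX}/trainjob-name={job_name}",
                f"{PREFIX}/step={step}",
            ]
        },
    )
    if not matches:
        raise ValueError(f"No TrainJob with name {job_name}")
    # stream=True returns an iterator of raw byte chunks.
    for raw in matches[0].logs(stream=True, follow=False):
        yield raw.decode("utf-8", errors="replace")
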
@@ -456,28 +531,42 @@ def wait_for_job_status(
         raise TimeoutError(f"Timeout waiting for TrainJob {name} to reach status: {status}")

     def delete_job(self, name: str):
-        job = self._jobs.get(name)
-        if not job:
+        """Delete a training job by querying container runtime."""
+        # Find containers for this job
+        filters = {"label": [f"{self.label_prefix}/trainjob-name={name}"]}
+        containers = self._adapter.list_containers(filters=filters)
+
+        if not containers:
             raise ValueError(f"No TrainJob with name {name}")

+        # Get network_id and workdir from labels
+        network_id = containers[0]["labels"].get(f"{self.label_prefix}/network-id")
+
+        # Get workdir from network labels
+        workdir_host = None
+        if network_id:
+            network_info = self._adapter.get_network(network_id)
+            if network_info:
+                network_labels = network_info.get("labels", {})
+                workdir_host = network_labels.get(f"{self.label_prefix}/workdir")
+
         # Stop containers and remove
         from contextlib import suppress

-        for node in job.nodes:
+        for container in containers:
             with suppress(Exception):
-                self._adapter.stop_container(node.container_id, timeout=10)
+                self._adapter.stop_container(container["id"], timeout=10)
             with suppress(Exception):
-                self._adapter.remove_container(node.container_id, force=True)
+                self._adapter.remove_container(container["id"], force=True)

         # Remove network (best-effort)
-        with suppress(Exception):
-            self._adapter.delete_network(job.network_id)
+        if network_id:
+            with suppress(Exception):
+                self._adapter.delete_network(network_id)

         # Remove working directory if configured
-        if self.cfg.auto_remove and os.path.isdir(job.workdir_host):
-            shutil.rmtree(job.workdir_host, ignore_errors=True)
-
-        del self._jobs[name]
+        if self.cfg.auto_remove and workdir_host and os.path.isdir(workdir_host):
+            shutil.rmtree(workdir_host, ignore_errors=True)

     # Helper methods

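
delete_job is why train() wrote the workdir into the network labels: the working directory is the one resource that lives on the host filesystem rather than in the container runtime, so its path has to be recoverable from runtime metadata. A sketch of the full teardown under the same docker-py assumption:

import os
import shutil
from contextlib import suppress

import docker

PREFIX = "trainer.kubeflow.org"


def delete_job(client: docker.DockerClient, name: str, auto_remove: bool = True) -> None:
    """Remove a job's containers, network, and (optionally) host workdir."""
    containers = client.containers.list(
        all=True, filters={"label": f"{PREFIX}/trainjob-name={name}"}
    )
    if not containers:
        raise ValueError(f"No TrainJob with name {name}")

    # Recover the host workdir from the job network's labels first,
    # since the network is about to be deleted.
    workdir = None
    network_id = containers[0].labels.get(f"{PREFIX}/network-id")
    if network_id:
        with suppress(Exception):
            labels = client.networks.get(network_id).attrs.get("Labels") or {}
            workdir = labels.get(f"{PREFIX}/workdir")

    for container in containers:
        with suppress(Exception):
            container.stop(timeout=10)
        with suppress(Exception):
            container.remove(force=True)

    if network_id:
        with suppress(Exception):
            client.networks.get(network_id).remove()

    if auto_remove and workdir and os.path.isdir(workdir):
        shutil.rmtree(workdir, ignore_errors=True)
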
@@ -572,9 +661,9 @@ def _container_status(self, container_id: str) -> str:
                 return constants.UNKNOWN
         return constants.UNKNOWN

-    def _aggregate_status(self, job: _Job) -> str:
-        """Aggregate status from all containers in a job."""
-        statuses = [self._container_status(n.container_id) for n in job.nodes]
+    def _aggregate_status_from_containers(self, containers: list[dict]) -> str:
+        """Aggregate status from container info dicts."""
+        statuses = [self._container_status(c["id"]) for c in containers]
         if constants.TRAINJOB_FAILED in statuses:
             return constants.TRAINJOB_FAILED
         if constants.TRAINJOB_RUNNING in statuses:
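
The aggregation precedence visible here is failure-wins, then running. The hunk is truncated before the remaining branches, so the Complete/Created fallbacks in this sketch are an assumption, and the status strings are spelled out instead of taken from the constants module:

def aggregate_status(statuses: list[str]) -> str:
    """Collapse per-container statuses into one job status."""
    # One failed node fails the job; one running node keeps it running.
    if "Failed" in statuses:
        return "Failed"
    if "Running" in statuses:
        return "Running"
    # Assumed fallbacks (not shown in the truncated hunk):
    if statuses and all(s == "Complete" for s in statuses):
        return "Complete"
    return "Created"


assert aggregate_status(["Complete", "Failed"]) == "Failed"
assert aggregate_status(["Running", "Complete"]) == "Running"
assert aggregate_status(["Complete", "Complete"]) == "Complete"
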