3939from __future__ import annotations
4040
4141from collections .abc import Iterator
42- from dataclasses import dataclass
4342from datetime import datetime
4443import logging
4544import os
6766logger = logging .getLogger (__name__ )
6867
6968
70- @dataclass
71- class _Node :
72- name : str
73- container_id : str
74- status : str = constants .TRAINJOB_CREATED
75-
76-
77- @dataclass
78- class _Job :
79- name : str
80- created : datetime
81- runtime : types .Runtime
82- network_id : str
83- nodes : list [_Node ]
84- workdir_host : str
85-
86-
8769class ContainerBackend (ExecutionBackend ):
8870 """
8971 Unified container backend that auto-detects Docker or Podman.
@@ -94,7 +76,6 @@ class ContainerBackend(ExecutionBackend):
9476
9577 def __init__ (self , cfg : ContainerBackendConfig ):
9678 self .cfg = cfg
97- self ._jobs : dict [str , _Job ] = {}
9879 self .label_prefix = "trainer.kubeflow.org"
9980
10081 # Initialize the container client adapter
@@ -215,12 +196,17 @@ def train(
215196
216197 network_id = self ._adapter .create_network (
217198 name = f"{ job_name } -net" ,
218- labels = {f"{ self .label_prefix } /trainjob-name" : job_name },
199+ labels = {
200+ f"{ self .label_prefix } /trainjob-name" : job_name ,
201+ f"{ self .label_prefix } /runtime-name" : runtime .name ,
202+ f"{ self .label_prefix } /workdir" : workdir ,
203+ f"{ self .label_prefix } /created" : datetime .now ().isoformat (),
204+ },
219205 )
220206 logger .info (f"Created network: { network_id } " )
221207
222208 # Create N containers (one per node)
223- containers : list [_Node ] = []
209+ container_ids : list [str ] = []
224210 master_container_id = None
225211 master_ip = None
226212
@@ -279,6 +265,8 @@ def train(
279265 labels = {
280266 f"{ self .label_prefix } /trainjob-name" : job_name ,
281267 f"{ self .label_prefix } /step" : f"node-{ rank } " ,
268+ f"{ self .label_prefix } /network-id" : network_id ,
269+ f"{ self .label_prefix } /num-nodes" : str (num_nodes ),
282270 }
283271
284272 volumes = {
@@ -302,7 +290,7 @@ def train(
302290 )
303291
304292 logger .info (f"Started container { container_name } (ID: { container_id [:12 ]} )" )
305- containers .append (_Node ( name = container_name , container_id = container_id ) )
293+ container_ids .append (container_id )
306294
307295 # If this is the master node and we're using Podman, get its IP address
308296 if rank == 0 :
@@ -318,18 +306,8 @@ def train(
318306 "Worker nodes will fall back to DNS resolution."
319307 )
320308
321- # Store job in backend
322- self ._jobs [job_name ] = _Job (
323- name = job_name ,
324- created = datetime .now (),
325- runtime = runtime ,
326- network_id = network_id ,
327- nodes = containers ,
328- workdir_host = workdir ,
329- )
330-
331309 logger .info (
332- f"Training job { job_name } created successfully with { len (containers )} container(s)"
310+ f"Training job { job_name } created successfully with { len (container_ids )} container(s)"
333311 )
334312 return job_name
335313
@@ -343,11 +321,11 @@ def train(
343321
344322 try :
345323 # Stop and remove any containers that were created
346- if "containers " in locals ():
347- for node in containers :
324+ if "container_ids " in locals ():
325+ for container_id in container_ids :
348326 with suppress (Exception ):
349- self ._adapter .stop_container (node . container_id , timeout = 5 )
350- self ._adapter .remove_container (node . container_id , force = True )
327+ self ._adapter .stop_container (container_id , timeout = 5 )
328+ self ._adapter .remove_container (container_id , force = True )
351329
352330 # Remove network if it was created
353331 if "network_id" in locals ():
@@ -365,53 +343,146 @@ def train(
365343 raise
366344
367345 def list_jobs (self , runtime : types .Runtime | None = None ) -> list [types .TrainJob ]:
346+ """List all training jobs by querying container runtime."""
347+ # Get all containers with our label prefix
348+ filters = {"label" : [f"{ self .label_prefix } /trainjob-name" ]}
349+ containers = self ._adapter .list_containers (filters = filters )
350+
351+ # Group containers by job name
352+ jobs_map : dict [str , list [dict ]] = {}
353+ for container in containers :
354+ job_name = container ["labels" ].get (f"{ self .label_prefix } /trainjob-name" )
355+ if job_name :
356+ if job_name not in jobs_map :
357+ jobs_map [job_name ] = []
358+ jobs_map [job_name ].append (container )
359+
368360 result : list [types .TrainJob ] = []
369- for job in self ._jobs .values ():
370- if runtime and job .runtime .name != runtime .name :
361+ for job_name , job_containers in jobs_map .items ():
362+ # Get metadata from first container's network
363+ if not job_containers :
364+ continue
365+
366+ network_id = job_containers [0 ]["labels" ].get (f"{ self .label_prefix } /network-id" )
367+ if not network_id :
368+ continue
369+
370+ network_info = self ._adapter .get_network (network_id )
371+ if not network_info :
371372 continue
373+
374+ network_labels = network_info .get ("labels" , {})
375+ runtime_name = network_labels .get (f"{ self .label_prefix } /runtime-name" )
376+
377+ # Filter by runtime if specified
378+ if runtime and runtime_name != runtime .name :
379+ continue
380+
381+ # Get runtime object
382+ try :
383+ job_runtime = self .get_runtime (runtime_name ) if runtime_name else None
384+ except Exception :
385+ job_runtime = None
386+
387+ if not job_runtime :
388+ continue
389+
390+ # Parse creation timestamp
391+ created_str = network_labels .get (f"{ self .label_prefix } /created" , "" )
392+ try :
393+ from dateutil import parser
394+ creation_timestamp = parser .isoparse (created_str )
395+ except Exception :
396+ creation_timestamp = datetime .now ()
397+
398+ # Build steps from containers
372399 steps = []
373- for node in job .nodes :
400+ for container in sorted (job_containers , key = lambda c : c ["name" ]):
401+ step_name = container ["labels" ].get (f"{ self .label_prefix } /step" , "" )
374402 steps .append (
375403 types .Step (
376- name = node . name . split ( f" { job . name } -" )[ - 1 ] ,
377- pod_name = node . name ,
378- status = self ._container_status (node . container_id ),
404+ name = step_name ,
405+ pod_name = container [ " name" ] ,
406+ status = self ._container_status (container [ "id" ] ),
379407 )
380408 )
409+
410+ # Get num_nodes from labels
411+ num_nodes = int (job_containers [0 ]["labels" ].get (f"{ self .label_prefix } /num-nodes" , len (job_containers )))
412+
381413 result .append (
382414 types .TrainJob (
383- name = job . name ,
384- creation_timestamp = job . created ,
385- runtime = job . runtime ,
415+ name = job_name ,
416+ creation_timestamp = creation_timestamp ,
417+ runtime = job_runtime ,
386418 steps = steps ,
387- num_nodes = len ( job . nodes ) ,
388- status = self ._aggregate_status ( job ),
419+ num_nodes = num_nodes ,
420+ status = self ._aggregate_status_from_containers ( job_containers ),
389421 )
390422 )
423+
391424 return result
392425
393426 def get_job (self , name : str ) -> types .TrainJob :
394- job = self ._jobs .get (name )
395- if not job :
427+ """Get a specific training job by querying container runtime."""
428+ # Find containers for this job
429+ filters = {"label" : [f"{ self .label_prefix } /trainjob-name={ name } " ]}
430+ containers = self ._adapter .list_containers (filters = filters )
431+
432+ if not containers :
396433 raise ValueError (f"No TrainJob with name { name } " )
397- # Refresh container statuses on demand
398- steps : list [types .Step ] = []
399- for node in job .nodes :
400- status = self ._container_status (node .container_id )
434+
435+ # Get metadata from network
436+ network_id = containers [0 ]["labels" ].get (f"{ self .label_prefix } /network-id" )
437+ if not network_id :
438+ raise ValueError (f"TrainJob { name } is missing network metadata" )
439+
440+ network_info = self ._adapter .get_network (network_id )
441+ if not network_info :
442+ raise ValueError (f"TrainJob { name } network not found" )
443+
444+ network_labels = network_info .get ("labels" , {})
445+ runtime_name = network_labels .get (f"{ self .label_prefix } /runtime-name" )
446+
447+ # Get runtime object
448+ try :
449+ job_runtime = self .get_runtime (runtime_name ) if runtime_name else None
450+ except Exception :
451+ raise ValueError (f"Runtime { runtime_name } not found for job { name } " )
452+
453+ if not job_runtime :
454+ raise ValueError (f"Runtime { runtime_name } not found for job { name } " )
455+
456+ # Parse creation timestamp
457+ created_str = network_labels .get (f"{ self .label_prefix } /created" , "" )
458+ try :
459+ from dateutil import parser
460+ creation_timestamp = parser .isoparse (created_str )
461+ except Exception :
462+ creation_timestamp = datetime .now ()
463+
464+ # Build steps from containers
465+ steps = []
466+ for container in sorted (containers , key = lambda c : c ["name" ]):
467+ step_name = container ["labels" ].get (f"{ self .label_prefix } /step" , "" )
401468 steps .append (
402469 types .Step (
403- name = node . name . split ( f" { job . name } -" )[ - 1 ] ,
404- pod_name = node . name ,
405- status = status ,
470+ name = step_name ,
471+ pod_name = container [ " name" ] ,
472+ status = self . _container_status ( container [ "id" ]) ,
406473 )
407474 )
475+
476+ # Get num_nodes from labels
477+ num_nodes = int (containers [0 ]["labels" ].get (f"{ self .label_prefix } /num-nodes" , len (containers )))
478+
408479 return types .TrainJob (
409- name = job . name ,
410- creation_timestamp = job . created ,
411- runtime = job . runtime ,
480+ name = name ,
481+ creation_timestamp = creation_timestamp ,
482+ runtime = job_runtime ,
412483 steps = steps ,
413- num_nodes = len ( job . nodes ) ,
414- status = self ._aggregate_status ( job ),
484+ num_nodes = num_nodes ,
485+ status = self ._aggregate_status_from_containers ( containers ),
415486 )
416487
417488 def get_job_logs (
@@ -420,19 +491,23 @@ def get_job_logs(
420491 follow : bool = False ,
421492 step : str = constants .NODE + "-0" ,
422493 ) -> Iterator [str ]:
423- job = self ._jobs .get (name )
424- if not job :
494+ """Get logs for a training job by querying container runtime."""
495+ # Find containers for this job
496+ filters = {"label" : [f"{ self .label_prefix } /trainjob-name={ name } " ]}
497+ containers = self ._adapter .list_containers (filters = filters )
498+
499+ if not containers :
425500 raise ValueError (f"No TrainJob with name { name } " )
426501
427502 want_all = step == constants .NODE + "-0"
428- for node in job . nodes :
429- node_step = node . name . split (f"{ job . name } -" )[ - 1 ]
430- if not want_all and node_step != step :
503+ for container in sorted ( containers , key = lambda c : c [ "name" ]) :
504+ container_step = container [ "labels" ]. get (f"{ self . label_prefix } /step" , "" )
505+ if not want_all and container_step != step :
431506 continue
432507 try :
433- yield from self ._adapter .container_logs (node . container_id , follow )
508+ yield from self ._adapter .container_logs (container [ "id" ] , follow )
434509 except Exception as e :
435- logger .warning (f"Failed to get logs for { node . name } : { e } " )
510+ logger .warning (f"Failed to get logs for { container [ ' name' ] } : { e } " )
436511 yield f"Error getting logs: { e } \n "
437512
438513 def wait_for_job_status (
@@ -456,28 +531,42 @@ def wait_for_job_status(
456531 raise TimeoutError (f"Timeout waiting for TrainJob { name } to reach status: { status } " )
457532
458533 def delete_job (self , name : str ):
459- job = self ._jobs .get (name )
460- if not job :
534+ """Delete a training job by querying container runtime."""
535+ # Find containers for this job
536+ filters = {"label" : [f"{ self .label_prefix } /trainjob-name={ name } " ]}
537+ containers = self ._adapter .list_containers (filters = filters )
538+
539+ if not containers :
461540 raise ValueError (f"No TrainJob with name { name } " )
462541
542+ # Get network_id and workdir from labels
543+ network_id = containers [0 ]["labels" ].get (f"{ self .label_prefix } /network-id" )
544+
545+ # Get workdir from network labels
546+ workdir_host = None
547+ if network_id :
548+ network_info = self ._adapter .get_network (network_id )
549+ if network_info :
550+ network_labels = network_info .get ("labels" , {})
551+ workdir_host = network_labels .get (f"{ self .label_prefix } /workdir" )
552+
463553 # Stop containers and remove
464554 from contextlib import suppress
465555
466- for node in job . nodes :
556+ for container in containers :
467557 with suppress (Exception ):
468- self ._adapter .stop_container (node . container_id , timeout = 10 )
558+ self ._adapter .stop_container (container [ "id" ] , timeout = 10 )
469559 with suppress (Exception ):
470- self ._adapter .remove_container (node . container_id , force = True )
560+ self ._adapter .remove_container (container [ "id" ] , force = True )
471561
472562 # Remove network (best-effort)
473- with suppress (Exception ):
474- self ._adapter .delete_network (job .network_id )
563+ if network_id :
564+ with suppress (Exception ):
565+ self ._adapter .delete_network (network_id )
475566
476567 # Remove working directory if configured
477- if self .cfg .auto_remove and os .path .isdir (job .workdir_host ):
478- shutil .rmtree (job .workdir_host , ignore_errors = True )
479-
480- del self ._jobs [name ]
568+ if self .cfg .auto_remove and workdir_host and os .path .isdir (workdir_host ):
569+ shutil .rmtree (workdir_host , ignore_errors = True )
481570
482571 # Helper methods
483572
@@ -572,9 +661,9 @@ def _container_status(self, container_id: str) -> str:
572661 return constants .UNKNOWN
573662 return constants .UNKNOWN
574663
575- def _aggregate_status (self , job : _Job ) -> str :
576- """Aggregate status from all containers in a job ."""
577- statuses = [self ._container_status (n . container_id ) for n in job . nodes ]
664+ def _aggregate_status_from_containers (self , containers : list [ dict ] ) -> str :
665+ """Aggregate status from container info dicts ."""
666+ statuses = [self ._container_status (c [ "id" ] ) for c in containers ]
578667 if constants .TRAINJOB_FAILED in statuses :
579668 return constants .TRAINJOB_FAILED
580669 if constants .TRAINJOB_RUNNING in statuses :
0 commit comments