5 | 5 | import os |
6 | 6 | import time |
7 | 7 | import uuid |
8 | | -from multiprocessing import Manager |
9 | | -from multiprocessing.managers import SyncManager |
10 | | -from typing import Any, Dict, Optional |
| 8 | +import pickle |
| 9 | +import fcntl |
| 10 | +import tempfile |
| 11 | +from typing import Any, Dict, Optional, Set |
11 | 12 |
12 | 13 | from .rp_logger import RunPodLogger |
13 | 14 |
@@ -63,149 +64,150 @@ def __str__(self) -> str: |
63 | 64 | # ---------------------------------------------------------------------------- # |
64 | 65 | # Tracker # |
65 | 66 | # ---------------------------------------------------------------------------- # |
66 | | -class JobsProgress: |
67 | | - """Track the state of current jobs in progress using shared memory.""" |
68 | | - |
69 | | - _instance: Optional['JobsProgress'] = None |
70 | | - _manager: SyncManager |
71 | | - _shared_data: Any |
72 | | - _lock: Any |
| 67 | +class JobsProgress(Set[Job]): |
| 68 | + """Track the state of current jobs in progress with persistent state.""" |
| 69 | + |
| 70 | + _instance = None |
| 71 | + _STATE_DIR = os.getcwd() |
| 72 | + _STATE_FILE = os.path.join(_STATE_DIR, ".runpod_jobs.pkl") |
73 | 73 |
74 | 74 | def __new__(cls): |
75 | | - if cls._instance is None: |
76 | | - instance = object.__new__(cls) |
77 | | - # Initialize instance variables |
78 | | - instance._manager = Manager() |
79 | | - instance._shared_data = instance._manager.dict() |
80 | | - instance._shared_data['jobs'] = instance._manager.list() |
81 | | - instance._lock = instance._manager.Lock() |
82 | | - cls._instance = instance |
83 | | - return cls._instance |
| 75 | + if JobsProgress._instance is None: |
| 76 | + os.makedirs(cls._STATE_DIR, exist_ok=True) |
| 77 | + JobsProgress._instance = set.__new__(cls) |
| 78 | + # Initialize as empty set before loading state |
| 79 | + set.__init__(JobsProgress._instance) |
| 80 | + JobsProgress._instance._load_state() |
| 81 | + return JobsProgress._instance |
84 | 82 |
85 | 83 | def __init__(self): |
86 | | - # Everything is already initialized in __new__ |
| 84 | + # This should never clear data in a singleton |
| 85 | + # Don't call parent __init__ as it would clear the set |
87 | 86 | pass |
88 | | - |
| 87 | + |
89 | 88 | def __repr__(self) -> str: |
90 | 89 | return f"<{self.__class__.__name__}>: {self.get_job_list()}" |
91 | 90 |
| 91 | + def _load_state(self): |
| 92 | + """Load jobs state from pickle file with file locking.""" |
| 93 | + try: |
| 94 | + if ( |
| 95 | + os.path.exists(self._STATE_FILE) |
| 96 | + and os.path.getsize(self._STATE_FILE) > 0 |
| 97 | + ): |
| 98 | + with open(self._STATE_FILE, "rb") as f: |
| 99 | + fcntl.flock(f, fcntl.LOCK_SH) |
| 100 | + try: |
| 101 | + loaded_jobs = pickle.load(f) |
| 102 | + # Clear current state and add loaded jobs |
| 103 | + super().clear() |
| 104 | + for job in loaded_jobs: |
| 105 | + set.add( |
| 106 | + self, job |
| 107 | + ) # Use set.add to avoid triggering _save_state |
| 108 | + |
| 109 | + except (EOFError, pickle.UnpicklingError): |
| 110 | + # Handle empty or corrupted file |
| 111 | + log.debug( |
| 112 | + "JobsProgress: Failed to load state file, starting with empty state" |
| 113 | + ) |
| 114 | + pass |
| 115 | + finally: |
| 116 | + fcntl.flock(f, fcntl.LOCK_UN) |
| 117 | + |
| 118 | + except FileNotFoundError: |
| 119 | + log.debug("JobsProgress: No state file found, starting with empty state") |
| 120 | + pass |
| 121 | + |
| 122 | + def _save_state(self): |
| 123 | + """Save jobs state to pickle file with atomic write and file locking.""" |
| 124 | + try: |
| 125 | + # Use temporary file for atomic write |
| 126 | + with tempfile.NamedTemporaryFile( |
| 127 | + dir=self._STATE_DIR, delete=False, mode="wb" |
| 128 | + ) as temp_f: |
| 129 | + fcntl.flock(temp_f, fcntl.LOCK_EX) |
| 130 | + try: |
| 131 | + pickle.dump(set(self), temp_f) |
| 132 | + finally: |
| 133 | + fcntl.flock(temp_f, fcntl.LOCK_UN) |
| 134 | + |
| 135 | + # Atomically replace the state file |
| 136 | + os.replace(temp_f.name, self._STATE_FILE) |
| 137 | + except Exception as e: |
| 138 | + log.error(f"Failed to save job state: {e}") |
| 139 | + |
92 | 140 | def clear(self) -> None: |
93 | | - with self._lock: |
94 | | - self._shared_data['jobs'][:] = [] |
| 141 | + super().clear() |
| 142 | + self._save_state() |
95 | 143 |
96 | 144 | def add(self, element: Any): |
97 | 145 | """ |
98 | 146 | Adds a Job object to the set. |
99 | | - """ |
100 | | - if isinstance(element, str): |
101 | | - job_dict = {'id': element} |
102 | | - elif isinstance(element, dict): |
103 | | - job_dict = element |
104 | | - elif hasattr(element, 'id'): |
105 | | - job_dict = {'id': element.id} |
106 | | - else: |
107 | | - raise TypeError("Only Job objects can be added to JobsProgress.") |
108 | 147 |
109 | | - with self._lock: |
110 | | - # Check if job already exists |
111 | | - job_list = self._shared_data['jobs'] |
112 | | - for existing_job in job_list: |
113 | | - if existing_job['id'] == job_dict['id']: |
114 | | - return # Job already exists |
115 | | - |
116 | | - # Add new job |
117 | | - job_list.append(job_dict) |
118 | | - log.debug(f"JobsProgress | Added job: {job_dict['id']}") |
119 | | - |
120 | | - def get(self, element: Any) -> Optional[Job]: |
121 | | - """ |
122 | | - Retrieves a Job object from the set. |
| 148 | + If the added element is a string, then `Job(id=element)` is added |
123 | 149 | |
124 | | - If the element is a string, searches for Job with that id. |
| 150 | + If the added element is a dict, then `Job(**element)` is added |
125 | 151 | """ |
126 | 152 | if isinstance(element, str): |
127 | | - search_id = element |
128 | | - elif isinstance(element, Job): |
129 | | - search_id = element.id |
130 | | - else: |
131 | | - raise TypeError("Only Job objects can be retrieved from JobsProgress.") |
| 153 | + element = Job(id=element) |
132 | 154 |
133 | | - with self._lock: |
134 | | - for job_dict in self._shared_data['jobs']: |
135 | | - if job_dict['id'] == search_id: |
136 | | - log.debug(f"JobsProgress | Retrieved job: {job_dict['id']}") |
137 | | - return Job(**job_dict) |
138 | | - |
139 | | - return None |
| 155 | + if isinstance(element, dict): |
| 156 | + element = Job(**element) |
| 157 | + |
| 158 | + if not isinstance(element, Job): |
| 159 | + raise TypeError("Only Job objects can be added to JobsProgress.") |
| 160 | + |
| 161 | + result = super().add(element) |
| 162 | + self._save_state() |
| 163 | + return result |
140 | 164 |
141 | 165 | def remove(self, element: Any): |
142 | 166 | """ |
143 | 167 | Removes a Job object from the set. |
| 168 | +
| 169 | + If the element is a string, then `Job(id=element)` is removed |
| 170 | + |
| 171 | + If the element is a dict, then `Job(**element)` is removed |
144 | 172 | """ |
145 | 173 | if isinstance(element, str): |
146 | | - job_id = element |
147 | | - elif isinstance(element, dict): |
148 | | - job_id = element.get('id') |
149 | | - elif hasattr(element, 'id'): |
150 | | - job_id = element.id |
151 | | - else: |
| 174 | + element = Job(id=element) |
| 175 | + |
| 176 | + if isinstance(element, dict): |
| 177 | + element = Job(**element) |
| 178 | + |
| 179 | + if not isinstance(element, Job): |
152 | 180 | raise TypeError("Only Job objects can be removed from JobsProgress.") |
153 | 181 |
154 | | - with self._lock: |
155 | | - job_list = self._shared_data['jobs'] |
156 | | - # Find and remove the job |
157 | | - for i, job_dict in enumerate(job_list): |
158 | | - if job_dict['id'] == job_id: |
159 | | - del job_list[i] |
160 | | - log.debug(f"JobsProgress | Removed job: {job_dict['id']}") |
161 | | - break |
| 182 | + result = super().discard(element) |
| 183 | + self._save_state() |
| 184 | + return result |
| 185 | + |
| 186 | + def get(self, element: Any) -> Optional[Job]: |
| 187 | + if isinstance(element, str): |
| 188 | + element = Job(id=element) |
| 189 | + |
| 190 | + if not isinstance(element, Job): |
| 191 | + raise TypeError("Only Job objects can be retrieved from JobsProgress.") |
| 192 | + |
| 193 | + for job in self: |
| 194 | + if job == element: |
| 195 | + return job |
| 196 | + return None |
162 | 197 |
163 | 198 | def get_job_list(self) -> Optional[str]: |
164 | 199 | """ |
165 | 200 | Returns the list of job IDs as comma-separated string. |
166 | 201 | """ |
167 | | - with self._lock: |
168 | | - job_list = list(self._shared_data['jobs']) |
169 | | - |
170 | | - if not job_list: |
| 202 | + self._load_state() |
| 203 | + |
| 204 | + if not len(self): |
171 | 205 | return None |
172 | 206 |
173 | | - log.debug(f"JobsProgress | Jobs in progress: {job_list}") |
174 | | - return ",".join(str(job_dict['id']) for job_dict in job_list) |
| 207 | + return ",".join(str(job) for job in self) |
175 | 208 |
176 | 209 | def get_job_count(self) -> int: |
177 | 210 | """ |
178 | 211 | Returns the number of jobs. |
179 | 212 | """ |
180 | | - with self._lock: |
181 | | - return len(self._shared_data['jobs']) |
182 | | - |
183 | | - def __iter__(self): |
184 | | - """Make the class iterable - returns Job objects""" |
185 | | - with self._lock: |
186 | | - # Create a snapshot of jobs to avoid holding lock during iteration |
187 | | - job_dicts = list(self._shared_data['jobs']) |
188 | | - |
189 | | - # Return an iterator of Job objects |
190 | | - return iter(Job(**job_dict) for job_dict in job_dicts) |
191 | | - |
192 | | - def __len__(self): |
193 | | - """Support len() operation""" |
194 | | - return self.get_job_count() |
195 | | - |
196 | | - def __contains__(self, element: Any) -> bool: |
197 | | - """Support 'in' operator""" |
198 | | - if isinstance(element, str): |
199 | | - search_id = element |
200 | | - elif isinstance(element, Job): |
201 | | - search_id = element.id |
202 | | - elif isinstance(element, dict): |
203 | | - search_id = element.get('id') |
204 | | - else: |
205 | | - return False |
206 | | - |
207 | | - with self._lock: |
208 | | - for job_dict in self._shared_data['jobs']: |
209 | | - if job_dict['id'] == search_id: |
210 | | - return True |
211 | | - return False |
| 213 | + return len(self) |
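For context, here is a minimal usage sketch of the new persistence-backed tracker. The import path and job IDs are assumptions for illustration, not part of this change.

```python
# Minimal usage sketch, assuming JobsProgress is importable from the worker_state
# module (the import path is an assumption) and behaves as in the diff above.
from runpod.serverless.modules.worker_state import JobsProgress

tracker = JobsProgress()        # singleton: repeated calls return the same instance
tracker.add("job-123")          # a str is wrapped as Job(id="job-123")
tracker.add({"id": "job-456"})  # a dict is wrapped as Job(**element)

assert JobsProgress() is tracker     # same object anywhere in the process
print(tracker.get_job_count())       # 2
print(tracker.get_job_list())        # "job-123,job-456" (order not guaranteed)

# State is pickled to .runpod_jobs.pkl in the current working directory,
# so a freshly started process reconstructs the same set on first access.
tracker.remove("job-123")
```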