-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodular_trainers.py
More file actions
212 lines (170 loc) · 5.74 KB
/
modular_trainers.py
File metadata and controls
212 lines (170 loc) · 5.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""
Shim module - Re-exports from src.training.modular_trainers for backward compatibility.
This module provides backward compatibility for code that imports from the old
'modular_trainers' location. The actual implementation has been moved to
src.training.modular_trainers as part of the project restructure.
DEPRECATION NOTICE:
This shim module is provided for backward compatibility only. New code should
import directly from src.training.modular_trainers:
from src.training.modular_trainers import TrainerConfig, BaseTrainer
Usage:
# Old style (still works via this shim):
from modular_trainers import TrainerConfig, TransformerDirectionTrainer
# New style (recommended):
from src.training.modular_trainers import TrainerConfig, TransformerDirectionTrainer
Performance:
Imports are lazy-loaded on first access to minimize startup overhead.
Explicit imports are used instead of wildcard for better IDE support
and static analysis.
"""
import sys
import warnings
from typing import TYPE_CHECKING
# Names imported under TYPE_CHECKING are visible only to static type checkers
# and IDEs; at runtime they are resolved lazily by the module-level __getattr__.
if TYPE_CHECKING:
from src.training.modular_trainers import (
TrainerConfig,
EMACallback,
EWCPenalty,
OverfitPreventionCallback,
EWCTrainingCallback,
RichEpochCallback,
ReplayBuffer,
DriftDetector,
TrainingLineage,
BaseTrainer,
TCNTrainer,
TransformerDirectionTrainer,
TransformerRegimeTrainer,
XGBoostTrainer,
RandomForestTrainer,
RidgeTrainer,
HistGradientBoostingDirectionTrainer,
)
from src.training.modular_trainers import (
create_ewc_loss,
migrate_xgboost_model,
migrate_all_models,
train_all_modular,
)
# Module-level cache for lazy-loaded imports
# Set to the imported source module after the first successful import.
_module_cache = None
# Set to the ImportError built on the first failed import, so that later
# calls re-raise it instead of retrying the import.
_import_error = None
def _get_source_module():
    """
    Lazily import the source module with error handling.

    The first successful import is cached in ``_module_cache``. The first
    failure is cached in ``_import_error`` and re-raised on every later call,
    so a broken import is attempted only once.

    Returns:
        The src.training.modular_trainers module

    Raises:
        ImportError: If the source module cannot be imported. Any other
            exception raised while importing it is wrapped in ImportError
            as well; the original exception is available as ``__cause__``.
    """
    global _module_cache, _import_error

    # Return cached module if already loaded
    if _module_cache is not None:
        return _module_cache

    # Re-raise the cached error if the import previously failed. The stored
    # traceback is cleared first: raising the same exception instance again
    # would otherwise keep growing its traceback with every call.
    if _import_error is not None:
        raise _import_error.with_traceback(None)

    try:
        from src.training import modular_trainers as source_module
    except ImportError as e:
        _import_error = ImportError(
            f"Cannot import from src.training.modular_trainers: {e}\n"
            "Ensure the module exists and all dependencies are installed."
        )
        # Chain explicitly so the root cause is reported as the direct cause.
        raise _import_error from e
    except Exception as e:
        # Deliberately broad: any failure while executing the source module
        # is surfaced to callers as an ImportError.
        _import_error = ImportError(
            f"Unexpected error importing src.training.modular_trainers: {e}"
        )
        raise _import_error from e

    # Success path kept outside the try so only the import itself is guarded.
    _module_cache = source_module
    return source_module
# Define public API explicitly for better IDE support and static analysis.
# Names listed here that are not defined in this file are resolved on demand
# by the module-level __getattr__.
__all__ = [
    # Configuration
    'TrainerConfig',
    # Callbacks
    'EMACallback',
    'EWCPenalty',
    'OverfitPreventionCallback',
    'EWCTrainingCallback',
    'RichEpochCallback',
    # Utilities
    'create_ewc_loss',
    'ReplayBuffer',
    'DriftDetector',
    'TrainingLineage',
    # Base Trainer
    'BaseTrainer',
    # Trainer Implementations
    'TCNTrainer',
    'TransformerDirectionTrainer',
    'TransformerRegimeTrainer',
    'XGBoostTrainer',
    'RandomForestTrainer',
    'RidgeTrainer',
    'HistGradientBoostingDirectionTrainer',
    # Convenience Functions
    'migrate_xgboost_model',
    'migrate_all_models',
    'train_all_modular',
    # Version info (defined in this file)
    '__version__',
    # Explicit accessor for the source module (defined in this file)
    'get_source_module',
]
def __getattr__(name: str):
    """
    Resolve a missing module attribute by forwarding to the source module.

    Python calls this hook only for names that are not found in this
    module's own namespace, which is what makes the re-exports lazy.

    Args:
        name: The attribute name being accessed

    Returns:
        The attribute of the same name taken from the source module

    Raises:
        AttributeError: If the source module has no such attribute
        ImportError: If the source module itself cannot be imported
    """
    # The marker attribute on ``sys`` records that the deprecation warning
    # has already been issued. The warning must be emitted from this frame
    # so that stacklevel=2 points at the code accessing the attribute.
    warned_already = hasattr(sys, '_modular_trainers_warned')
    if not warned_already:
        warnings.warn(
            "Importing from 'modular_trainers' is deprecated. "
            "Please import from 'src.training.modular_trainers' instead.",
            DeprecationWarning,
            stacklevel=2
        )
        sys._modular_trainers_warned = True

    target = _get_source_module()
    try:
        value = getattr(target, name)
    except AttributeError:
        # Report the failure under this shim's name and list what the
        # source module does offer; the inner error is suppressed.
        listing = ', '.join(sorted(dir(target)))
        message = (
            f"module '{__name__}' has no attribute '{name}'. "
            f"Available attributes: {listing}"
        )
        raise AttributeError(message) from None
    return value
def __dir__():
    """
    Provide directory listing including source module attributes.

    Returns:
        A new, sorted list of names: the union of ``__all__`` and the
        source module's attributes, or only the names in ``__all__`` when
        the source module cannot be imported. The list is always a fresh
        copy, so callers may mutate it without affecting ``__all__``.
    """
    names = set(__all__)
    try:
        names.update(dir(_get_source_module()))
    except ImportError:
        # Source module unavailable: fall back to the declared public API.
        pass
    return sorted(names)
# Version information for compatibility checking
__version__ = '1.0.0'
# Flags this shim as deprecated; the value is also appended to the module
# docstring at the end of the file.
__deprecated__ = True
# Convenience function for explicit module access
def get_source_module():
    """
    Return the underlying src.training.modular_trainers module object.

    Thin public wrapper around the private loader: the source module is
    imported on the first call and served from the module-level cache
    afterwards. Because this function is defined in the shim itself,
    calling it does not go through the module-level ``__getattr__``.

    Returns:
        The src.training.modular_trainers module

    Raises:
        ImportError: If the source module cannot be imported
    """
    source_module = _get_source_module()
    return source_module
# Explicitly export version and module info.
# __doc__ is None when docstrings are stripped (python -OO); appending to it
# unconditionally would raise TypeError at import time, so guard the update.
if __doc__ is not None:
    __doc__ += f"""
Version: {__version__}
Deprecated: {__deprecated__}
"""