Skip to content

Commit c849eae

Browse files
authored
change: Improve defaults handling in ModelTrainer (#5170)
* Improve default handling * format * add tests & update docs * fix docstyle * fix input_data_config * fix use input_data_config parameter in train as authoritative source * fix tests * format * update checkpoint config * docstyle * make config creation backwards compatible * format * fix condition * fix Compute and Networking config when attributes are None * format * fix * format
1 parent 8adb660 commit c849eae

File tree

4 files changed

+328
-65
lines changed

4 files changed

+328
-65
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ dependencies = [
5555
"tblib>=1.7.0,<4",
5656
"tqdm",
5757
"urllib3>=1.26.8,<3.0.0",
58-
"uvicorn"
58+
"uvicorn",
59+
"graphene>=3,<4"
5960
]
6061

6162
[project.scripts]

src/sagemaker/modules/configs.py

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from sagemaker_core.shapes import (
3131
StoppingCondition,
3232
RetryStrategy,
33-
OutputDataConfig,
3433
Channel,
3534
ShuffleConfig,
3635
DataSource,
@@ -43,8 +42,6 @@
4342
RemoteDebugConfig,
4443
SessionChainingConfig,
4544
InstanceGroup,
46-
TensorBoardOutputConfig,
47-
CheckpointConfig,
4845
)
4946

5047
from sagemaker.modules.utils import convert_unassigned_to_none
@@ -131,6 +128,8 @@ class Compute(shapes.ResourceConfig):
131128
subsequent training jobs.
132129
instance_groups (Optional[List[InstanceGroup]]):
133130
A list of instance groups for heterogeneous clusters to be used in the training job.
131+
training_plan_arn (Optional[str]):
132+
The Amazon Resource Name (ARN) of the training plan to use for this resource configuration.
134133
enable_managed_spot_training (Optional[bool]):
135134
To train models using managed spot training, choose True. Managed spot training
136135
provides a fully managed and scalable infrastructure for training machine learning
@@ -151,8 +150,12 @@ def _to_resource_config(self) -> shapes.ResourceConfig:
151150
compute_config_dict = self.model_dump()
152151
resource_config_fields = set(shapes.ResourceConfig.__annotations__.keys())
153152
filtered_dict = {
154-
k: v for k, v in compute_config_dict.items() if k in resource_config_fields
153+
k: v
154+
for k, v in compute_config_dict.items()
155+
if k in resource_config_fields and v is not None
155156
}
157+
if not filtered_dict:
158+
return None
156159
return shapes.ResourceConfig(**filtered_dict)
157160

158161

@@ -194,10 +197,12 @@ def _model_validator(self) -> "Networking":
194197
def _to_vpc_config(self) -> shapes.VpcConfig:
195198
"""Convert to a sagemaker_core.shapes.VpcConfig object."""
196199
compute_config_dict = self.model_dump()
197-
resource_config_fields = set(shapes.VpcConfig.__annotations__.keys())
200+
vpc_config_fields = set(shapes.VpcConfig.__annotations__.keys())
198201
filtered_dict = {
199-
k: v for k, v in compute_config_dict.items() if k in resource_config_fields
202+
k: v for k, v in compute_config_dict.items() if k in vpc_config_fields and v is not None
200203
}
204+
if not filtered_dict:
205+
return None
201206
return shapes.VpcConfig(**filtered_dict)
202207

203208

@@ -224,3 +229,66 @@ class InputData(BaseConfig):
224229

225230
channel_name: str = None
226231
data_source: Union[str, FileSystemDataSource, S3DataSource] = None
232+
233+
234+
class OutputDataConfig(shapes.OutputDataConfig):
235+
"""OutputDataConfig.
236+
237+
The OutputDataConfig class is a subclass of ``sagemaker_core.shapes.OutputDataConfig``
238+
and allows the user to specify the output data configuration for the training job.
239+
240+
Parameters:
241+
s3_output_path (Optional[str]):
242+
The S3 URI where the output data will be stored. This is the location where the
243+
training job will save its output data, such as model artifacts and logs.
244+
kms_key_id (Optional[str]):
245+
The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that
246+
SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side
247+
encryption.
248+
compression_type (Optional[str]):
249+
The model output compression type. Select `NONE` to output an uncompressed model,
250+
recommended for large model outputs. Defaults to `GZIP`.
251+
"""
252+
253+
s3_output_path: Optional[str] = None
254+
kms_key_id: Optional[str] = None
255+
compression_type: Optional[str] = None
256+
257+
258+
class TensorBoardOutputConfig(shapes.TensorBoardOutputConfig):
259+
"""TensorBoardOutputConfig.
260+
261+
The TensorBoardOutputConfig class is a subclass of ``sagemaker_core.shapes.TensorBoardOutputConfig``
262+
and allows the user to specify the storage locations for the Amazon SageMaker
263+
Debugger TensorBoard.
264+
265+
Parameters:
266+
s3_output_path (Optional[str]):
267+
Path to Amazon S3 storage location for TensorBoard output. If not specified, will
268+
default to
269+
``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/tensorboard-output``
270+
local_path (Optional[str]):
271+
Path to local storage location for tensorBoard output. Defaults to /opt/ml/output/tensorboard.
272+
"""
273+
274+
s3_output_path: Optional[str] = None
275+
local_path: Optional[str] = "/opt/ml/output/tensorboard"
276+
277+
278+
class CheckpointConfig(shapes.CheckpointConfig):
279+
"""CheckpointConfig.
280+
281+
The CheckpointConfig class is a subclass of ``sagemaker_core.shapes.CheckpointConfig``
282+
and allows the user to specify the checkpoint configuration for the training job.
283+
284+
Parameters:
285+
s3_uri (Optional[str]):
286+
Path to Amazon S3 storage location for the Checkpoint data. If not specified, will
287+
default to
288+
``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/checkpoints``
289+
local_path (Optional[str]):
290+
The local directory where checkpoints are written. The default directory is /opt/ml/checkpoints.
291+
"""
292+
293+
s3_uri: Optional[str] = None
294+
local_path: Optional[str] = "/opt/ml/checkpoints"

0 commit comments

Comments
 (0)