30
30
from sagemaker_core .shapes import (
31
31
StoppingCondition ,
32
32
RetryStrategy ,
33
- OutputDataConfig ,
34
33
Channel ,
35
34
ShuffleConfig ,
36
35
DataSource ,
43
42
RemoteDebugConfig ,
44
43
SessionChainingConfig ,
45
44
InstanceGroup ,
46
- TensorBoardOutputConfig ,
47
- CheckpointConfig ,
48
45
)
49
46
50
47
from sagemaker .modules .utils import convert_unassigned_to_none
@@ -131,6 +128,8 @@ class Compute(shapes.ResourceConfig):
131
128
subsequent training jobs.
132
129
instance_groups (Optional[List[InstanceGroup]]):
133
130
A list of instance groups for heterogeneous clusters to be used in the training job.
131
+ training_plan_arn (Optional[str]):
132
+ The Amazon Resource Name (ARN) of the training plan to use for this resource configuration.
134
133
enable_managed_spot_training (Optional[bool]):
135
134
To train models using managed spot training, choose True. Managed spot training
136
135
provides a fully managed and scalable infrastructure for training machine learning
@@ -151,8 +150,12 @@ def _to_resource_config(self) -> shapes.ResourceConfig:
151
150
compute_config_dict = self .model_dump ()
152
151
resource_config_fields = set (shapes .ResourceConfig .__annotations__ .keys ())
153
152
filtered_dict = {
154
- k : v for k , v in compute_config_dict .items () if k in resource_config_fields
153
+ k : v
154
+ for k , v in compute_config_dict .items ()
155
+ if k in resource_config_fields and v is not None
155
156
}
157
+ if not filtered_dict :
158
+ return None
156
159
return shapes .ResourceConfig (** filtered_dict )
157
160
158
161
@@ -194,10 +197,12 @@ def _model_validator(self) -> "Networking":
194
197
def _to_vpc_config (self ) -> shapes .VpcConfig :
195
198
"""Convert to a sagemaker_core.shapes.VpcConfig object."""
196
199
compute_config_dict = self .model_dump ()
197
- resource_config_fields = set (shapes .VpcConfig .__annotations__ .keys ())
200
+ vpc_config_fields = set (shapes .VpcConfig .__annotations__ .keys ())
198
201
filtered_dict = {
199
- k : v for k , v in compute_config_dict .items () if k in resource_config_fields
202
+ k : v for k , v in compute_config_dict .items () if k in vpc_config_fields and v is not None
200
203
}
204
+ if not filtered_dict :
205
+ return None
201
206
return shapes .VpcConfig (** filtered_dict )
202
207
203
208
@@ -224,3 +229,66 @@ class InputData(BaseConfig):
224
229
225
230
channel_name : str = None
226
231
data_source : Union [str , FileSystemDataSource , S3DataSource ] = None
232
+
233
+
234
+ class OutputDataConfig (shapes .OutputDataConfig ):
235
+ """OutputDataConfig.
236
+
237
+ The OutputDataConfig class is a subclass of ``sagemaker_core.shapes.OutputDataConfig``
238
+ and allows the user to specify the output data configuration for the training job.
239
+
240
+ Parameters:
241
+ s3_output_path (Optional[str]):
242
+ The S3 URI where the output data will be stored. This is the location where the
243
+ training job will save its output data, such as model artifacts and logs.
244
+ kms_key_id (Optional[str]):
245
+ The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that
246
+ SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side
247
+ encryption.
248
+ compression_type (Optional[str]):
249
+ The model output compression type. Select `NONE` to output an uncompressed model,
250
+ recommended for large model outputs. Defaults to `GZIP`.
251
+ """
252
+
253
+ s3_output_path : Optional [str ] = None
254
+ kms_key_id : Optional [str ] = None
255
+ compression_type : Optional [str ] = None
256
+
257
+
258
+ class TensorBoardOutputConfig (shapes .TensorBoardOutputConfig ):
259
+ """TensorBoardOutputConfig.
260
+
261
+ The TensorBoardOutputConfig class is a subclass of ``sagemaker_core.shapes.TensorBoardOutputConfig``
262
+ and allows the user to specify the storage locations for the Amazon SageMaker
263
+ Debugger TensorBoard.
264
+
265
+ Parameters:
266
+ s3_output_path (Optional[str]):
267
+ Path to Amazon S3 storage location for TensorBoard output. If not specified, will
268
+ default to
269
+ ``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/tensorboard-output``
270
+ local_path (Optional[str]):
271
+ Path to local storage location for tensorBoard output. Defaults to /opt/ml/output/tensorboard.
272
+ """
273
+
274
+ s3_output_path : Optional [str ] = None
275
+ local_path : Optional [str ] = "/opt/ml/output/tensorboard"
276
+
277
+
278
+ class CheckpointConfig (shapes .CheckpointConfig ):
279
+ """CheckpointConfig.
280
+
281
+ The CheckpointConfig class is a subclass of ``sagemaker_core.shapes.CheckpointConfig``
282
+ and allows the user to specify the checkpoint configuration for the training job.
283
+
284
+ Parameters:
285
+ s3_uri (Optional[str]):
286
+ Path to Amazon S3 storage location for the Checkpoint data. If not specified, will
287
+ default to
288
+ ``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/checkpoints``
289
+ local_path (Optional[str]):
290
+ The local directory where checkpoints are written. The default directory is /opt/ml/checkpoints.
291
+ """
292
+
293
+ s3_uri : Optional [str ] = None
294
+ local_path : Optional [str ] = "/opt/ml/checkpoints"
0 commit comments