1
1
#!/usr/bin/env python
2
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
- from dataclasses import dataclass , field
5
- from typing import List , Optional
6
4
7
- from ads . aqua . data import AquaJobSummary
8
- from ads . common . serializer import DataClassSerializable
5
+ import json
6
+ from typing import List , Literal , Optional , Union
9
7
8
+ from pydantic import Field , model_validator
10
9
11
- @dataclass (repr = False )
12
- class AquaFineTuningParams (DataClassSerializable ):
13
- epochs : int
10
+ from ads .aqua .common .errors import AquaValueError
11
+ from ads .aqua .config .utils .serializer import Serializable
12
+ from ads .aqua .data import AquaResourceIdentifier
13
+ from ads .aqua .finetuning .constants import FineTuningRestrictedParams
14
+
15
+
16
+ class AquaFineTuningParams (Serializable ):
17
+ """Class for maintaining aqua fine-tuning model parameters"""
18
+
19
+ epochs : Optional [int ] = None
14
20
learning_rate : Optional [float ] = None
15
- sample_packing : Optional [bool ] = "auto"
21
+ sample_packing : Union [bool , None , Literal [ "auto" ] ] = "auto"
16
22
batch_size : Optional [int ] = (
17
23
None # make it batch_size for user, but internally this is micro_batch_size
18
24
)
@@ -22,21 +28,59 @@ class AquaFineTuningParams(DataClassSerializable):
22
28
lora_alpha : Optional [int ] = None
23
29
lora_dropout : Optional [float ] = None
24
30
lora_target_linear : Optional [bool ] = None
25
- lora_target_modules : Optional [List ] = None
31
+ lora_target_modules : Optional [List [ str ] ] = None
26
32
early_stopping_patience : Optional [int ] = None
27
33
early_stopping_threshold : Optional [float ] = None
28
34
35
+ class Config :
36
+ extra = "allow"
37
+
38
+ def to_dict (self ) -> dict :
39
+ return json .loads (super ().to_json (exclude_none = True ))
40
+
41
+ @model_validator (mode = "before" )
42
+ @classmethod
43
+ def validate_restricted_fields (cls , data : dict ):
44
+ # we may want to skip validation if loading data from config files instead of user entered parameters
45
+ validate = data .pop ("_validate" , True )
46
+ if not (validate and isinstance (data , dict )):
47
+ return data
48
+ restricted_params = [
49
+ param for param in data if param in FineTuningRestrictedParams .values ()
50
+ ]
51
+ if restricted_params :
52
+ raise AquaValueError (
53
+ f"Found restricted parameter name: { restricted_params } "
54
+ )
55
+ return data
29
56
30
- @dataclass (repr = False )
31
- class AquaFineTuningSummary (AquaJobSummary , DataClassSerializable ):
32
- parameters : AquaFineTuningParams = field (default_factory = AquaFineTuningParams )
33
57
58
+ class AquaFineTuningSummary (Serializable ):
59
+ """Represents a summary of Aqua Finetuning job."""
34
60
35
- @dataclass (repr = False )
36
- class CreateFineTuningDetails (DataClassSerializable ):
37
- """Dataclass to create aqua model fine tuning.
61
+ id : str
62
+ name : str
63
+ console_url : str
64
+ lifecycle_state : str
65
+ lifecycle_details : str
66
+ time_created : str
67
+ tags : dict
68
+ experiment : AquaResourceIdentifier = Field (default_factory = AquaResourceIdentifier )
69
+ source : AquaResourceIdentifier = Field (default_factory = AquaResourceIdentifier )
70
+ job : AquaResourceIdentifier = Field (default_factory = AquaResourceIdentifier )
71
+ parameters : AquaFineTuningParams = Field (default_factory = AquaFineTuningParams )
38
72
39
- Fields
73
+ class Config :
74
+ extra = "ignore"
75
+
76
+ def to_dict (self ) -> dict :
77
+ return json .loads (super ().to_json (exclude_none = True ))
78
+
79
+
80
+ class CreateFineTuningDetails (Serializable ):
81
+ """Class to create aqua model fine-tuning instance.
82
+
83
+ Properties
40
84
------
41
85
ft_source_id: str
42
86
The fine tuning source id. Must be model ocid.
@@ -107,3 +151,6 @@ class CreateFineTuningDetails(DataClassSerializable):
107
151
force_overwrite : Optional [bool ] = False
108
152
freeform_tags : Optional [dict ] = None
109
153
defined_tags : Optional [dict ] = None
154
+
155
+ class Config :
156
+ extra = "ignore"
0 commit comments