Skip to content
This repository was archived by the owner on Feb 18, 2026. It is now read-only.

Commit af3d69e

Browse files
danecorkboydbinaryaaron
authored and committed
[FIX] Unsloth fix for rope scaling >= 1.0
Co-authored-by: Kendrick Boyd <kendrick@gretel.ai> Co-authored-by: Aaron Gonzales <aagonzales@nvidia.com> GitOrigin-RevId: a80a21bf679857791e024be9222c9b25d09b7bb2
1 parent 866f3e1 commit af3d69e

File tree

2 files changed

+71
-49
lines changed

2 files changed

+71
-49
lines changed

src/gretel_client/workflows/configs/registry.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,63 +11,63 @@ def __setattr__(cls, name: str, value: Any) -> None:
1111

1212

1313
class Registry(metaclass=RegistryMeta):
14-
ConcatDatasets = tasks.ConcatDatasets
15-
ExtractDataSeedsFromSampleRecords = tasks.ExtractDataSeedsFromSampleRecords
16-
IdGenerator = tasks.IdGenerator
17-
LoadDataSeeds = tasks.LoadDataSeeds
18-
GenerateColumnFromTemplateV2 = tasks.GenerateColumnFromTemplateV2
19-
DropColumns = tasks.DropColumns
20-
NameGenerator = tasks.NameGenerator
2114
GenerateDatasetFromSampleRecords = tasks.GenerateDatasetFromSampleRecords
2215
SampleDataSeeds = tasks.SampleDataSeeds
16+
EvaluateDataDesignerDataset = tasks.EvaluateDataDesignerDataset
17+
SampleFromDataset = tasks.SampleFromDataset
18+
Holdout = tasks.Holdout
19+
GenerateColumnFromExpression = tasks.GenerateColumnFromExpression
20+
GenerateSamplingColumnConfigFromInstruction = (
21+
tasks.GenerateSamplingColumnConfigFromInstruction
22+
)
23+
ExtractDataSeedsFromSampleRecords = tasks.ExtractDataSeedsFromSampleRecords
2324
RunSampleToDataset = tasks.RunSampleToDataset
2425
GetGretelDataset = tasks.GetGretelDataset
25-
GenerateColumnFromExpression = tasks.GenerateColumnFromExpression
26+
JudgeWithLlm = tasks.JudgeWithLlm
27+
IdGenerator = tasks.IdGenerator
2628
Combiner = tasks.Combiner
29+
GenerateColumnFromTemplateV2 = tasks.GenerateColumnFromTemplateV2
30+
DropColumns = tasks.DropColumns
31+
ConcatDatasets = tasks.ConcatDatasets
32+
GenerateColumnConfigFromInstruction = tasks.GenerateColumnConfigFromInstruction
33+
EvaluateSafeSyntheticsDataset = tasks.EvaluateSafeSyntheticsDataset
2734
DummyTaskWithInputs = tasks.DummyTaskWithInputs
2835
DummyTaskWithListOfInputs = tasks.DummyTaskWithListOfInputs
36+
NameGenerator = tasks.NameGenerator
2937
TestFailingTask = tasks.TestFailingTask
3038
TestOptionalArgTask = tasks.TestOptionalArgTask
3139
TestRequiredAndOptionalArgsTask = tasks.TestRequiredAndOptionalArgsTask
3240
TestTaskCallingTask = tasks.TestTaskCallingTask
3341
TestUnhandledErrorTask = tasks.TestUnhandledErrorTask
34-
GenerateSamplingColumnConfigFromInstruction = (
35-
tasks.GenerateSamplingColumnConfigFromInstruction
36-
)
37-
EvaluateDataDesignerDataset = tasks.EvaluateDataDesignerDataset
38-
Holdout = tasks.Holdout
39-
GenerateColumnConfigFromInstruction = tasks.GenerateColumnConfigFromInstruction
40-
SampleFromDataset = tasks.SampleFromDataset
41-
JudgeWithLlm = tasks.JudgeWithLlm
4242
GenerateColumnsUsingSamplers = tasks.GenerateColumnsUsingSamplers
43-
EvaluateSafeSyntheticsDataset = tasks.EvaluateSafeSyntheticsDataset
44-
ValidateCode = tasks.ValidateCode
45-
SeedFromRecords = tasks.SeedFromRecords
46-
EvaluateDataset = tasks.EvaluateDataset
47-
TabularGan = tasks.TabularGan
43+
LoadDataSeeds = tasks.LoadDataSeeds
4844
TabularFt = tasks.TabularFt
45+
PromptPretrainedModel = tasks.PromptPretrainedModel
46+
TabularGan = tasks.TabularGan
4947
Transform = tasks.Transform
5048
TextFt = tasks.TextFt
51-
PromptPretrainedModel = tasks.PromptPretrainedModel
49+
ValidateCode = tasks.ValidateCode
50+
EvaluateDataset = tasks.EvaluateDataset
51+
SeedFromRecords = tasks.SeedFromRecords
52+
S3Destination = tasks.S3Destination
53+
S3Source = tasks.S3Source
5254
DataSource = tasks.DataSource
53-
AzureDestination = tasks.AzureDestination
54-
AzureSource = tasks.AzureSource
55-
MssqlDestination = tasks.MssqlDestination
56-
MssqlSource = tasks.MssqlSource
5755
GcsDestination = tasks.GcsDestination
5856
GcsSource = tasks.GcsSource
5957
BigqueryDestination = tasks.BigqueryDestination
6058
BigquerySource = tasks.BigquerySource
6159
SnowflakeDestination = tasks.SnowflakeDestination
6260
SnowflakeSource = tasks.SnowflakeSource
63-
PostgresDestination = tasks.PostgresDestination
64-
PostgresSource = tasks.PostgresSource
6561
DatabricksDestination = tasks.DatabricksDestination
6662
DatabricksSource = tasks.DatabricksSource
67-
OracleDestination = tasks.OracleDestination
68-
OracleSource = tasks.OracleSource
69-
S3Destination = tasks.S3Destination
70-
S3Source = tasks.S3Source
63+
PostgresDestination = tasks.PostgresDestination
64+
PostgresSource = tasks.PostgresSource
7165
MysqlDestination = tasks.MysqlDestination
7266
MysqlSource = tasks.MysqlSource
67+
AzureDestination = tasks.AzureDestination
68+
AzureSource = tasks.AzureSource
69+
OracleDestination = tasks.OracleDestination
70+
OracleSource = tasks.OracleSource
71+
MssqlDestination = tasks.MssqlDestination
72+
MssqlSource = tasks.MssqlSource
7373

src/gretel_client/workflows/configs/tasks.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ class EvaluateDataset(ConfigBase):
126126

127127

128128
class EvaluateSafeSyntheticsDataset(ConfigBase):
129+
model_suite: Annotated[Optional[str], Field(title="Model Suite")] = "apache-2.0"
130+
error_rate: Annotated[
131+
Optional[float], Field(ge=0.0, le=1.0, title="Error Rate")
132+
] = 0.2
129133
skip_attribute_inference_protection: Annotated[
130134
Optional[bool], Field(title="Skip Attribute Inference Protection")
131135
] = False
@@ -938,6 +942,26 @@ class NumInputRecordsToSample1(str, Enum):
938942
AUTO = "auto"
939943

940944

945+
class UseUnsloth(str, Enum):
946+
AUTO = "auto"
947+
948+
949+
class RopeScalingFactor(RootModel[int]):
950+
root: Annotated[
951+
int,
952+
Field(
953+
description="Scale the base LLM's context length by this factor using RoPE scaling.",
954+
ge=1,
955+
le=6,
956+
title="rope_scaling_factor",
957+
),
958+
]
959+
960+
961+
class RopeScalingFactor1(str, Enum):
962+
AUTO = "auto"
963+
964+
941965
class TabularFTTrainingParams(ConfigBase):
942966
num_input_records_to_sample: Annotated[
943967
Optional[Union[NumInputRecordsToSample, NumInputRecordsToSample1]],
@@ -1020,18 +1044,16 @@ class TabularFTTrainingParams(ConfigBase):
10201044
),
10211045
] = ["q_proj", "k_proj", "v_proj", "o_proj"]
10221046
use_unsloth: Annotated[
1023-
Optional[bool],
1047+
Optional[Union[bool, UseUnsloth]],
10241048
Field(description="Whether to use unsloth.", title="use_unsloth"),
1025-
] = True
1049+
] = "auto"
10261050
rope_scaling_factor: Annotated[
1027-
Optional[int],
1051+
Optional[Union[RopeScalingFactor, RopeScalingFactor1]],
10281052
Field(
1029-
description="Scale the base LLM's context length by this factor using RoPE scaling. Only works if use_unsloth is set to True.",
1030-
ge=1,
1031-
le=6,
1053+
description="Scale the base LLM's context length by this factor using RoPE scaling.",
10321054
title="rope_scaling_factor",
10331055
),
1034-
] = 1
1056+
] = "auto"
10351057

10361058

10371059
class MaxSequencesPerExample(str, Enum):
@@ -1294,7 +1316,7 @@ class PeftParams(ConfigBase):
12941316
),
12951317
] = 1
12961318
target_modules: Annotated[
1297-
Optional[Union[str, List[str]]],
1319+
Optional[Union[List[str], str]],
12981320
Field(
12991321
description="List of module names or regex expression of the module names to replace with LoRA. For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer. If not specified, modules will be chosen according to the model architecture. If the architecture is not known, an error will be raised -- in this case, you should specify the target modules manually.",
13001322
title="Target Modules",
@@ -1546,11 +1568,11 @@ class Column(ConfigBase):
15461568
Optional[str], Field(description="Rename to value.", title="Value")
15471569
] = None
15481570
entity: Annotated[
1549-
Optional[Union[str, List[str]]],
1571+
Optional[Union[List[str], str]],
15501572
Field(description="Column entity match.", title="Entity"),
15511573
] = None
15521574
type: Annotated[
1553-
Optional[Union[str, List[str]]],
1575+
Optional[Union[List[str], str]],
15541576
Field(description="Column type match.", title="Type"),
15551577
] = None
15561578

@@ -1622,7 +1644,7 @@ class NERConfig(ConfigBase):
16221644

16231645
class Row(ConfigBase):
16241646
name: Annotated[
1625-
Optional[Union[str, List[str]]], Field(description="Row name.", title="Name")
1647+
Optional[Union[List[str], str]], Field(description="Row name.", title="Name")
16261648
] = None
16271649
condition: Annotated[
16281650
Optional[str], Field(description="Row condition match.", title="Condition")
@@ -1634,11 +1656,11 @@ class Row(ConfigBase):
16341656
Optional[str], Field(description="Row value definition.", title="Value")
16351657
] = None
16361658
entity: Annotated[
1637-
Optional[Union[str, List[str]]],
1659+
Optional[Union[List[str], str]],
16381660
Field(description="Row entity match.", title="Entity"),
16391661
] = None
16401662
type: Annotated[
1641-
Optional[Union[str, List[str]]],
1663+
Optional[Union[List[str], str]],
16421664
Field(description="Row type match.", title="Type"),
16431665
] = None
16441666
fallback_value: Annotated[
@@ -2087,21 +2109,21 @@ class JudgeWithLlm(ConfigBase):
20872109
prompt: Annotated[
20882110
str,
20892111
Field(
2090-
description="Template for generating prompts. Use Jinja2 templates to reference dataset columns.",
2112+
description="Template for generating prompts. Use Jinja2 templates to reference dataset columns.",
20912113
title="Prompt",
20922114
),
20932115
]
20942116
num_samples_to_judge: Annotated[
20952117
Optional[int],
20962118
Field(
2097-
description="Number of samples to judge. Default is 100.",
2119+
description="Number of samples to judge. If unset or None, then defaults to judging all records. Default is None.",
20982120
title="Num Samples To Judge",
20992121
),
2100-
] = 100
2122+
] = None
21012123
rubrics: Annotated[
21022124
List[Rubric],
21032125
Field(
2104-
description="List of rubric configurations to use for evaluation. At least one must be provided.",
2126+
description="List of rubric configurations to use for evaluation. At least one must be provided.",
21052127
min_length=1,
21062128
title="Rubrics",
21072129
),

0 commit comments

Comments
 (0)