Add Sharding Support for Neo Optimization Jobs #4924

Open
wants to merge 11 commits into base: master
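For context, a minimal end-to-end sketch of how the new sharding_config parameter is meant to be used (the model ID, instance type, and tensor-parallel degree are illustrative assumptions, not values taken from this PR):

# Hypothetical usage sketch of the sharding support added in this PR.
from sagemaker.serve.builder.model_builder import ModelBuilder

model_builder = ModelBuilder(model="meta-textgeneration-llama-3-70b")

# sharding_config is mutually exclusive with quantization, compilation, and
# speculative decoding, and must set OPTION_TENSOR_PARALLEL_DEGREE.
optimized_model = model_builder.optimize(
    instance_type="ml.g5.48xlarge",  # assumed instance type
    sharding_config={
        "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "8"},  # assumed degree
    },
)

# optimize() marks the model as sharded, so deploy() forces an
# inference-component-based endpoint (see src/sagemaker/model.py below).
predictor = optimized_model.deploy(initial_instance_count=1, instance_type="ml.g5.48xlarge")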
6 changes: 6 additions & 0 deletions src/sagemaker/model.py
@@ -372,6 +372,7 @@ def __init__(
        self.endpoint_name = None
        self.inference_component_name = None
        self._is_compiled_model = False
        self._is_sharded_model = False
        self._compilation_job_name = None
        self._is_edge_packaged_model = False
        self.inference_recommender_job_results = None
@@ -1599,6 +1600,11 @@ def deploy(
        if self._base_name is not None:
            self._base_name = "-".join((self._base_name, compiled_model_suffix))

        if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
            logging.warning(
                "Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
                "Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
            )
            endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED

        # Support multiple models on same endpoint
        if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
            if endpoint_name:
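The effect of the new flag in isolation, as a hedged sketch (resolve_endpoint_type is a hypothetical helper, not part of the diff; EndpointType is the SDK's existing enum):

# Hypothetical standalone mirror of the coercion branch added above.
import logging

from sagemaker.enums import EndpointType


def resolve_endpoint_type(is_sharded_model: bool, endpoint_type: EndpointType) -> EndpointType:
    # Sharded models only work with inference-component-based endpoints,
    # so any other requested endpoint type is overridden with a warning.
    if is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
        logging.warning("Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model.")
        return EndpointType.INFERENCE_COMPONENT_BASED
    return endpoint_type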
7 changes: 5 additions & 2 deletions src/sagemaker/serve/builder/jumpstart_builder.py
@@ -681,6 +681,7 @@ def _optimize_for_jumpstart(
        quantization_config: Optional[Dict] = None,
        compilation_config: Optional[Dict] = None,
        speculative_decoding_config: Optional[Dict] = None,
        sharding_config: Optional[Dict] = None,
        env_vars: Optional[Dict] = None,
        vpc_config: Optional[Dict] = None,
        kms_key: Optional[str] = None,
@@ -702,6 +703,8 @@
            compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
            speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
                Defaults to ``None``.
            sharding_config (Optional[Dict]): Model sharding configuration.
                Defaults to ``None``.
            env_vars (Optional[Dict]): Additional environment variables to run the optimization
                container. Defaults to ``None``.
            vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -727,7 +730,7 @@
        pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)

        optimization_config, override_env = _extract_optimization_config_and_env(
            quantization_config, compilation_config, sharding_config
        )
        if not optimization_config and is_compilation:
            override_env = override_env or pysdk_model_env_vars
@@ -792,7 +795,7 @@ def _optimize_for_jumpstart(
        optimization_env_vars = _update_environment_variables(optimization_env_vars, override_env)
        if optimization_env_vars:
            self.pysdk_model.env.update(optimization_env_vars)
        if quantization_config or sharding_config or is_compilation:
            return create_optimization_job_args
        return None

23 changes: 22 additions & 1 deletion src/sagemaker/serve/builder/model_builder.py
@@ -1119,6 +1119,7 @@ def optimize(
        quantization_config: Optional[Dict] = None,
        compilation_config: Optional[Dict] = None,
        speculative_decoding_config: Optional[Dict] = None,
        sharding_config: Optional[Dict] = None,
        env_vars: Optional[Dict] = None,
        vpc_config: Optional[Dict] = None,
        kms_key: Optional[str] = None,
@@ -1142,6 +1143,8 @@
            compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
            speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
                Defaults to ``None``.
            sharding_config (Optional[Dict]): Model sharding configuration.
                Defaults to ``None``.
            env_vars (Optional[Dict]): Additional environment variables to run the optimization
                container. Defaults to ``None``.
            vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1170,6 +1173,7 @@
            quantization_config=quantization_config,
            compilation_config=compilation_config,
            speculative_decoding_config=speculative_decoding_config,
            sharding_config=sharding_config,
            env_vars=env_vars,
            vpc_config=vpc_config,
            kms_key=kms_key,
@@ -1189,6 +1193,7 @@ def _model_builder_optimize_wrapper(
        quantization_config: Optional[Dict] = None,
        compilation_config: Optional[Dict] = None,
        speculative_decoding_config: Optional[Dict] = None,
        sharding_config: Optional[Dict] = None,
        env_vars: Optional[Dict] = None,
        vpc_config: Optional[Dict] = None,
        kms_key: Optional[str] = None,
@@ -1212,6 +1217,8 @@
            compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
            speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
                Defaults to ``None``.
            sharding_config (Optional[Dict]): Model sharding configuration.
                Defaults to ``None``.
            env_vars (Optional[Dict]): Additional environment variables to run the optimization
                container. Defaults to ``None``.
            vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1238,6 +1245,12 @@ def _model_builder_optimize_wrapper(
        if quantization_config and compilation_config:
            raise ValueError("Quantization config and compilation config are mutually exclusive.")

        if sharding_config and (
            quantization_config or compilation_config or speculative_decoding_config
        ):
            raise ValueError(
                "Sharding config is mutually exclusive and cannot be combined with any "
                "other optimization."
            )

        # Reviewer note: the same validation is also performed in NeoLambda.
        if sharding_config and (
            (env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" not in env_vars)
            or (
                sharding_config.get("OverrideEnvironment")
                and "OPTION_TENSOR_PARALLEL_DEGREE" not in sharding_config["OverrideEnvironment"]
            )
        ):
            raise ValueError(
                "OPTION_TENSOR_PARALLEL_DEGREE is a required environment variable with "
                "Sharding config."
            )

        self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
        self.instance_type = instance_type or self.instance_type
        self.role_arn = role_arn or self.role_arn
@@ -1254,6 +1267,7 @@
            quantization_config=quantization_config,
            compilation_config=compilation_config,
            speculative_decoding_config=speculative_decoding_config,
            sharding_config=sharding_config,
            env_vars=env_vars,
            vpc_config=vpc_config,
            kms_key=kms_key,
@@ -1272,6 +1286,7 @@
            quantization_config=quantization_config,
            compilation_config=compilation_config,
            speculative_decoding_config=speculative_decoding_config,
            sharding_config=sharding_config,
            env_vars=env_vars,
            vpc_config=vpc_config,
            kms_key=kms_key,
@@ -1287,6 +1302,9 @@
        if not speculative_decoding_config:
            self.pysdk_model.remove_tag_with_key(Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER)

        if sharding_config:
            self.pysdk_model._is_sharded_model = True

        return self.pysdk_model

    def _optimize_for_hf(
@@ -1297,6 +1315,7 @@ def _optimize_for_hf(
        quantization_config: Optional[Dict] = None,
        compilation_config: Optional[Dict] = None,
        speculative_decoding_config: Optional[Dict] = None,
        sharding_config: Optional[Dict] = None,
        env_vars: Optional[Dict] = None,
        vpc_config: Optional[Dict] = None,
        kms_key: Optional[str] = None,
@@ -1312,6 +1331,8 @@
            compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
            speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
                Defaults to ``None``.
            sharding_config (Optional[Dict]): Model sharding configuration.
                Defaults to ``None``.
            env_vars (Optional[Dict]): Additional environment variables to run the optimization
                container. Defaults to ``None``.
            vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1327,7 +1348,7 @@
            self.pysdk_model, speculative_decoding_config, False
        )

        if quantization_config or compilation_config or sharding_config:
            create_optimization_job_args = {
                "OptimizationJobName": job_name,
                "DeploymentInstanceType": self.instance_type,
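Taken together, the new validation admits only a standalone sharding_config that sets OPTION_TENSOR_PARALLEL_DEGREE. A hedged sketch of a passing and a failing value (the degree is an assumed example):

# Passes the checks in _model_builder_optimize_wrapper: sharding is the only
# optimization requested, and the required variable is present.
valid_sharding_config = {
    "OverrideEnvironment": {
        "OPTION_TENSOR_PARALLEL_DEGREE": "2",  # assumed degree for illustration
    }
}

# Rejected: OverrideEnvironment is set but OPTION_TENSOR_PARALLEL_DEGREE is missing.
invalid_sharding_config = {
    "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
}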
8 changes: 7 additions & 1 deletion src/sagemaker/serve/utils/optimize_utils.py
@@ -259,13 +259,15 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool:


def _extract_optimization_config_and_env(
    quantization_config: Optional[Dict] = None,
    compilation_config: Optional[Dict] = None,
    sharding_config: Optional[Dict] = None,
) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]:
"""Extracts optimization config and environment variables.

Args:
quantization_config (Optional[Dict]): The quantization config.
compilation_config (Optional[Dict]): The compilation config.
sharding_config (Optional[Dict]): The sharding config.

Returns:
Optional[Tuple[Optional[Dict], Optional[Dict]]]:
@@ -279,6 +281,10 @@ def _extract_optimization_config_and_env(
return {"ModelCompilationConfig": compilation_config}, compilation_config.get(
"OverrideEnvironment"
)
if sharding_config:
return {"ModelShardingConfig": sharding_config}, sharding_config.get(
"OverrideEnvironment"
)
return None, None


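For reference, a sketch of what the new branch returns when only a sharding config is supplied (this mirrors the parametrized test case added below):

# Hedged usage sketch of the helper with only sharding_config set.
from sagemaker.serve.utils.optimize_utils import _extract_optimization_config_and_env

config, env = _extract_optimization_config_and_env(
    sharding_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}},
)
assert config == {
    "ModelShardingConfig": {
        "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}
    }
}
assert env == {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}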
34 changes: 34 additions & 0 deletions tests/unit/sagemaker/serve/builder/test_model_builder.py
@@ -2667,6 +2667,40 @@ def test_optimize_exclusive_args(self, mock_get_serve_setting):
            ),
        )

    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
    def test_optimize_exclusive_sharding(self, mock_get_serve_setting):
        mock_sagemaker_session = Mock()
        model_builder = ModelBuilder(
            model="meta-textgeneration-llama-3-70b",
            sagemaker_session=mock_sagemaker_session,
        )

        self.assertRaisesRegex(
            ValueError,
            "Sharding config is mutually exclusive and cannot be combined with any other optimization.",
            lambda: model_builder.optimize(
                quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
                compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
                sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
            ),
        )

    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
    def test_optimize_exclusive_sharding_args(self, mock_get_serve_setting):
        mock_sagemaker_session = Mock()
        model_builder = ModelBuilder(
            model="meta-textgeneration-llama-3-70b",
            sagemaker_session=mock_sagemaker_session,
        )

        self.assertRaisesRegex(
            ValueError,
            "OPTION_TENSOR_PARALLEL_DEGREE is a required environment variable with Sharding config.",
            lambda: model_builder.optimize(
                sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
            ),
        )

    @patch.object(ModelBuilder, "_prepare_for_mode")
    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
    def test_optimize_for_hf_with_custom_s3_path(
25 changes: 23 additions & 2 deletions tests/unit/sagemaker/serve/utils/test_optimize_utils.py
@@ -261,7 +261,7 @@ def test_is_s3_uri(s3_uri, expected):


@pytest.mark.parametrize(
    "quantization_config, compilation_config, sharding_config, expected_config, expected_env",
    [
        (
            None,
            {
                "OverrideEnvironment": {
                    "OPTION_TENSOR_PARALLEL_DEGREE": "2",
                }
            },
            None,
            {
                "ModelCompilationConfig": {
                    "OverrideEnvironment": {
                        "OPTION_TENSOR_PARALLEL_DEGREE": "2",
                    }
                },
            },
            {
                "OPTION_TENSOR_PARALLEL_DEGREE": "2",
            },
        ),
        (
            {
                "OverrideEnvironment": {
                    "OPTION_TENSOR_PARALLEL_DEGREE": "2",
                }
            },
            None,
            None,
            {
                "ModelQuantizationConfig": {
                    "OverrideEnvironment": {
                        "OPTION_TENSOR_PARALLEL_DEGREE": "2",
                    }
                },
            },
            {
                "OPTION_TENSOR_PARALLEL_DEGREE": "2",
            },
        ),
        (
            None,
            None,
            {
                "OverrideEnvironment": {
                    "OPTION_TENSOR_PARALLEL_DEGREE": "2",
                }
            },
            {
                "ModelShardingConfig": {
                    "OverrideEnvironment": {
                        "OPTION_TENSOR_PARALLEL_DEGREE": "2",
                    }
                },
            },
            {
                "OPTION_TENSOR_PARALLEL_DEGREE": "2",
            },
        ),
        (None, None, None, None, None),
    ],
)
def test_extract_optimization_config_and_env(