Skip to content
Open
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
@@ -372,6 +372,7 @@ def __init__(
self.endpoint_name = None
self.inference_component_name = None
self._is_compiled_model = False
self._is_sharded_model = False
self._compilation_job_name = None
self._is_edge_packaged_model = False
self.inference_recommender_job_results = None
@@ -1599,6 +1600,11 @@ def deploy(
if self._base_name is not None:
self._base_name = "-".join((self._base_name, compiled_model_suffix))

if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
logging.warning("Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints.")
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED

# Support multiple models on same endpoint
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
if endpoint_name:
7 changes: 5 additions & 2 deletions src/sagemaker/serve/builder/jumpstart_builder.py
Original file line number Diff line number Diff line change
@@ -681,6 +681,7 @@ def _optimize_for_jumpstart(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
speculative_decoding_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
env_vars: Optional[Dict] = None,
vpc_config: Optional[Dict] = None,
kms_key: Optional[str] = None,
@@ -702,6 +703,8 @@ def _optimize_for_jumpstart(
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
Defaults to ``None``
sharding_config (Optional[Dict]): Model sharding configuration.
Defaults to ``None``
env_vars (Optional[Dict]): Additional environment variables to run the optimization
container. Defaults to ``None``.
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -727,7 +730,7 @@ def _optimize_for_jumpstart(
pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)

optimization_config, override_env = _extract_optimization_config_and_env(
quantization_config, compilation_config
quantization_config, compilation_config, sharding_config
)
if not optimization_config and is_compilation:
override_env = override_env or pysdk_model_env_vars
@@ -792,7 +795,7 @@ def _optimize_for_jumpstart(
optimization_env_vars = _update_environment_variables(optimization_env_vars, override_env)
if optimization_env_vars:
self.pysdk_model.env.update(optimization_env_vars)
if quantization_config or is_compilation:
if quantization_config or sharding_config or is_compilation:
return create_optimization_job_args
return None

23 changes: 22 additions & 1 deletion src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
@@ -1119,6 +1119,7 @@ def optimize(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
speculative_decoding_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
env_vars: Optional[Dict] = None,
vpc_config: Optional[Dict] = None,
kms_key: Optional[str] = None,
@@ -1142,6 +1143,8 @@ def optimize(
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
Defaults to ``None``
sharding_config (Optional[Dict]): Model sharding configuration.
Defaults to ``None``
env_vars (Optional[Dict]): Additional environment variables to run the optimization
container. Defaults to ``None``.
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1170,6 +1173,7 @@ def optimize(
quantization_config=quantization_config,
compilation_config=compilation_config,
speculative_decoding_config=speculative_decoding_config,
sharding_config=sharding_config,
env_vars=env_vars,
vpc_config=vpc_config,
kms_key=kms_key,
@@ -1189,6 +1193,7 @@ def _model_builder_optimize_wrapper(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
speculative_decoding_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
env_vars: Optional[Dict] = None,
vpc_config: Optional[Dict] = None,
kms_key: Optional[str] = None,
@@ -1212,6 +1217,8 @@ def _model_builder_optimize_wrapper(
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
Defaults to ``None``
sharding_config (Optional[Dict]): Model sharding configuration.
Defaults to ``None``
env_vars (Optional[Dict]): Additional environment variables to run the optimization
container. Defaults to ``None``.
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1238,6 +1245,12 @@ def _model_builder_optimize_wrapper(
if quantization_config and compilation_config:
raise ValueError("Quantization config and compilation config are mutually exclusive.")

if sharding_config and (quantization_config or compilation_config or speculative_decoding_config):
raise ValueError("Sharding config is mutually exclusive and cannot be combined with any other optimization.")

if sharding_config and ((env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" not in env_vars) or (sharding_config.get("OverrideEnvironment") and "OPTION_TENSOR_PARALLEL_DEGREE" not in sharding_config["OverrideEnvironment"])):

Reviewer comment on the line above: side note — the same validation is also performed in NeoLambda.

raise ValueError("OPTION_TENSOR_PARALLEL_DEGREE is required environment variable with Sharding config.")

self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
self.instance_type = instance_type or self.instance_type
self.role_arn = role_arn or self.role_arn
@@ -1254,6 +1267,7 @@ def _model_builder_optimize_wrapper(
quantization_config=quantization_config,
compilation_config=compilation_config,
speculative_decoding_config=speculative_decoding_config,
sharding_config=sharding_config,
env_vars=env_vars,
vpc_config=vpc_config,
kms_key=kms_key,
@@ -1272,6 +1286,7 @@ def _model_builder_optimize_wrapper(
quantization_config=quantization_config,
compilation_config=compilation_config,
speculative_decoding_config=speculative_decoding_config,
sharding_config=sharding_config,
env_vars=env_vars,
vpc_config=vpc_config,
kms_key=kms_key,
@@ -1287,6 +1302,9 @@ def _model_builder_optimize_wrapper(
if not speculative_decoding_config:
self.pysdk_model.remove_tag_with_key(Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER)

if sharding_config:
self.pysdk_model._is_sharded_model = True

return self.pysdk_model

def _optimize_for_hf(
@@ -1297,6 +1315,7 @@ def _optimize_for_hf(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
speculative_decoding_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
env_vars: Optional[Dict] = None,
vpc_config: Optional[Dict] = None,
kms_key: Optional[str] = None,
@@ -1312,6 +1331,8 @@ def _optimize_for_hf(
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
Defaults to ``None``
sharding_config (Optional[Dict]): Model sharding configuration.
Defaults to ``None``
env_vars (Optional[Dict]): Additional environment variables to run the optimization
container. Defaults to ``None``.
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1327,7 +1348,7 @@ def _optimize_for_hf(
self.pysdk_model, speculative_decoding_config, False
)

if quantization_config or compilation_config:
if quantization_config or compilation_config or sharding_config:
create_optimization_job_args = {
"OptimizationJobName": job_name,
"DeploymentInstanceType": self.instance_type,
8 changes: 7 additions & 1 deletion src/sagemaker/serve/utils/optimize_utils.py
Original file line number Diff line number Diff line change
@@ -259,13 +259,15 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool:


def _extract_optimization_config_and_env(
quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None
quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None
) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]:
"""Extracts optimization config and environment variables.

Args:
quantization_config (Optional[Dict]): The quantization config.
compilation_config (Optional[Dict]): The compilation config.
sharding_config (Optional[Dict]): The sharding config.

Returns:
Optional[Tuple[Optional[Dict], Optional[Dict]]]:
@@ -279,6 +281,10 @@ def _extract_optimization_config_and_env(
return {"ModelCompilationConfig": compilation_config}, compilation_config.get(
"OverrideEnvironment"
)
if sharding_config:
return {"ModelShardingConfig": sharding_config}, sharding_config.get(
"OverrideEnvironment"
)
return None, None


48 changes: 48 additions & 0 deletions tests/unit/sagemaker/model/test_model.py
Original file line number Diff line number Diff line change
@@ -958,6 +958,54 @@ def test_all_framework_models_inference_component_based_endpoint_deploy_path(
sagemaker_session.endpoint_in_service_or_not.reset_mock()
sagemaker_session.create_model.reset_mock()

@patch("sagemaker.utils.repack_model")
@patch("sagemaker.fw_utils.tar_and_upload_dir")
def test_sharded_model_force_inference_component_based_endpoint_deploy_path(
repack_model, tar_and_uload_dir, sagemaker_session
):
framework_model_classes_to_kwargs = {
HuggingFaceModel: {
"pytorch_version": "1.7.1",
"py_version": "py36",
"transformers_version": "4.6.1"
},
}

sagemaker_session.settings = SessionSettings(include_jumpstart_tags=False)

source_dir = "s3://blah/blah/blah"
for framework_model_class, kwargs in framework_model_classes_to_kwargs.items():
test_sharded_model = framework_model_class(
entry_point=ENTRY_POINT_INFERENCE,
role=ROLE,
sagemaker_session=sagemaker_session,
model_data=source_dir,
**kwargs,
)
test_sharded_model._is_sharded_model = True
test_sharded_model.deploy(
instance_type="ml.m2.xlarge",
initial_instance_count=INSTANCE_COUNT,
endpoint_type=EndpointType.MODEL_BASED,
resources=ResourceRequirements(
requests={
"num_accelerators": 1,
"memory": 8192,
"copies": 1,
},
limits={},
),
)

# Verified inference component based endpoint and inference component creation
# path
sagemaker_session.endpoint_in_service_or_not.assert_called_once()
sagemaker_session.create_model.assert_called_once()
sagemaker_session.create_inference_component.assert_called_once()

sagemaker_session.create_inference_component.reset_mock()
sagemaker_session.endpoint_in_service_or_not.reset_mock()
sagemaker_session.create_model.reset_mock()

@patch("sagemaker.utils.repack_model")
def test_repack_code_location_with_key_prefix(repack_model, sagemaker_session):
51 changes: 51 additions & 0 deletions tests/unit/sagemaker/serve/builder/test_js_builder.py
Original file line number Diff line number Diff line change
@@ -1198,6 +1198,57 @@ def test_optimize_quantize_for_jumpstart(

self.assertIsNotNone(out_put)

@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
def test_optimize_sharding_for_jumpstart(
self,
mock_serve_settings,
mock_telemetry,
):
mock_sagemaker_session = Mock()

mock_pysdk_model = Mock()
mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"}
mock_pysdk_model.model_data = mock_model_data
mock_pysdk_model.image_uri = mock_tgi_image_uri
mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS
mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0]

sample_input = {
"inputs": "The diamondback terrapin or simply terrapin is a species "
"of turtle native to the brackish coastal tidal marshes of the",
"parameters": {"max_new_tokens": 1024},
}
sample_output = [
{
"generated_text": "The diamondback terrapin or simply terrapin is a "
"species of turtle native to the brackish coastal "
"tidal marshes of the east coast."
}
]

model_builder = ModelBuilder(
model="meta-textgeneration-llama-3-70b",
schema_builder=SchemaBuilder(sample_input, sample_output),
sagemaker_session=mock_sagemaker_session,
)

model_builder.pysdk_model = mock_pysdk_model

out_put = model_builder._optimize_for_jumpstart(
accept_eula=True,
sharding_config={
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
},
env_vars={
"OPTION_TENSOR_PARALLEL_DEGREE": "1",
"OPTION_MAX_ROLLING_BATCH_SIZE": "2",
},
output_path="s3://bucket/code/",
)

self.assertIsNotNone(out_put)

@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
@patch(
33 changes: 33 additions & 0 deletions tests/unit/sagemaker/serve/builder/test_model_builder.py
Original file line number Diff line number Diff line change
@@ -2667,6 +2667,39 @@ def test_optimize_exclusive_args(self, mock_get_serve_setting):
),
)

@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
def test_optimize_exclusive_sharding(self, mock_get_serve_setting):
    """optimize() must reject a sharding config combined with any other
    optimization config (here: compilation)."""
    builder = ModelBuilder(
        model="meta-textgeneration-llama-3-70b",
        sagemaker_session=Mock(),
    )

    with self.assertRaisesRegex(
        ValueError,
        "Sharding config is mutually exclusive and cannot be combined with any other optimization.",
    ):
        builder.optimize(
            compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
            sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
        )

@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
def test_optimize_exclusive_sharding_args(self, mock_get_serve_setting):
    """optimize() must require OPTION_TENSOR_PARALLEL_DEGREE in the
    environment when a sharding config is supplied."""
    builder = ModelBuilder(
        model="meta-textgeneration-llama-3-70b",
        sagemaker_session=Mock(),
    )

    with self.assertRaisesRegex(
        ValueError,
        "OPTION_TENSOR_PARALLEL_DEGREE is required environment variable with Sharding config.",
    ):
        builder.optimize(
            sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
        )

@patch.object(ModelBuilder, "_prepare_for_mode")
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
def test_optimize_for_hf_with_custom_s3_path(
29 changes: 25 additions & 4 deletions tests/unit/sagemaker/serve/utils/test_optimize_utils.py
Original file line number Diff line number Diff line change
@@ -261,7 +261,7 @@ def test_is_s3_uri(s3_uri, expected):


@pytest.mark.parametrize(
"quantization_config, compilation_config, expected_config, expected_env",
"quantization_config, compilation_config, sharding_config, expected_config, expected_env",
[
(
None,
@@ -270,6 +270,7 @@ def test_is_s3_uri(s3_uri, expected):
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
}
},
None,
{
"ModelCompilationConfig": {
"OverrideEnvironment": {
@@ -288,6 +289,7 @@ def test_is_s3_uri(s3_uri, expected):
}
},
None,
None,
{
"ModelQuantizationConfig": {
"OverrideEnvironment": {
@@ -299,13 +301,32 @@ def test_is_s3_uri(s3_uri, expected):
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
},
),
(None, None, None, None),
(
None,
None,
{
"OverrideEnvironment": {
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
}
},
{
"ModelShardingConfig": {
"OverrideEnvironment": {
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
}
},
},
{
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
},
),
(None, None, None, None, None),
],
)
def test_extract_optimization_config_and_env(
quantization_config, compilation_config, expected_config, expected_env
quantization_config, compilation_config, sharding_config, expected_config, expected_env
):
assert _extract_optimization_config_and_env(quantization_config, compilation_config) == (
assert _extract_optimization_config_and_env(quantization_config, compilation_config, sharding_config) == (
expected_config,
expected_env,
)