From 4cf4ea5b35b032e5bb172b6008c75adcb4dede19 Mon Sep 17 00:00:00 2001 From: Vaidotas Simkus Date: Mon, 9 Sep 2024 18:07:24 +0300 Subject: [PATCH] fix: maintain hyperparameter order when invoking jobs --- src/sagemaker_training/mapping.py | 8 ++++---- test/unit/test_mapping.py | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/sagemaker_training/mapping.py b/src/sagemaker_training/mapping.py index 6df40a83..83fe4e7f 100644 --- a/src/sagemaker_training/mapping.py +++ b/src/sagemaker_training/mapping.py @@ -75,7 +75,7 @@ def to_cmd_args(mapping): # type: (dict) -> list (list): List of cmd arguments. """ - sorted_keys = sorted(mapping.keys()) + mapping_keys = mapping.keys() def arg_name(obj): string = _decode(obj) @@ -84,15 +84,15 @@ def arg_name(obj): else: return "" - arg_names = [arg_name(argument) for argument in sorted_keys] + arg_names = [arg_name(argument) for argument in mapping_keys] def arg_value(value): if hasattr(value, "items"): - map_items = ["%s=%s" % (k, v) for k, v in sorted(value.items())] + map_items = ["%s=%s" % (k, v) for k, v in value.items()] return ",".join(map_items) return _decode(value) - arg_values = [arg_value(mapping[key]) for key in sorted_keys] + arg_values = [arg_value(mapping[key]) for key in mapping_keys] items = zip(arg_names, arg_values) diff --git a/test/unit/test_mapping.py b/test/unit/test_mapping.py index d2b2f983..bb9d1b8e 100644 --- a/test/unit/test_mapping.py +++ b/test/unit/test_mapping.py @@ -102,21 +102,21 @@ def test_mapping_throws_exception_trying_to_access_non_properties(property, erro [ ( {"da-sh": "1", "un_der": "2", "un-sh": "3", "da_der": "2"}, - ["--da-sh", "1", "--da_der", "2", "--un-sh", "3", "--un_der", "2"], + ["--da-sh", "1", "--un_der", "2", "--un-sh", "3", "--da_der", "2"], ), ({}, []), ({"": ""}, ["", ""]), ( {"unicode": "¡ø", "bytes": b"2", "floats": 4.0, "int": 2}, - ["--bytes", "2", "--floats", "4.0", "--int", "2", "--unicode", "¡ø"], + ["--unicode", "¡ø", "--bytes", "2", "--floats", "4.0", "--int", "2"], ), - ({"U": "1", "b": b"2", "T": "", "": "42"}, ["", "42", "-T", "", "-U", "1", "-b", "2"]), + ({"U": "1", "b": b"2", "T": "", "": "42"}, ["-U", "1", "-b", "2", "-T", "", "", "42" ]), ({"nested": ["1", ["2", "3", [["6"]]]]}, ["--nested", "['1', ['2', '3', [['6']]]]"]), ( {"map": {"a": [1, 3, 4]}, "channel_dirs": {"train": "foo", "eval": "bar"}}, - ["--channel_dirs", "eval=bar,train=foo", "--map", "a=[1, 3, 4]"], + ["--map", "a=[1, 3, 4]", "--channel_dirs", "train=foo,eval=bar"], ), - ({"truthy": True, "falsy": False}, ["--falsy", "False", "--truthy", "True"]), + ({"truthy": True, "falsy": False}, ["--truthy", "True", "--falsy", "False"]), ], ) def test_to_cmd_args(target, expected): @@ -218,7 +218,7 @@ def test_env_vars_round_trip(): ) assert env_vars["SM_MODULE_NAME"] == "user_script" assert env_vars["SM_INPUT_CONFIG_DIR"].endswith("/opt/ml/input/config") - assert env_vars["SM_USER_ARGS"] == "--batch_size 64 --epochs 10 --loss SGD --precision 5.434322" + assert env_vars["SM_USER_ARGS"] == "--loss SGD --epochs 10 --batch_size 64 --precision 5.434322" assert env_vars["SM_OUTPUT_DIR"].endswith("/opt/ml/output") assert env_vars["SM_MODEL_DIR"].endswith("/opt/ml/model") assert env_vars["SM_HOSTS"] == '["algo-1","algo-2","algo-3"]'