diff --git a/src/xpk/core/nodepool.py b/src/xpk/core/nodepool.py index 58889c439..914aca1bd 100644 --- a/src/xpk/core/nodepool.py +++ b/src/xpk/core/nodepool.py @@ -267,7 +267,7 @@ def run_gke_node_pool_create_command( placement_args = '' if system.requires_workload_policy and is_topology_valid(system.topology): - placement_policy = f'{args.cluster}-placement-policy' + placement_policy = f'{system.device_type}-{system.topology}-placement-policy' ensure_resource_policy_exists(placement_policy, args, system.topology) placement_args = f' --placement-policy={placement_policy}' diff --git a/src/xpk/parser/cluster.py b/src/xpk/parser/cluster.py index a4e416769..e147b41f9 100644 --- a/src/xpk/parser/cluster.py +++ b/src/xpk/parser/cluster.py @@ -143,12 +143,7 @@ def set_cluster_create_parser(cluster_create_parser: ArgumentParser): ' enable cluster to accept Pathways workloads.' ), ) - if FeatureFlags.SUB_SLICING_ENABLED: - cluster_create_optional_arguments.add_argument( - '--sub-slicing', - action='store_true', - help='Whether to set up cluster to support sub-slicing', - ) + add_sub_slicing_arguments(cluster_create_optional_arguments) autoprovisioning_arguments = cluster_create_parser.add_argument_group( 'Autoprovisioning Arguments', @@ -221,6 +216,7 @@ def set_cluster_create_pathways_parser( add_shared_cluster_create_optional_arguments( cluster_create_pathways_optional_arguments ) + add_sub_slicing_arguments(cluster_create_pathways_optional_arguments) autoprovisioning_arguments = ( cluster_create_pathways_parser.add_argument_group( @@ -350,7 +346,9 @@ def set_cluster_create_ray_parser(cluster_create_ray_parser: ArgumentParser): ) add_resource_limits(cluster_create_resource_limits) - cluster_create_ray_parser.set_defaults(func=cluster_create_ray_cluster) + cluster_create_ray_parser.set_defaults( + func=cluster_create_ray_cluster, sub_slicing=False + ) def set_cluster_delete_parser(cluster_delete_parser: ArgumentParser): @@ -567,6 +565,15 @@ def set_cluster_adapt_parser(cluster_adapt_parser: ArgumentParser): cluster_adapt_parser.set_defaults(func=cluster_adapt) +def add_sub_slicing_arguments(parser_or_group: ParserOrArgumentGroup): + if FeatureFlags.SUB_SLICING_ENABLED: + parser_or_group.add_argument( + '--sub-slicing', + action='store_true', + help='Whether to set up cluster to support sub-slicing', + ) + + def add_autoprovisioning_arguments(parser_or_group: ParserOrArgumentGroup): parser_or_group.add_argument( '--enable-autoprovisioning', diff --git a/src/xpk/parser/cluster_test.py b/src/xpk/parser/cluster_test.py index ec1d81b6b..b5c572fa4 100644 --- a/src/xpk/parser/cluster_test.py +++ b/src/xpk/parser/cluster_test.py @@ -15,7 +15,7 @@ """ import argparse -from xpk.parser.cluster import set_cluster_create_parser +from xpk.parser.cluster import set_cluster_create_parser, set_cluster_create_pathways_parser, set_cluster_create_ray_parser import pytest from ..utils.feature_flags import FeatureFlags @@ -64,3 +64,69 @@ def test_cluster_create_sub_slicing_can_be_set(): ) assert args.sub_slicing is True + + +def test_cluster_create_pathways_sub_slicing_is_hidden_with_flag_off(): + FeatureFlags.SUB_SLICING_ENABLED = False + parser = argparse.ArgumentParser() + + set_cluster_create_pathways_parser(parser) + help_str = parser.format_help() + + assert "--sub-slicing" not in help_str + + +def test_cluster_create_pathways_sub_slicing_is_shown_with_flag_on(): + parser = argparse.ArgumentParser() + + set_cluster_create_pathways_parser(parser) + help_str = parser.format_help() + + assert "--sub-slicing" in help_str + + +def test_cluster_create_pathways_sub_slicing_is_false_by_default(): + parser = argparse.ArgumentParser() + + set_cluster_create_pathways_parser(parser) + args = parser.parse_args( + ["--cluster", "test-cluster", "--tpu-type", "test-tpu"] + ) + + assert args.sub_slicing is False + + +def test_cluster_create_pathways_sub_slicing_can_be_set(): + parser = argparse.ArgumentParser() + + set_cluster_create_pathways_parser(parser) + args = parser.parse_args( + ["--cluster", "test-cluster", "--tpu-type", "test-tpu", "--sub-slicing"] + ) + + assert args.sub_slicing is True + + +def test_cluster_create_ray_sub_slicing_is_hidden(): + parser = argparse.ArgumentParser() + + set_cluster_create_ray_parser(parser) + help_str = parser.format_help() + + assert "--sub-slicing" not in help_str + + +def test_cluster_create_ray_sub_slicing_is_false(): + parser = argparse.ArgumentParser() + + set_cluster_create_ray_parser(parser) + args = parser.parse_args([ + "--cluster", + "test-cluster", + "--tpu-type", + "test-tpu", + "--ray-version", + "1.0.0", + ]) + + assert args.sub_slicing is False diff --git a/src/xpk/parser/workload_test.py b/src/xpk/parser/workload_test.py index 7e0d53d19..4988aa6ab 100644 --- a/src/xpk/parser/workload_test.py +++ b/src/xpk/parser/workload_test.py @@ -15,7 +15,7 @@ """ import argparse -from xpk.parser.workload import set_workload_create_parser +from xpk.parser.workload import set_workload_create_parser, set_workload_create_pathways_parser from ..utils.feature_flags import FeatureFlags import pytest @@ -32,7 +32,7 @@ def test_workload_create_sub_slicing_topology_is_hidden_with_flag_off(): set_workload_create_parser(parser) help_str = parser.format_help() - assert "--sub-slicing" not in help_str + assert "--sub-slicing-topology" not in help_str def test_workload_create_sub_slicing_topology_is_shown_with_flag_on(): @@ -41,7 +41,7 @@ def test_workload_create_sub_slicing_topology_is_shown_with_flag_on(): set_workload_create_parser(parser) help_str = parser.format_help() - assert "--sub-slicing" in help_str + assert "--sub-slicing-topology" in help_str def test_workload_create_sub_slicing_topology_is_none_by_default(): @@ -80,3 +80,60 @@ def test_workload_create_sub_slicing_topology_can_be_set(): ]) assert args.sub_slicing_topology is "2x2" + + +def test_workload_create_pathways_sub_slicing_topology_is_hidden_with_flag_off(): + FeatureFlags.SUB_SLICING_ENABLED = False + parser = argparse.ArgumentParser() + + set_workload_create_pathways_parser(parser) + help_str = parser.format_help() + + assert "--sub-slicing-topology" not in help_str + + +def test_workload_create_pathways_sub_slicing_topology_is_shown_with_flag_on(): + parser = argparse.ArgumentParser() + + set_workload_create_pathways_parser(parser) + help_str = parser.format_help() + + assert "--sub-slicing-topology" in help_str + + +def test_workload_create_pathways_sub_slicing_topology_is_none_by_default(): + parser = argparse.ArgumentParser() + + set_workload_create_pathways_parser(parser) + args = parser.parse_args([ + "--cluster", + "test-cluster", + "--command", + "python3", + "--workload", + "test", + "--tpu-type", + "test-tpu", + ]) + + assert args.sub_slicing_topology is None + + +def test_workload_create_pathways_sub_slicing_topology_can_be_set(): + parser = argparse.ArgumentParser() + + set_workload_create_pathways_parser(parser) + args = parser.parse_args([ + "--cluster", + "test-cluster", + "--command", + "python3", + "--workload", + "test", + "--tpu-type", + "test-tpu", + "--sub-slicing-topology", + "2x2", + ]) + + assert args.sub_slicing_topology is "2x2"