Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion src/xpk/commands/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from tabulate import tabulate

from ..utils.feature_flags import FeatureFlags
from ..core.capacity import H100_DEVICE_TYPE, H200_DEVICE_TYPE, B200_DEVICE_TYPE
from ..core.capacity import H100_DEVICE_TYPE, H200_DEVICE_TYPE, B200_DEVICE_TYPE, get_reservation_deployment_type
from ..core.cluster import (
get_all_clusters_programmatic,
get_cluster_credentials,
Expand Down Expand Up @@ -204,6 +204,38 @@ def cluster_adapt(args) -> None:
def _validate_cluster_create_args(args, system: SystemCharacteristics):
  """Validate cluster-creation arguments against the requested system.

  Sub-slicing validation runs only when the SUB_SLICING_ENABLED feature flag
  is on and `args.sub_slicing` is truthy; otherwise nothing is checked here.

  Args:
    args: parsed command arguments.
    system: characteristics of the system the cluster is created for.
  """
  wants_sub_slicing = FeatureFlags.SUB_SLICING_ENABLED and args.sub_slicing
  if not wants_sub_slicing:
    return

  # System support is checked first, then the reservation backing it.
  validate_sub_slicing_system(system)
  _validate_sub_slicing_reservation(args)


def _validate_sub_slicing_reservation(args):
  """Check that sub-slicing is backed by a Cluster Director reservation.

  Prints an error and calls `xpk_exit(1)` when `args.reservation` is None, or
  when the deployment type looked up for the reservation (using
  `args.project` and `args.zone`) is anything other than 'DENSE'.

  Args:
    args: parsed command arguments; reads `reservation`, `project`, `zone`.
  """
  if args.reservation is None:
    missing_reservation_msg = (
        'Error: Validation failed: Sub-slicing cluster creation requires'
        ' Cluster Director reservation to be specified.'
    )
    xpk_print(missing_reservation_msg)
    xpk_exit(1)

  reservation_kind = get_reservation_deployment_type(
      reservation=args.reservation, project=args.project, zone=args.zone
  )
  if reservation_kind == 'DENSE':
    return

  not_cluster_director_msg = (
      'Error: Validation failed: The specified reservation'
      f' "{args.reservation}" is not a Cluster Director reservation.'
  )
  docs_pointer_msg = (
      'Refer to the documentation for more information on creating Cluster'
      ' Director reservations:'
      ' https://cloud.google.com/cluster-director/docs/reserve-capacity'
  )
  # Each entry is emitted as its own xpk_print call, in this order.
  guidance = (
      not_cluster_director_msg,
      'Please provide a reservation created for Cluster Director to proceed.',
      'To list valid Cluster Director reservations, run:',
      ' gcloud compute reservations list --filter="deploymentType=DENSE"',
      docs_pointer_msg,
  )
  for message in guidance:
    xpk_print(message)
  xpk_exit(1)


def cluster_create(args) -> None:
Expand Down
78 changes: 67 additions & 11 deletions src/xpk/commands/cluster_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
@dataclass
class _Mocks:
  # Bundle of the mock objects that the mock_common_print_and_exit fixture
  # returns, so tests can inspect what the code under test called.

  # Patch of 'xpk.commands.common.xpk_print'.
  common_print_mock: MagicMock
  # Patch of 'xpk.commands.common.xpk_exit'.
  # NOTE(review): the fixture's current return statement does not fill this
  # field and the newer tests expect a real SystemExit; this field looks like
  # it is being removed. Confirm before relying on it.
  common_exit_mock: MagicMock
  # Patch of 'xpk.commands.cluster.xpk_print'.
  commands_print_mock: MagicMock
  # Patch of 'xpk.commands.cluster.get_reservation_deployment_type'; the
  # fixture gives it a default return value of 'DENSE'.
  commands_get_reservation_deployment_type: MagicMock


@pytest.fixture
Expand All @@ -36,12 +37,17 @@ def mock_common_print_and_exit(mocker):
'xpk.commands.common.xpk_print',
return_value=None,
)
common_exit_mock = mocker.patch(
'xpk.commands.common.xpk_exit',
return_value=None,
commands_print_mock = mocker.patch(
'xpk.commands.cluster.xpk_print', return_value=None
)
commands_get_reservation_deployment_type = mocker.patch(
'xpk.commands.cluster.get_reservation_deployment_type',
return_value='DENSE',
)
return _Mocks(
common_print_mock=common_print_mock, common_exit_mock=common_exit_mock
common_print_mock=common_print_mock,
commands_get_reservation_deployment_type=commands_get_reservation_deployment_type,
commands_print_mock=commands_print_mock,
)


Expand All @@ -61,32 +67,82 @@ def test_validate_cluster_create_args_for_correct_args_pass(
_validate_cluster_create_args(args, DEFAULT_TEST_SYSTEM)

assert mock_common_print_and_exit.common_print_mock.call_count == 0
assert mock_common_print_and_exit.common_exit_mock.call_count == 0


def test_validate_cluster_create_args_for_correct_sub_slicing_args_pass(
    mock_common_print_and_exit: _Mocks,
    monkeypatch,
):
  """Valid sub-slicing arguments pass validation without printing an error."""
  # monkeypatch restores the flag at teardown, so enabling it here cannot
  # leak into tests that run afterwards.
  monkeypatch.setattr(FeatureFlags, 'SUB_SLICING_ENABLED', True)
  args = Namespace(
      sub_slicing=True,
      reservation='test-reservation',
      project='project',
      zone='zone',
  )

  _validate_cluster_create_args(args, SUB_SLICING_SYSTEM)

  assert mock_common_print_and_exit.common_print_mock.call_count == 0


def test_validate_cluster_create_args_for_not_supported_system_throws(
    mock_common_print_and_exit: _Mocks,
    monkeypatch,
):
  """A system without sub-slicing support prints one error and exits."""
  # monkeypatch restores the flag at teardown, so enabling it here cannot
  # leak into tests that run afterwards.
  monkeypatch.setattr(FeatureFlags, 'SUB_SLICING_ENABLED', True)
  args = Namespace(
      sub_slicing=True,
      reservation='test-reservation',
      project='project',
      zone='zone',
  )

  # The validator is expected to exit, so it is invoked only inside
  # pytest.raises; a call outside the context manager would abort the test.
  with pytest.raises(SystemExit):
    _validate_cluster_create_args(args, DEFAULT_TEST_SYSTEM)

  assert mock_common_print_and_exit.common_print_mock.call_count == 1
  assert (
      mock_common_print_and_exit.common_print_mock.call_args[0][0]
      == 'Error: l4-1 does not support Sub-slicing.'
  )


def test_validate_cluster_create_args_for_missing_reservation(
    mock_common_print_and_exit: _Mocks,
    monkeypatch,
):
  """Sub-slicing without a reservation prints one error and exits."""
  # monkeypatch restores the flag at teardown, so enabling it here cannot
  # leak into tests that run afterwards.
  monkeypatch.setattr(FeatureFlags, 'SUB_SLICING_ENABLED', True)
  args = Namespace(
      sub_slicing=True, project='project', zone='zone', reservation=None
  )

  with pytest.raises(SystemExit):
    _validate_cluster_create_args(args, SUB_SLICING_SYSTEM)

  assert mock_common_print_and_exit.commands_print_mock.call_count == 1
  assert (
      'Validation failed: Sub-slicing cluster creation requires'
      in mock_common_print_and_exit.commands_print_mock.call_args[0][0]
  )


def test_validate_cluster_create_args_for_invalid_reservation(
    mock_common_print_and_exit: _Mocks,
    monkeypatch,
):
  """A reservation whose deployment type is not 'DENSE' is rejected."""
  # monkeypatch restores the flag at teardown, so enabling it here cannot
  # leak into tests that run afterwards.
  monkeypatch.setattr(FeatureFlags, 'SUB_SLICING_ENABLED', True)
  args = Namespace(
      sub_slicing=True,
      project='project',
      zone='zone',
      reservation='test-reservation',
  )
  # Override the fixture's default 'DENSE' answer with a non-matching type.
  mock_common_print_and_exit.commands_get_reservation_deployment_type.return_value = (
      'SPARSE'
  )

  with pytest.raises(SystemExit):
    _validate_cluster_create_args(args, SUB_SLICING_SYSTEM)

  # Five messages are printed; the last one points at the documentation.
  assert mock_common_print_and_exit.commands_print_mock.call_count == 5
  assert (
      'Refer to the documentation for more information on creating Cluster'
      in mock_common_print_and_exit.commands_print_mock.call_args[0][0]
  )
17 changes: 17 additions & 0 deletions src/xpk/core/capacity.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,23 @@ def get_reservation_placement_policy(
return output.strip()


def get_reservation_deployment_type(
    reservation: str, zone: str, project: str
) -> str:
  """Look up the deployment type of a reservation via gcloud.

  Runs `gcloud beta compute reservations describe` for the reservation and
  asks for only the `deploymentType` value.

  Args:
    reservation: name of the reservation to describe.
    zone: zone passed to gcloud via --zone.
    project: project passed to gcloud via --project.

  Returns:
    The command output with surrounding whitespace stripped.
  """
  task = 'Get reservation deployment type'
  describe_cmd = (
      f'gcloud beta compute reservations describe {reservation}'
      f' --project={project} --zone={zone} --format="value(deploymentType)"'
  )
  # 'DENSE' is handed over as the value to use for dry runs.
  status, stdout = run_command_for_value(
      describe_cmd, task, dry_run_return_val='DENSE'
  )
  if status != 0:
    xpk_print(f'{task} ERROR {status}')
    xpk_exit(1)
  return stdout.strip()


def verify_reservation_exists(args) -> int:
"""Verify the reservation exists.

Expand Down
50 changes: 50 additions & 0 deletions src/xpk/core/capacity_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import pytest
from unittest.mock import MagicMock, patch
from .capacity import get_reservation_deployment_type


@patch('xpk.core.capacity.xpk_print')
def test_get_reservation_deployment_type_exits_with_command_fails(
    xpk_print: MagicMock, mocker
):
  """A non-zero return code is reported through xpk_print, then exits."""
  # Simulate the underlying command failing with return code 1.
  mocker.patch(
      target='xpk.core.capacity.run_command_for_value', return_value=(1, '')
  )

  with pytest.raises(SystemExit):
    get_reservation_deployment_type(
        reservation='reservation', zone='zone', project='project'
    )

  first_printed_message = xpk_print.mock_calls[0].args[0]
  assert 'Get reservation deployment type ERROR 1' in first_printed_message


def test_get_reservation_deployment_type_returns_deployment_type_when_command_succeeds(
    mocker,
):
  """On return code 0 the command output is returned as the deployment type."""
  successful_run = (0, 'DENSE')
  mocker.patch(
      target='xpk.core.capacity.run_command_for_value',
      return_value=successful_run,
  )

  deployment_type = get_reservation_deployment_type(
      reservation='reservation', zone='zone', project='project'
  )

  assert deployment_type == 'DENSE'
Loading