Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions src/dstack/_internal/core/backends/nebius/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@
get_offers_disk_modifier,
)
from dstack._internal.core.backends.nebius import resources
from dstack._internal.core.backends.nebius.fabrics import get_suitable_infiniband_fabrics
from dstack._internal.core.backends.nebius.models import NebiusConfig, NebiusServiceAccountCreds
from dstack._internal.core.backends.nebius.models import (
NebiusConfig,
NebiusOfferBackendData,
NebiusServiceAccountCreds,
)
from dstack._internal.core.errors import (
BackendError,
NotYetTerminated,
Expand Down Expand Up @@ -281,23 +284,30 @@ def create_placement_group(
master_instance_offer: InstanceOffer,
) -> PlacementGroupProvisioningData:
assert placement_group.configuration.placement_strategy == PlacementStrategy.CLUSTER
backend_data = NebiusPlacementGroupBackendData(cluster=None)
master_instance_offer_backend_data: NebiusOfferBackendData = (
NebiusOfferBackendData.__response__.parse_obj(master_instance_offer.backend_data)
)
fabrics = list(master_instance_offer_backend_data.fabrics)
if self.config.fabrics is not None:
fabrics = [f for f in fabrics if f in self.config.fabrics]
placement_group_backend_data = NebiusPlacementGroupBackendData(cluster=None)
# Only create a Nebius cluster if the instance supports it.
# For other instances, return dummy PlacementGroupProvisioningData.
if fabrics := get_suitable_infiniband_fabrics(
master_instance_offer, allowed_fabrics=self.config.fabrics
):
if fabrics:
fabric = random.choice(fabrics)
op = resources.create_cluster(
self._sdk,
name=placement_group.name,
project_id=self._region_to_project_id[placement_group.configuration.region],
fabric=fabric,
)
backend_data.cluster = NebiusClusterBackendData(id=op.resource_id, fabric=fabric)
placement_group_backend_data.cluster = NebiusClusterBackendData(
id=op.resource_id,
fabric=fabric,
)
return PlacementGroupProvisioningData(
backend=BackendType.NEBIUS,
backend_data=backend_data.json(),
backend_data=placement_group_backend_data.json(),
)

def delete_placement_group(self, placement_group: PlacementGroup) -> None:
Expand All @@ -317,16 +327,15 @@ def is_suitable_placement_group(
if placement_group.configuration.region != instance_offer.region:
return False
assert placement_group.provisioning_data is not None
backend_data = NebiusPlacementGroupBackendData.load(
placement_group_backend_data = NebiusPlacementGroupBackendData.load(
placement_group.provisioning_data.backend_data
)
instance_offer_backend_data: NebiusOfferBackendData = (
NebiusOfferBackendData.__response__.parse_obj(instance_offer.backend_data)
)
return (
backend_data.cluster is None
or backend_data.cluster.fabric
in get_suitable_infiniband_fabrics(
instance_offer,
allowed_fabrics=None, # enforced at cluster creation time, no need to enforce here
)
placement_group_backend_data.cluster is None
or placement_group_backend_data.cluster.fabric in instance_offer_backend_data.fabrics
)


Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/core/backends/nebius/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
)
from dstack._internal.core.backends.nebius import resources
from dstack._internal.core.backends.nebius.backend import NebiusBackend
from dstack._internal.core.backends.nebius.fabrics import get_all_infiniband_fabrics
from dstack._internal.core.backends.nebius.models import (
NebiusBackendConfig,
NebiusBackendConfigWithCreds,
Expand All @@ -19,6 +18,7 @@
NebiusServiceAccountCreds,
NebiusStoredConfig,
)
from dstack._internal.core.backends.nebius.resources import get_all_infiniband_fabrics
from dstack._internal.core.errors import BackendError, ServerClientError
from dstack._internal.core.models.backends.base import BackendType

Expand Down
49 changes: 0 additions & 49 deletions src/dstack/_internal/core/backends/nebius/fabrics.py

This file was deleted.

4 changes: 4 additions & 0 deletions src/dstack/_internal/core/backends/nebius/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,7 @@ class NebiusConfig(NebiusStoredConfig):
"""

creds: AnyNebiusCreds


class NebiusOfferBackendData(CoreModel):
fabrics: set[str] = set()
14 changes: 14 additions & 0 deletions src/dstack/_internal/core/backends/nebius/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,14 @@
from nebius.sdk import SDK

from dstack._internal.core.backends.base.configurator import raise_invalid_credentials_error
from dstack._internal.core.backends.base.offers import get_catalog_offers
from dstack._internal.core.backends.nebius.models import (
DEFAULT_PROJECT_NAME_PREFIX,
NebiusOfferBackendData,
NebiusServiceAccountCreds,
)
from dstack._internal.core.errors import BackendError, NoCapacityError
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.utils.event_loop import DaemonEventLoop
from dstack._internal.utils.logging import get_logger

Expand Down Expand Up @@ -249,6 +252,17 @@ def get_default_subnet(sdk: SDK, project_id: str) -> Subnet:
raise BackendError(f"Could not find default subnet in project {project_id}")


def get_all_infiniband_fabrics() -> set[str]:
offers = get_catalog_offers(backend=BackendType.NEBIUS)
result = set()
for offer in offers:
backend_data: NebiusOfferBackendData = NebiusOfferBackendData.__response__.parse_obj(
offer.backend_data
)
result |= backend_data.fabrics
return result


def create_disk(
sdk: SDK, name: str, project_id: str, size_mib: int, image_family: str, labels: Dict[str, str]
) -> SDKOperation[Operation]:
Expand Down
82 changes: 58 additions & 24 deletions src/tests/_internal/server/routers/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@

from dstack._internal.core.backends.oci import region as oci_region
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.instances import InstanceStatus
from dstack._internal.core.models.instances import (
Gpu,
InstanceOffer,
InstanceStatus,
InstanceType,
Resources,
)
from dstack._internal.core.models.users import GlobalRole, ProjectRole
from dstack._internal.core.models.volumes import VolumeStatus
from dstack._internal.server.models import BackendModel
Expand Down Expand Up @@ -212,6 +218,30 @@ async def test_creates_lambda_backend(
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Nebius requires Python 3.10")
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
class TestNebius:
@pytest.fixture(autouse=True)
def patch_catalog(self):
with patch(
"dstack._internal.core.backends.nebius.resources.get_catalog_offers"
) as get_catalog_offers_mock:
get_catalog_offers_mock.return_value = [
InstanceOffer(
backend=BackendType.NEBIUS,
instance=InstanceType(
name="gpu-h100-sxm 8gpu-128vcpu-1600gb",
resources=Resources(
cpus=128,
memory_mib=1600 * 1024,
gpus=[Gpu(name="H100", memory_mib=80 * 1024)] * 8,
spot=False,
),
),
region="eu-north1",
price=23.6,
backend_data={"fabrics": ["fabric-2", "fabric-3"]},
)
]
yield

async def test_not_creates_with_invalid_creds(
self, test_db, session: AsyncSession, client: AsyncClient
):
Expand All @@ -238,18 +268,16 @@ async def test_not_creates_with_invalid_creds(
assert len(res.scalars().all()) == 0

@pytest.mark.parametrize(
("config_regions", "config_projects", "mocked_projects", "error"),
("config_extra", "mocked_projects", "error"),
[
pytest.param(
None,
None,
{},
[_nebius_project()],
None,
id="default",
),
pytest.param(
["eu-north1"],
None,
{"regions": ["eu-north1"]},
[
_nebius_project(
"project-e00test", "default-project-eu-north1", "eu-north1"
Expand All @@ -260,15 +288,13 @@ async def test_not_creates_with_invalid_creds(
id="with-regions",
),
pytest.param(
["xx-xxxx1"],
None,
{"regions": ["xx-xxxx1"]},
[_nebius_project()],
"do not exist in this Nebius tenancy",
id="error-invalid-regions",
),
pytest.param(
["eu-north1"],
None,
{"regions": ["eu-north1"]},
[
_nebius_project(
"project-e00test0", "default-project-eu-north1", "eu-north1"
Expand All @@ -279,8 +305,7 @@ async def test_not_creates_with_invalid_creds(
id="finds-default-project-among-many",
),
pytest.param(
["eu-north1"],
None,
{"regions": ["eu-north1"]},
[
_nebius_project("project-e00test0", "non-default-project-0", "eu-north1"),
_nebius_project("project-e00test1", "non-default-project-1", "eu-north1"),
Expand All @@ -289,8 +314,7 @@ async def test_not_creates_with_invalid_creds(
id="error-no-default-project",
),
pytest.param(
None,
["project-e00test0"],
{"projects": ["project-e00test0"]},
[
_nebius_project("project-e00test0", "non-default-project-0", "eu-north1"),
_nebius_project("project-e00test1", "non-default-project-1", "eu-north1"),
Expand All @@ -299,15 +323,13 @@ async def test_not_creates_with_invalid_creds(
id="with-projects",
),
pytest.param(
None,
["project-e00xxxx"],
{"projects": ["project-e00xxxx"]},
[_nebius_project()],
"not found in this Nebius tenancy",
id="error-invalid-projects",
),
pytest.param(
None,
["project-e00test0", "project-e00test1"],
{"projects": ["project-e00test0", "project-e00test1"]},
[
_nebius_project("project-e00test0", "non-default-project-0", "eu-north1"),
_nebius_project("project-e00test1", "non-default-project-1", "eu-north1"),
Expand All @@ -316,8 +338,10 @@ async def test_not_creates_with_invalid_creds(
id="error-multiple-projects-in-same-region",
),
pytest.param(
["eu-north1"],
["project-e00test"],
{
"regions": ["eu-north1"],
"projects": ["project-e00test"],
},
[
_nebius_project(
"project-e00test", "default-project-eu-north1", "eu-north1"
Expand All @@ -327,15 +351,26 @@ async def test_not_creates_with_invalid_creds(
None,
id="with-regions-and-projects",
),
pytest.param(
{"fabrics": ["fabric-2", "fabric-3"]},
[_nebius_project()],
None,
id="with-valid-fabrics",
),
pytest.param(
{"fabrics": ["fabric-2", "fabric-invalid"]},
[_nebius_project()],
"InfiniBand fabrics do not exist",
id="with-invalid-fabrics",
),
],
)
async def test_create(
self,
test_db,
session: AsyncSession,
client: AsyncClient,
config_regions: Optional[list[str]],
config_projects: Optional[list[str]],
config_extra: dict[str, Any],
mocked_projects: Sequence[Any],
error: Optional[str],
):
Expand All @@ -347,8 +382,7 @@ async def test_create(
body = {
"type": "nebius",
"creds": FAKE_NEBIUS_SERVICE_ACCOUNT_CREDS,
"regions": config_regions,
"projects": config_projects,
**config_extra,
}
with patch(
"dstack._internal.core.backends.nebius.resources.list_tenant_projects"
Expand Down