Skip to content

Commit 00691b0

Browse files
authored
[Nebius]: Add fabrics list to provider_data (#192)
Associate each Nebius catalog item with the list of InfiniBand fabrics that it supports. The list of fabrics and their details is hardcoded until Nebius exposes it in the API. Previously, the list of fabrics was hardcoded in dstack. Moving it to gpuhunt allows new fabrics to be added without a dstack release.
1 parent d5fb338 commit 00691b0

File tree

2 files changed

+74
-1
lines changed

2 files changed

+74
-1
lines changed

src/gpuhunt/providers/nebius.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
import re
3-
from typing import Optional
3+
from dataclasses import dataclass
4+
from typing import Optional, cast
45

56
from nebius.aio.channel import Credentials
67
from nebius.api.nebius.billing.v1alpha1 import (
@@ -26,11 +27,13 @@
2627
TenantServiceClient,
2728
)
2829
from nebius.sdk import SDK
30+
from typing_extensions import TypedDict
2931

3032
from gpuhunt._internal.constraints import find_accelerators
3133
from gpuhunt._internal.models import (
3234
AcceleratorInfo,
3335
AcceleratorVendor,
36+
JSONObject,
3437
QueryFilter,
3538
RawCatalogItem,
3639
)
@@ -40,6 +43,26 @@
4043
TIMEOUT = 7
4144

4245

46+
@dataclass(frozen=True)
class InfinibandFabric:
    """Immutable description of a single Nebius InfiniBand fabric.

    A fabric is identified by its name and is tied to one platform in one
    region; catalog items are matched against both of these fields.
    """

    # Fabric identifier, e.g. "fabric-2" or "us-central1-a".
    name: str
    # Platform the fabric serves, e.g. "gpu-h100-sxm".
    platform: str
    # Region the fabric is located in, e.g. "eu-north1".
    region: str
51+
52+
53+
# Known InfiniBand fabrics, hardcoded here because the list is not fetched
# from the Nebius API. Source of the values:
# https://docs.nebius.com/compute/clusters/gpu#fabrics
INFINIBAND_FABRICS = [
    InfinibandFabric(name="fabric-2", platform="gpu-h100-sxm", region="eu-north1"),
    InfinibandFabric(name="fabric-3", platform="gpu-h100-sxm", region="eu-north1"),
    InfinibandFabric(name="fabric-4", platform="gpu-h100-sxm", region="eu-north1"),
    InfinibandFabric(name="fabric-5", platform="gpu-h200-sxm", region="eu-west1"),
    InfinibandFabric(name="fabric-6", platform="gpu-h100-sxm", region="eu-north1"),
    InfinibandFabric(name="fabric-7", platform="gpu-h200-sxm", region="eu-north1"),
    InfinibandFabric(name="us-central1-a", platform="gpu-h200-sxm", region="us-central1"),
    InfinibandFabric(name="us-central1-b", platform="gpu-b200-sxm", region="us-central1"),
]
64+
65+
4366
class NebiusProvider(AbstractProvider):
4467
NAME = "nebius"
4568

@@ -77,6 +100,10 @@ def get(
77100
return items
78101

79102

103+
class NebiusCatalogItemProviderData(TypedDict):
    """Shape of the ``provider_data`` payload stored on Nebius catalog items."""

    # Names of the InfiniBand fabrics available to the catalog item;
    # an empty list when the item has none.
    fabrics: list[str]
105+
106+
80107
def get_sample_projects(sdk: SDK) -> dict[str, str]:
81108
"""
82109
Returns:
@@ -141,6 +168,12 @@ def make_item(
141168
spot: bool,
142169
price: float,
143170
) -> Optional[RawCatalogItem]:
171+
fabrics = []
172+
if preset.allow_gpu_clustering:
173+
fabrics = [
174+
f.name for f in INFINIBAND_FABRICS if f.platform == platform and f.region == region
175+
]
176+
144177
item = RawCatalogItem(
145178
instance_name=f"{platform} {preset.name}",
146179
location=region,
@@ -153,6 +186,7 @@ def make_item(
153186
gpu_vendor=None,
154187
spot=spot,
155188
disk_size=None,
189+
provider_data=cast(JSONObject, NebiusCatalogItemProviderData(fabrics=fabrics)),
156190
)
157191

158192
if preset.resources.gpu_count:

src/integrity_tests/test_nebius.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import csv
2+
import json
23
from operator import itemgetter
34
from pathlib import Path
45

@@ -28,3 +29,41 @@ def test_spots_presented(data_rows: list[dict]):
2829
@pytest.mark.parametrize("location", ["eu-north1", "eu-west1"])
def test_location_present(location: str, data_rows: list[dict]):
    """Check that at least one catalog row is located in ``location``."""
    assert any(row["location"] == location for row in data_rows)
32+
33+
34+
def test_fabrics_unique(data_rows: list[dict]) -> None:
    """Check that no catalog row lists the same fabric more than once.

    Each row's ``provider_data`` column is parsed as JSON and its
    ``fabrics`` list is compared against its de-duplicated form.
    """
    for entry in data_rows:
        provider_data = json.loads(entry["provider_data"])
        names = provider_data["fabrics"]
        distinct_names = set(names)
        assert len(names) == len(distinct_names), f"Duplicate fabrics in row: {entry}"
38+
39+
40+
def test_fabrics_on_sample_offer(data_rows: list[dict]) -> None:
    """Check that a known clustered offer lists the fabrics expected for it.

    The sample is the first row for the 8-GPU H100 preset in eu-north1. The
    check only requires the expected fabrics to be present; additional
    fabrics on the row do not fail the test.

    Raises:
        ValueError: if the sample offer is absent from the catalog.
    """
    offer = next(
        (
            candidate
            for candidate in data_rows
            if candidate["instance_name"] == "gpu-h100-sxm 8gpu-128vcpu-1600gb"
            and candidate["location"] == "eu-north1"
        ),
        None,
    )
    if offer is None:
        raise ValueError("Offer not found")

    actual = set(json.loads(offer["provider_data"])["fabrics"])
    required = {"fabric-2", "fabric-3", "fabric-4", "fabric-6"}
    # Subset check: every required fabric must appear on the offer.
    assert required.issubset(actual)
58+
59+
60+
def test_no_fabrics_on_sample_non_clustered_offer(data_rows: list[dict]) -> None:
    """Check that a known single-GPU offer has an empty fabrics list.

    The sample is the first row for the 1-GPU H100 preset in eu-north1.

    Raises:
        ValueError: if the sample offer is absent from the catalog.
    """
    offer = next(
        (
            candidate
            for candidate in data_rows
            if candidate["instance_name"] == "gpu-h100-sxm 1gpu-16vcpu-200gb"
            and candidate["location"] == "eu-north1"
        ),
        None,
    )
    if offer is None:
        raise ValueError("Offer not found")

    provider_data = json.loads(offer["provider_data"])
    assert provider_data["fabrics"] == []

0 commit comments

Comments
 (0)