Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/gpuhunt/providers/hotaisle.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import logging
import os
from typing import Optional
from typing import Optional, TypedDict, cast

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think TypedDict should come from typing_extensions. Other providers: nebius, gcp, runpod, oci using typing_extensions.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd say importing from the standard library is preferred, since it could let us drop the typing-extensions dependency in the future. In dstack, all TypedDict imports come from the standard library. Both options are equivalent, though

import requests
from requests import Response

from gpuhunt._internal.constraints import find_accelerators
from gpuhunt._internal.models import AcceleratorVendor, QueryFilter, RawCatalogItem
from gpuhunt._internal.models import AcceleratorVendor, JSONObject, QueryFilter, RawCatalogItem
from gpuhunt.providers import AbstractProvider

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -53,6 +53,10 @@ def _make_request(self, method: str, url: str) -> Response:
return response


class HotAisleCatalogItemProviderData(TypedDict):
vm_specs: JSONObject


def get_gpu_memory(gpu_name: str) -> Optional[float]:
if accelerators := find_accelerators(names=[gpu_name], vendors=[AcceleratorVendor.AMD]):
return float(accelerators[0].memory)
Expand Down Expand Up @@ -96,6 +100,14 @@ def convert_response_to_raw_catalog_items(response: Response) -> list[RawCatalog
gpu_vendor=gpu_vendor,
spot=False,
disk_size=disk_gb,
provider_data=cast(
JSONObject,
HotAisleCatalogItemProviderData(
# The specs object may duplicate some RawCatalogItem fields, but we store it in
# full because we need to pass it back to the API when creating VMs.
vm_specs=specs,
),
),
)
offers.append(offer)

Expand Down