Skip to content

Commit

Permalink
introduce cluster list cli cmd
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandra Belousov authored and Alexandra Belousov committed Sep 8, 2024
1 parent bce8e12 commit 6e69743
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 0 deletions.
103 changes: 103 additions & 0 deletions runhouse/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import datetime
import importlib
import logging
Expand All @@ -9,6 +10,8 @@
from pathlib import Path
from typing import Any, Dict, List, Optional

import httpx

import ray

import requests
Expand All @@ -17,6 +20,7 @@
import typer
import yaml
from rich.console import Console
from rich.table import Table

import runhouse as rh

Expand All @@ -41,9 +45,18 @@
kill_actors,
)

from runhouse.utils import get_status_color

# create an explicit Typer application
app = typer.Typer(add_completion=False)

# creating a cluster app so we could create subcommands of cluster (i.e runhouse cluster list)
italic_bold_ansi = "\x1B[3m\x1B[1m"
reset_format = "\x1B[0m"
cluster_app = typer.Typer(
help=f"Cluster information commands. For more info run {italic_bold_ansi}runhouse cluster --help{reset_format}"
)

# For printing with typer
console = Console()

Expand Down Expand Up @@ -518,6 +531,9 @@ def _print_status(status_data: dict, current_cluster: Cluster) -> None:
_print_envs_info(env_servlet_processes, current_cluster)


###############################
# Cluster CLI commands
###############################
@app.command()
def status(
cluster_name: str = typer.Argument(
Expand Down Expand Up @@ -584,6 +600,93 @@ def status(
_print_status(cluster_status, current_cluster)


async def aget_clusters_from_den():
httpx_client = httpx.AsyncClient()

get_clusters_params = {"resource_type": "cluster", "folder": rns_client.username}
clusters_in_den_resp = await httpx_client.get(
f"{rns_client.api_server_url}/resource",
params=get_clusters_params,
headers=rns_client.request_headers(),
)

return clusters_in_den_resp


def get_clusters_from_den():
return asyncio.run(aget_clusters_from_den())


@cluster_app.command("list")
def cluster_list():
"""Load Runhouse clusters"""
import sky

# logged out case
if not rh.configs.token:
# TODO [SB]: adjust msg formatting (coloring etc)
sky_cli_command_formatted = f"{italic_bold_ansi}sky status -r{reset_format}" # will be printed bold and italic
console.print(
f"This feature is available only for Den users. Please run {sky_cli_command_formatted} to get on-demand cluster(s) information or sign-up Den."
)
return

on_demand_clusters_sky = sky.status(refresh=True)

clusters_in_den_resp = get_clusters_from_den()

if clusters_in_den_resp.status_code != 200:
logger.error(f"Failed to load {rns_client.username}'s clusters from Den")
clusters_in_den = []
else:
clusters_in_den = clusters_in_den_resp.json().get("data")

clusters_in_den_names = [cluster.get("name") for cluster in clusters_in_den]

if not on_demand_clusters_sky and not clusters_in_den:
console.print("No existing clusters.")

if on_demand_clusters_sky:
# getting the on-demand clusters that are not saved in den.
on_demand_clusters_sky = [
cluster
for cluster in on_demand_clusters_sky
if f'/{rns_client.username}/{cluster.get("name")}'
not in clusters_in_den_names
]

total_clusters = len(clusters_in_den)
table_title = f"[bold cyan]{rns_client.username}'s Clusters (Total: {total_clusters})[/bold cyan]"

table = Table(title=table_title)

# Add columns to the table
table.add_column("Name", justify="left", no_wrap=True)
table.add_column("Cluster Type", justify="center", no_wrap=True)
table.add_column("Status", justify="left")

for den_cluster in clusters_in_den:
# get just name, not full rns address. reset is used so the name will be printed all in white.
cluster_name = f'[reset]{den_cluster.get("name").split("/")[-1]}'
cluster_type = den_cluster.get("data").get("resource_subtype")
cluster_status = (
den_cluster.get("status") if den_cluster.get("status") else "unknown"
)
cluster_status_colored = get_status_color(cluster_status)
table.add_row(cluster_name, cluster_type, cluster_status_colored)

console.print(table)

if len(on_demand_clusters_sky) > 0:
console.print(
f"There are {len(on_demand_clusters_sky)} live clusters that are not saved in Den. To get information about them, please run [bold italic]sky status -r[/bold italic]."
)


# Register the 'cluster' command group with the main runhouse application
app.add_typer(cluster_app, name="cluster")


def load_cluster(cluster_name: str):
"""Load a cluster from RNS into the local environment, e.g. to be able to ssh."""
c = cluster(name=cluster_name)
Expand Down
12 changes: 12 additions & 0 deletions runhouse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,3 +652,15 @@ def get_gpu_usage(collected_gpus_info: dict, servlet_type: ServletType):
gpu_usage["utilization_percent"] = gpu_utilization_percent

return gpu_usage


class StatusColors(str, Enum):
RUNNING = "[green]Running[/green]"
SERVER_DOWN = "[orange1]Runhouse server down[/orange1]"
TERMINATED = "[red]Terminated[/red]"
UNKNOWN = "Unknown"
LOCAL_CLUSTER = "[bright_yellow]Local cluster[/bright_yellow]"


def get_status_color(status: str):
return getattr(StatusColors, status.upper()).value

0 comments on commit 6e69743

Please sign in to comment.