Skip to content

Commit 60ae55a

Browse files
authored
Merge pull request #916 from transformerlab/add/bridge
adding providers support
2 parents 6cf3536 + a337522 commit 60ae55a

27 files changed

+5225
-181
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""add team_providers_tables
2+
3+
Revision ID: 63ca6eebc24c
4+
Revises: f7661070ec23
5+
Create Date: 2025-11-24 11:35:14.455588
6+
7+
"""
8+
from typing import Sequence, Union
9+
10+
from alembic import op
11+
import sqlalchemy as sa
12+
13+
14+
# revision identifiers, used by Alembic.
15+
revision: str = '63ca6eebc24c'
16+
down_revision: Union[str, Sequence[str], None] = 'f7661070ec23'
17+
branch_labels: Union[str, Sequence[str], None] = None
18+
depends_on: Union[str, Sequence[str], None] = None
19+
20+
21+
def upgrade() -> None:
    """Upgrade schema: create the ``team_providers`` table and its indexes."""
    table_name = "team_providers"

    def timestamp_column(column_name: str) -> sa.Column:
        # created_at / updated_at share one definition: a non-null DateTime
        # whose value defaults to the database's current timestamp.
        return sa.Column(
            column_name,
            sa.DateTime(),
            server_default=sa.text("(CURRENT_TIMESTAMP)"),
            nullable=False,
        )

    op.create_table(
        table_name,
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("team_id", sa.String(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("type", sa.String(), nullable=False),
        sa.Column("config", sa.JSON(), nullable=False),
        sa.Column("created_by_user_id", sa.String(), nullable=False),
        timestamp_column("created_at"),
        timestamp_column("updated_at"),
        sa.PrimaryKeyConstraint("id"),
    )

    # Composite (team_id, name) index first, then the two single-column indexes.
    op.create_index("idx_team_provider_name", table_name, ["team_id", "name"], unique=False)
    op.create_index(op.f("ix_team_providers_team_id"), table_name, ["team_id"], unique=False)
    op.create_index(op.f("ix_team_providers_type"), table_name, ["type"], unique=False)
39+
40+
41+
def downgrade() -> None:
    """Downgrade schema: drop the ``team_providers`` indexes, then the table."""
    table_name = "team_providers"

    index_names = (
        op.f("ix_team_providers_type"),
        op.f("ix_team_providers_team_id"),
        "idx_team_provider_name",
    )
    for index_name in index_names:
        op.drop_index(index_name, table_name=table_name)

    op.drop_table(table_name)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""rename_team_providers_to_compute_providers
2+
3+
Revision ID: be6b6cb9f784
4+
Revises: 63ca6eebc24c
5+
Create Date: 2025-11-26 14:47:16.424026
6+
7+
"""
8+
9+
from typing import Sequence, Union
10+
11+
from alembic import op
12+
13+
14+
# revision identifiers, used by Alembic.
15+
revision: str = "be6b6cb9f784"
16+
down_revision: Union[str, Sequence[str], None] = "63ca6eebc24c"
17+
branch_labels: Union[str, Sequence[str], None] = None
18+
depends_on: Union[str, Sequence[str], None] = None
19+
20+
21+
def upgrade() -> None:
    """Upgrade schema.

    Renames ``team_providers`` to ``compute_providers`` and replaces its three
    indexes with equivalents named after the new table:

    * ``idx_team_provider_name``     -> ``idx_compute_provider_name``
    * ``ix_team_providers_team_id``  -> ``ix_compute_providers_team_id``
    * ``ix_team_providers_type``     -> ``ix_compute_providers_type``

    The two ``ix_team_providers_*`` indexes are dropped on a best-effort basis
    (only when they are actually present). Presence is checked explicitly
    rather than attempting the drop and swallowing every exception: a blanket
    ``except Exception: pass`` hides genuine failures, and on backends with
    transactional DDL a failed statement can leave the migration transaction
    unusable for the statements that follow.
    """
    # Local import: this migration module only imports ``op`` at file level.
    import sqlalchemy as sa

    table_name = "compute_providers"

    op.rename_table("team_providers", table_name)

    # Replace the composite (team_id, name) index under its new name.
    op.drop_index("idx_team_provider_name", table_name=table_name)
    op.create_index("idx_compute_provider_name", table_name, ["team_id", "name"], unique=False)

    # In offline (--sql) mode there is no live connection to inspect, so the
    # DROP statements are emitted unconditionally into the generated script.
    offline = op.get_context().as_sql
    if offline:
        present = set()
    else:
        inspector = sa.inspect(op.get_bind())
        present = {index["name"] for index in inspector.get_indexes(table_name)}

    # Drop the old table-name-derived indexes only if they exist.
    for old_index in ("ix_team_providers_team_id", "ix_team_providers_type"):
        if offline or old_index in present:
            op.drop_index(op.f(old_index), table_name=table_name)

    # Recreate the single-column indexes using the new table's naming pattern.
    op.create_index(op.f("ix_compute_providers_team_id"), table_name, ["team_id"], unique=False)
    op.create_index(op.f("ix_compute_providers_type"), table_name, ["type"], unique=False)
45+
46+
47+
def downgrade() -> None:
    """Downgrade schema: restore the ``team_providers`` name and its indexes."""
    current_table = "compute_providers"
    previous_table = "team_providers"

    # 1. Remove the indexes that carry the new naming.
    indexes_to_drop = (
        op.f("ix_compute_providers_type"),
        op.f("ix_compute_providers_team_id"),
        "idx_compute_provider_name",
    )
    for index_name in indexes_to_drop:
        op.drop_index(index_name, table_name=current_table)

    # 2. Give the table its previous name back.
    op.rename_table(current_table, previous_table)

    # 3. Recreate the original indexes on the renamed table.
    indexes_to_create = (
        ("idx_team_provider_name", ["team_id", "name"]),
        (op.f("ix_team_providers_team_id"), ["team_id"]),
        (op.f("ix_team_providers_type"), ["type"]),
    )
    for index_name, columns in indexes_to_create:
        op.create_index(index_name, previous_table, columns, unique=False)

api/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
batched_prompts,
6161
recipes,
6262
teams,
63+
compute_provider,
6364
auth,
6465
)
6566
from transformerlab.routers.auth import get_user_and_team # noqa: E402
@@ -236,6 +237,7 @@ async def validation_exception_handler(request, exc):
236237
app.include_router(batched_prompts.router, dependencies=[Depends(get_user_and_team)])
237238
app.include_router(fastchat_openai_api.router, dependencies=[Depends(get_user_and_team)])
238239
app.include_router(teams.router, dependencies=[Depends(get_user_and_team)])
240+
app.include_router(compute_provider.router)
239241
app.include_router(auth.router)
240242

241243
controller_process = None
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Compute provider bridge system for abstracting GPU orchestration providers."""
2+
3+
from .base import ComputeProvider
4+
from .router import ComputeProviderRouter, get_provider
5+
from .config import load_compute_providers_config, ComputeProviderConfig
6+
7+
__all__ = [
8+
"ComputeProvider",
9+
"ComputeProviderRouter",
10+
"get_provider",
11+
"load_compute_providers_config",
12+
"ComputeProviderConfig",
13+
]
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
"""Abstract base class for provider implementations."""
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Dict, List, Any, Optional, Union
5+
from .models import (
6+
ClusterConfig,
7+
JobConfig,
8+
ClusterStatus,
9+
JobInfo,
10+
ResourceInfo,
11+
)
12+
13+
14+
class ComputeProvider(ABC):
    """Interface that every compute provider implementation must satisfy.

    All methods are abstract; a concrete provider has to override each of
    them before it can be instantiated.
    """

    @abstractmethod
    def launch_cluster(self, cluster_name: str, config: ClusterConfig) -> Dict[str, Any]:
        """Provision and start a new cluster.

        Args:
            cluster_name: Name to give the cluster being launched.
            config: Configuration describing the cluster.

        Returns:
            A dictionary describing the launch outcome (for example a
            request_id and the cluster_name).
        """
        raise NotImplementedError()

    @abstractmethod
    def stop_cluster(self, cluster_name: str) -> Dict[str, Any]:
        """Stop a running cluster without tearing it down.

        Args:
            cluster_name: Name of the cluster to stop.

        Returns:
            A dictionary describing the stop outcome.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_cluster_status(self, cluster_name: str) -> ClusterStatus:
        """Report the current status of a cluster.

        Args:
            cluster_name: Name of the cluster to query.

        Returns:
            A ClusterStatus carrying the cluster's information.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_cluster_resources(self, cluster_name: str) -> ResourceInfo:
        """Report a cluster's resources (GPUs, CPUs, memory, and so on).

        Args:
            cluster_name: Name of the cluster to query.

        Returns:
            A ResourceInfo carrying the resource details.
        """
        raise NotImplementedError()

    @abstractmethod
    def submit_job(self, cluster_name: str, job_config: JobConfig) -> Dict[str, Any]:
        """Submit a job to a cluster that already exists.

        Args:
            cluster_name: Name of the target cluster.
            job_config: Configuration of the job to run.

        Returns:
            A dictionary describing the submission outcome (for example a
            job_id).
        """
        raise NotImplementedError()

    @abstractmethod
    def get_job_logs(
        self,
        cluster_name: str,
        job_id: Union[str, int],
        tail_lines: Optional[int] = None,
        follow: bool = False,
    ) -> Union[str, Any]:
        """Fetch the logs of a job.

        Args:
            cluster_name: Name of the cluster the job belongs to.
            job_id: Identifier of the job.
            tail_lines: How many lines to take from the end of the log;
                None means the whole log.
            follow: When True, stream/follow the log instead of returning
                a snapshot.

        Returns:
            The log content as a string, or a stream object when
            follow=True.
        """
        raise NotImplementedError()

    @abstractmethod
    def cancel_job(self, cluster_name: str, job_id: Union[str, int]) -> Dict[str, Any]:
        """Cancel a job that is running or queued.

        Args:
            cluster_name: Name of the cluster the job belongs to.
            job_id: Identifier of the job.

        Returns:
            A dictionary describing the cancellation outcome.
        """
        raise NotImplementedError()

    @abstractmethod
    def list_jobs(self, cluster_name: str) -> List[JobInfo]:
        """List every job of a cluster.

        Args:
            cluster_name: Name of the cluster to query.

        Returns:
            A list of JobInfo objects.
        """
        raise NotImplementedError()

0 commit comments

Comments
 (0)