Skip to content

Commit 9ae0532

Browse files
committed
Type cleanup
Signed-off-by: Samuel Monson <[email protected]>
1 parent afb6fa6 commit 9ae0532

File tree

7 files changed

+46
-80
lines changed

7 files changed

+46
-80
lines changed

src/guidellm/scheduler/__init__.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
from .objects import (
1616
BackendInterface,
1717
BackendT,
18+
DatasetIterT,
1819
HistoryT,
1920
MeasuredRequestTimings,
20-
MultiTurnRequestT,
21-
MultiTurnT,
21+
RequestDataT,
2222
RequestSchedulerTimings,
2323
RequestT,
2424
ResponseT,
@@ -28,7 +28,6 @@
2828
SchedulerState,
2929
SchedulerUpdateAction,
3030
SchedulerUpdateActionProgress,
31-
TurnT,
3231
)
3332
from .scheduler import Scheduler
3433
from .strategies import (
@@ -59,6 +58,7 @@
5958
"Constraint",
6059
"ConstraintInitializer",
6160
"ConstraintsInitializerFactory",
61+
"DatasetIterT",
6262
"Environment",
6363
"HistoryT",
6464
"LastCompletionRequestTimings",
@@ -68,12 +68,11 @@
6868
"MaxGlobalErrorRateConstraint",
6969
"MaxNumberConstraint",
7070
"MeasuredRequestTimings",
71-
"MultiTurnRequestT",
72-
"MultiTurnT",
7371
"NoDelayRequestTimings",
7472
"NonDistributedEnvironment",
7573
"PoissonRateRequestTimings",
7674
"PydanticConstraintInitializer",
75+
"RequestDataT",
7776
"RequestSchedulerTimings",
7877
"RequestT",
7978
"ResponseT",
@@ -91,7 +90,6 @@
9190
"StrategyType",
9291
"SynchronousStrategy",
9392
"ThroughputStrategy",
94-
"TurnT",
9593
"UnserializableConstraintInitializer",
9694
"WorkerProcess",
9795
"WorkerProcessGroup",

src/guidellm/scheduler/environments.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919

2020
import time
2121
from abc import ABC, abstractmethod
22-
from collections.abc import AsyncIterator, Iterable
22+
from collections.abc import AsyncIterator
2323
from typing import (
2424
Generic,
2525
)
2626

2727
from guidellm.scheduler.constraints import Constraint
2828
from guidellm.scheduler.objects import (
29-
MultiTurnRequestT,
29+
DatasetIterT,
3030
RequestT,
3131
ResponseT,
3232
ScheduledRequestInfo,
@@ -52,11 +52,11 @@ class Environment(ABC, Generic[RequestT, ResponseT], InfoMixin):
5252
@abstractmethod
5353
async def sync_run_params(
5454
self,
55-
requests: Iterable[RequestT | MultiTurnRequestT[RequestT]],
55+
requests: DatasetIterT[RequestT],
5656
strategy: SchedulingStrategy,
5757
constraints: dict[str, Constraint],
5858
) -> tuple[
59-
Iterable[RequestT | MultiTurnRequestT[RequestT]],
59+
DatasetIterT[RequestT],
6060
SchedulingStrategy,
6161
dict[str, Constraint],
6262
]:
@@ -130,7 +130,7 @@ async def sync_run_end(
130130
) -> AsyncIterator[
131131
tuple[
132132
ResponseT,
133-
RequestT | MultiTurnRequestT[RequestT],
133+
RequestT,
134134
ScheduledRequestInfo,
135135
SchedulerState,
136136
]
@@ -194,11 +194,11 @@ def __init__(self):
194194

195195
async def sync_run_params(
196196
self,
197-
requests: Iterable[RequestT | MultiTurnRequestT[RequestT]],
197+
requests: DatasetIterT[RequestT],
198198
strategy: SchedulingStrategy,
199199
constraints: dict[str, Constraint],
200200
) -> tuple[
201-
Iterable[RequestT | MultiTurnRequestT[RequestT]],
201+
DatasetIterT[RequestT],
202202
SchedulingStrategy,
203203
dict[str, Constraint],
204204
]:
@@ -250,7 +250,7 @@ async def sync_run_end(
250250
) -> AsyncIterator[
251251
tuple[
252252
ResponseT,
253-
RequestT | MultiTurnRequestT[RequestT],
253+
RequestT,
254254
ScheduledRequestInfo,
255255
SchedulerState,
256256
]

src/guidellm/scheduler/objects.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import time
1313
import uuid
14-
from collections.abc import AsyncIterator
14+
from collections.abc import AsyncIterator, Iterable
1515
from typing import (
1616
Any,
1717
ClassVar,
@@ -34,10 +34,10 @@
3434
__all__ = [
3535
"BackendInterface",
3636
"BackendT",
37+
"DatasetIterT",
3738
"HistoryT",
3839
"MeasuredRequestTimings",
39-
"MultiTurnRequestT",
40-
"MultiTurnT",
40+
"RequestDataT",
4141
"RequestSchedulerTimings",
4242
"RequestT",
4343
"ResponseT",
@@ -47,36 +47,32 @@
4747
"SchedulerState",
4848
"SchedulerUpdateAction",
4949
"SchedulerUpdateActionProgress",
50-
"TurnT",
5150
]
5251

5352
RequestT = TypeVar("RequestT")
5453
"""Generic request object type for scheduler processing."""
5554

56-
# TODO: Remove
57-
MultiTurnRequestT = RequestT
58-
5955
ResponseT = TypeVar("ResponseT")
6056
"""Generic response object type returned by backend processing."""
6157

62-
TurnT = TypeAliasType(
63-
"TurnT",
58+
RequestDataT = TypeAliasType(
59+
"RequestDataT",
6460
tuple[RequestT, "ScheduledRequestAugmentation", "ScheduledRequestInfo"],
6561
type_params=(RequestT,),
6662
)
67-
68-
MultiTurnT = TypeAliasType(
69-
"MultiTurnT",
70-
list[TurnT[RequestT]],
71-
type_params=(RequestT,),
72-
)
73-
"""Multi-turn request structure supporting conversation history with optional delays."""
63+
"""Request including external metadata and scheduling config."""
7464

7565
HistoryT = TypeAliasType(
7666
"HistoryT",
7767
list[tuple[RequestT, ResponseT]],
7868
type_params=(RequestT, ResponseT),
7969
)
70+
"""Record of requests + responses in conversation."""
71+
72+
73+
DatasetIterT = TypeAliasType(
74+
"DatasetIterT", Iterable[Iterable[tuple[RequestT, float]]], type_params=(RequestT,)
75+
)
8076

8177

8278
class SchedulerMessagingPydanticRegistry(RegistryMixin[RegistryObjT]):

src/guidellm/scheduler/scheduler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from __future__ import annotations
1212

13-
from collections.abc import AsyncIterator, Iterable
13+
from collections.abc import AsyncIterator
1414
from typing import Any, Generic
1515

1616
from guidellm.scheduler.constraints import (
@@ -20,7 +20,7 @@
2020
from guidellm.scheduler.environments import Environment, NonDistributedEnvironment
2121
from guidellm.scheduler.objects import (
2222
BackendInterface,
23-
MultiTurnRequestT,
23+
DatasetIterT,
2424
RequestT,
2525
ResponseT,
2626
ScheduledRequestInfo,
@@ -66,7 +66,7 @@ class Scheduler(
6666

6767
async def run(
6868
self,
69-
requests: Iterable[RequestT | MultiTurnRequestT[RequestT]],
69+
requests: DatasetIterT[RequestT],
7070
backend: BackendInterface[RequestT, ResponseT],
7171
strategy: SchedulingStrategy,
7272
env: Environment | None,

src/guidellm/scheduler/worker.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@
3232
from guidellm.scheduler.objects import (
3333
BackendInterface,
3434
HistoryT,
35-
MultiTurnRequestT,
36-
MultiTurnT,
35+
RequestDataT,
3736
RequestT,
3837
ResponseT,
3938
ScheduledRequestAugmentation,
@@ -54,7 +53,7 @@
5453
"ProcessRequestT",
5554
tuple[
5655
HistoryT[RequestT, ResponseT],
57-
MultiTurnT[RequestT],
56+
list[RequestDataT[RequestT]],
5857
ScheduledRequestAugmentation,
5958
],
6059
type_params=(RequestT, ResponseT),
@@ -87,11 +86,8 @@ class WorkerProcess(Generic[RequestT, ResponseT]):
8786
def __init__(
8887
self,
8988
messaging: InterProcessMessaging[
90-
tuple[
91-
ResponseT | None,
92-
RequestT | MultiTurnRequestT[RequestT],
93-
ScheduledRequestInfo,
94-
],
89+
tuple[ResponseT | None, RequestT, ScheduledRequestInfo],
90+
list[RequestDataT[RequestT]],
9591
],
9692
backend: BackendInterface[RequestT, ResponseT],
9793
request_timings: ScheduledRequestTimings,
@@ -132,7 +128,7 @@ def __init__(
132128
self.backend_started = False
133129
self.messaging_started = False
134130
self.turns_queue: list[
135-
tuple[HistoryT[RequestT, ResponseT], MultiTurnT[RequestT]]
131+
tuple[HistoryT[RequestT, ResponseT], list[RequestDataT[RequestT]]]
136132
] = []
137133

138134
def run(self):
@@ -332,7 +328,7 @@ async def _cancel_requests_loop(self):
332328
self._send_update("cancelled", None, request, request_info)
333329

334330
async def _process_next_request(self) -> ProcessRequestT[RequestT, ResponseT]:
335-
conversation: MultiTurnT[RequestT] = []
331+
conversation: list[RequestDataT[RequestT]] = []
336332
history: HistoryT[RequestT, ResponseT] = []
337333
request: RequestT | None = None
338334
request_info: ScheduledRequestInfo | None = None
@@ -409,7 +405,7 @@ async def _process_next_request(self) -> ProcessRequestT[RequestT, ResponseT]:
409405
async def _wait_then_requeue(
410406
self,
411407
history: HistoryT[RequestT, ResponseT],
412-
conversation: MultiTurnT[RequestT],
408+
conversation: list[RequestDataT[RequestT]],
413409
aug: ScheduledRequestAugmentation,
414410
):
415411
try:
@@ -427,7 +423,7 @@ def _send_update(
427423
"pending", "in_progress", "completed", "errored", "cancelled"
428424
],
429425
response: ResponseT | None,
430-
request: RequestT | MultiTurnRequestT[RequestT],
426+
request: RequestT,
431427
request_info: ScheduledRequestInfo,
432428
):
433429
prev_status = request_info.status

src/guidellm/scheduler/worker_group.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
from guidellm.scheduler.constraints import Constraint, RequestsExhaustedConstraint
2626
from guidellm.scheduler.objects import (
2727
BackendInterface,
28-
MultiTurnRequestT,
29-
MultiTurnT,
28+
DatasetIterT,
29+
RequestDataT,
3030
RequestT,
3131
ResponseT,
3232
ScheduledRequestAugmentation,
@@ -83,8 +83,8 @@ class WorkerProcessGroup(Generic[RequestT, ResponseT]):
8383

8484
def __init__(
8585
self,
86-
requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None,
87-
cycle_requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None,
86+
requests: DatasetIterT[RequestT] | None,
87+
cycle_requests: DatasetIterT[RequestT] | None,
8888
backend: BackendInterface[RequestT, ResponseT],
8989
strategy: SchedulingStrategy,
9090
constraints: dict[str, Constraint],
@@ -131,16 +131,8 @@ def __init__(
131131
# Scheduler and messaging state, created in start
132132
self.state: WorkerGroupState[ResponseT, RequestT] = None
133133
self.messaging: InterProcessMessaging[
134-
tuple[
135-
RequestT | MultiTurnRequestT[RequestT],
136-
ScheduledRequestInfo,
137-
],
138-
tuple[
139-
ResponseT | None,
140-
RequestT | MultiTurnRequestT[RequestT],
141-
ScheduledRequestInfo,
142-
SchedulerState,
143-
],
134+
list[RequestDataT[RequestT]],
135+
tuple[ResponseT | None, RequestT, ScheduledRequestInfo, SchedulerState],
144136
] = None
145137

146138
async def create_processes(self):
@@ -473,9 +465,9 @@ def __init__(
473465

474466
def requests_generator(
475467
self,
476-
requests: Iterable[Iterable[tuple[RequestT, float]]] | None,
477-
cycle_requests: Iterable[Iterable[tuple[RequestT, float]]] | None,
478-
) -> Generator[MultiTurnT[RequestT], None, None]:
468+
requests: DatasetIterT[RequestT] | None,
469+
cycle_requests: DatasetIterT[RequestT] | None,
470+
) -> Generator[list[RequestDataT[RequestT]], None, None]:
479471
"""
480472
Generate request-info pairs for worker processing with constraint evaluation.
481473
@@ -544,12 +536,12 @@ def received_callback(
544536
self,
545537
update: tuple[
546538
ResponseT | None,
547-
RequestT | MultiTurnRequestT,
539+
RequestT,
548540
ScheduledRequestInfo,
549541
],
550542
) -> tuple[
551543
ResponseT | None,
552-
RequestT | MultiTurnRequestT,
544+
RequestT,
553545
ScheduledRequestInfo,
554546
SchedulerState,
555547
]:

tests/unit/scheduler/test_objects.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,11 @@
77

88
import pytest
99
from pydantic import ValidationError
10-
from typing_extensions import TypeAliasType
1110

1211
from guidellm.scheduler import (
1312
BackendInterface,
1413
BackendT,
1514
MeasuredRequestTimings,
16-
MultiTurnRequestT,
1715
RequestSchedulerTimings,
1816
RequestT,
1917
ResponseT,
@@ -49,20 +47,6 @@ def test_backend_t():
4947
assert BackendT.__constraints__ == ()
5048

5149

52-
def test_multi_turn_request_t():
53-
"""Validate MultiTurnRequestT is a TypeAliasType for multi-turn requests."""
54-
assert isinstance(MultiTurnRequestT, TypeAliasType)
55-
assert MultiTurnRequestT.__name__ == "MultiTurnRequestT"
56-
57-
value = MultiTurnRequestT.__value__
58-
assert hasattr(value, "__origin__")
59-
assert value.__origin__ is Union
60-
61-
type_params = getattr(MultiTurnRequestT, "__type_params__", ())
62-
assert len(type_params) == 1
63-
assert type_params[0].__name__ == "RequestT"
64-
65-
6650
class TestBackendInterface:
6751
"""Test the BackendInterface abstract base class."""
6852

0 commit comments

Comments
 (0)