Skip to content

Commit ad50d54

Browse files
committed
fix: handler types
1 parent fc677e7 commit ad50d54

File tree

3 files changed

+128
-21
lines changed

3 files changed

+128
-21
lines changed

src/socketio-stubs/_types.pyi

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ from collections.abc import Mapping, Sequence
22
from contextlib import AbstractAsyncContextManager, AbstractContextManager
33
from threading import Event as ThreadingEvent
44
from types import ModuleType
5-
from typing import Any, Literal, TypeAlias, overload
5+
from typing import Any, Literal, Protocol, TypeAlias, overload
66

7+
import engineio
78
from _typeshed import Incomplete
89
from engineio.async_drivers.eventlet import EventletThread
910
from engineio.async_drivers.gevent import Thread as GeventThread
@@ -206,3 +207,35 @@ class JsonModule(ModuleType):
206207
def dumps(obj: Any, **kwargs: Any) -> str: ...
207208
@staticmethod
208209
def loads(s: str | bytes | bytearray, **kwargs: Any) -> Any: ...
210+
211+
## handlers
212+
213+
class ServerConnectHandler(Protocol):
214+
def __call__(self, sid: str, environ: Mapping[str, Any]) -> Any: ...
215+
216+
class ServerConnectHandlerWithData(Protocol):
217+
def __call__(self, sid: str, environ: Mapping[str, Any], data: Any) -> Any: ...
218+
219+
class ServerDisconnectHandler(Protocol):
220+
def __call__(self, sid: str, reason: engineio.Server.reason) -> Any: ...
221+
222+
class ServerDisconnectLegacyHandler(Protocol):
223+
def __call__(self, sid: str) -> Any: ...
224+
225+
class ClientConnectHandler(Protocol):
226+
def __call__(self) -> Any: ...
227+
228+
class ClientDisconnectHandler(Protocol):
229+
def __call__(self, reason: engineio.Client.reason) -> Any: ...
230+
231+
class ClientDisconnectLegacyHandler(Protocol):
232+
def __call__(self) -> Any: ...
233+
234+
class ClientConnectErrorHandler(Protocol):
235+
def __call__(self, data: Any) -> Any: ...
236+
237+
class CatchAllHandler(Protocol):
238+
def __call__(self, event: str, sid: str, data: Any) -> Any: ...
239+
240+
class EventHandler(Protocol):
241+
def __call__(self, sid: str, data: Any) -> Any: ...

src/socketio-stubs/base_client.pyi

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,31 @@ from typing import Any, ClassVar, Generic, Literal, overload
66
import engineio
77
from engineio.async_client import AsyncClient
88
from engineio.client import Client
9-
from socketio._types import JsonModule, SerializerType, TransportType
9+
from socketio._types import (
10+
CatchAllHandler,
11+
ClientConnectErrorHandler,
12+
ClientConnectHandler,
13+
ClientDisconnectHandler,
14+
ClientDisconnectLegacyHandler,
15+
EventHandler,
16+
JsonModule,
17+
SerializerType,
18+
TransportType,
19+
)
1020
from socketio.base_namespace import BaseClientNamespace
1121
from socketio.packet import Packet
1222
from typing_extensions import TypeVar
1323

1424
_T_co = TypeVar("_T_co", bound=Client | AsyncClient, covariant=True, default=Any)
1525
_IsAsyncio = TypeVar("_IsAsyncio", bound=bool, default=Literal[False])
1626
_F = TypeVar("_F", bound=Callable[..., Any])
27+
_F_event = TypeVar("_F_event", bound=EventHandler)
28+
_F_connect = TypeVar("_F_connect", bound=ClientConnectHandler)
29+
_F_connect_error = TypeVar("_F_connect_error", bound=ClientConnectErrorHandler)
30+
_F_disconnect = TypeVar(
31+
"_F_disconnect", bound=ClientDisconnectHandler | ClientDisconnectLegacyHandler
32+
)
33+
_F_catch_all = TypeVar("_F_catch_all", bound=CatchAllHandler)
1734

1835
default_logger: logging.Logger
1936
reconnecting_clients: list[BaseClient[Any]]
@@ -62,12 +79,40 @@ class BaseClient(Generic[_IsAsyncio, _T_co]):
6279
def is_asyncio_based(self) -> _IsAsyncio: ...
6380
@overload
6481
def on(
65-
self, event: Callable[..., Any], handler: None = ..., namespace: None = ...
66-
) -> None: ...
82+
self,
83+
event: Literal["connect"],
84+
handler: None = ...,
85+
namespace: str | None = ...,
86+
) -> Callable[[_F_connect], _F_connect]: ...
6787
@overload
6888
def on(
69-
self, event: str, handler: Callable[..., Any], namespace: str | None = ...
70-
) -> Callable[[_F], _F] | None: ...
89+
self,
90+
event: Literal["connect_error"],
91+
handler: None = ...,
92+
namespace: str | None = ...,
93+
) -> Callable[[_F_connect_error], _F_connect_error]: ...
94+
@overload
95+
def on(
96+
self,
97+
event: Literal["disconnect"],
98+
handler: None = ...,
99+
namespace: str | None = ...,
100+
) -> Callable[[_F_disconnect], _F_disconnect]: ...
101+
@overload
102+
def on(
103+
self, event: Literal["*"], handler: None = ..., namespace: str | None = ...
104+
) -> Callable[[_F_catch_all], _F_catch_all]: ...
105+
@overload
106+
def on(
107+
self, event: str, handler: None = ..., namespace: str | None = ...
108+
) -> Callable[[_F_event], _F_event]: ...
109+
@overload
110+
def on(
111+
self,
112+
event: Callable[..., Any],
113+
handler: None = ...,
114+
namespace: str | None = ...,
115+
) -> None: ...
71116
@overload
72117
def on(
73118
self,
@@ -76,15 +121,9 @@ class BaseClient(Generic[_IsAsyncio, _T_co]):
76121
namespace: str | None = ...,
77122
) -> Callable[[_F], _F] | None: ...
78123
@overload
79-
def event(self, handler: Callable[..., Any]) -> None: ...
124+
def event(self, handler: EventHandler) -> None: ...
80125
@overload
81-
def event(
82-
self, handler: Callable[..., Any], namespace: str | None
83-
) -> Callable[[_F], _F]: ...
84-
@overload
85-
def event(
86-
self, handler: Callable[..., Any], namespace: str | None = ...
87-
) -> Callable[[_F], _F] | None: ...
126+
def event(self, namespace: str | None) -> Callable[[_F_event], _F_event]: ...
88127
def register_namespace(self, namespace_handler: BaseClientNamespace) -> None: ...
89128
def get_sid(self, namespace: str | None = ...) -> str | None: ...
90129
def transport(self) -> TransportType: ...

src/socketio-stubs/base_server.pyi

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,40 @@ from typing import Any, ClassVar, Generic, Literal, overload
55
import engineio
66
from engineio.async_server import AsyncServer
77
from engineio.server import Server
8-
from socketio._types import JsonModule, SerializerType, SyncAsyncModeType, TransportType
8+
from socketio._types import (
9+
CatchAllHandler,
10+
EventHandler,
11+
JsonModule,
12+
SerializerType,
13+
ServerConnectHandler,
14+
ServerConnectHandlerWithData,
15+
ServerDisconnectHandler,
16+
ServerDisconnectLegacyHandler,
17+
SyncAsyncModeType,
18+
TransportType,
19+
)
920
from socketio.base_namespace import BaseClientNamespace
1021
from socketio.manager import Manager
1122
from socketio.packet import Packet
1223
from typing_extensions import TypeVar
1324

1425
_T_co = TypeVar("_T_co", bound=Server | AsyncServer, covariant=True, default=Any)
1526
_F = TypeVar("_F", bound=Callable[..., Any])
27+
_F_event = TypeVar("_F_event", bound=EventHandler)
28+
_F_connect = TypeVar(
29+
"_F_connect", bound=ServerConnectHandler | ServerConnectHandlerWithData
30+
)
31+
_F_disconnect = TypeVar(
32+
"_F_disconnect", bound=ServerDisconnectHandler | ServerDisconnectLegacyHandler
33+
)
34+
_F_catch_all = TypeVar("_F_catch_all", bound=CatchAllHandler)
1635
_IsAsyncio = TypeVar("_IsAsyncio", bound=bool, default=Literal[False])
1736

1837
default_logger: logging.Logger
1938

2039
class BaseServer(Generic[_IsAsyncio, _T_co]):
2140
reserved_events: ClassVar[list[str]]
22-
reason: ClassVar[type[engineio.Client.reason]]
41+
reason: ClassVar[type[engineio.Server.reason]]
2342
packet_class: type[Packet]
2443
eio: _T_co
2544
environ: Mapping[str, Any]
@@ -46,9 +65,27 @@ class BaseServer(Generic[_IsAsyncio, _T_co]):
4665
) -> None: ...
4766
def is_asyncio_based(self) -> _IsAsyncio: ...
4867
@overload
68+
def on(
69+
self,
70+
event: Literal["connect"],
71+
handler: None = ...,
72+
namespace: str | None = ...,
73+
) -> Callable[[_F_connect], _F_connect]: ...
74+
@overload
75+
def on(
76+
self,
77+
event: Literal["disconnect"],
78+
handler: None = ...,
79+
namespace: str | None = ...,
80+
) -> Callable[[_F_disconnect], _F_disconnect]: ...
81+
@overload
82+
def on(
83+
self, event: Literal["*"], handler: None = ..., namespace: str | None = ...
84+
) -> Callable[[_F_catch_all], _F_catch_all]: ...
85+
@overload
4986
def on(
5087
self, event: str, handler: None = ..., namespace: str | None = ...
51-
) -> Callable[[_F], _F]: ...
88+
) -> Callable[[_F_event], _F_event]: ...
5289
@overload
5390
def on(
5491
self,
@@ -64,11 +101,9 @@ class BaseServer(Generic[_IsAsyncio, _T_co]):
64101
namespace: str | None = ...,
65102
) -> Callable[[_F], _F] | None: ...
66103
@overload
67-
def event(
68-
self, handler: Callable[..., Any], namespace: str | None = ...
69-
) -> None: ...
104+
def event(self, handler: EventHandler, namespace: str | None = ...) -> None: ...
70105
@overload
71-
def event(self, namespace: str | None) -> Callable[[_F], _F]: ...
106+
def event(self, namespace: str | None) -> Callable[[_F_event], _F_event]: ...
72107
def register_namespace(
73108
self, namespace_handler: BaseClientNamespace[_IsAsyncio]
74109
) -> None: ...

0 commit comments

Comments
 (0)