13
13
Callable ,
14
14
Optional ,
15
15
Type ,
16
+ Union ,
16
17
cast ,
17
18
)
18
19
19
20
import yarl
20
21
21
- from .abc import AbstractAccessLogger , AbstractStreamWriter
22
+ from .abc import (
23
+ AbstractAccessLogger ,
24
+ AbstractAsyncAccessLogger ,
25
+ AbstractStreamWriter ,
26
+ )
22
27
from .base_protocol import BaseProtocol
23
28
from .helpers import CeilTimeout , current_task
24
29
from .http import (
50
55
BaseRequest ]
51
56
52
57
_RequestHandler = Callable [[BaseRequest ], Awaitable [StreamResponse ]]
58
+ _AnyAbstractAccessLogger = Union [
59
+ Type [AbstractAsyncAccessLogger ],
60
+ Type [AbstractAccessLogger ],
61
+ ]
53
62
54
63
55
64
ERROR = RawRequestMessage (
@@ -65,6 +74,22 @@ class PayloadAccessError(Exception):
65
74
"""Payload was accessed after response was sent."""
66
75
67
76
77
+ class AccessLoggerWrapper (AbstractAsyncAccessLogger ):
78
+ """
79
+ Wraps an AbstractAccessLogger so it behaves
80
+ like an AbstractAsyncAccessLogger.
81
+ """
82
+ def __init__ (self , access_logger : AbstractAccessLogger ):
83
+ self .access_logger = access_logger
84
+ super ().__init__ ()
85
+
86
+ async def log (self ,
87
+ request : BaseRequest ,
88
+ response : StreamResponse ,
89
+ request_start : float ) -> None :
90
+ self .access_logger .log (request , response , request_start )
91
+
92
+
68
93
class RequestHandler (BaseProtocol ):
69
94
"""HTTP protocol implementation.
70
95
@@ -120,7 +145,7 @@ def __init__(self, manager: 'Server', *,
120
145
keepalive_timeout : float = 75. , # NGINX default is 75 secs
121
146
tcp_keepalive : bool = True ,
122
147
logger : Logger = server_logger ,
123
- access_log_class : Type [ AbstractAccessLogger ] = AccessLogger ,
148
+ access_log_class : _AnyAbstractAccessLogger = AccessLogger ,
124
149
access_log : Logger = access_logger ,
125
150
access_log_format : str = AccessLogger .LOG_FORMAT ,
126
151
debug : bool = False ,
@@ -164,8 +189,11 @@ def __init__(self, manager: 'Server', *,
164
189
self .debug = debug
165
190
self .access_log = access_log
166
191
if access_log :
167
- self .access_logger = access_log_class (
168
- access_log , access_log_format ) # type: Optional[AbstractAccessLogger] # noqa
192
+ if issubclass (access_log_class , AbstractAsyncAccessLogger ):
193
+ self .access_logger = access_log_class () # type: Optional[AbstractAsyncAccessLogger] # noqa
194
+ else :
195
+ access_logger = access_log_class (access_log , access_log_format )
196
+ self .access_logger = AccessLoggerWrapper (access_logger )
169
197
else :
170
198
self .access_logger = None
171
199
@@ -339,13 +367,13 @@ def force_close(self) -> None:
339
367
self .transport .close ()
340
368
self .transport = None
341
369
342
- def log_access (self ,
343
- request : BaseRequest ,
344
- response : StreamResponse ,
345
- request_start : float ) -> None :
370
+ async def log_access (self ,
371
+ request : BaseRequest ,
372
+ response : StreamResponse ,
373
+ request_start : float ) -> None :
346
374
if self .access_logger is not None :
347
- self .access_logger .log (request , response ,
348
- self ._loop .time () - request_start )
375
+ await self .access_logger .log (request , response ,
376
+ self ._loop .time () - request_start )
349
377
350
378
def log_debug (self , * args : Any , ** kw : Any ) -> None :
351
379
if self .debug :
@@ -526,10 +554,10 @@ async def finish_response(self,
526
554
await prepare_meth (request )
527
555
await resp .write_eof ()
528
556
except ConnectionResetError :
529
- self .log_access (request , resp , start_time )
557
+ await self .log_access (request , resp , start_time )
530
558
return True
531
559
else :
532
- self .log_access (request , resp , start_time )
560
+ await self .log_access (request , resp , start_time )
533
561
return False
534
562
535
563
def handle_error (self ,
0 commit comments