Skip to content

Commit 0693a1c

Browse files
committed
add more restrictive typing checks
fix create_pool int parameters fix typo REQUEST_DB_PORT instead of SHARING_DB_PORT
1 parent eadcd2d commit 0693a1c

19 files changed

+423
-364
lines changed

swift_browser_ui/common/common_middleware.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import logging
55
import os
6-
import typing
76

87
import aiohttp.web
98
import asyncpg.exceptions
@@ -45,7 +44,7 @@ async def catch_uniqueness_error(
4544
@aiohttp.web.middleware
4645
async def check_db_conn(
4746
request: aiohttp.web.Request, handler: swift_browser_ui.common.types.AiohttpHandler
48-
):
47+
) -> aiohttp.web.Response:
4948
"""Check if an established database connection exists."""
5049
if request.path == "/health":
5150
return await handler(request)
@@ -80,7 +79,7 @@ async def handle_validate_authentication(
8079
reason="Query string missing validity or signature"
8180
)
8281

83-
project: typing.Union[None, str]
82+
project: None | str = None
8483
project_tokens = []
8584
try:
8685
project = request.match_info["project"]

swift_browser_ui/common/common_util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@ async def read_in_keys(app: aiohttp.web.Application) -> None:
1616
app["tokens"] = [token.encode("utf-8") for token in app["tokens"]]
1717

1818

19-
async def sleep_random():
19+
async def sleep_random() -> None:
2020
"""Sleep a random time."""
2121
return await asyncio.sleep(random.randint(2, 5)) # nosec # noqa: S311

swift_browser_ui/common/signature.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
LOGGER.setLevel(os.environ.get("LOG_LEVEL", "INFO"))
1515

1616

17-
def sign_api_request(path: str, valid_for: int = 3600, key: bytes = b"") -> dict:
17+
def sign_api_request(
18+
path: str, valid_for: int = 3600, key: bytes = b""
19+
) -> typing.Dict[str, typing.Any]:
1820
"""Handle authentication with a signature."""
1921
valid_until = str(int(time.time() + valid_for))
2022
to_sign = (valid_until + path).encode("utf-8")

swift_browser_ui/common/types.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77

88
AiohttpHandler = typing.Callable[
99
[aiohttp.web.Request],
10-
typing.Coroutine[typing.Awaitable, typing.Any, aiohttp.web.Response],
10+
typing.Coroutine[typing.Awaitable[typing.Any], typing.Any, aiohttp.web.Response],
1111
]

swift_browser_ui/common/vault_client.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
from base64 import standard_b64encode
77
from dataclasses import dataclass
8-
from typing import Any, Dict, List, Optional, Union
8+
from typing import Any, Dict, List
99

1010
from aiohttp import ClientSession, ClientTimeout
1111
from aiohttp.web import HTTPError, HTTPGatewayTimeout, HTTPInternalServerError
@@ -135,10 +135,10 @@ async def _request(
135135
self,
136136
method: str = "GET",
137137
path: str = "",
138-
params: Optional[dict] = None,
139-
json_data: Union[None, Dict, List[Dict]] = None,
138+
params: None | Dict[str, Any] = None,
139+
json_data: None | Dict[str, Any] | List[Dict[str, Any]] = None,
140140
timeout: int = 10,
141-
) -> Optional[dict]:
141+
) -> None | str | Dict[Any, Any]:
142142
"""Request to Vault API with error handling logic, and retries in case of token expiry.
143143
144144
:param method: HTTP method
@@ -185,7 +185,7 @@ async def _request(
185185
if response.status == 204 or response.status == 404:
186186
return None
187187

188-
content = await response.json()
188+
content: Dict[str, Any] = await response.json()
189189
if response.status == 200:
190190
LOGGER.debug("Content: %r", content)
191191
return content
@@ -241,7 +241,7 @@ async def get_key() -> str:
241241
key_json = await self._request("GET", f"c4ghtransit/keys/{project}")
242242
if isinstance(key_json, dict) and "data" in key_json:
243243
latest_version = str(key_json["data"]["latest_version"])
244-
return key_json["data"]["keys"][latest_version]["public_key_c4gh_64"]
244+
return str(key_json["data"]["keys"][latest_version]["public_key_c4gh_64"])
245245
return ""
246246

247247
LOGGER.debug("Getting public key for project %r", project)
@@ -298,7 +298,7 @@ async def get_header(self, project: str, container: str, path: str) -> str:
298298
params={"service": self.service, "key": self._key_name},
299299
)
300300
if isinstance(header_response, dict) and "data" in header_response:
301-
return header_response["data"]["headers"]["1"]["header"]
301+
return str(header_response["data"]["headers"]["1"]["header"])
302302
return ""
303303

304304
async def put_header(

swift_browser_ui/launcher.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,22 @@
66
import swift_browser_ui.upload.server
77

88

9-
def run_ui():
9+
def run_ui() -> None:
1010
"""Run the UI."""
1111
swift_browser_ui.ui.shell.main()
1212

1313

14-
def run_sharing():
14+
def run_sharing() -> None:
1515
"""Run swift-x-account-sharing service."""
1616
swift_browser_ui.sharing.server.main()
1717

1818

19-
def run_request():
19+
def run_request() -> None:
2020
"""Run swift-sharing-request service."""
2121
swift_browser_ui.request.server.main()
2222

2323

24-
def run_upload():
24+
def run_upload() -> None:
2525
"""Run swiftui-upload-runner service."""
2626
swift_browser_ui.upload.server.main()
2727

swift_browser_ui/request/db.py

+100-79
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class DBConn:
1919
def __init__(self) -> None:
2020
"""."""
2121
self.log = MODULE_LOGGER
22-
self.pool: asyncpg.Pool = None
22+
self.pool: asyncpg.Pool | None = None
2323

2424
async def open(self) -> None:
2525
"""Gracefully open the database."""
@@ -32,12 +32,14 @@ async def open(self) -> None:
3232
port=int(os.environ.get("REQUEST_DB_PORT", 5432)),
3333
ssl=os.environ.get("REQUEST_DB_SSL", "prefer"),
3434
database=os.environ.get("REQUEST_DB_NAME", "swiftbrowserdb"),
35-
min_size=os.environ.get("REQUEST_DB_MIN_CONNECTIONS", 0),
36-
max_size=os.environ.get("REQUEST_DB_MAX_CONNECTIONS", 49),
37-
timeout=os.environ.get("REQUEST_DB_TIMEOUT", 120),
38-
command_timeout=os.environ.get("REQUEST_DB_COMMAND_TIMEOUT", 180),
39-
max_inactive_connection_lifetime=os.environ.get(
40-
"REQUEST_DB_MAX_INACTIVE_CONN_LIFETIME", 0
35+
min_size=int(os.environ.get("REQUEST_DB_MIN_CONNECTIONS", 0)),
36+
max_size=int(os.environ.get("REQUEST_DB_MAX_CONNECTIONS", 49)),
37+
timeout=int(os.environ.get("REQUEST_DB_TIMEOUT", 120)),
38+
command_timeout=int(
39+
os.environ.get("REQUEST_DB_COMMAND_TIMEOUT", 180)
40+
),
41+
max_inactive_connection_lifetime=int(
42+
os.environ.get("REQUEST_DB_MAX_INACTIVE_CONN_LIFETIME", 0)
4143
),
4244
)
4345
except (ConnectionError, OSError):
@@ -60,11 +62,14 @@ async def close(self) -> None:
6062

6163
def erase(self) -> None:
6264
"""Immediately erase the connection."""
63-
self.pool.terminate()
64-
self.pool = None
65+
if self.pool is not None:
66+
self.pool.terminate()
67+
self.pool = None
6568

6669
@staticmethod
67-
async def parse_query(query: typing.List[asyncpg.Record]) -> typing.List[dict]:
70+
async def parse_query(
71+
query: typing.List[asyncpg.Record],
72+
) -> typing.List[typing.Dict[str, typing.Any]]:
6873
"""Parse a database query list to JSON serializable form."""
6974
return [
7075
{
@@ -78,79 +83,95 @@ async def parse_query(query: typing.List[asyncpg.Record]) -> typing.List[dict]:
7883

7984
async def add_request(self, user: str, container: str, owner: str) -> bool:
8085
"""Add an access request to the database."""
81-
async with self.pool.acquire() as conn:
82-
async with conn.transaction():
83-
await conn.execute(
84-
"""
85-
INSERT INTO Requests(
86+
if self.pool is not None:
87+
async with self.pool.acquire() as conn:
88+
async with conn.transaction():
89+
await conn.execute(
90+
"""
91+
INSERT INTO Requests(
92+
container,
93+
container_owner,
94+
recipient,
95+
created
96+
) VALUES (
97+
$1, $2, $3, NOW()
98+
);
99+
""",
86100
container,
87-
container_owner,
88-
recipient,
89-
created
90-
) VALUES (
91-
$1, $2, $3, NOW()
92-
);
93-
""",
94-
container,
95-
owner,
96-
user,
97-
)
98-
return True
99-
100-
async def get_request_owned(self, user: str) -> typing.List:
101+
owner,
102+
user,
103+
)
104+
return True
105+
return False
106+
107+
async def get_request_owned(
108+
self, user: str
109+
) -> typing.List[typing.Dict[str, typing.Any]]:
101110
"""Get the requests owned by the getter."""
102-
query = await self.pool.fetch(
103-
"""
104-
SELECT *
105-
FROM Requests
106-
WHERE container_owner = $1
107-
;
108-
""",
109-
user,
110-
)
111-
return await self.parse_query(query)
112-
113-
async def get_request_made(self, user: str) -> typing.List:
111+
if self.pool is not None:
112+
query = await self.pool.fetch(
113+
"""
114+
SELECT *
115+
FROM Requests
116+
WHERE container_owner = $1
117+
;
118+
""",
119+
user,
120+
)
121+
return await self.parse_query(query)
122+
return []
123+
124+
async def get_request_made(
125+
self, user: str
126+
) -> typing.List[typing.Dict[str, typing.Any]]:
114127
"""Get the requests made by the getter."""
115-
query = await self.pool.fetch(
116-
"""
117-
SELECT *
118-
FROM Requests
119-
WHERE recipient = $1
120-
;
121-
""",
122-
user,
123-
)
124-
return await self.parse_query(query)
125-
126-
async def get_request_container(self, container: str) -> typing.List:
128+
if self.pool is not None:
129+
query = await self.pool.fetch(
130+
"""
131+
SELECT *
132+
FROM Requests
133+
WHERE recipient = $1
134+
;
135+
""",
136+
user,
137+
)
138+
return await self.parse_query(query)
139+
return []
140+
141+
async def get_request_container(
142+
self, container: str
143+
) -> typing.List[typing.Dict[str, typing.Any]]:
127144
"""Get the requests made for a container."""
128-
query = await self.pool.fetch(
129-
"""
130-
SELECT *
131-
FROM Requests
132-
WHERE container = $1
133-
;
134-
""",
135-
container,
136-
)
137-
return await self.parse_query(query)
145+
if self.pool is not None:
146+
query = await self.pool.fetch(
147+
"""
148+
SELECT *
149+
FROM Requests
150+
WHERE container = $1
151+
;
152+
""",
153+
container,
154+
)
155+
return await self.parse_query(query)
156+
return []
138157

139158
async def delete_request(self, container: str, owner: str, recipient: str) -> bool:
140159
"""Delete an access request from the database."""
141-
async with self.pool.acquire() as conn:
142-
async with conn.transaction():
143-
await conn.execute(
144-
"""
145-
DELETE FROM Requests
146-
WHERE
147-
container = $1 AND
148-
container_owner = $2 AND
149-
recipient = $3
150-
;
151-
""",
152-
container,
153-
owner,
154-
recipient,
155-
)
156-
return True
160+
if self.pool is not None:
161+
async with self.pool.acquire() as conn:
162+
async with conn.transaction():
163+
await conn.execute(
164+
"""
165+
DELETE FROM Requests
166+
WHERE
167+
container = $1 AND
168+
container_owner = $2 AND
169+
recipient = $3
170+
;
171+
""",
172+
container,
173+
owner,
174+
recipient,
175+
)
176+
return True
177+
return False

0 commit comments

Comments
 (0)