1
- """Sharing backend database implementation ."""
1
+ """Module for database interfaces using postgres ."""
2
2
3
3
4
4
import logging
5
5
import os
6
6
import typing
7
7
8
+ import aiohttp .web
8
9
import asyncpg
9
10
10
11
from swift_browser_ui .common .common_util import sleep_random
13
14
MODULE_LOGGER .setLevel (os .environ .get ("LOG_LEVEL" , "INFO" ))
14
15
15
16
16
- class DBConn :
17
- """Class for the account sharing database functionality."""
17
+ async def db_graceful_start (app : aiohttp .web .Application ) -> None :
18
+ """Gracefully start the database."""
19
+ app ["db_conn" ] = app ["db_class" ]()
20
+ await app ["db_conn" ].open ()
21
+
22
+
23
+ async def db_graceful_close (app : aiohttp .web .Application ) -> None :
24
+ """Gracefully close the database."""
25
+ if app ["db_conn" ] is not None :
26
+ await app ["db_conn" ].close ()
27
+
28
+
29
+ class BaseDBConn :
30
+ """Class for base database connection."""
18
31
19
32
def __init__ (self ) -> None :
20
33
"""Initialize connection variable."""
@@ -27,27 +40,16 @@ def erase(self) -> None:
27
40
self .pool .terminate ()
28
41
self .pool = None
29
42
30
- async def open (self ) -> None :
43
+ async def close (self ) -> None :
44
+ """Safely close the database connection."""
45
+ if self .pool is not None :
46
+ await self .pool .close ()
47
+
48
+ async def _open (self , ** kwargs ) -> None :
31
49
"""Initialize the database connection."""
32
50
while self .pool is None :
33
51
try :
34
- self .pool = await asyncpg .create_pool (
35
- password = os .environ .get ("SHARING_DB_PASSWORD" , None ),
36
- user = os .environ .get ("SHARING_DB_USER" , "sharing" ),
37
- host = os .environ .get ("SHARING_DB_HOST" , "localhost" ),
38
- port = int (os .environ .get ("SHARING_DB_PORT" , 5432 )),
39
- ssl = os .environ .get ("SHARING_DB_SSL" , "prefer" ),
40
- database = os .environ .get ("SHARING_DB_NAME" , "swiftbrowserdb" ),
41
- min_size = int (os .environ .get ("SHARING_DB_MIN_CONNECTIONS" , 0 )),
42
- max_size = int (os .environ .get ("SHARING_DB_MAX_CONNECTIONS" , 2 )),
43
- timeout = int (os .environ .get ("SHARING_DB_TIMEOUT" , 120 )),
44
- command_timeout = int (
45
- os .environ .get ("SHARING_DB_COMMAND_TIMEOUT" , 180 )
46
- ),
47
- max_inactive_connection_lifetime = int (
48
- os .environ .get ("SHARING_DB_MAX_INACTIVE_CONN_LIFETIME" , 10 )
49
- ),
50
- )
52
+ self .pool = await asyncpg .create_pool (** kwargs )
51
53
except (ConnectionError , OSError ):
52
54
self .log .error (
53
55
"Failed to establish connection. "
@@ -61,10 +63,47 @@ async def open(self) -> None:
61
63
self .log .error ("Database is not ready yet." )
62
64
await sleep_random ()
63
65
64
- async def close (self ) -> None :
65
- """Safely close the database connection."""
66
+ async def get_tokens (
67
+ self , token_owner : str
68
+ ) -> typing .List [typing .Dict [str , typing .Any ]]:
69
+ """Get tokens created for a project."""
66
70
if self .pool is not None :
67
- await self .pool .close ()
71
+ query = await self .pool .fetch (
72
+ """SELECT *
73
+ FROM Tokens
74
+ WHERE token_owner = $1
75
+ ;
76
+ """ ,
77
+ token_owner ,
78
+ )
79
+ return list (query )
80
+ return []
81
+
82
+
83
+ class SharingDBConn (BaseDBConn ):
84
+ """Class for the account sharing database functionality."""
85
+
86
+ def __init__ (self ) -> None :
87
+ """Initialize connection variable."""
88
+ super ().__init__ ()
89
+
90
+ async def open (self ) -> None :
91
+ """Initialize the database connection."""
92
+ await super ()._open (
93
+ password = os .environ .get ("SHARING_DB_PASSWORD" , None ),
94
+ user = os .environ .get ("SHARING_DB_USER" , "sharing" ),
95
+ host = os .environ .get ("SHARING_DB_HOST" , "localhost" ),
96
+ port = int (os .environ .get ("SHARING_DB_PORT" , 5432 )),
97
+ ssl = os .environ .get ("SHARING_DB_SSL" , "prefer" ),
98
+ database = os .environ .get ("SHARING_DB_NAME" , "swiftbrowserdb" ),
99
+ min_size = int (os .environ .get ("SHARING_DB_MIN_CONNECTIONS" , 0 )),
100
+ max_size = int (os .environ .get ("SHARING_DB_MAX_CONNECTIONS" , 2 )),
101
+ timeout = int (os .environ .get ("SHARING_DB_TIMEOUT" , 120 )),
102
+ command_timeout = int (os .environ .get ("SHARING_DB_COMMAND_TIMEOUT" , 180 )),
103
+ max_inactive_connection_lifetime = int (
104
+ os .environ .get ("SHARING_DB_MAX_INACTIVE_CONN_LIFETIME" , 10 )
105
+ ),
106
+ )
68
107
69
108
async def add_share (
70
109
self ,
@@ -311,22 +350,6 @@ async def get_shared_container_details(
311
350
return ret
312
351
return []
313
352
314
- async def get_tokens (
315
- self , token_owner : str
316
- ) -> typing .List [typing .Dict [str , typing .Any ]]:
317
- """Get tokens created for a project."""
318
- if self .pool is not None :
319
- query = await self .pool .fetch (
320
- """SELECT *
321
- FROM Tokens
322
- WHERE token_owner = $1
323
- ;
324
- """ ,
325
- token_owner ,
326
- )
327
- return list (query )
328
- return []
329
-
330
353
async def revoke_token (self , token_owner : str , token_identifier : str ) -> None :
331
354
"""Remove a token from the database."""
332
355
if self .pool is not None :
@@ -415,3 +438,139 @@ async def match_name_id(self, name: str) -> list:
415
438
return list (query )
416
439
417
440
return []
441
+
442
+
443
+ class RequestDBConn (BaseDBConn ):
444
+ """Class for handling sharing request database connection."""
445
+
446
+ def __init__ (self ) -> None :
447
+ """."""
448
+ super ().__init__ ()
449
+
450
+ async def open (self ) -> None :
451
+ """Gracefully open the database."""
452
+ await super ()._open (
453
+ password = os .environ .get ("REQUEST_DB_PASSWORD" , None ),
454
+ user = os .environ .get ("REQUEST_DB_USER" , "request" ),
455
+ host = os .environ .get ("REQUEST_DB_HOST" , "localhost" ),
456
+ port = int (os .environ .get ("REQUEST_DB_PORT" , 5432 )),
457
+ ssl = os .environ .get ("REQUEST_DB_SSL" , "prefer" ),
458
+ database = os .environ .get ("REQUEST_DB_NAME" , "swiftbrowserdb" ),
459
+ min_size = int (os .environ .get ("REQUEST_DB_MIN_CONNECTIONS" , 0 )),
460
+ max_size = int (os .environ .get ("REQUEST_DB_MAX_CONNECTIONS" , 49 )),
461
+ timeout = int (os .environ .get ("REQUEST_DB_TIMEOUT" , 120 )),
462
+ command_timeout = int (os .environ .get ("REQUEST_DB_COMMAND_TIMEOUT" , 180 )),
463
+ max_inactive_connection_lifetime = int (
464
+ os .environ .get ("REQUEST_DB_MAX_INACTIVE_CONN_LIFETIME" , 0 )
465
+ ),
466
+ )
467
+
468
+ @staticmethod
469
+ async def parse_query (
470
+ query : typing .List [asyncpg .Record ],
471
+ ) -> typing .List [typing .Dict [str , typing .Any ]]:
472
+ """Parse a database query list to JSON serializable form."""
473
+ return [
474
+ {
475
+ "container" : rec ["container" ],
476
+ "user" : rec ["recipient" ],
477
+ "owner" : rec ["container_owner" ],
478
+ "date" : rec ["created" ].isoformat (),
479
+ }
480
+ for rec in query
481
+ ]
482
+
483
+ async def add_request (self , user : str , container : str , owner : str ) -> bool :
484
+ """Add an access request to the database."""
485
+ if self .pool is not None :
486
+ async with self .pool .acquire () as conn :
487
+ async with conn .transaction ():
488
+ await conn .execute (
489
+ """
490
+ INSERT INTO Requests(
491
+ container,
492
+ container_owner,
493
+ recipient,
494
+ created
495
+ ) VALUES (
496
+ $1, $2, $3, NOW()
497
+ );
498
+ """ ,
499
+ container ,
500
+ owner ,
501
+ user ,
502
+ )
503
+ return True
504
+ return False
505
+
506
+ async def get_request_owned (
507
+ self , user : str
508
+ ) -> typing .List [typing .Dict [str , typing .Any ]]:
509
+ """Get the requests owned by the getter."""
510
+ if self .pool is not None :
511
+ query = await self .pool .fetch (
512
+ """
513
+ SELECT *
514
+ FROM Requests
515
+ WHERE container_owner = $1
516
+ ;
517
+ """ ,
518
+ user ,
519
+ )
520
+ return await self .parse_query (query )
521
+ return []
522
+
523
+ async def get_request_made (
524
+ self , user : str
525
+ ) -> typing .List [typing .Dict [str , typing .Any ]]:
526
+ """Get the requests made by the getter."""
527
+ if self .pool is not None :
528
+ query = await self .pool .fetch (
529
+ """
530
+ SELECT *
531
+ FROM Requests
532
+ WHERE recipient = $1
533
+ ;
534
+ """ ,
535
+ user ,
536
+ )
537
+ return await self .parse_query (query )
538
+ return []
539
+
540
+ async def get_request_container (
541
+ self , container : str
542
+ ) -> typing .List [typing .Dict [str , typing .Any ]]:
543
+ """Get the requests made for a container."""
544
+ if self .pool is not None :
545
+ query = await self .pool .fetch (
546
+ """
547
+ SELECT *
548
+ FROM Requests
549
+ WHERE container = $1
550
+ ;
551
+ """ ,
552
+ container ,
553
+ )
554
+ return await self .parse_query (query )
555
+ return []
556
+
557
+ async def delete_request (self , container : str , owner : str , recipient : str ) -> bool :
558
+ """Delete an access request from the database."""
559
+ if self .pool is not None :
560
+ async with self .pool .acquire () as conn :
561
+ async with conn .transaction ():
562
+ await conn .execute (
563
+ """
564
+ DELETE FROM Requests
565
+ WHERE
566
+ container = $1 AND
567
+ container_owner = $2 AND
568
+ recipient = $3
569
+ ;
570
+ """ ,
571
+ container ,
572
+ owner ,
573
+ recipient ,
574
+ )
575
+ return True
576
+ return False
0 commit comments