From 71e5bc43e27e048030e2db9da2648d50a94cb144 Mon Sep 17 00:00:00 2001 From: Guinea Wheek Date: Thu, 29 Dec 2022 01:10:45 -0800 Subject: [PATCH 1/9] initial design commit --- dozer/{db.py => db/__init__.py} | 2 +- dozer/db/orm.py | 188 ++++++++++++++++++++++++++++++++ dozer/db/pqt.py | 8 ++ 3 files changed, 197 insertions(+), 1 deletion(-) rename dozer/{db.py => db/__init__.py} (99%) create mode 100644 dozer/db/orm.py create mode 100644 dozer/db/pqt.py diff --git a/dozer/db.py b/dozer/db/__init__.py similarity index 99% rename from dozer/db.py rename to dozer/db/__init__.py index e5c07f8c..365271fd 100755 --- a/dozer/db.py +++ b/dozer/db/__init__.py @@ -4,7 +4,7 @@ import asyncpg from loguru import logger -Pool = None +Pool: asyncpg.Pool = None async def db_init(db_url): diff --git a/dozer/db/orm.py b/dozer/db/orm.py new file mode 100644 index 00000000..81a87e0b --- /dev/null +++ b/dozer/db/orm.py @@ -0,0 +1,188 @@ +from . import DatabaseTable, Pool +from .pqt import Column +from typing import List, Tuple, Dict +import asyncpg + + + +class ORMTable(DatabaseTable): + """ORM tables are a new variant on DatabaseTables: + + * they are defined from class attributes + * initial_create, initial_migrate, get_by, delete, and update_or_add (upsert) are handled for you + * ability to instantiate objects directly from the results of SQL queries + + This class can vastly reduce the amount of boilerplate in a codebase. + + notes: + * __uniques__ MUST be a tuple! Do not set it to a string! Runtime will check for this and yell at you! + """ + __tablename__: str = '' + __versions__: Tuple[int] = tuple() + __uniques__: Tuple[str] = tuple() + + _columns: Dict[str, Column] = None + + + # Declare the migrate/create functions + @classmethod + async def initial_create(cls): + """Create the table in the database. Already implemented for you. Can still override if desired.""" + + if cls is ORMTable: + # abstract class need not apply + return + + columns = cls.get_columns() + # assemble "column = Column('integer not null') into column integer not null, " + query_params = ", ".join(map(" ".join, zip(columns.keys(), (c.sql for c in columns.values())))) + + if cls.__uniques__: + # TODO: determine if we should still use __uniques__ or have primary key data in Column + query_params += f", PRIMARY KEY({', '.join(k for k in cls.__uniques__)})" + + query_str = f"CREATE TABLE {cls.__tablename__}({query_params})" + async with Pool.acquire() as conn: + await conn.execute(query_str) + + + @classmethod + async def initial_migrate(cls): + """Create a version entry in the versions table""" + if cls is ORMTable: + # abstract class need not apply + return + + async with Pool.acquire() as conn: + await conn.execute("""INSERT INTO versions VALUES ($1, 0)""", cls.__tablename__) + + @staticmethod + def nullify(): + """Function to be referenced when a table entry value needs to be set to null""" + + def __init__(self, *args, **kwargs): + # yeah the one drawback of this approach is that you don't get hints for constructors anymore + # which doesn't stop language servers from somehow figuring out how to do this with dataclasses + # turns out dataclasses use dark magic + # (they dynamically generate and eval an __init__ which they staple onto the class) + + self.__dict__.update({k: v for k, v in kwargs.items() if k in self.get_columns()}) + super().__init__() + + async def update_or_add(self): + """Assign the attribute to this object, then call this method to either insert the object if it doesn't exist in + the DB or update it if it does exist. It will update every column not specified in __uniques__.""" + keys = [] + values = [] + for var, value in self.__dict__.items(): + if var not in self.get_columns(): + continue + # Done so that the two are guaranteed to be in the same order, which isn't true of keys() and values() + keys.append(var) + values.append(None if value is self.nullify else value) + + updates = "" + for key in keys: + if key in self.__uniques__: + # Skip updating anything that has a unique constraint on it + continue + updates += f"{key} = EXCLUDED.{key}" + if keys.index(key) == len(keys) - 1: + updates += " ;" + else: + updates += ", \n" + async with Pool.acquire() as conn: + if updates: + statement = f""" + INSERT INTO {self.__tablename__} ({", ".join(keys)}) + VALUES({','.join(f'${i + 1}' for i in range(len(values)))}) + ON CONFLICT ({self.__uniques__}) DO UPDATE + SET {updates} + """ + else: + statement = f""" + INSERT INTO {self.__tablename__} ({", ".join(keys)}) + VALUES({','.join(f'${i + 1}' for i in range(len(values)))}) + ON CONFLICT ({self.__uniques__}) DO NOTHING; + """ + await conn.execute(statement, *values) + + def __repr__(self): + # repr is supposed to be a representation of the object + # usually you want eval(repr(obj)) == obj more or less + return self.__class__.__name__ + "(" + ", ".join(f"{key}={val!r}" for key, val in \ + {key: getattr(self, key) for key in self._columns.keys()}.items()) + ")" + + # Class Methods + + @classmethod + def get_columns(cls) -> Dict[str, Column]: + """Returns all columns in the table. Also initializes the column cache.""" + if cls._columns is not None: + return cls._columns + + cls._columns = {name: val for name, val in vars(cls) if isinstance(val, Column)} + return cls._columns + + + @classmethod + def from_record(cls, record: asyncpg.Record): + """Converts an asyncpg query record into an instance of the class.""" + if record is None: + return None + + return cls(**{k: record[k] for k in cls.get_columns().keys()}) + + @classmethod + async def get_by(cls, **filters) -> list: + """Selects a list of all records matching the given column=value criteria. + Since pretty much every subclass overrides this to return lists of instantiated objects rather than queries, + we simply automate this. + """ + async with Pool.acquire() as conn: + statement = f"SELECT * FROM {cls.__tablename__}" + if filters: + # note: this code relies on subsequent iterations of the same dict having the same iteration order. + # This is an implementation detail of CPython 3.6 and a language guarantee in Python 3.7+. + conditions = " AND ".join(f"{column_name} = ${i + 1}" for (i, column_name) in enumerate(filters)) + statement = f"{statement} WHERE {conditions};" + else: + statement += ";" + records = await conn.fetch(statement, *filters.values()) + return [*map(cls.from_record, records)] + + @classmethod + async def get_one(cls, **filters): + """It's like get_by except it returns exactly one record or None.""" + return (cls.get_by(**filters) or [None])[0] + + @classmethod + async def delete(cls, **filters): + """Deletes by any number of criteria specified as column=value keyword arguments. Returns the number of entries deleted.""" + async with Pool.acquire() as conn: + if filters: + # This code relies on properties of dicts - see get_by + conditions = " AND ".join(f"{column_name} = ${i + 1}" for (i, column_name) in enumerate(filters)) + statement = f"DELETE FROM {cls.__tablename__} WHERE {conditions};" + else: + # Should this be a warning/error? It's almost certainly not intentional + statement = f"TRUNCATE {cls.__tablename__};" + return await conn.execute(statement, *filters.values()) + + @classmethod + async def set_initial_version(cls): + """Sets initial version""" + await Pool.execute("""INSERT INTO versions (table_name, version_num) VALUES ($1,$2)""", cls.__tablename__, 0) + + +# this part isn't super necessary. also need to figure out how to do this without circular imports. + +#async def startup_init(): +# """Initializes all ORM classes, and validates them.""" +# for cls in ORMTable.__subclasses__: +# cls.get_columns() +# if not isinstance(cls.__uniques__, tuple): +# raise ValueError(f"{cls}.__uniques__ MUST be a tuple!") +# +# if not cls.__tablename__: +# raise ValueError(f"{cls}.__tablename__ is blank!") \ No newline at end of file diff --git a/dozer/db/pqt.py b/dozer/db/pqt.py new file mode 100644 index 00000000..03eaf217 --- /dev/null +++ b/dozer/db/pqt.py @@ -0,0 +1,8 @@ +"""Postgres types. Add aliases here.""" + +class Column: + """Represents a sql column.""" + def __init__(self, sql: str): + self.sql: str = sql + +Col = Column \ No newline at end of file From e811e28a3a2fb5a184b3dda73f318485b09ddd94 Mon Sep 17 00:00:00 2001 From: Guinea Wheek Date: Thu, 29 Dec 2022 01:13:00 -0800 Subject: [PATCH 2/9] forgot an await --- dozer/db/orm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dozer/db/orm.py b/dozer/db/orm.py index 81a87e0b..d7f45f85 100644 --- a/dozer/db/orm.py +++ b/dozer/db/orm.py @@ -154,7 +154,7 @@ async def get_by(cls, **filters) -> list: @classmethod async def get_one(cls, **filters): """It's like get_by except it returns exactly one record or None.""" - return (cls.get_by(**filters) or [None])[0] + return ((await cls.get_by(**filters)) or [None])[0] @classmethod async def delete(cls, **filters): From 2840bb44523328584b225808fae657ffd1f438f8 Mon Sep 17 00:00:00 2001 From: Guinea Wheek Date: Thu, 29 Dec 2022 01:17:42 -0800 Subject: [PATCH 3/9] add example --- dozer/db/orm.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/dozer/db/orm.py b/dozer/db/orm.py index d7f45f85..edac7596 100644 --- a/dozer/db/orm.py +++ b/dozer/db/orm.py @@ -16,6 +16,23 @@ class ORMTable(DatabaseTable): notes: * __uniques__ MUST be a tuple! Do not set it to a string! Runtime will check for this and yell at you! + + + For example: + + ```python + class StarboardConfig(db.orm.ORMTable): + __tablename__ = 'starboard_settings' + __uniques__ = ('guild_id',) + guild_id: int = Column("bigint NOT NULL") + channel_id: int = Column("bigint NOT NULL") + star_emoji: str = Column("varchar NOT NULL") + cancel_emoji: str = Column("varchar") + threshold: int = Column("bigint NOT NULL") + ``` + + will produce a functionally equivalent class without overriding initial_create or get_by. + """ __tablename__: str = '' __versions__: Tuple[int] = tuple() From 30015bc82a53bbbbca218b47241820f291b27fe6 Mon Sep 17 00:00:00 2001 From: Guinea Wheek Date: Sat, 31 Dec 2022 21:44:26 -0800 Subject: [PATCH 4/9] start testing on one table --- dozer/cogs/_utils.py | 29 ++---- dozer/db/__init__.py | 203 +++++++++++++++++++++++++++++++++++++++++- dozer/db/orm.py | 205 ------------------------------------------- 3 files changed, 207 insertions(+), 230 deletions(-) delete mode 100644 dozer/db/orm.py diff --git a/dozer/cogs/_utils.py b/dozer/cogs/_utils.py index a3ce3e16..41505b45 100755 --- a/dozer/cogs/_utils.py +++ b/dozer/cogs/_utils.py @@ -362,32 +362,15 @@ async def refresh(self): logger.info(f"{len(prefixes)} prefixes loaded from database") -class DynamicPrefixEntry(db.DatabaseTable): +class DynamicPrefixEntry(db.ORMTable): """Holds the custom prefixes for guilds""" __tablename__ = 'dynamic_prefixes' - __uniques__ = 'guild_id' - - @classmethod - async def initial_create(cls): - """Create the table in the database""" - async with db.Pool.acquire() as conn: - await conn.execute(f""" - CREATE TABLE {cls.__tablename__} ( - guild_id bigint NOT NULL, - prefix text NOT NULL, - PRIMARY KEY (guild_id) - )""") + __uniques__ = ('guild_id',) + + guild_id = db.Column("bigint not null") + prefix: str = db.Column("text not null") def __init__(self, guild_id: int, prefix: str): super().__init__() self.guild_id = guild_id - self.prefix = prefix - - @classmethod - async def get_by(cls, **kwargs): - results = await super().get_by(**kwargs) - result_list = [] - for result in results: - obj = DynamicPrefixEntry(guild_id=result.get("guild_id"), prefix=result.get("prefix")) - result_list.append(obj) - return result_list + self.prefix = prefix \ No newline at end of file diff --git a/dozer/db/__init__.py b/dozer/db/__init__.py index 365271fd..4eba3217 100755 --- a/dozer/db/__init__.py +++ b/dozer/db/__init__.py @@ -1,9 +1,12 @@ """Provides database storage for the Dozer Discord bot""" -from typing import List, Dict +from typing import List, Dict, Tuple +from .pqt import Column, Col import asyncpg from loguru import logger +__all__ = ["Pool", "Column", "Col", "db_init", "db_migrate", "DatabaseTable", "ORMTable", "ConfigCache"] + Pool: asyncpg.Pool = None @@ -21,7 +24,14 @@ async def db_migrate(): version_num int NOT NULL )""") logger.info("Checking for db migrations") - for cls in DatabaseTable.__subclasses__(): + for cls in DatabaseTable.__subclasses__() + ORMTable.__subclasses__(): + + if not cls.__tablename__: + raise ValueError(f"{cls.__name__}.__tablename__ cannot be blank!") + if isinstance(cls, ORMTable) and cls.__uniques__ and not isinstance(cls.__uniques__, tuple): + raise ValueError(f"{cls.__name__}.__uniques__ must be a tuple!") + + exists = await Pool.fetchrow("""SELECT EXISTS( SELECT 1 FROM information_schema.tables @@ -157,6 +167,195 @@ async def set_initial_version(cls): await Pool.execute("""INSERT INTO versions (table_name, version_num) VALUES ($1,$2)""", cls.__tablename__, 0) +class ORMTable(DatabaseTable): + """ORM tables are a new variant on DatabaseTables: + + * they are defined from class attributes + * initial_create, initial_migrate, get_by, delete, and update_or_add (upsert) are handled for you + * ability to instantiate objects directly from the results of SQL queries using from_record + + This class can vastly reduce the amount of boilerplate in a codebase. + + notes: + * __uniques__ MUST be a tuple of primary key column names! + Do not set it to a string! Runtime will check for this and yell at you! + + For example: + + ```python + class StarboardConfig(db.orm.ORMTable): + __tablename__ = 'starboard_settings' + __uniques__ = ('guild_id',) + guild_id: int = Column("bigint NOT NULL") + channel_id: int = Column("bigint NOT NULL") + star_emoji: str = Column("varchar NOT NULL") + cancel_emoji: str = Column("varchar") + threshold: int = Column("bigint NOT NULL") + ``` + + will produce a functionally equivalent class without overriding initial_create or get_by. + + """ + __tablename__: str = '' + __versions__: Tuple[int] = tuple() + __uniques__: Tuple[str] = tuple() + + _columns: Dict[str, Column] = None + + + # Declare the migrate/create functions + @classmethod + async def initial_create(cls): + """Create the table in the database. Already implemented for you. Can still override if desired.""" + + logger.debug(cls.__name__ + " process") + if cls is ORMTable: + # abstract class need not apply + logger.debug(cls.__name__ + " skipped") + return + + columns = cls.get_columns() + # assemble "column = Column('integer not null') into column integer not null, " + query_params = ", ".join(map(" ".join, zip(columns.keys(), (c.sql for c in columns.values())))) + + if cls.__uniques__: + # TODO: determine if we should still use __uniques__ or have primary key data in Column + query_params += f", PRIMARY KEY({', '.join(k for k in cls.__uniques__)})" + + query_str = f"CREATE TABLE {cls.__tablename__}({query_params})" + + logger.debug("exec " + query_str) + async with Pool.acquire() as conn: + await conn.execute(query_str) + + + @classmethod + async def initial_migrate(cls): + """Create a version entry in the versions table""" + if cls is ORMTable: + # abstract class need not apply + return + + async with Pool.acquire() as conn: + await conn.execute("""INSERT INTO versions VALUES ($1, 0)""", cls.__tablename__) + + @staticmethod + def nullify(): + """Function to be referenced when a table entry value needs to be set to null""" + + def __init__(self, *args, **kwargs): + """ + yeah the one drawback of this approach is that you don't get hints for constructors anymore + which doesn't stop language servers from somehow figuring out how to do this with dataclasses + turns out dataclasses use dark magic + (they dynamically generate and eval an __init__ which they staple onto the class) + + you can just avoid this by overriding the init functions anyway, but if you do, + it's suggested that you include all the columns as arguments as from_record will call + the constructor with them using the ** operator + + if that is a problem, just override from_record to not do that + + """ + self.__dict__.update({k: v for k, v in kwargs.items() if k in self.get_columns()}) + super().__init__() + + async def update_or_add(self): + """Assign the attribute to this object, then call this method to either insert the object if it doesn't exist in + the DB or update it if it does exist. It will update every column not specified in __uniques__.""" + keys = [] + values = [] + for var, value in self.__dict__.items(): + if var not in self.get_columns(): + continue + # Done so that the two are guaranteed to be in the same order, which isn't true of keys() and values() + keys.append(var) + values.append(None if value is self.nullify else value) + + updates = "" + for key in keys: + if key in self.__uniques__: + # Skip updating anything that has a unique constraint on it + continue + updates += f"{key} = EXCLUDED.{key}" + if keys.index(key) == len(keys) - 1: + updates += " ;" + else: + updates += ", \n" + + primary_key = ", ".join(self.__uniques__) + + async with Pool.acquire() as conn: + if updates: + statement = f""" + INSERT INTO {self.__tablename__} ({", ".join(keys)}) + VALUES({','.join(f'${i + 1}' for i in range(len(values)))}) + ON CONFLICT ({primary_key}) DO UPDATE + SET {updates} + """ + else: + statement = f""" + INSERT INTO {self.__tablename__} ({", ".join(keys)}) + VALUES({','.join(f'${i + 1}' for i in range(len(values)))}) + ON CONFLICT ({primary_key}) DO NOTHING; + """ + + logger.debug("exec " + statement) + await conn.execute(statement, *values) + + def __repr__(self): + # repr is supposed to be a representation of the object + # usually you want eval(repr(obj)) == obj more or less + return self.__class__.__name__ + "(" + ", ".join(f"{key}={val!r}" for key, val in \ + {key: getattr(self, key) for key in self._columns.keys()}.items()) + ")" + + # Class Methods + + @classmethod + def get_columns(cls) -> Dict[str, Column]: + """Returns all columns in the table. Also initializes the column cache.""" + if cls._columns is not None: + return cls._columns + + cls._columns = {name: val for name, val in vars(cls).items() if isinstance(val, Column)} + return cls._columns + + + @classmethod + def from_record(cls, record: asyncpg.Record): + """Converts an asyncpg query record into an instance of the class. Nonexistent entries will get filled with None.""" + if record is None: + return None + + return cls(**{k: record.get(k) for k in cls.get_columns().keys()}) + + @classmethod + async def get_by(cls, **filters) -> list: + """Selects a list of all records matching the given column=value criteria. + Since pretty much every subclass overrides this to return lists of instantiated objects rather than queries, + we simply automate this. + """ + async with Pool.acquire() as conn: + statement = f"SELECT * FROM {cls.__tablename__}" + if filters: + # note: this code relies on subsequent iterations of the same dict having the same iteration order. + # This is an implementation detail of CPython 3.6 and a language guarantee in Python 3.7+. + conditions = " AND ".join(f"{column_name} = ${i + 1}" for (i, column_name) in enumerate(filters)) + statement = f"{statement} WHERE {conditions};" + else: + statement += ";" + logger.debug("exec " + statement) + records = await conn.fetch(statement, *filters.values()) + return [*map(cls.from_record, records)] + + @classmethod + async def get_one(cls, **filters): + """It's like get_by except it returns exactly one record or None.""" + return ((await cls.get_by(**filters)) or [None])[0] + + + + class ConfigCache: """Class that will reduce calls to sqlalchemy as much as possible. Has no growth limit (yet)""" diff --git a/dozer/db/orm.py b/dozer/db/orm.py deleted file mode 100644 index edac7596..00000000 --- a/dozer/db/orm.py +++ /dev/null @@ -1,205 +0,0 @@ -from . import DatabaseTable, Pool -from .pqt import Column -from typing import List, Tuple, Dict -import asyncpg - - - -class ORMTable(DatabaseTable): - """ORM tables are a new variant on DatabaseTables: - - * they are defined from class attributes - * initial_create, initial_migrate, get_by, delete, and update_or_add (upsert) are handled for you - * ability to instantiate objects directly from the results of SQL queries - - This class can vastly reduce the amount of boilerplate in a codebase. - - notes: - * __uniques__ MUST be a tuple! Do not set it to a string! Runtime will check for this and yell at you! - - - For example: - - ```python - class StarboardConfig(db.orm.ORMTable): - __tablename__ = 'starboard_settings' - __uniques__ = ('guild_id',) - guild_id: int = Column("bigint NOT NULL") - channel_id: int = Column("bigint NOT NULL") - star_emoji: str = Column("varchar NOT NULL") - cancel_emoji: str = Column("varchar") - threshold: int = Column("bigint NOT NULL") - ``` - - will produce a functionally equivalent class without overriding initial_create or get_by. - - """ - __tablename__: str = '' - __versions__: Tuple[int] = tuple() - __uniques__: Tuple[str] = tuple() - - _columns: Dict[str, Column] = None - - - # Declare the migrate/create functions - @classmethod - async def initial_create(cls): - """Create the table in the database. Already implemented for you. Can still override if desired.""" - - if cls is ORMTable: - # abstract class need not apply - return - - columns = cls.get_columns() - # assemble "column = Column('integer not null') into column integer not null, " - query_params = ", ".join(map(" ".join, zip(columns.keys(), (c.sql for c in columns.values())))) - - if cls.__uniques__: - # TODO: determine if we should still use __uniques__ or have primary key data in Column - query_params += f", PRIMARY KEY({', '.join(k for k in cls.__uniques__)})" - - query_str = f"CREATE TABLE {cls.__tablename__}({query_params})" - async with Pool.acquire() as conn: - await conn.execute(query_str) - - - @classmethod - async def initial_migrate(cls): - """Create a version entry in the versions table""" - if cls is ORMTable: - # abstract class need not apply - return - - async with Pool.acquire() as conn: - await conn.execute("""INSERT INTO versions VALUES ($1, 0)""", cls.__tablename__) - - @staticmethod - def nullify(): - """Function to be referenced when a table entry value needs to be set to null""" - - def __init__(self, *args, **kwargs): - # yeah the one drawback of this approach is that you don't get hints for constructors anymore - # which doesn't stop language servers from somehow figuring out how to do this with dataclasses - # turns out dataclasses use dark magic - # (they dynamically generate and eval an __init__ which they staple onto the class) - - self.__dict__.update({k: v for k, v in kwargs.items() if k in self.get_columns()}) - super().__init__() - - async def update_or_add(self): - """Assign the attribute to this object, then call this method to either insert the object if it doesn't exist in - the DB or update it if it does exist. It will update every column not specified in __uniques__.""" - keys = [] - values = [] - for var, value in self.__dict__.items(): - if var not in self.get_columns(): - continue - # Done so that the two are guaranteed to be in the same order, which isn't true of keys() and values() - keys.append(var) - values.append(None if value is self.nullify else value) - - updates = "" - for key in keys: - if key in self.__uniques__: - # Skip updating anything that has a unique constraint on it - continue - updates += f"{key} = EXCLUDED.{key}" - if keys.index(key) == len(keys) - 1: - updates += " ;" - else: - updates += ", \n" - async with Pool.acquire() as conn: - if updates: - statement = f""" - INSERT INTO {self.__tablename__} ({", ".join(keys)}) - VALUES({','.join(f'${i + 1}' for i in range(len(values)))}) - ON CONFLICT ({self.__uniques__}) DO UPDATE - SET {updates} - """ - else: - statement = f""" - INSERT INTO {self.__tablename__} ({", ".join(keys)}) - VALUES({','.join(f'${i + 1}' for i in range(len(values)))}) - ON CONFLICT ({self.__uniques__}) DO NOTHING; - """ - await conn.execute(statement, *values) - - def __repr__(self): - # repr is supposed to be a representation of the object - # usually you want eval(repr(obj)) == obj more or less - return self.__class__.__name__ + "(" + ", ".join(f"{key}={val!r}" for key, val in \ - {key: getattr(self, key) for key in self._columns.keys()}.items()) + ")" - - # Class Methods - - @classmethod - def get_columns(cls) -> Dict[str, Column]: - """Returns all columns in the table. Also initializes the column cache.""" - if cls._columns is not None: - return cls._columns - - cls._columns = {name: val for name, val in vars(cls) if isinstance(val, Column)} - return cls._columns - - - @classmethod - def from_record(cls, record: asyncpg.Record): - """Converts an asyncpg query record into an instance of the class.""" - if record is None: - return None - - return cls(**{k: record[k] for k in cls.get_columns().keys()}) - - @classmethod - async def get_by(cls, **filters) -> list: - """Selects a list of all records matching the given column=value criteria. - Since pretty much every subclass overrides this to return lists of instantiated objects rather than queries, - we simply automate this. - """ - async with Pool.acquire() as conn: - statement = f"SELECT * FROM {cls.__tablename__}" - if filters: - # note: this code relies on subsequent iterations of the same dict having the same iteration order. - # This is an implementation detail of CPython 3.6 and a language guarantee in Python 3.7+. - conditions = " AND ".join(f"{column_name} = ${i + 1}" for (i, column_name) in enumerate(filters)) - statement = f"{statement} WHERE {conditions};" - else: - statement += ";" - records = await conn.fetch(statement, *filters.values()) - return [*map(cls.from_record, records)] - - @classmethod - async def get_one(cls, **filters): - """It's like get_by except it returns exactly one record or None.""" - return ((await cls.get_by(**filters)) or [None])[0] - - @classmethod - async def delete(cls, **filters): - """Deletes by any number of criteria specified as column=value keyword arguments. Returns the number of entries deleted.""" - async with Pool.acquire() as conn: - if filters: - # This code relies on properties of dicts - see get_by - conditions = " AND ".join(f"{column_name} = ${i + 1}" for (i, column_name) in enumerate(filters)) - statement = f"DELETE FROM {cls.__tablename__} WHERE {conditions};" - else: - # Should this be a warning/error? It's almost certainly not intentional - statement = f"TRUNCATE {cls.__tablename__};" - return await conn.execute(statement, *filters.values()) - - @classmethod - async def set_initial_version(cls): - """Sets initial version""" - await Pool.execute("""INSERT INTO versions (table_name, version_num) VALUES ($1,$2)""", cls.__tablename__, 0) - - -# this part isn't super necessary. also need to figure out how to do this without circular imports. - -#async def startup_init(): -# """Initializes all ORM classes, and validates them.""" -# for cls in ORMTable.__subclasses__: -# cls.get_columns() -# if not isinstance(cls.__uniques__, tuple): -# raise ValueError(f"{cls}.__uniques__ MUST be a tuple!") -# -# if not cls.__tablename__: -# raise ValueError(f"{cls}.__tablename__ is blank!") \ No newline at end of file From 0454a5f83186d24bc76ede633927358af3bf14e9 Mon Sep 17 00:00:00 2001 From: Guinea Wheek Date: Sat, 31 Dec 2022 23:38:07 -0800 Subject: [PATCH 5/9] add basic versioning support --- dozer/cogs/_utils.py | 2 +- dozer/db/__init__.py | 76 ++++++++++++++++++++++++++++++-------------- dozer/db/pqt.py | 9 ++++-- 3 files changed, 60 insertions(+), 27 deletions(-) diff --git a/dozer/cogs/_utils.py b/dozer/cogs/_utils.py index 41505b45..ad4f4fd8 100755 --- a/dozer/cogs/_utils.py +++ b/dozer/cogs/_utils.py @@ -373,4 +373,4 @@ class DynamicPrefixEntry(db.ORMTable): def __init__(self, guild_id: int, prefix: str): super().__init__() self.guild_id = guild_id - self.prefix = prefix \ No newline at end of file + self.prefix = prefix diff --git a/dozer/db/__init__.py b/dozer/db/__init__.py index 4eba3217..82639bcb 100755 --- a/dozer/db/__init__.py +++ b/dozer/db/__init__.py @@ -1,9 +1,11 @@ """Provides database storage for the Dozer Discord bot""" -from typing import List, Dict, Tuple -from .pqt import Column, Col +# pylint generates false positives on this warning +# pylint: disable=unsupported-membership-test +from typing import List, Dict, Tuple import asyncpg from loguru import logger +from .pqt import Column, Col __all__ = ["Pool", "Column", "Col", "db_init", "db_migrate", "DatabaseTable", "ORMTable", "ConfigCache"] @@ -25,6 +27,9 @@ async def db_migrate(): )""") logger.info("Checking for db migrations") for cls in DatabaseTable.__subclasses__() + ORMTable.__subclasses__(): + + if cls is ORMTable: # abstract class, do not check + continue if not cls.__tablename__: raise ValueError(f"{cls.__name__}.__tablename__ cannot be blank!") @@ -44,7 +49,12 @@ async def db_migrate(): # Migration/creation required, go to the function in the subclass for it await cls.initial_migrate() version = {"version_num": 0} - if int(version["version_num"]) < len(cls.__versions__): + + if cls.__versions__ is None: + # this uses the ORMTable autoversioner + await cls.migrate_to_version(int(version["version_num"])) + + elif int(version["version_num"]) < len(cls.__versions__): # the version in the DB is less than the version in the bot, run all the migrate scripts necessary logger.info(f"Table {cls.__tablename__} is out of date attempting to migrate") for i in range(int(version["version_num"]), len(cls.__versions__)): @@ -173,12 +183,16 @@ class ORMTable(DatabaseTable): * they are defined from class attributes * initial_create, initial_migrate, get_by, delete, and update_or_add (upsert) are handled for you * ability to instantiate objects directly from the results of SQL queries using from_record + * inferred versioning from column parameters This class can vastly reduce the amount of boilerplate in a codebase. notes: * __uniques__ MUST be a tuple of primary key column names! Do not set it to a string! Runtime will check for this and yell at you! + * By default, __version__ is None. This means that versions/migrations will be computed from arguments given to Column. + You can set __version__ to a List for the default functionality of calling migration functions if desired. + For example: @@ -186,18 +200,23 @@ class ORMTable(DatabaseTable): class StarboardConfig(db.orm.ORMTable): __tablename__ = 'starboard_settings' __uniques__ = ('guild_id',) - guild_id: int = Column("bigint NOT NULL") - channel_id: int = Column("bigint NOT NULL") - star_emoji: str = Column("varchar NOT NULL") - cancel_emoji: str = Column("varchar") - threshold: int = Column("bigint NOT NULL") + + # the column definitions + guild_id: int = db.Column("bigint NOT NULL") + channel_id: int = db.Column("bigint NOT NULL") + star_emoji: str = db.Column("varchar NOT NULL") + cancel_emoji: str = db.Column("varchar") + threshold: int = db.Column("bigint NOT NULL") + + # this parameter could be added later down the line, and it will be added using ALTER TABLE. + some_new_col: int = db.Column("bigint NOT NULL DEFAULT 10", version=1) ``` will produce a functionally equivalent class without overriding initial_create or get_by. """ __tablename__: str = '' - __versions__: Tuple[int] = tuple() + __versions__: Tuple[int] | None = None __uniques__: Tuple[str] = tuple() _columns: Dict[str, Column] = None @@ -209,10 +228,6 @@ async def initial_create(cls): """Create the table in the database. Already implemented for you. Can still override if desired.""" logger.debug(cls.__name__ + " process") - if cls is ORMTable: - # abstract class need not apply - logger.debug(cls.__name__ + " skipped") - return columns = cls.get_columns() # assemble "column = Column('integer not null') into column integer not null, " @@ -229,16 +244,6 @@ async def initial_create(cls): await conn.execute(query_str) - @classmethod - async def initial_migrate(cls): - """Create a version entry in the versions table""" - if cls is ORMTable: - # abstract class need not apply - return - - async with Pool.acquire() as conn: - await conn.execute("""INSERT INTO versions VALUES ($1, 0)""", cls.__tablename__) - @staticmethod def nullify(): """Function to be referenced when a table entry value needs to be set to null""" @@ -352,7 +357,32 @@ async def get_by(cls, **filters) -> list: async def get_one(cls, **filters): """It's like get_by except it returns exactly one record or None.""" return ((await cls.get_by(**filters)) or [None])[0] + + @classmethod + async def migrate_to_version(cls, prev_version:int, next_version:int=None): + """Migrates current table from prev_version up to next_version (setting next_version=None assumes latest) + This only supports ALTER TABLE ADD COLUMN type edits at the moment, but if you + really need more complex functionality, feel free to override this function. + + If __versions__ is set to None, this will get called in db_migrate. + """ + versions = {} + for col_name, col in cls.get_columns().items(): + v = col.version + if v <= prev_version or (next_version is not None and v > next_version): + continue + if v not in versions: + versions[v] = [(col_name, col)] + else: + versions[v].append((col_name, col)) + + async with Pool.acquire() as conn: + for vnum in sorted(versions.keys()): + for col_name, col in versions[vnum]: + await conn.execute(f"ALTER TABLE {cls.__tablename__} ADD COLUMN {col_name} {col.sql};") + logger.info(f"updated {cls.__tablename__} from version {prev_version} to {vnum}") + prev_version = vnum diff --git a/dozer/db/pqt.py b/dozer/db/pqt.py index 03eaf217..df06cf9f 100644 --- a/dozer/db/pqt.py +++ b/dozer/db/pqt.py @@ -1,8 +1,11 @@ """Postgres types. Add aliases here.""" class Column: - """Represents a sql column.""" - def __init__(self, sql: str): + """Represents a sql column. + Includes an optional version parameter for columsn added later. + """ + def __init__(self, sql: str, version=0): self.sql: str = sql + self.version: int = version -Col = Column \ No newline at end of file +Col = Column From 8ddf837acfd22f4241b5d68d8a7ed3f8f1fc8d93 Mon Sep 17 00:00:00 2001 From: Guinea Wheek Date: Sun, 1 Jan 2023 01:35:30 -0800 Subject: [PATCH 6/9] more versioning work --- dozer/db/__init__.py | 87 +++++++++++++++++++++++++++++++++----------- dozer/db/pqt.py | 29 ++++++++++++++- 2 files changed, 93 insertions(+), 23 deletions(-) diff --git a/dozer/db/__init__.py b/dozer/db/__init__.py index b278113e..032a21d7 100755 --- a/dozer/db/__init__.py +++ b/dozer/db/__init__.py @@ -2,7 +2,7 @@ # pylint generates false positives on this warning # pylint: disable=unsupported-membership-test -from typing import List, Dict, Tuple +from typing import List, Dict, Tuple, Callable import asyncpg from loguru import logger from .pqt import Column, Col @@ -44,23 +44,28 @@ async def db_migrate(): table_name = $1)""", cls.__tablename__) if not exists['exists']: await cls.initial_create() - version = None - else: - version = await Pool.fetchrow("""SELECT version_num FROM versions WHERE table_name = $1""", - cls.__tablename__) + await cls.initial_migrate() + version = await Pool.fetchrow("""SELECT version_num FROM versions WHERE table_name = $1""", + cls.__tablename__) if version is None: # Migration/creation required, go to the function in the subclass for it await cls.initial_migrate() version = {"version_num": 0} + current_version = int(version['version_num']) if cls.__versions__ is None: # this uses the ORMTable autoversioner - await cls.migrate_to_version(int(version["version_num"])) + max_cls_version = max(cls.get_columns().values(), key=lambda c: c.version) + if current_version > max_cls_version: + raise RuntimeError(f"database version for {cls.__name__} ({cls.__tablename__}) is higher than in code " + f"({current_version} > {max_cls_version})") + + await cls.migrate_to_version(current_version, None) - elif int(version["version_num"]) < len(cls.__versions__): + elif current_version < len(cls.__versions__): # the version in the DB is less than the version in the bot, run all the migrate scripts necessary logger.info(f"Table {cls.__tablename__} is out of date attempting to migrate") - for i in range(int(version["version_num"]), len(cls.__versions__)): + for i in range(current_version, len(cls.__versions__)): # Run the update script for this version! await cls.__versions__[i](cls) logger.info(f"Successfully updated table {cls.__tablename__} from version {i} to {i + 1}") @@ -73,8 +78,8 @@ async def db_migrate(): class DatabaseTable: """Defines a database table""" __tablename__: str = '' - __versions__: List[int] = [] - __uniques__: List[str] = [] + __versions__: List[Callable] = [] + __uniques__: str = '' # Declare the migrate/create functions @classmethod @@ -184,18 +189,33 @@ class ORMTable(DatabaseTable): * they are defined from class attributes * initial_create, initial_migrate, get_by, delete, and update_or_add (upsert) are handled for you * ability to instantiate objects directly from the results of SQL queries using from_record - * inferred versioning from column parameters + * inferred versioning from column parameters (mostly adds columns) This class can vastly reduce the amount of boilerplate in a codebase. - notes: + Differences from DatabaseTable: * __uniques__ MUST be a tuple of primary key column names! Do not set it to a string! Runtime will check for this and yell at you! * By default, __version__ is None. This means that versions/migrations will be computed from arguments given to Column. - You can set __version__ to a List for the default functionality of calling migration functions if desired. + * initial_create/initial_migrate will create the latest version of the table, rather than version 0. + + Versioning: + * In general, columns should be defined using the latest version of the schema. + * If your table previously had an id field that's now channel_id, name the thing - For example: + channel_id: int = db.Column("bigint NOT NULL") + + * The default versioning scheme looks at db.Column.version to see when the first version a column is introduced. + It will then either just add the column using ALTER TABLE ADD COLUMN or run a custom script via whatever script is in Column.alter_tbl. + + * You can set __version__ to a List for the default functionality of calling migration functions in order. + + You can also override cls.initial_create and cls.initial_migrate to have tables created at version 0 and upgraded all the way up, + like DatabaseTable. + + + Worked example: ```python class StarboardConfig(db.orm.ORMTable): @@ -217,7 +237,7 @@ class StarboardConfig(db.orm.ORMTable): """ __tablename__: str = '' - __versions__: Tuple[int] | None = None + __versions__: List[Callable] | None = None __uniques__: Tuple[str] = tuple() _columns: Dict[str, Column] = None @@ -226,7 +246,10 @@ class StarboardConfig(db.orm.ORMTable): # Declare the migrate/create functions @classmethod async def initial_create(cls): - """Create the table in the database. Already implemented for you. Can still override if desired.""" + """Create the table in the database. Already implemented for you. Can still override if desired. + Note that unlike DatabaseTable this will create a table directly from the latest version of the class, + rather than always version 0. + """ logger.debug(cls.__name__ + " process") @@ -358,15 +381,33 @@ async def get_by(cls, **filters) -> list: async def get_one(cls, **filters): """It's like get_by except it returns exactly one record or None.""" return ((await cls.get_by(**filters)) or [None])[0] + + + @classmethod + async def set_initial_version(cls): + """Sets initial version. + + Note that unlike DatabaseTable, it will pick the max version available, as the table on creation + will directly use the latest schema. + + """ + + if cls.__versions__ is not None: + max_version = len(cls.__versions__) + else: + max_version = max(cls.get_columns().values(), key=lambda c: c.version) + await Pool.execute("""INSERT INTO versions (table_name, version_num) VALUES ($1,$2)""", cls.__tablename__, max_version) @classmethod async def migrate_to_version(cls, prev_version:int, next_version:int=None): - """Migrates current table from prev_version up to next_version (setting next_version=None assumes latest) - This only supports ALTER TABLE ADD COLUMN type edits at the moment, but if you - really need more complex functionality, feel free to override this function. + """Migrates current table from prev_version up to next_version. (setting next_version=None assumes latest) + For each Column object, it checks if the version attr is > prev_version, and calls the corresponding alter table + action to update it. - If __versions__ is set to None, this will get called in db_migrate. + If you really need more complex functionality, feel free to override this function, or not use it. + + If __versions__ is set to None, this will get called in db_migrate. Otherwise, this function will not be used. """ versions = {} for col_name, col in cls.get_columns().items(): @@ -381,7 +422,11 @@ async def migrate_to_version(cls, prev_version:int, next_version:int=None): async with Pool.acquire() as conn: for vnum in sorted(versions.keys()): for col_name, col in versions[vnum]: - await conn.execute(f"ALTER TABLE {cls.__tablename__} ADD COLUMN {col_name} {col.sql};") + + if col.alter_tbl is None: + await conn.execute(f"ALTER TABLE {cls.__tablename__} ADD COLUMN {col_name} {col.sql};") + else: + await conn.execute(f"ALTER TABLE {cls.__tablename__} {col.alter_tbl};") logger.info(f"updated {cls.__tablename__} from version {prev_version} to {vnum}") prev_version = vnum diff --git a/dozer/db/pqt.py b/dozer/db/pqt.py index df06cf9f..4165bc49 100644 --- a/dozer/db/pqt.py +++ b/dozer/db/pqt.py @@ -2,10 +2,35 @@ class Column: """Represents a sql column. - Includes an optional version parameter for columsn added later. + Includes an optional version parameter for columns added later, and an alter_tbl field. + + + SQL injectability: don't supply user-provided input to the sql or alter_tbl fields. """ - def __init__(self, sql: str, version=0): + def __init__(self, sql: str, version=0, alter_tbl=None): + """ + sql: the type and parameters, such as "bigint NOT NULL". In general, this translates to + + CREATE TABLE tbl ({col_name} {self.sql}, ...) ...; + + during initial_create. + + + version: the first table version this column appears in. Defaults zero. + Optional if __version__ is a List. + + alter_tbl: + if None, run + ALTER TABLE ADD COLUMN {col_name} {self.sql}; + to add the table. + + If not None, run + ALTER TABLE {alter_tbl}; + instead. + + """ self.sql: str = sql self.version: int = version + self.alter_tbl = alter_tbl Col = Column From d564b3ee8bfb77eaee2854774d13a27326b8e7ce Mon Sep 17 00:00:00 2001 From: Guinea Wheek Date: Fri, 6 Jan 2023 14:38:35 -0800 Subject: [PATCH 7/9] add some type annotations --- dozer/db/__init__.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/dozer/db/__init__.py b/dozer/db/__init__.py index 032a21d7..77540833 100755 --- a/dozer/db/__init__.py +++ b/dozer/db/__init__.py @@ -2,7 +2,7 @@ # pylint generates false positives on this warning # pylint: disable=unsupported-membership-test -from typing import List, Dict, Tuple, Callable +from typing import List, Dict, Tuple, Callable, TypeVar import asyncpg from loguru import logger from .pqt import Column, Col @@ -182,6 +182,8 @@ async def set_initial_version(cls): """Sets initial version""" await Pool.execute("""INSERT INTO versions (table_name, version_num) VALUES ($1,$2)""", cls.__tablename__, 0) +# typing.Self is in 3.11 which is a shade too new for us +TORMTable = TypeVar("TORMTable", bound="ORMTable") class ORMTable(DatabaseTable): """ORM tables are a new variant on DatabaseTables: @@ -351,7 +353,7 @@ def get_columns(cls) -> Dict[str, Column]: @classmethod - def from_record(cls, record: asyncpg.Record): + def from_record(cls, record: asyncpg.Record) -> TORMTable: """Converts an asyncpg query record into an instance of the class. Nonexistent entries will get filled with None.""" if record is None: return None @@ -359,7 +361,7 @@ def from_record(cls, record: asyncpg.Record): return cls(**{k: record.get(k) for k in cls.get_columns().keys()}) @classmethod - async def get_by(cls, **filters) -> list: + async def get_by(cls, **filters) -> List[TORMTable]: """Selects a list of all records matching the given column=value criteria. Since pretty much every subclass overrides this to return lists of instantiated objects rather than queries, we simply automate this. @@ -378,7 +380,7 @@ async def get_by(cls, **filters) -> list: return [*map(cls.from_record, records)] @classmethod - async def get_one(cls, **filters): + async def get_one(cls, **filters) -> TORMTable: """It's like get_by except it returns exactly one record or None.""" return ((await cls.get_by(**filters)) or [None])[0] From 0993ad37f8960742f3f94164b5b883d745b3dc2c Mon Sep 17 00:00:00 2001 From: Guinea Wheek Date: Wed, 18 Jan 2023 15:19:50 -0800 Subject: [PATCH 8/9] revert dynamicprefixentry for now --- dozer/cogs/_utils.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/dozer/cogs/_utils.py b/dozer/cogs/_utils.py index ad4f4fd8..a3ce3e16 100755 --- a/dozer/cogs/_utils.py +++ b/dozer/cogs/_utils.py @@ -362,15 +362,32 @@ async def refresh(self): logger.info(f"{len(prefixes)} prefixes loaded from database") -class DynamicPrefixEntry(db.ORMTable): +class DynamicPrefixEntry(db.DatabaseTable): """Holds the custom prefixes for guilds""" __tablename__ = 'dynamic_prefixes' - __uniques__ = ('guild_id',) - - guild_id = db.Column("bigint not null") - prefix: str = db.Column("text not null") + __uniques__ = 'guild_id' + + @classmethod + async def initial_create(cls): + """Create the table in the database""" + async with db.Pool.acquire() as conn: + await conn.execute(f""" + CREATE TABLE {cls.__tablename__} ( + guild_id bigint NOT NULL, + prefix text NOT NULL, + PRIMARY KEY (guild_id) + )""") def __init__(self, guild_id: int, prefix: str): super().__init__() self.guild_id = guild_id self.prefix = prefix + + @classmethod + async def get_by(cls, **kwargs): + results = await super().get_by(**kwargs) + result_list = [] + for result in results: + obj = DynamicPrefixEntry(guild_id=result.get("guild_id"), prefix=result.get("prefix")) + result_list.append(obj) + return result_list From 15fcc10d27094053b84804a9a172a54d4dac8a65 Mon Sep 17 00:00:00 2001 From: Guinea Wheek Date: Wed, 18 Jan 2023 15:22:43 -0800 Subject: [PATCH 9/9] rename db.py and merge pqt in to try and avoid merge conflicts --- dozer/{db/__init__.py => db.py} | 35 +++++++++++++++++++++++++++++++- dozer/db/pqt.py | 36 --------------------------------- 2 files changed, 34 insertions(+), 37 deletions(-) rename dozer/{db/__init__.py => db.py} (95%) delete mode 100644 dozer/db/pqt.py diff --git a/dozer/db/__init__.py b/dozer/db.py similarity index 95% rename from dozer/db/__init__.py rename to dozer/db.py index 77540833..da9fd795 100755 --- a/dozer/db/__init__.py +++ b/dozer/db.py @@ -5,12 +5,45 @@ from typing import List, Dict, Tuple, Callable, TypeVar import asyncpg from loguru import logger -from .pqt import Column, Col __all__ = ["Pool", "Column", "Col", "db_init", "db_migrate", "DatabaseTable", "ORMTable", "ConfigCache"] Pool: asyncpg.Pool = None +class Column: + """Represents a sql column. + Includes an optional version parameter for columns added later, and an alter_tbl field. + + + SQL injectability: don't supply user-provided input to the sql or alter_tbl fields. + """ + def __init__(self, sql: str, version=0, alter_tbl=None): + """ + sql: the type and parameters, such as "bigint NOT NULL". In general, this translates to + + CREATE TABLE tbl ({col_name} {self.sql}, ...) ...; + + during initial_create. + + + version: the first table version this column appears in. Defaults zero. + Optional if __version__ is a List. + + alter_tbl: + if None, run + ALTER TABLE ADD COLUMN {col_name} {self.sql}; + to add the table. + + If not None, run + ALTER TABLE {alter_tbl}; + instead. + + """ + self.sql: str = sql + self.version: int = version + self.alter_tbl = alter_tbl + +Col = Column async def db_init(db_url): """Initializes the database connection""" diff --git a/dozer/db/pqt.py b/dozer/db/pqt.py deleted file mode 100644 index 4165bc49..00000000 --- a/dozer/db/pqt.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Postgres types. Add aliases here.""" - -class Column: - """Represents a sql column. - Includes an optional version parameter for columns added later, and an alter_tbl field. - - - SQL injectability: don't supply user-provided input to the sql or alter_tbl fields. - """ - def __init__(self, sql: str, version=0, alter_tbl=None): - """ - sql: the type and parameters, such as "bigint NOT NULL". In general, this translates to - - CREATE TABLE tbl ({col_name} {self.sql}, ...) ...; - - during initial_create. - - - version: the first table version this column appears in. Defaults zero. - Optional if __version__ is a List. - - alter_tbl: - if None, run - ALTER TABLE ADD COLUMN {col_name} {self.sql}; - to add the table. - - If not None, run - ALTER TABLE {alter_tbl}; - instead. - - """ - self.sql: str = sql - self.version: int = version - self.alter_tbl = alter_tbl - -Col = Column