diff --git a/postnormalism/schema/database.py b/postnormalism/schema/database.py index 723cff4..770560d 100644 --- a/postnormalism/schema/database.py +++ b/postnormalism/schema/database.py @@ -1,7 +1,8 @@ import os from dataclasses import dataclass, field + from ..core import create_items, create_extensions -from . import DatabaseItem, PostnormalismMigrations, Schema +from . import DatabaseItem, PostnormalismMigrations, Schema, Table class SchemaProxy: @@ -38,11 +39,21 @@ def __post_init__(self): if not os.path.isdir(self.migrations_folder): print("Invalid migrations folder.") for entry in self.load_order: + self._set_database_reference(entry) if isinstance(entry, list): self.add_items(*entry, schema_loaded=schema_loaded) else: self.add_items(entry, schema_loaded=schema_loaded) + def _set_database_reference(self, item: DatabaseItem): + if isinstance(item, list): + for sub_item in item: + self._set_database_reference(sub_item) + else: + object.__setattr__(item, '_database', self) + if isinstance(item, Table): + item._initialize_columns() + def add_items(self, *items: DatabaseItem, schema_loaded: set) -> None: for item in items: item_type = type(item).__name__ diff --git a/postnormalism/schema/database_item.py b/postnormalism/schema/database_item.py index d3623b0..a502c3c 100644 --- a/postnormalism/schema/database_item.py +++ b/postnormalism/schema/database_item.py @@ -1,11 +1,12 @@ from dataclasses import dataclass, field import re +import warnings @dataclass(frozen=True) class DatabaseItem: """ - A base data class for schema items like tables and functions. + A base data class for schema items like tables, functions, etc. """ create: str comment: str = field(default=None) @@ -14,6 +15,7 @@ class DatabaseItem: _name: str = field(init=False, default=None) _schema_pattern: str = field(default=None) _schema: str = field(init=False, default=None) + _database: object = field(default=None, init=False, repr=False) # Internal use only def __post_init__(self): create = self.create.upper() @@ -53,3 +55,20 @@ def schema(self) -> str: @property def itype(self) -> str: return self._item_type + + @property + def database(self): + """Get the database reference.""" + return self._database + + @database.setter + def database(self, db): + """Set the database reference for this item.""" + object.__setattr__(self, '_database', db) + + def warn_if_no_database(self): + """Warn if the database is not set and required for operations.""" + if not self.database: + warnings.warn( + f"Database reference is not set for '{self.name}'. Certain operations may not work correctly." + ) diff --git a/postnormalism/schema/table.py b/postnormalism/schema/table.py index 8513715..7232c02 100644 --- a/postnormalism/schema/table.py +++ b/postnormalism/schema/table.py @@ -10,20 +10,36 @@ class Table(DatabaseItem): A data class for tables. """ _item_type: str = 'table' - _name_pattern: str = field(default=r'CREATE\s+(?:OR\s+REPLACE\s+)?(?:TEMP\s+)?(?:TABLE)\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:\w+\.)?(\w+)') + _name_pattern: str = field( + default=r'CREATE\s+(?:OR\s+REPLACE\s+)?(?:TEMP\s+)?(?:TABLE)\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:\w+\.)?(\w+)') _schema_pattern: str = field(default=r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(\w+)\.\w+') + _pattern_create: str = field(default=r"^\s*(\w+)\s+(?:[\w\(\)]+).*?(?:,|$)") + _pattern_alter: str = field(default=r"ADD COLUMN\s+(\w+)\s+[\w\(\)]+") + _pattern_inherits: str = field(default=r"INHERITS\s*\((\w+)\)") + alter: str = field(default=None) - columns: list[str] = field(init=False) + inherits: bool = field(default=False, init=False) + _columns: list[str] = field(default=None, init=False, repr=False) def __post_init__(self): super().__post_init__() + inherit_match = re.search(self._pattern_inherits, self.create, re.IGNORECASE) + if inherit_match: + object.__setattr__(self, 'inherits', True) + self._initialize_columns() + + @property + def columns(self): + if self._columns is None: + self._initialize_columns() + return self._columns + + def _initialize_columns(self): self._extract_columns() + if self.inherits: + self._extract_inherited_columns() def _extract_columns(self): - # match column definitions in both CREATE TABLE and ALTER TABLE statements - pattern_create = r"^\s*(\w+)\s+(?:[\w\(\)]+).*?(?:,|$)" - pattern_alter = r"ADD COLUMN\s+(\w+)\s+[\w\(\)]+" - columns = [] in_table_definition = False @@ -35,26 +51,45 @@ def _extract_columns(self): if in_table_definition: parts = line.split(',') for part in parts: - match = re.match(pattern_create, part.strip()) + match = re.match(self._pattern_create, part.strip()) if match: columns.append(match.group(1)) if self.alter: for line in self.alter.splitlines(): - match = re.search(pattern_alter, line.strip(), re.IGNORECASE) + match = re.search(self._pattern_alter, line.strip(), re.IGNORECASE) if match: columns.append(match.group(1)) + object.__setattr__(self, '_columns', columns) - object.__setattr__(self, 'columns', columns) + def _extract_inherited_columns(self): + inherit_match = re.search(self._pattern_inherits, self.create, re.IGNORECASE) + if inherit_match and self.database: + parent_table_name = inherit_match.group(1).lower() + parent_table = self._get_parent_table(parent_table_name) + if parent_table: + parent_columns = parent_table.columns if parent_table else [] + columns = parent_columns + (self._columns or []) + object.__setattr__(self, '_columns', columns) + else: + # Fallback to setting columns to an empty list if no parent found + object.__setattr__(self, '_columns', []) + + def _get_parent_table(self, parent_table_name): + if not self.database: + raise ValueError("Database reference not set in Table instance.") + return getattr(self.database, parent_table_name) def full_sql(self, exists=False) -> str: - sql_parts = super().full_sql().split("\n\n") + sql_parts = [self.create.strip()] if exists: sql_parts[0] = sql_parts[0].replace("CREATE TABLE", "CREATE TABLE IF NOT EXISTS") + if self.comment: + sql_parts.append(self.comment.strip()) + if self.alter: sql_parts.append(self.alter.strip()) return "\n\n".join(sql_parts) - diff --git a/tests/items/test_table.py b/tests/items/test_table.py index 2cf938e..bf6eab3 100644 --- a/tests/items/test_table.py +++ b/tests/items/test_table.py @@ -185,3 +185,35 @@ def test_create_table_with_alter_table(self): table = Table(create=sql_create, alter=sql_alter) expected_columns = ["order_id", "customer_id", "order_date", "amount", "status", "notes"] self.assertEqual(table.columns, expected_columns) + + def test_inherited_table_columns(self): + create_parent = """ + CREATE TABLE process_element ( + id uuid PRIMARY KEY DEFAULT uuid_generate_v4(), + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP + ); + """ + + create_child = """ + CREATE TABLE process_element_material ( + element uuid REFERENCES material NOT NULL + ) INHERITS (process_element); + """ + + ProcessElement = Table(create=create_parent) + ProcessElementMaterial = Table(create=create_child) + + from postnormalism.schema import Database + + universe_db = Database( + load_order=[ + ProcessElement, + ProcessElementMaterial, + ], + ) + + expected_columns = [ + "id", "created_at", "updated_at", "element" + ] + self.assertEqual(universe_db.process_element_material.columns, expected_columns)