Skip to content

Commit

Permalink
Handle column population for tables that inherit from other tables
Browse files Browse the repository at this point in the history
  • Loading branch information
jzmiller1 committed Aug 15, 2024
1 parent 478c6e3 commit 6904b27
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 13 deletions.
13 changes: 12 additions & 1 deletion postnormalism/schema/database.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
from dataclasses import dataclass, field

from ..core import create_items, create_extensions
from . import DatabaseItem, PostnormalismMigrations, Schema
from . import DatabaseItem, PostnormalismMigrations, Schema, Table


class SchemaProxy:
Expand Down Expand Up @@ -38,11 +39,21 @@ def __post_init__(self):
if not os.path.isdir(self.migrations_folder):
print("Invalid migrations folder.")
for entry in self.load_order:
self._set_database_reference(entry)
if isinstance(entry, list):
self.add_items(*entry, schema_loaded=schema_loaded)
else:
self.add_items(entry, schema_loaded=schema_loaded)

def _set_database_reference(self, item: DatabaseItem):
    """
    Attach this database to ``item`` and refresh table columns.

    ``item`` may be a single schema item or a list of items; lists are
    walked recursively so every contained item gets the reference.
    DatabaseItem is declared as a frozen dataclass, so the reference is
    written with ``object.__setattr__``.
    """
    if not isinstance(item, list):
        object.__setattr__(item, '_database', self)
        # Tables recompute their columns once a database is attached.
        # NOTE(review): inherited columns are looked up on this database, so
        # the parent table presumably has to be registered already - confirm
        # against the load order handling.
        if isinstance(item, Table):
            item._initialize_columns()
        return

    for member in item:
        self._set_database_reference(member)

def add_items(self, *items: DatabaseItem, schema_loaded: set) -> None:
for item in items:
item_type = type(item).__name__
Expand Down
21 changes: 20 additions & 1 deletion postnormalism/schema/database_item.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from dataclasses import dataclass, field
import re
import warnings


@dataclass(frozen=True)
class DatabaseItem:
"""
A base data class for schema items like tables and functions.
A base data class for schema items like tables, functions, etc.
"""
create: str
comment: str = field(default=None)
Expand All @@ -14,6 +15,7 @@ class DatabaseItem:
_name: str = field(init=False, default=None)
_schema_pattern: str = field(default=None)
_schema: str = field(init=False, default=None)
_database: object = field(default=None, init=False, repr=False) # Internal use only

def __post_init__(self):
create = self.create.upper()
Expand Down Expand Up @@ -53,3 +55,20 @@ def schema(self) -> str:
@property
def itype(self) -> str:
    """Return the item type label held in ``_item_type``."""
    return self._item_type

@property
def database(self):
    """Return the database this item is attached to (``None`` until set)."""
    return self._database

@database.setter
def database(self, db):
    """
    Store ``db`` as the database reference for this item.

    The value is written with ``object.__setattr__`` rather than normal
    attribute assignment.
    """
    # NOTE(review): DatabaseItem is a frozen dataclass, so a plain
    # ``item.database = db`` on a dataclass instance is expected to raise
    # FrozenInstanceError before this setter runs - confirm it is reachable.
    object.__setattr__(self, '_database', db)

def warn_if_no_database(self):
    """
    Emit a warning when no database reference has been attached.

    Only a missing reference (``None``) triggers the warning; a database
    object that happens to evaluate as falsy still counts as set.
    """
    if self.database is not None:
        return
    # stacklevel=2 attributes the warning to the caller of this helper,
    # which is where the missing reference actually matters.
    warnings.warn(
        f"Database reference is not set for '{self.name}'. Certain operations may not work correctly.",
        stacklevel=2,
    )
57 changes: 46 additions & 11 deletions postnormalism/schema/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,36 @@ class Table(DatabaseItem):
A data class for tables.
"""
_item_type: str = 'table'
_name_pattern: str = field(default=r'CREATE\s+(?:OR\s+REPLACE\s+)?(?:TEMP\s+)?(?:TABLE)\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:\w+\.)?(\w+)')
_name_pattern: str = field(
default=r'CREATE\s+(?:OR\s+REPLACE\s+)?(?:TEMP\s+)?(?:TABLE)\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:\w+\.)?(\w+)')
_schema_pattern: str = field(default=r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(\w+)\.\w+')
_pattern_create: str = field(default=r"^\s*(\w+)\s+(?:[\w\(\)]+).*?(?:,|$)")
_pattern_alter: str = field(default=r"ADD COLUMN\s+(\w+)\s+[\w\(\)]+")
_pattern_inherits: str = field(default=r"INHERITS\s*\((\w+)\)")

alter: str = field(default=None)
columns: list[str] = field(init=False)
inherits: bool = field(default=False, init=False)
_columns: list[str] = field(default=None, init=False, repr=False)

def __post_init__(self):
    """
    Finish construction of the table item.

    Runs the base-class initialiser, sets ``inherits`` when the CREATE
    statement contains an INHERITS clause, then collects the columns.
    """
    super().__post_init__()
    declares_parent = re.search(self._pattern_inherits, self.create, re.IGNORECASE) is not None
    if declares_parent:
        # Written with object.__setattr__, as the other attribute writes in this class are.
        object.__setattr__(self, 'inherits', True)
    self._initialize_columns()

@property
def columns(self):
    """Return the column names, computing them first if not yet populated."""
    cached = self._columns
    if cached is not None:
        return cached
    self._initialize_columns()
    return self._columns

def _initialize_columns(self):
    """Collect this table's own columns, then inherited ones if it inherits."""
    self._extract_columns()
    if not self.inherits:
        return
    self._extract_inherited_columns()

def _extract_columns(self):
# match column definitions in both CREATE TABLE and ALTER TABLE statements
pattern_create = r"^\s*(\w+)\s+(?:[\w\(\)]+).*?(?:,|$)"
pattern_alter = r"ADD COLUMN\s+(\w+)\s+[\w\(\)]+"

columns = []
in_table_definition = False

Expand All @@ -35,26 +51,45 @@ def _extract_columns(self):
if in_table_definition:
parts = line.split(',')
for part in parts:
match = re.match(pattern_create, part.strip())
match = re.match(self._pattern_create, part.strip())
if match:
columns.append(match.group(1))

if self.alter:
for line in self.alter.splitlines():
match = re.search(pattern_alter, line.strip(), re.IGNORECASE)
match = re.search(self._pattern_alter, line.strip(), re.IGNORECASE)
if match:
columns.append(match.group(1))
object.__setattr__(self, '_columns', columns)

object.__setattr__(self, 'columns', columns)
def _extract_inherited_columns(self):
    """
    Prepend the parent table's columns to this table's own columns.

    The parent name is taken from the INHERITS clause of ``create``
    (lower-cased) and resolved through ``_get_parent_table``. When there is
    no INHERITS clause, no database reference, or no parent table is found,
    the table keeps its own columns instead of being reset to an empty list.
    """
    own_columns = list(self._columns or [])
    inherited_columns = []

    inherit_match = re.search(self._pattern_inherits, self.create, re.IGNORECASE)
    if inherit_match is not None and self.database is not None:
        parent_table = self._get_parent_table(inherit_match.group(1).lower())
        if parent_table is not None:
            # Copy so the parent's column list is never shared or mutated.
            inherited_columns = list(parent_table.columns)

    # Parent columns come first, followed by the columns declared here.
    object.__setattr__(self, '_columns', inherited_columns + own_columns)

def _get_parent_table(self, parent_table_name):
    """
    Look up the parent table on the attached database by attribute name.

    Returns the attribute of ``self.database`` named ``parent_table_name``,
    or ``None`` when the database has no such attribute, so callers can
    test the result for a missing parent.

    Raises:
        ValueError: if no database reference has been set on this table.
    """
    database = self.database
    if database is None:
        raise ValueError("Database reference not set in Table instance.")
    # Default to None so an unknown parent is reported as "not found"
    # instead of surfacing as an AttributeError.
    return getattr(database, parent_table_name, None)

def full_sql(self, exists=False) -> str:
sql_parts = super().full_sql().split("\n\n")
sql_parts = [self.create.strip()]

if exists:
sql_parts[0] = sql_parts[0].replace("CREATE TABLE", "CREATE TABLE IF NOT EXISTS")

if self.comment:
sql_parts.append(self.comment.strip())

if self.alter:
sql_parts.append(self.alter.strip())

return "\n\n".join(sql_parts)

32 changes: 32 additions & 0 deletions tests/items/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,35 @@ def test_create_table_with_alter_table(self):
table = Table(create=sql_create, alter=sql_alter)
expected_columns = ["order_id", "customer_id", "order_date", "amount", "status", "notes"]
self.assertEqual(table.columns, expected_columns)

def test_inherited_table_columns(self):
    """A table declared with INHERITS lists the parent's columns before its own."""
    from postnormalism.schema import Database

    parent_sql = """
CREATE TABLE process_element (
id uuid PRIMARY KEY DEFAULT uuid_generate_v4(),
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
"""
    child_sql = """
CREATE TABLE process_element_material (
element uuid REFERENCES material NOT NULL
) INHERITS (process_element);
"""

    parent_table = Table(create=parent_sql)
    child_table = Table(create=child_sql)

    # The parent is listed before the child in the load order.
    database = Database(load_order=[parent_table, child_table])

    self.assertEqual(
        database.process_element_material.columns,
        ["id", "created_at", "updated_at", "element"],
    )

0 comments on commit 6904b27

Please sign in to comment.