Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Incorrect Column Type Crash #63

Merged
merged 4 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions omymodels/models/dataclass/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Dict, List, Optional
from typing import List, Optional

from table_meta import TableMeta
from table_meta.model import Column

import omymodels.types as t
from omymodels.helpers import create_class_name, datetime_now_check
Expand All @@ -24,7 +27,7 @@ def add_custom_type(self, _type: str) -> str:
column_type = column_type[0]
return _type

def generate_attr(self, column: Dict, defaults_off: bool) -> str:
def generate_attr(self, column: Column, defaults_off: bool) -> str:
column_str = dt.dataclass_attr

if "." in column.type:
Expand Down Expand Up @@ -57,19 +60,19 @@ def generate_attr(self, column: Dict, defaults_off: bool) -> str:
return column_str

@staticmethod
def add_column_default(column_str: str, column: Dict) -> str:
def add_column_default(column_str: str, column: Column) -> str:
if column.type.upper() in datetime_types:
if datetime_now_check(column.default.lower()):
# todo: need to add other popular PostgreSQL & MySQL functions
column.default = dt.field_datetime_now
elif "'" not in column.default:
column.default = f"'{column['default']}'"
column.default = f"'{column.default}'"
column_str += dt.dataclass_default_attr.format(default=column.default)
return column_str

def generate_model(
self,
table: Dict,
table: TableMeta,
singular: bool = True,
exceptions: Optional[List] = None,
defaults_off: Optional[bool] = False,
Expand Down
35 changes: 21 additions & 14 deletions omymodels/models/pydantic/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, List, Optional
from typing import List, Optional

from table_meta.model import Column
from table_meta.model import Column, TableMeta

import omymodels.types as t
from omymodels.helpers import create_class_name, datetime_now_check
Expand All @@ -11,29 +11,28 @@

class ModelGenerator:
def __init__(self):
self.imports = set([pt.base_model])
self.imports = {pt.base_model}
self.types_for_import = ["Json"]
self.datetime_import = False
self.typing_imports = set()
self.custom_types = {}
self.uuid_import = False
self.prefix = ""

def add_custom_type(self, target_type):
def add_custom_type(self, target_type: str) -> Optional[str]:
column_type = self.custom_types.get(target_type, None)
_type = None
if isinstance(column_type, tuple):
_type = column_type[1]
return _type

def get_not_custom_type(self, column: Column):
def get_not_custom_type(self, column: Column) -> str:
_type = None
if "." in column.type:
_type = column.type.split(".")[1]
else:
_type = column.type.lower().split("[")[0]
if _type == _type:
_type = types_mapping.get(_type, _type)
_type = types_mapping.get(_type, _type)
if _type in self.types_for_import:
self.imports.add(_type)
elif "datetime" in _type:
Expand All @@ -45,40 +44,48 @@ def get_not_custom_type(self, column: Column):
self.uuid_import = True
return _type

def generate_attr(self, column: Dict, defaults_off: bool) -> str:
def generate_attr(self, column: Column, defaults_off: bool) -> str:
_type = None

if column.nullable:
self.typing_imports.add("Optional")
column_str = pt.pydantic_optional_attr
else:
column_str = pt.pydantic_attr

if self.custom_types:
_type = self.add_custom_type(column.type)
if not _type:
_type = self.get_not_custom_type(column)

column_str = column_str.format(arg_name=column.name, type=_type)

if column.default and defaults_off is False:
if column.default is not None and not defaults_off:
column_str = self.add_default_values(column_str, column)

return column_str

@staticmethod
def add_default_values(column_str: str, column: Dict) -> str:
def add_default_values(column_str: str, column: Column) -> str:
# Handle datetime default values
if column.type.upper() in datetime_types:
if datetime_now_check(column.default.lower()):
# todo: need to add other popular PostgreSQL & MySQL functions
# Handle functions like CURRENT_TIMESTAMP
column.default = "datetime.datetime.now()"
elif "'" not in column.default:
column.default = f"'{column['default']}'"
elif column.default.upper() != "NULL" and "'" not in column.default:
column.default = f"'{column.default}'"

# If the default is 'NULL', don't set a default in Pydantic (it already defaults to None)
if column.default.upper() == "NULL":
return column_str

# Append the default value if it's not None (e.g., explicit default values like '0' or CURRENT_TIMESTAMP)
column_str += pt.pydantic_default_attr.format(default=column.default)
return column_str

def generate_model(
self,
table: Dict,
table: TableMeta,
singular: bool = True,
exceptions: Optional[List] = None,
defaults_off: Optional[bool] = False,
Expand Down
2 changes: 1 addition & 1 deletion omymodels/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def process_types_after_models_parser(column_data: Column) -> Column:
return column_data


def prepare_column_data(column_data: Column) -> str:
def prepare_column_data(column_data: Column) -> Column:
if "." in column_data.type or "(":
column_data = process_types_after_models_parser(column_data)
return column_data
Expand Down
Loading