Skip to content

Commit

Permalink
Support continuous aggregate on view
Browse files Browse the repository at this point in the history
  • Loading branch information
diorcety authored and Yann Diorcet committed Feb 26, 2024
1 parent 9ecd4c2 commit 64767df
Showing 1 changed file with 61 additions and 24 deletions.
85 changes: 61 additions & 24 deletions sqlalchemy_timescaledb/dialect.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import textwrap
from typing import Optional, Mapping, Any

from sqlalchemy import schema, event, DDL
from sqlalchemy import schema, event, DDL, Table, Dialect, ExecutableDDLElement
from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
from sqlalchemy.dialects.postgresql.base import PGDDLCompiler
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.engine.interfaces import SchemaTranslateMapType
from sqlalchemy.ext import compiler
from sqlalchemy_utils.view import CreateView, compile_create_materialized_view

try:
import alembic
Expand All @@ -16,8 +20,44 @@ class TimescaledbImpl(postgresql.PostgresqlImpl):
__dialect__ = 'timescaledb'


def _get_interval(value):
if isinstance(value, str):
return f"INTERVAL '{value}'"
elif isinstance(value, int):
return str(value)
else:
return "NULL"


def _create_map(mapping: dict):
return ", ".join([f'{key} => {value}' for key, value in mapping.items()])


@compiler.compiles(CreateView, 'timescaledb')
def compile_create_view(create, compiler, **kw):
return compiler.visit_create_view(create, **kw)

class TimescaledbDDLCompiler(PGDDLCompiler):
def post_create_table(self, table):

def visit_create_view(self, create, **kw):
ret = compile_create_materialized_view(create, self, **kw)
view = create.element
continuous = view.kwargs.get('timescaledb_continuous', {})
if continuous:
event.listen(
view,
'after_create',
self.ddl_add_continuous(
view.name, continuous
).execute_if(
dialect='timescaledb'
)
)
return ret

def visit_create_table(self, create, **kw):
ret = super().visit_create_table(create, **kw)
table = create.element
hypertable = table.kwargs.get('timescaledb_hypertable', {})
compress = table.kwargs.get('timescaledb_compress', {})

Expand Down Expand Up @@ -52,28 +92,15 @@ def post_create_table(self, table):
)
)


return super().post_create_table(table)
return ret

@staticmethod
def ddl_hypertable(table_name, hypertable):
time_column_name = hypertable['time_column_name']
chunk_time_interval = hypertable.get('chunk_time_interval', '7 days')

if isinstance(chunk_time_interval, str):
if chunk_time_interval.isdigit():
chunk_time_interval = int(chunk_time_interval)
else:
chunk_time_interval = f"INTERVAL '{chunk_time_interval}'"
chunk_time_interval = _get_interval(hypertable.get('chunk_time_interval', '7 days'))

return DDL(textwrap.dedent(f"""
SELECT create_hypertable(
'{table_name}',
'{time_column_name}',
chunk_time_interval => {chunk_time_interval},
if_not_exists => TRUE
)
"""))
parameters = _create_map(dict(chunk_time_interval=chunk_time_interval, if_not_exists="TRUE"))
return DDL(textwrap.dedent(f"""SELECT create_hypertable('{table_name}','{time_column_name}',{parameters})"""))

@staticmethod
def ddl_compress(table_name, compress):
Expand All @@ -85,11 +112,20 @@ def ddl_compress(table_name, compress):

@staticmethod
def ddl_compression_policy(table_name, compress):
compression_policy_interval = compress.get('compression_policy_interval', '7 days')
schedule_interval = _get_interval(compress.get('compression_policy_schedule_interval', '7 days'))

return DDL(textwrap.dedent(f"""
SELECT add_compression_policy('{table_name}', INTERVAL '{compression_policy_interval}')
"""))
parameters = _create_map(dict(schedule_interval=schedule_interval))
return DDL(textwrap.dedent(f"""SELECT add_compression_policy('{table_name}', {parameters}')"""))

@staticmethod
def ddl_add_continuous(table_name, continuous):
start_offset = _get_interval(continuous.get('continuous_aggregate_policy_start_offset', None))
end_offset = _get_interval(continuous.get('continuous_aggregate_policy_end_offset', None))
schedule_interval = _get_interval(continuous.get('continuous_aggregate_policy_schedule_interval', None))

parameters = _create_map(
dict(start_offset=start_offset, end_offset=end_offset, schedule_interval=schedule_interval))
return DDL(textwrap.dedent(f"""SELECT add_continuous_aggregate_policy('{table_name}', {parameters})"""))


class TimescaledbDialect:
Expand All @@ -99,7 +135,8 @@ class TimescaledbDialect:
(
schema.Table, {
"hypertable": {},
"compress": {}
"compress": {},
"continuous": {},
}
)
]
Expand Down

0 comments on commit 64767df

Please sign in to comment.