diff --git a/src/ldlite/__init__.py b/src/ldlite/__init__.py
index 61c6081..063bea9 100644
--- a/src/ldlite/__init__.py
+++ b/src/ldlite/__init__.py
@@ -50,7 +50,7 @@
 from ._csv import to_csv
 from ._database import Prefix
 from ._folio import FolioClient
-from ._jsonx import Attr, drop_json_tables, transform_json
+from ._jsonx import Attr, transform_json
 from ._select import select
 from ._sqlx import (
     DBType,
@@ -82,6 +82,7 @@ def __init__(self) -> None:
         self._quiet = False
         self.dbtype: DBType = DBType.UNDEFINED
         self.db: dbapi.DBAPIConnection | None = None
+        self._db: DBTypeDatabase | None = None
         self._folio: FolioClient | None = None
         self.page_size = 1000
         self._okapi_timeout = 60
@@ -132,6 +133,11 @@ def _connect_db_duckdb(
         fn = filename if filename is not None else ":memory:"
         db = duckdb.connect(database=fn)
         self.db = cast("dbapi.DBAPIConnection", db.cursor())
+        self._db = DBTypeDatabase(
+            DBType.DUCKDB,
+            lambda: cast("dbapi.DBAPIConnection", db.cursor()),
+        )
+
         return db.cursor()
 
     def connect_db_postgresql(self, dsn: str) -> psycopg2.extensions.connection:
@@ -150,7 +156,10 @@ def connect_db_postgresql(self, dsn: str) -> psycopg2.extensions.connection:
         self.dbtype = DBType.POSTGRES
         db = psycopg.connect(dsn)
         self.db = cast("dbapi.DBAPIConnection", db)
-        autocommit(self.db, self.dbtype, True)
+        self._db = DBTypeDatabase(
+            DBType.POSTGRES,
+            lambda: cast("dbapi.DBAPIConnection", psycopg.connect(dsn)),
+        )
 
         ret_db = psycopg2.connect(dsn)
         ret_db.rollback()
@@ -180,6 +189,10 @@ def experimental_connect_db_sqlite(
         self.dbtype = DBType.SQLITE
         fn = filename if filename is not None else "file::memory:?cache=shared"
         self.db = sqlite3.connect(fn)
+        self._db = DBTypeDatabase(
+            DBType.SQLITE,
+            lambda: cast("dbapi.DBAPIConnection", sqlite3.connect(fn)),
+        )
 
         db = sqlite3.connect(fn)
         autocommit(db, self.dbtype, True)
@@ -223,22 +236,16 @@ def drop_tables(self, table: str) -> None:
             ld.drop_tables('g')
 
         """
-        if self.db is None:
+        if self.db is None or self._db is None:
             self._check_db()
             return
-        autocommit(self.db, self.dbtype, True)
         schema_table = table.strip().split(".")
-        if len(schema_table) < 1 or len(schema_table) > 2:
+        if len(schema_table) != 1 and len(schema_table) != 2:
             raise ValueError("invalid table name: " + table)
-        self._check_db()
-        cur = self.db.cursor()
-        try:
-            cur.execute("DROP TABLE IF EXISTS " + sqlid(table))
-        except (RuntimeError, psycopg2.Error):
-            pass
-        finally:
-            cur.close()
-        drop_json_tables(self.db, table)
+        if len(schema_table) == 2 and self.dbtype == DBType.SQLITE:
+            table = schema_table[0] + "_" + schema_table[1]
+        prefix = Prefix(table)
+        self._db.drop_prefix(prefix)
 
     def set_folio_max_retries(self, max_retries: int) -> None:
         """Sets the maximum number of retries for FOLIO requests.
@@ -338,16 +345,14 @@ def query( # noqa: C901, PLR0912, PLR0913, PLR0915
         if self._folio is None:
             self._check_folio()
             return []
-        if self.db is None:
+        if self.db is None or self._db is None:
             self._check_db()
             return []
         if len(schema_table) == 2 and self.dbtype == DBType.SQLITE:
             table = schema_table[0] + "_" + schema_table[1]
-            schema_table = [table]
+        prefix = Prefix(table)
         if not self._quiet:
             print("ldlite: querying: " + path, file=sys.stderr)
-        drop_json_tables(self.db, table)
-        autocommit(self.db, self.dbtype, False)
         try:
             # First get total number of records
             records = self._folio.iterate_records(
@@ -362,8 +367,9 @@ def query( # noqa: C901, PLR0912, PLR0913, PLR0915
             if self._verbose:
                 print("ldlite: estimated row count: " + str(total), file=sys.stderr)
 
-            processed = count(0)
-            pbar = None
+            p_count = count(0)
+            processed = 0
+            pbar: tqdm | PbarNoop  # type:ignore[type-arg]
             if not self._quiet:
                 pbar = tqdm(
                     desc="reading",
@@ -374,28 +380,43 @@ def query( # noqa: C901, PLR0912, PLR0913, PLR0915
                     colour="#A9A9A9",
                     bar_format="{desc} {bar}{postfix}",
                 )
+            else:
+
+                class PbarNoop:
+                    def update(self, _: int) -> None: ...
+                    def close(self) -> None: ...
+
+                pbar = PbarNoop()
 
             def on_processed() -> bool:
-                if pbar is not None:
-                    pbar.update(1)
-                p = next(processed)
-                return limit is None or p >= limit
-
-            cur = self.db.cursor()
-            db = DBTypeDatabase(self.dbtype, self.db)
-            db.ingest_records(self.db, Prefix(table), on_processed, records)
-            self.db.commit()
-            if pbar is not None:
-                pbar.close()
+                pbar.update(1)
+                nonlocal processed
+                processed = next(p_count)
+                return True
+
+            def on_processed_limit() -> bool:
+                pbar.update(1)
+                nonlocal processed
+                processed = next(p_count)
+                return limit is None or processed >= limit
+
+            self._db.ingest_records(
+                prefix,
+                on_processed_limit if limit is not None else on_processed,
+                records,
+            )
+            pbar.close()
+            self._db.drop_extracted_tables(prefix)
 
             newtables = [table]
             newattrs = {}
             if json_depth > 0:
+                autocommit(self.db, self.dbtype, False)
                 jsontables, jsonattrs = transform_json(
                     self.db,
                     self.dbtype,
                     table,
-                    next(processed) - 1,
+                    processed,
                     self._quiet,
                     json_depth,
                 )
@@ -406,12 +427,7 @@ def on_processed() -> bool:
                 newattrs[table] = {"__id": Attr("__id", "bigint")}
 
             if not keep_raw:
-                cur = self.db.cursor()
-                try:
-                    cur.execute("DROP TABLE " + sqlid(table))
-                    self.db.commit()
-                finally:
-                    cur.close()
+                self._db.drop_raw_table(prefix)
 
         finally:
             autocommit(self.db, self.dbtype, True)
@@ -446,10 +462,8 @@ def on_processed() -> bool:
                 pass
             finally:
                 cur.close()
-            if pbar is not None:
-                pbar.update(1)
-        if pbar is not None:
-            pbar.close()
+            pbar.update(1)
+        pbar.close()
         # Return table names
         if not self._quiet:
             print("ldlite: created tables: " + ", ".join(newtables), file=sys.stderr)
diff --git a/src/ldlite/_database.py b/src/ldlite/_database.py
index f1988de..0024e4f 100644
--- a/src/ldlite/_database.py
+++ b/src/ldlite/_database.py
@@ -2,12 +2,12 @@
 
 from abc import ABC, abstractmethod
 from contextlib import closing
-from typing import TYPE_CHECKING, Callable, Generic, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, cast
 
 from psycopg import sql
 
 if TYPE_CHECKING:
-    from collections.abc import Iterator
+    from collections.abc import Iterator, Sequence
 
     from _typeshed import dbapi
 
@@ -27,29 +27,125 @@ def __init__(self, table: str):
     def schema_name(self) -> sql.Identifier | None:
         return None if self._schema is None else sql.Identifier(self._schema)
 
+    def identifier(self, table: str) -> sql.Identifier:
+        if self._schema is None:
+            return sql.Identifier(table)
+        return sql.Identifier(self._schema, table)
+
     @property
     def raw_table_name(self) -> sql.Identifier:
-        return (
-            sql.Identifier(self._schema, self._prefix)
-            if self._schema is not None
-            else sql.Identifier(self._prefix)
-        )
+        return self.identifier(self._prefix)
+
+    @property
+    def catalog_table_name(self) -> sql.Identifier:
+        return self.identifier(f"{self._prefix}__tcatalog")
+
+    @property
+    def legacy_jtable(self) -> sql.Identifier:
+        return self.identifier(f"{self._prefix}_jtable")
 
 
 class Database(ABC, Generic[DB]):
     def __init__(self, conn_factory: Callable[[], DB]):
         self._conn_factory = conn_factory
 
+    @abstractmethod
+    def _rollback(self, conn: DB) -> None: ...
+
+    def drop_prefix(
+        self,
+        prefix: Prefix,
+    ) -> None:
+        with closing(self._conn_factory()) as conn:
+            self._drop_extracted_tables(conn, prefix)
+            self._drop_raw_table(conn, prefix)
+            conn.commit()
+
+    def drop_raw_table(
+        self,
+        prefix: Prefix,
+    ) -> None:
+        with closing(self._conn_factory()) as conn:
+            self._drop_raw_table(conn, prefix)
+            conn.commit()
+
+    def _drop_raw_table(
+        self,
+        conn: DB,
+        prefix: Prefix,
+    ) -> None:
+        with closing(conn.cursor()) as cur:
+            cur.execute(
+                sql.SQL("DROP TABLE IF EXISTS {table};")
+                .format(table=prefix.raw_table_name)
+                .as_string(),
+            )
+
+    def drop_extracted_tables(
+        self,
+        prefix: Prefix,
+    ) -> None:
+        with closing(self._conn_factory()) as conn:
+            self._drop_extracted_tables(conn, prefix)
+            conn.commit()
+
     @property
     @abstractmethod
-    def _truncate_raw_table_sql(self) -> sql.SQL: ...
+    def _missing_table_error(self) -> tuple[type[Exception], ...]: ...
 
+    def _drop_extracted_tables(
+        self,
+        conn: DB,
+        prefix: Prefix,
+    ) -> None:
+        tables: list[Sequence[Sequence[Any]]] = []
+        with closing(conn.cursor()) as cur:
+            try:
+                cur.execute(
+                    sql.SQL("SELECT table_name FROM {catalog};")
+                    .format(catalog=prefix.catalog_table_name)
+                    .as_string(),
+                )
+            except self._missing_table_error:
+                self._rollback(conn)
+            else:
+                tables.extend(cur.fetchall())
+
+        with closing(conn.cursor()) as cur:
+            try:
+                cur.execute(
+                    sql.SQL("SELECT table_name FROM {catalog};")
+                    .format(catalog=prefix.legacy_jtable)
+                    .as_string(),
+                )
+            except self._missing_table_error:
+                self._rollback(conn)
+            else:
+                tables.extend(cur.fetchall())
+
+        with closing(conn.cursor()) as cur:
+            for (et,) in tables:
+                cur.execute(
+                    sql.SQL("DROP TABLE IF EXISTS {table};")
+                    .format(table=sql.Identifier(cast("str", et)))
+                    .as_string(),
+                )
+            cur.execute(
+                sql.SQL("DROP TABLE IF EXISTS {catalog};")
+                .format(catalog=prefix.catalog_table_name)
+                .as_string(),
+            )
+            cur.execute(
+                sql.SQL("DROP TABLE IF EXISTS {catalog};")
+                .format(catalog=prefix.legacy_jtable)
+                .as_string(),
+            )
+
     @property
     @abstractmethod
-    def _create_raw_table_sql(self) -> sql.SQL: ...
+    def _truncate_raw_table_sql(self) -> sql.SQL: ...
 
     @property
     @abstractmethod
-    def _insert_record_sql(self) -> sql.SQL: ...
-
+    def _create_raw_table_sql(self) -> sql.SQL: ...
     def _prepare_raw_table(
         self,
         conn: DB,
@@ -62,37 +158,33 @@ def _prepare_raw_table(
                     .format(schema=prefix.schema_name)
                     .as_string(),
                 )
-
+        self._drop_raw_table(conn, prefix)
+
         with closing(conn.cursor()) as cur:
             cur.execute(
                 self._create_raw_table_sql.format(
                     table=prefix.raw_table_name,
                 ).as_string(),
             )
-            cur.execute(
-                self._truncate_raw_table_sql.format(
-                    table=prefix.raw_table_name,
-                ).as_string(),
-            )
+    @property
+    @abstractmethod
+    def _insert_record_sql(self) -> sql.SQL: ...
 
     def ingest_records(
         self,
-        conn: DB,
         prefix: Prefix,
         on_processed: Callable[[], bool],
         records: Iterator[tuple[int, str | bytes]],
     ) -> None:
-        # the only implementation right now is a hack
-        # the db connection is managed outside of the factory
-        # for now it's taken as a parameter
-        # with self._conn_factory() as conn:
-        self._prepare_raw_table(conn, prefix)
-        with closing(conn.cursor()) as cur:
-            for pkey, d in records:
-                cur.execute(
-                    self._insert_record_sql.format(
-                        table=prefix.raw_table_name,
-                    ).as_string(),
-                    [pkey, d if isinstance(d, str) else d.decode("utf-8")],
-                )
-                if not on_processed():
-                    return
+        with closing(self._conn_factory()) as conn:
+            self._prepare_raw_table(conn, prefix)
+            with closing(conn.cursor()) as cur:
+                for pkey, d in records:
+                    cur.execute(
+                        self._insert_record_sql.format(
+                            table=prefix.raw_table_name,
+                        ).as_string(),
+                        [pkey, d if isinstance(d, str) else d.decode("utf-8")],
+                    )
+                    if not on_processed():
+                        return
+            conn.commit()
diff --git a/src/ldlite/_jsonx.py b/src/ldlite/_jsonx.py
index efce0d3..4983e94 100644
--- a/src/ldlite/_jsonx.py
+++ b/src/ldlite/_jsonx.py
@@ -86,93 +86,10 @@ def __repr__(self) -> str:
         )
 
 
-def _old_jtable(table: str) -> str:
-    return table + "_jtable"
-
-
 def _tcatalog(table: str) -> str:
     return table + "__tcatalog"
 
 
-# noinspection DuplicatedCode
-def _old_drop_json_tables(db: dbapi.DBAPIConnection, table: str) -> None:
-    jtable_sql = sqlid(_old_jtable(table))
-    cur = db.cursor()
-    try:
-        cur.execute("SELECT table_name FROM " + jtable_sql)
-        rows = list(cur.fetchall())
-        for row in rows:
-            t = row[0]
-            cur2 = db.cursor()
-            try:
-                cur2.execute("DROP TABLE " + sqlid(t))
-            except (psycopg.Error, duckdb.CatalogException, sqlite3.OperationalError):
-                continue
-            finally:
-                cur2.close()
-    except (
-        psycopg.errors.UndefinedTable,
-        sqlite3.OperationalError,
-        duckdb.CatalogException,
-    ):
-        pass
-    finally:
-        cur.close()
-    cur = db.cursor()
-    try:
-        cur.execute("DROP TABLE " + jtable_sql)
-    except (
-        psycopg.errors.UndefinedTable,
-        duckdb.CatalogException,
-        sqlite3.OperationalError,
-    ):
-        pass
-    finally:
-        cur.close()
-
-
-# noinspection DuplicatedCode
-def drop_json_tables(db: dbapi.DBAPIConnection, table: str) -> None:
-    tcatalog_sql = sqlid(_tcatalog(table))
-    cur = db.cursor()
-    try:
-        cur.execute("SELECT table_name FROM " + tcatalog_sql)
-        rows = list(cur.fetchall())
-        for row in rows:
-            t = row[0]
-            cur2 = db.cursor()
-            try:
-                cur2.execute("DROP TABLE " + sqlid(t))
-            except (
-                psycopg.errors.UndefinedTable,
-                duckdb.CatalogException,
-                sqlite3.OperationalError,
-            ):
-                continue
-            finally:
-                cur2.close()
-    except (
-        psycopg.errors.UndefinedTable,
-        duckdb.CatalogException,
-        sqlite3.OperationalError,
-    ):
-        pass
-    finally:
-        cur.close()
-    cur = db.cursor()
-    try:
-        cur.execute("DROP TABLE " + tcatalog_sql)
-    except (
-        psycopg.errors.UndefinedTable,
-        duckdb.CatalogException,
-        sqlite3.OperationalError,
-    ):
-        pass
-    finally:
-        cur.close()
-    _old_drop_json_tables(db, table)
-
-
 def _table_name(parents: list[tuple[int, str]]) -> str:
     j = len(parents)
     while j > 0 and parents[j - 1][0] == 0:
diff --git a/src/ldlite/_sqlx.py b/src/ldlite/_sqlx.py
index 6277fda..e87a38f 100644
--- a/src/ldlite/_sqlx.py
+++ b/src/ldlite/_sqlx.py
@@ -1,18 +1,17 @@
 from __future__ import annotations
 
 import secrets
+import sqlite3
 from enum import Enum
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Callable, cast
 
+import duckdb
+import psycopg
 from psycopg import sql
 
 from ._database import Database
 
 if TYPE_CHECKING:
-    import sqlite3
-
-    import duckdb
-    import psycopg
     from _typeshed import dbapi
 
     from ._jsonx import JsonValue
@@ -26,9 +25,23 @@ class DBType(Enum):
 
 
 class DBTypeDatabase(Database["dbapi.DBAPIConnection"]):
-    def __init__(self, dbtype: DBType, db: dbapi.DBAPIConnection):
+    def __init__(self, dbtype: DBType, factory: Callable[[], dbapi.DBAPIConnection]):
         self._dbtype = dbtype
-        super().__init__(lambda: db)
+        super().__init__(factory)
+
+    @property
+    def _missing_table_error(self) -> tuple[type[Exception], ...]:
+        return (
+            psycopg.errors.UndefinedTable,
+            sqlite3.OperationalError,
+            duckdb.CatalogException,
+        )
+
+    def _rollback(self, conn: dbapi.DBAPIConnection) -> None:
+        if sql3db := as_sqlite(conn, self._dbtype):
+            sql3db.rollback()
+        if pgdb := as_postgres(conn, self._dbtype):
+            pgdb.rollback()
 
     @property
     def _create_raw_table_sql(self) -> sql.SQL: