98 changes: 56 additions & 42 deletions src/ldlite/__init__.py
@@ -50,7 +50,7 @@
from ._csv import to_csv
from ._database import Prefix
from ._folio import FolioClient
from ._jsonx import Attr, drop_json_tables, transform_json
from ._jsonx import Attr, transform_json
from ._select import select
from ._sqlx import (
DBType,
@@ -82,6 +82,7 @@ def __init__(self) -> None:
self._quiet = False
self.dbtype: DBType = DBType.UNDEFINED
self.db: dbapi.DBAPIConnection | None = None
self._db: DBTypeDatabase | None = None
self._folio: FolioClient | None = None
self.page_size = 1000
self._okapi_timeout = 60
@@ -132,6 +133,11 @@ def _connect_db_duckdb(
fn = filename if filename is not None else ":memory:"
db = duckdb.connect(database=fn)
self.db = cast("dbapi.DBAPIConnection", db.cursor())
self._db = DBTypeDatabase(
DBType.DUCKDB,
lambda: cast("dbapi.DBAPIConnection", db.cursor()),
)

return db.cursor()

def connect_db_postgresql(self, dsn: str) -> psycopg2.extensions.connection:
@@ -150,7 +156,10 @@ def connect_db_postgresql(self, dsn: str) -> psycopg2.extensions.connection:
self.dbtype = DBType.POSTGRES
db = psycopg.connect(dsn)
self.db = cast("dbapi.DBAPIConnection", db)
autocommit(self.db, self.dbtype, True)
self._db = DBTypeDatabase(
DBType.POSTGRES,
lambda: cast("dbapi.DBAPIConnection", psycopg.connect(dsn)),
)

ret_db = psycopg2.connect(dsn)
ret_db.rollback()
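
Note: each connect method now hands DBTypeDatabase a zero-argument connection factory (a lambda) rather than the already-open handle, so the database layer can open fresh connections on demand. A minimal sketch of that pattern, assuming only the constructor shape visible in this diff; the class and method bodies below are illustrative, not ldlite's actual internals:

import sqlite3
from collections.abc import Callable
from enum import Enum


class DBType(Enum):
    DUCKDB = 1
    POSTGRES = 2
    SQLITE = 3


class Database:
    """Illustrative stand-in for ldlite's DBTypeDatabase."""

    def __init__(self, dbtype: DBType, connect: Callable[[], sqlite3.Connection]) -> None:
        self.dbtype = dbtype
        self._connect = connect  # zero-argument factory, not a live handle

    def execute(self, sql: str) -> None:
        # A fresh connection per operation keeps transaction state
        # isolated from whatever connection the caller holds open.
        conn = self._connect()
        try:
            with conn:  # commits on success, rolls back on error
                conn.execute(sql)
        finally:
            conn.close()


db = Database(DBType.SQLITE, lambda: sqlite3.connect(":memory:"))
db.execute("CREATE TABLE IF NOT EXISTS t (id INTEGER)")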
@@ -180,6 +189,10 @@ def experimental_connect_db_sqlite(
self.dbtype = DBType.SQLITE
fn = filename if filename is not None else "file::memory:?cache=shared"
self.db = sqlite3.connect(fn)
self._db = DBTypeDatabase(
DBType.SQLITE,
lambda: cast("dbapi.DBAPIConnection", sqlite3.connect(fn)),
)

db = sqlite3.connect(fn)
autocommit(db, self.dbtype, True)
@@ -223,22 +236,16 @@ def drop_tables(self, table: str) -> None:
ld.drop_tables('g')

"""
if self.db is None:
if self.db is None or self._db is None:
self._check_db()
return
autocommit(self.db, self.dbtype, True)
schema_table = table.strip().split(".")
if len(schema_table) < 1 or len(schema_table) > 2:
if len(schema_table) != 1 and len(schema_table) != 2:
raise ValueError("invalid table name: " + table)
self._check_db()
cur = self.db.cursor()
try:
cur.execute("DROP TABLE IF EXISTS " + sqlid(table))
except (RuntimeError, psycopg2.Error):
pass
finally:
cur.close()
drop_json_tables(self.db, table)
if len(schema_table) == 2 and self.dbtype == DBType.SQLITE:
table = schema_table[0] + "_" + schema_table[1]
prefix = Prefix(table)
self._db.drop_prefix(prefix)
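
Note: drop_tables now validates the name, flattens schema-qualified names on SQLite (which has no schema support), and delegates the actual dropping to Prefix/drop_prefix in the database layer. A self-contained sketch of just the name handling, mirroring the logic above; the helper name is made up for illustration:

def normalize_table_name(table: str, is_sqlite: bool) -> str:
    """Allow 'table' or 'schema.table', flattening the latter
    to 'schema_table' on SQLite, as drop_tables and query do."""
    parts = table.strip().split(".")
    if len(parts) not in (1, 2):
        raise ValueError("invalid table name: " + table)
    if len(parts) == 2 and is_sqlite:
        return parts[0] + "_" + parts[1]
    return table


assert normalize_table_name("g", is_sqlite=True) == "g"
assert normalize_table_name("folio.g", is_sqlite=True) == "folio_g"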

def set_folio_max_retries(self, max_retries: int) -> None:
"""Sets the maximum number of retries for FOLIO requests.
@@ -338,16 +345,14 @@ def query( # noqa: C901, PLR0912, PLR0913, PLR0915
if self._folio is None:
self._check_folio()
return []
if self.db is None:
if self.db is None or self._db is None:
self._check_db()
return []
if len(schema_table) == 2 and self.dbtype == DBType.SQLITE:
table = schema_table[0] + "_" + schema_table[1]
schema_table = [table]
prefix = Prefix(table)
if not self._quiet:
print("ldlite: querying: " + path, file=sys.stderr)
drop_json_tables(self.db, table)
autocommit(self.db, self.dbtype, False)
try:
# First get total number of records
records = self._folio.iterate_records(
@@ -362,8 +367,9 @@
if self._verbose:
print("ldlite: estimated row count: " + str(total), file=sys.stderr)

processed = count(0)
pbar = None
p_count = count(0)
processed = 0
pbar: tqdm | PbarNoop # type:ignore[type-arg]
if not self._quiet:
pbar = tqdm(
desc="reading",
@@ -374,28 +380,43 @@
colour="#A9A9A9",
bar_format="{desc} {bar}{postfix}",
)
else:

class PbarNoop:
def update(self, _: int) -> None: ...
def close(self) -> None: ...

pbar = PbarNoop()

def on_processed() -> bool:
if pbar is not None:
pbar.update(1)
p = next(processed)
return limit is None or p >= limit

cur = self.db.cursor()
db = DBTypeDatabase(self.dbtype, self.db)
db.ingest_records(self.db, Prefix(table), on_processed, records)
self.db.commit()
if pbar is not None:
pbar.close()
pbar.update(1)
nonlocal processed
processed = next(p_count)
return True

def on_processed_limit() -> bool:
pbar.update(1)
nonlocal processed
processed = next(p_count)
return limit is None or processed >= limit

self._db.ingest_records(
prefix,
on_processed_limit if limit is not None else on_processed,
records,
)
pbar.close()

self._db.drop_extracted_tables(prefix)
newtables = [table]
newattrs = {}
if json_depth > 0:
autocommit(self.db, self.dbtype, False)
jsontables, jsonattrs = transform_json(
self.db,
self.dbtype,
table,
next(processed) - 1,
processed,
self._quiet,
json_depth,
)
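
Note: the progress handling above combines two small patterns: a null-object bar (PbarNoop) so call sites no longer need `if pbar is not None` guards, and an itertools.count iterator as the record counter that the callbacks share via `nonlocal`. A runnable sketch under those assumptions; the consumer loop below is hypothetical, standing in for ingest_records:

from itertools import count


class PbarNoop:
    """Null object standing in for tqdm when output is suppressed."""

    def update(self, _: int) -> None: ...

    def close(self) -> None: ...


def consume(records: list[dict], limit: int | None) -> int:
    pbar = PbarNoop()  # in ldlite this is a real tqdm bar unless quiet
    p_count = count(0)  # as in the diff: the first next() yields 0
    processed = 0
    for _ in records:
        pbar.update(1)
        processed = next(p_count)
        if limit is not None and processed >= limit:
            break
    pbar.close()
    return processed


assert consume([{}] * 10, limit=3) == 3
assert consume([{}] * 10, limit=None) == 9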
@@ -406,12 +427,7 @@ def on_processed() -> bool:
newattrs[table] = {"__id": Attr("__id", "bigint")}

if not keep_raw:
cur = self.db.cursor()
try:
cur.execute("DROP TABLE " + sqlid(table))
self.db.commit()
finally:
cur.close()
self._db.drop_raw_table(prefix)

finally:
autocommit(self.db, self.dbtype, True)
@@ -446,10 +462,8 @@ def on_processed() -> bool:
pass
finally:
cur.close()
if pbar is not None:
pbar.update(1)
if pbar is not None:
pbar.close()
pbar.update(1)
pbar.close()
# Return table names
if not self._quiet:
print("ldlite: created tables: " + ", ".join(newtables), file=sys.stderr)