3838
3939import sqlite3
4040import sys
41+ from itertools import count
4142from typing import TYPE_CHECKING , NoReturn , cast
4243
4344import duckdb
45+ import psycopg
4446import psycopg2
4547from httpx_folio .auth import FolioParams
4648from tqdm import tqdm
4749
4850from ._csv import to_csv
51+ from ._database import Prefix
4952from ._folio import FolioClient
50- from ._jsonx import Attr , drop_json_tables , transform_json
53+ from ._jsonx import Attr , transform_json
5154from ._select import select
52- from ._sqlx import DBType , as_postgres , autocommit , encode_sql_str , json_type , sqlid
55+ from ._sqlx import (
56+ DBType ,
57+ DBTypeDatabase ,
58+ as_postgres ,
59+ autocommit ,
60+ sqlid ,
61+ )
5362from ._xlsx import to_xlsx
5463
5564if TYPE_CHECKING :
@@ -73,6 +82,7 @@ def __init__(self) -> None:
7382 self ._quiet = False
7483 self .dbtype : DBType = DBType .UNDEFINED
7584 self .db : dbapi .DBAPIConnection | None = None
85+ self ._db : DBTypeDatabase | None = None
7686 self ._folio : FolioClient | None = None
7787 self .page_size = 1000
7888 self ._okapi_timeout = 60
@@ -122,8 +132,13 @@ def _connect_db_duckdb(
122132 self .dbtype = DBType .DUCKDB
123133 fn = filename if filename is not None else ":memory:"
124134 db = duckdb .connect (database = fn )
125- self .db = cast ("dbapi.DBAPIConnection" , db )
126- return db
135+ self .db = cast ("dbapi.DBAPIConnection" , db .cursor ())
136+ self ._db = DBTypeDatabase (
137+ DBType .DUCKDB ,
138+ lambda : cast ("dbapi.DBAPIConnection" , db .cursor ()),
139+ )
140+
141+ return db .cursor ()
127142
128143 def connect_db_postgresql (self , dsn : str ) -> psycopg2 .extensions .connection :
129144 """Connects to a PostgreSQL database for storing data.
@@ -132,15 +147,24 @@ def connect_db_postgresql(self, dsn: str) -> psycopg2.extensions.connection:
132147 connection to the database which can be used to submit SQL queries.
133148 The returned connection defaults to autocommit mode.
134149
150+ This will return a psycopg3 connection in the next major release of LDLite.
151+
135152 Example:
136153 db = ld.connect_db_postgresql(dsn='dbname=ld host=localhost user=ldlite')
137154
138155 """
139156 self .dbtype = DBType .POSTGRES
140- db = psycopg2 .connect (dsn )
157+ db = psycopg .connect (dsn )
141158 self .db = cast ("dbapi.DBAPIConnection" , db )
142- autocommit (self .db , self .dbtype , True )
143- return db
159+ self ._db = DBTypeDatabase (
160+ DBType .POSTGRES ,
161+ lambda : cast ("dbapi.DBAPIConnection" , psycopg .connect (dsn )),
162+ )
163+
164+ ret_db = psycopg2 .connect (dsn )
165+ ret_db .rollback ()
166+ ret_db .set_session (autocommit = True )
167+ return ret_db
144168
145169 def experimental_connect_db_sqlite (
146170 self ,
@@ -163,9 +187,15 @@ def experimental_connect_db_sqlite(
163187
164188 """
165189 self .dbtype = DBType .SQLITE
166- fn = filename if filename is not None else ": memory:"
190+ fn = filename if filename is not None else "file:: memory:?cache=shared "
167191 self .db = sqlite3 .connect (fn )
168- autocommit (self .db , self .dbtype , True )
192+ self ._db = DBTypeDatabase (
193+ DBType .SQLITE ,
194+ lambda : cast ("dbapi.DBAPIConnection" , sqlite3 .connect (fn )),
195+ )
196+
197+ db = sqlite3 .connect (fn )
198+ autocommit (db , self .dbtype , True )
169199 return self .db
170200
171201 def _check_folio (self ) -> None :
@@ -206,22 +236,16 @@ def drop_tables(self, table: str) -> None:
206236 ld.drop_tables('g')
207237
208238 """
209- if self .db is None :
239+ if self .db is None or self . _db is None :
210240 self ._check_db ()
211241 return
212- autocommit (self .db , self .dbtype , True )
213242 schema_table = table .strip ().split ("." )
214- if len (schema_table ) < 1 or len (schema_table ) > 2 :
243+ if len (schema_table ) != 1 and len (schema_table ) != 2 :
215244 raise ValueError ("invalid table name: " + table )
216- self ._check_db ()
217- cur = self .db .cursor ()
218- try :
219- cur .execute ("DROP TABLE IF EXISTS " + sqlid (table ))
220- except (RuntimeError , psycopg2 .Error ):
221- pass
222- finally :
223- cur .close ()
224- drop_json_tables (self .db , table )
245+ if len (schema_table ) == 2 and self .dbtype == DBType .SQLITE :
246+ table = schema_table [0 ] + "_" + schema_table [1 ]
247+ prefix = Prefix (table )
248+ self ._db .drop_prefix (prefix )
225249
226250 def set_folio_max_retries (self , max_retries : int ) -> None :
227251 """Sets the maximum number of retries for FOLIO requests.
@@ -321,32 +345,15 @@ def query( # noqa: C901, PLR0912, PLR0913, PLR0915
321345 if self ._folio is None :
322346 self ._check_folio ()
323347 return []
324- if self .db is None :
348+ if self .db is None or self . _db is None :
325349 self ._check_db ()
326350 return []
327351 if len (schema_table ) == 2 and self .dbtype == DBType .SQLITE :
328352 table = schema_table [0 ] + "_" + schema_table [1 ]
329- schema_table = [ table ]
353+ prefix = Prefix ( table )
330354 if not self ._quiet :
331355 print ("ldlite: querying: " + path , file = sys .stderr )
332- drop_json_tables (self .db , table )
333- autocommit (self .db , self .dbtype , False )
334356 try :
335- cur = self .db .cursor ()
336- try :
337- if len (schema_table ) == 2 :
338- cur .execute ("CREATE SCHEMA IF NOT EXISTS " + sqlid (schema_table [0 ]))
339- cur .execute ("DROP TABLE IF EXISTS " + sqlid (table ))
340- cur .execute (
341- "CREATE TABLE "
342- + sqlid (table )
343- + "(__id integer, jsonb "
344- + json_type (self .dbtype )
345- + ")" ,
346- )
347- finally :
348- cur .close ()
349- self .db .commit ()
350357 # First get total number of records
351358 records = self ._folio .iterate_records (
352359 path ,
@@ -355,70 +362,61 @@ def query( # noqa: C901, PLR0912, PLR0913, PLR0915
355362 self .page_size ,
356363 query = cast ("QueryType" , query ),
357364 )
358-
359365 (total_records , _ ) = next (records )
360- total = total_records if total_records is not None else 0
366+ total = min ( total_records , limit or total_records )
361367 if self ._verbose :
362368 print ("ldlite: estimated row count: " + str (total ), file = sys .stderr )
363- # Read result pages
364- pbar = None
365- pbartotal = 0
369+
370+ p_count = count (1 )
371+ processed = 0
372+ pbar : tqdm | PbarNoop # type:ignore[type-arg]
366373 if not self ._quiet :
367- if total == - 1 :
368- pbar = tqdm (
369- desc = "reading" ,
370- leave = False ,
371- mininterval = 3 ,
372- smoothing = 0 ,
373- colour = "#A9A9A9" ,
374- bar_format = "{desc} {elapsed} {bar}{postfix}" ,
375- )
376- else :
377- pbar = tqdm (
378- desc = "reading" ,
379- total = total ,
380- leave = False ,
381- mininterval = 3 ,
382- smoothing = 0 ,
383- colour = "#A9A9A9" ,
384- bar_format = "{desc} {bar}{postfix}" ,
385- )
386- cur = self .db .cursor ()
387- try :
388- count = 0
389- for pkey , d in records :
390- cur .execute (
391- "INSERT INTO "
392- + sqlid (table )
393- + " VALUES("
394- + str (pkey )
395- + ","
396- + encode_sql_str (self .dbtype , d )
397- + ")" ,
398- )
399- count += 1
400- if pbar is not None :
401- if pbartotal + 1 > total :
402- pbartotal = total
403- pbar .update (total - pbartotal )
404- else :
405- pbartotal += 1
406- pbar .update (1 )
407- if limit is not None and count == limit :
408- break
409- finally :
410- cur .close ()
411- if pbar is not None :
412- pbar .close ()
413- self .db .commit ()
374+ pbar = tqdm (
375+ desc = "reading" ,
376+ total = total ,
377+ leave = False ,
378+ mininterval = 3 ,
379+ smoothing = 0 ,
380+ colour = "#A9A9A9" ,
381+ bar_format = "{desc} {bar}{postfix}" ,
382+ )
383+ else :
384+
385+ class PbarNoop :
386+ def update (self , _ : int ) -> None : ...
387+ def close (self ) -> None : ...
388+
389+ pbar = PbarNoop ()
390+
391+ def on_processed () -> bool :
392+ pbar .update (1 )
393+ nonlocal processed
394+ processed = next (p_count )
395+ return True
396+
397+ def on_processed_limit () -> bool :
398+ pbar .update (1 )
399+ nonlocal processed , limit
400+ processed = next (p_count )
401+ return limit is None or processed < limit
402+
403+ self ._db .ingest_records (
404+ prefix ,
405+ on_processed_limit if limit is not None else on_processed ,
406+ records ,
407+ )
408+ pbar .close ()
409+
410+ self ._db .drop_extracted_tables (prefix )
414411 newtables = [table ]
415412 newattrs = {}
416413 if json_depth > 0 :
414+ autocommit (self .db , self .dbtype , False )
417415 jsontables , jsonattrs = transform_json (
418416 self .db ,
419417 self .dbtype ,
420418 table ,
421- count ,
419+ processed ,
422420 self ._quiet ,
423421 json_depth ,
424422 )
@@ -429,12 +427,7 @@ def query( # noqa: C901, PLR0912, PLR0913, PLR0915
429427 newattrs [table ] = {"__id" : Attr ("__id" , "bigint" )}
430428
431429 if not keep_raw :
432- cur = self .db .cursor ()
433- try :
434- cur .execute ("DROP TABLE " + sqlid (table ))
435- self .db .commit ()
436- finally :
437- cur .close ()
430+ self ._db .drop_raw_table (prefix )
438431
439432 finally :
440433 autocommit (self .db , self .dbtype , True )
@@ -459,22 +452,18 @@ def query( # noqa: C901, PLR0912, PLR0913, PLR0915
459452 colour = "#A9A9A9" ,
460453 bar_format = "{desc} {bar}{postfix}" ,
461454 )
462- pbartotal = 0
463455 for t , attr in indexable_attrs :
464456 cur = self .db .cursor ()
465457 try :
466458 cur .execute (
467459 "CREATE INDEX ON " + sqlid (t ) + " (" + sqlid (attr .name ) + ")" ,
468460 )
469- except (RuntimeError , psycopg2 .Error ):
461+ except (RuntimeError , psycopg .Error ):
470462 pass
471463 finally :
472464 cur .close ()
473- if pbar is not None :
474- pbartotal += 1
475- pbar .update (1 )
476- if pbar is not None :
477- pbar .close ()
465+ pbar .update (1 )
466+ pbar .close ()
478467 # Return table names
479468 if not self ._quiet :
480469 print ("ldlite: created tables: " + ", " .join (newtables ), file = sys .stderr )
0 commit comments