diff --git a/CHANGELOG.md b/CHANGELOG.md
index 42beae1..494909c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,7 @@ Please see [MIGRATING.md](./MIGRATING.md) for information on breaking changes.
 - psycopg3 is now used for internal operations. LDLite.connect_db_postgres will return a psycopg3 connection instead of psycopg2 in the next major release.
 - psycopg2 is now installed using the binary version.
 - Refactored internal database handling logic
+- Ingesting data into postgres now uses COPY FROM, which significantly improves download performance.
 
 ### Removed
 
diff --git a/src/ldlite/__init__.py b/src/ldlite/__init__.py
index 063bea9..59381f4 100644
--- a/src/ldlite/__init__.py
+++ b/src/ldlite/__init__.py
@@ -62,6 +62,8 @@
 from ._xlsx import to_xlsx
 
 if TYPE_CHECKING:
+    from collections.abc import Iterator
+
     from _typeshed import dbapi
     from httpx_folio.query import QueryType
 
@@ -362,7 +364,7 @@ def query(  # noqa: C901, PLR0912, PLR0913, PLR0915
             self.page_size,
             query=cast("QueryType", query),
         )
-        (total_records, _) = next(records)
+        total_records = cast("int", next(records))
         total = min(total_records, limit or total_records)
         if self._verbose:
             print("ldlite: estimated row count: " + str(total), file=sys.stderr)
@@ -403,7 +405,7 @@ def on_processed_limit() -> bool:
         self._db.ingest_records(
             prefix,
             on_processed_limit if limit is not None else on_processed,
-            records,
+            cast("Iterator[tuple[bytes, bytes] | tuple[int, str]]", records),
         )
         pbar.close()
 
diff --git a/src/ldlite/_database.py b/src/ldlite/_database.py
index 0024e4f..a18a0bf 100644
--- a/src/ldlite/_database.py
+++ b/src/ldlite/_database.py
@@ -173,18 +173,31 @@ def ingest_records(
         self,
         prefix: Prefix,
         on_processed: Callable[[], bool],
-        records: Iterator[tuple[int, str | bytes]],
+        records: Iterator[tuple[bytes, bytes] | tuple[int, str]],
     ) -> None:
         with closing(self._conn_factory()) as conn:
             self._prepare_raw_table(conn, prefix)
 
+            insert_sql = self._insert_record_sql.format(
+                table=prefix.raw_table_name,
+            ).as_string()
             with closing(conn.cursor()) as cur:
-                for pkey, d in records:
-                    cur.execute(
-                        self._insert_record_sql.format(
-                            table=prefix.raw_table_name,
-                        ).as_string(),
-                        [pkey, d if isinstance(d, str) else d.decode("utf-8")],
-                    )
-                    if not on_processed():
-                        return
+                fr = next(records)
+                if isinstance(fr[0], bytes):
+                    record = fr
+                    while record is not None:
+                        (pkey, rb) = record
+                        cur.execute(
+                            insert_sql,
+                            (int.from_bytes(pkey, "big"), rb.decode()),
+                        )
+                        if not on_processed():
+                            break
+                        record = cast("tuple[bytes, bytes]", next(records, None))
+                else:
+                    cur.execute(insert_sql, fr)
+                    for r in records:
+                        cur.execute(insert_sql, r)
+                        if not on_processed():
+                            break
+            conn.commit()
diff --git a/src/ldlite/_folio.py b/src/ldlite/_folio.py
index cea5d7d..b6e4cd4 100644
--- a/src/ldlite/_folio.py
+++ b/src/ldlite/_folio.py
@@ -28,7 +28,7 @@ def iterate_records(
         retries: int,
         page_size: int,
         query: QueryType | None = None,
-    ) -> Iterator[tuple[int, str | bytes]]:
+    ) -> Iterator[int | tuple[bytes, bytes] | tuple[int, str]]:
         """Iterates all records for a given path.
 
         Returns:
@@ -54,7 +54,7 @@ def iterate_records(
         res.raise_for_status()
         j = orjson.loads(res.text)
         r = int(j["totalRecords"])
-        yield (r, b"")
+        yield r
 
         if r == 0:
             return
@@ -103,7 +103,7 @@ def iterate_records(
             last = None
             for r in (o for o in orjson.loads(res.text)[key] if o is not None):
                 last = r
-                yield (next(pkey), orjson.dumps(r))
+                yield (next(pkey).to_bytes(4, "big"), orjson.dumps(r))
 
             if last is None:
                 return
diff --git a/src/ldlite/_sqlx.py b/src/ldlite/_sqlx.py
index e87a38f..18743d1 100644
--- a/src/ldlite/_sqlx.py
+++ b/src/ldlite/_sqlx.py
@@ -2,6 +2,7 @@
 
 import secrets
 import sqlite3
+from contextlib import closing
 from enum import Enum
 from typing import TYPE_CHECKING, Callable, cast
 
@@ -12,8 +13,11 @@
 from ._database import Database
 
 if TYPE_CHECKING:
+    from collections.abc import Iterator
+
     from _typeshed import dbapi
 
+    from ._database import Prefix
     from ._jsonx import JsonValue
 
 
@@ -69,6 +73,54 @@ def _insert_record_sql(self) -> sql.SQL:
 
         return sql.SQL(insert_sql)
 
+    def ingest_records(
+        self,
+        prefix: Prefix,
+        on_processed: Callable[[], bool],
+        records: Iterator[tuple[bytes, bytes] | tuple[int, str]],
+    ) -> None:
+        if self._dbtype != DBType.POSTGRES:
+            super().ingest_records(prefix, on_processed, records)
+            return
+
+        with closing(self._conn_factory()) as conn:
+            self._prepare_raw_table(conn, prefix)
+
+            fr = next(records)
+            copy_from = "COPY {table} (__id, jsonb) FROM STDIN"
+            if is_bytes := isinstance(fr[0], bytes):
+                copy_from += " (FORMAT BINARY)"
+
+            if pgconn := as_postgres(conn, self._dbtype):
+                with (
+                    pgconn.cursor() as cur,
+                    cur.copy(
+                        sql.SQL(copy_from).format(table=prefix.raw_table_name),
+                    ) as copy,
+                ):
+                    if is_bytes:
+                        # postgres jsonb is always version 1
+                        # and it always goes in front
+                        jver = (1).to_bytes(1, "big")
+                        record = fr
+                        while record is not None:
+                            pkey, rb = record
+                            rbpg = bytearray()
+                            rbpg.extend(jver)
+                            rbpg.extend(cast("bytes", rb))
+                            copy.write_row((pkey, rbpg))
+                            if not on_processed():
+                                break
+                            record = cast("tuple[bytes, bytes]", next(records, None))
+                    else:
+                        copy.write_row(fr)
+                        for r in records:
+                            copy.write_row(r)
+                            if not on_processed():
+                                break
+
+            pgconn.commit()
+
 
 def as_duckdb(
     db: dbapi.DBAPIConnection,
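
A note on the mechanism behind the changelog entry above: psycopg 3 exposes COPY ... FROM STDIN through Cursor.copy(), and Copy.write_row() streams one row at a time without a per-record INSERT round trip, which is where the ingest speedup comes from. The sketch below is illustrative only and is not LDLite code; the DSN ("dbname=example"), table name ("example_raw"), and sample rows are placeholders chosen to mirror the (__id, jsonb) raw-table shape used in the patch.

# Sketch only: text-format COPY with psycopg 3; placeholder DSN, table, and rows.
import psycopg

rows = [(1, '{"id": "a"}'), (2, '{"id": "b"}')]  # (__id, JSON text) pairs

with psycopg.connect("dbname=example") as conn:
    with conn.cursor() as cur:
        cur.execute(
            "CREATE TABLE IF NOT EXISTS example_raw (__id integer, jsonb jsonb)",
        )
        # One COPY statement loads every row; write_row handles escaping.
        with cur.copy("COPY example_raw (__id, jsonb) FROM STDIN") as copy:
            for row in rows:
                copy.write_row(row)
    conn.commit()

The per-row cur.execute() path in _database.py remains as the fallback for non-Postgres backends.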
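
The patch itself requests FORMAT BINARY: the record key arrives from _folio.py as four big-endian bytes, and each JSON payload is prefixed with the jsonb wire-format version byte, which is always 1, as the comment in _sqlx.py notes. A minimal standalone sketch of that framing, again with a placeholder DSN and table:

# Sketch only: binary-format COPY with a pre-framed jsonb payload; placeholder DSN and table.
import psycopg

records = [((7).to_bytes(4, "big"), b'{"id": "a"}')]  # (4-byte big-endian key, JSON bytes)

with psycopg.connect("dbname=example") as conn:
    with conn.cursor() as cur:
        cur.execute(
            "CREATE TABLE IF NOT EXISTS example_raw (__id integer, jsonb jsonb)",
        )
        with cur.copy(
            "COPY example_raw (__id, jsonb) FROM STDIN (FORMAT BINARY)",
        ) as copy:
            for pkey, payload in records:
                # jsonb binary format: one version byte (1), then the UTF-8 JSON text
                copy.write_row((pkey, b"\x01" + payload))
    conn.commit()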