diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 69ffc80b..01863574 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -53,7 +53,7 @@ jobs:
         pip install "setuptools>=64" --upgrade
 
         # Install package in editable mode.
-        pip install --use-pep517 --prefer-binary --editable=.[test,develop]
+        pip install --use-pep517 --prefer-binary --editable=.[io,test,develop]
 
     - name: Run linter and software tests
       run: |
diff --git a/CHANGES.md b/CHANGES.md
index 853f9c95..fc0f9e80 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -4,6 +4,7 @@
 
 ## Unreleased
 - Add SQL runner utility primitives to `io.sql` namespace
+- Add `import_csv_pandas` and `import_csv_dask` utility primitives
 
 
 ## 2023/11/06 v0.0.2
diff --git a/cratedb_toolkit/util/database.py b/cratedb_toolkit/util/database.py
index 7ef709ac..e5d435d4 100644
--- a/cratedb_toolkit/util/database.py
+++ b/cratedb_toolkit/util/database.py
@@ -1,6 +1,7 @@
 # Copyright (c) 2023, Crate.io Inc.
 # Distributed under the terms of the AGPLv3 license, see LICENSE.
 import io
+import os
 import typing as t
 from pathlib import Path
 
@@ -194,6 +195,60 @@ def ensure_repository_az(
         """
         self.run_sql(sql)
 
+    def import_csv_pandas(
+        self, filepath: t.Union[str, Path], tablename: str, index=False, chunksize=1000, if_exists="replace"
+    ):
+        """
+        Import CSV data using pandas.
+        """
+        import pandas as pd
+        from crate.client.sqlalchemy.support import insert_bulk
+
+        df = pd.read_csv(filepath)
+        with self.engine.connect() as connection:
+            return df.to_sql(
+                tablename, connection, index=index, chunksize=chunksize, if_exists=if_exists, method=insert_bulk
+            )
+
+    def import_csv_dask(
+        self,
+        filepath: t.Union[str, Path],
+        tablename: str,
+        index=False,
+        chunksize=1000,
+        if_exists="replace",
+        npartitions: int = None,
+        progress: bool = False,
+    ):
+        """
+        Import CSV data using Dask.
+        """
+        import dask.dataframe as dd
+        import pandas as pd
+        from crate.client.sqlalchemy.support import insert_bulk
+
+        # Set a few defaults.
+        npartitions = npartitions or os.cpu_count()
+
+        if progress:
+            from dask.diagnostics import ProgressBar
+
+            pbar = ProgressBar()
+            pbar.register()
+
+        # Load data into database.
+        df = pd.read_csv(filepath)
+        ddf = dd.from_pandas(df, npartitions=npartitions)
+        return ddf.to_sql(
+            tablename,
+            uri=self.dburi,
+            index=index,
+            chunksize=chunksize,
+            if_exists=if_exists,
+            method=insert_bulk,
+            parallel=True,
+        )
+
 
 def sa_is_empty(thing):
     """
diff --git a/pyproject.toml b/pyproject.toml
index 49d257a9..7592cf46 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -102,6 +102,10 @@ develop = [
     "ruff==0.1.3",
     "validate-pyproject<0.16",
 ]
+io = [
+    "dask<=2023.10.1,>=2020",
+    "pandas<3,>=1",
+]
 release = [
     "build<2",
     "twine<5",
diff --git a/release/oci/Dockerfile b/release/oci/Dockerfile
index da865623..ecb38a52 100644
--- a/release/oci/Dockerfile
+++ b/release/oci/Dockerfile
@@ -21,7 +21,7 @@ COPY . /src
 
 # Install package.
 RUN --mount=type=cache,id=pip,target=/root/.cache/pip \
-    pip install --use-pep517 --prefer-binary '/src'
+    pip install --use-pep517 --prefer-binary '/src[io]'
 
 # Uninstall Git again.
 RUN apt-get --yes remove --purge git && apt-get --yes autoremove
diff --git a/tests/io/test_import.py b/tests/io/test_import.py
new file mode 100644
index 00000000..cf126d8b
--- /dev/null
+++ b/tests/io/test_import.py
@@ -0,0 +1,49 @@
+import pytest
+
+
+@pytest.fixture
+def dummy_csv(tmp_path):
+    """
+    Provide a dummy CSV file to the test cases.
+ """ + csvfile = tmp_path / "dummy.csv" + csvfile.write_text("name,value\ntemperature,42.42\nhumidity,84.84") + return csvfile + + +def test_import_csv_pandas(cratedb, dummy_csv): + """ + Invoke convenience function `import_csv_pandas`, and verify database content. + """ + result = cratedb.database.import_csv_pandas(filepath=dummy_csv, tablename="foobar") + assert result is None + + cratedb.database.run_sql("REFRESH TABLE foobar;") + result = cratedb.database.run_sql("SELECT COUNT(*) FROM foobar;") + assert result == [(2,)] + + +def test_import_csv_dask(cratedb, dummy_csv): + """ + Invoke convenience function `import_csv_dask`, and verify database content. + """ + result = cratedb.database.import_csv_dask(filepath=dummy_csv, tablename="foobar") + assert result is None + + cratedb.database.run_sql("REFRESH TABLE foobar;") + result = cratedb.database.run_sql("SELECT COUNT(*) FROM foobar;") + assert result == [(2,)] + + +def test_import_csv_dask_with_progressbar(cratedb, dummy_csv): + """ + Invoke convenience function `import_csv_dask`, and verify database content. + This time, use `progress=True` to make Dask display a progress bar. + However, the code does not verify it. + """ + result = cratedb.database.import_csv_dask(filepath=dummy_csv, tablename="foobar", progress=True) + assert result is None + + cratedb.database.run_sql("REFRESH TABLE foobar;") + result = cratedb.database.run_sql("SELECT COUNT(*) FROM foobar;") + assert result == [(2,)]