From b84e0e19c8102de6e43129bc91af7e5bd201dd5b Mon Sep 17 00:00:00 2001 From: Chris Markiewicz Date: Sun, 18 Dec 2022 11:16:23 -0500 Subject: [PATCH 1/3] RF: Write DFT database manager as object This adds a dft._DB class that handles the _init_db and _db_(no)change functions. The default instance remains at dft.DB, but this allows us to create new instances for testing purposes. --- nibabel/dft.py | 114 +++++++++++++++++++++++++++---------------------- 1 file changed, 62 insertions(+), 52 deletions(-) diff --git a/nibabel/dft.py b/nibabel/dft.py index f47d70ccb6..51b6424a84 100644 --- a/nibabel/dft.py +++ b/nibabel/dft.py @@ -11,6 +11,7 @@ """ +import contextlib import os from os.path import join as pjoin import tempfile @@ -74,7 +75,7 @@ def __getattribute__(self, name): val = object.__getattribute__(self, name) if name == 'series' and val is None: val = [] - with _db_nochange() as c: + with DB.readonly_cursor() as c: c.execute("SELECT * FROM series WHERE study = ?", (self.uid, )) cols = [el[0] for el in c.description] for row in c: @@ -106,7 +107,7 @@ def __getattribute__(self, name): val = object.__getattribute__(self, name) if name == 'storage_instances' and val is None: val = [] - with _db_nochange() as c: + with DB.readonly_cursor() as c: query = """SELECT * FROM storage_instance WHERE series = ? @@ -227,7 +228,7 @@ def __init__(self, d): def __getattribute__(self, name): val = object.__getattribute__(self, name) if name == 'files' and val is None: - with _db_nochange() as c: + with DB.readonly_cursor() as c: query = """SELECT directory, name FROM file WHERE storage_instance = ? @@ -241,34 +242,6 @@ def dicom(self): return pydicom.read_file(self.files[0]) -class _db_nochange: - """context guard for read-only database access""" - - def __enter__(self): - self.c = DB.cursor() - return self.c - - def __exit__(self, type, value, traceback): - if type is None: - self.c.close() - DB.rollback() - - -class _db_change: - """context guard for database access requiring a commit""" - - def __enter__(self): - self.c = DB.cursor() - return self.c - - def __exit__(self, type, value, traceback): - if type is None: - self.c.close() - DB.commit() - else: - DB.rollback() - - def _get_subdirs(base_dir, files_dict=None, followlinks=False): dirs = [] for (dirpath, dirnames, filenames) in os.walk(base_dir, followlinks=followlinks): @@ -288,7 +261,7 @@ def update_cache(base_dir, followlinks=False): for d in dirs: os.stat(d) mtimes[d] = os.stat(d).st_mtime - with _db_nochange() as c: + with DB.readwrite_cursor() as c: c.execute("SELECT path, mtime FROM directory") db_mtimes = dict(c) c.execute("SELECT uid FROM study") @@ -297,7 +270,6 @@ def update_cache(base_dir, followlinks=False): series = [row[0] for row in c] c.execute("SELECT uid FROM storage_instance") storage_instances = [row[0] for row in c] - with _db_change() as c: for dir in sorted(mtimes.keys()): if dir in db_mtimes and mtimes[dir] <= db_mtimes[dir]: continue @@ -316,7 +288,7 @@ def get_studies(base_dir=None, followlinks=False): if base_dir is not None: update_cache(base_dir, followlinks) if base_dir is None: - with _db_nochange() as c: + with DB.readonly_cursor() as c: c.execute("SELECT * FROM study") studies = [] cols = [el[0] for el in c.description] @@ -331,7 +303,7 @@ def get_studies(base_dir=None, followlinks=False): WHERE uid IN (SELECT storage_instance FROM file WHERE directory = ?))""" - with _db_nochange() as c: + with DB.readonly_cursor() as c: study_uids = {} for dir in _get_subdirs(base_dir, followlinks=followlinks): c.execute(query, (dir, )) @@ -443,7 +415,7 @@ def _update_file(c, path, fname, studies, series, storage_instances): def clear_cache(): - with _db_change() as c: + with DB.readwrite_cursor() as c: c.execute("DELETE FROM file") c.execute("DELETE FROM directory") c.execute("DELETE FROM storage_instance") @@ -478,26 +450,64 @@ def clear_cache(): mtime INTEGER NOT NULL, storage_instance TEXT DEFAULT NULL REFERENCES storage_instance, PRIMARY KEY (directory, name))""") -DB_FNAME = pjoin(tempfile.gettempdir(), f'dft.{getpass.getuser()}.sqlite') -DB = None -def _init_db(verbose=True): - """ Initialize database """ - if verbose: - logger.info('db filename: ' + DB_FNAME) - global DB - DB = sqlite3.connect(DB_FNAME, check_same_thread=False) - with _db_change() as c: - c.execute("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'") - if c.fetchone()[0] == 0: - logger.debug('create') - for q in CREATE_QUERIES: - c.execute(q) +class _DB: + def __init__(self, fname=None, verbose=True): + self.fname = fname or pjoin(tempfile.gettempdir(), f'dft.{getpass.getuser()}.sqlite') + self.verbose = verbose + + @property + def session(self): + """Get sqlite3 Connection + + The connection is created on the first call of this property + """ + try: + return self._session + except AttributeError: + self._init_db() + return self._session + + def _init_db(self): + if self.verbose: + logger.info('db filename: ' + self.fname) + + self._session = sqlite3.connect(self.fname, isolation_level="EXCLUSIVE") + with self.readwrite_cursor() as c: + c.execute("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'") + if c.fetchone()[0] == 0: + logger.debug('create') + for q in CREATE_QUERIES: + c.execute(q) + + def __repr__(self): + return f"" + + @contextlib.contextmanager + def readonly_cursor(self): + cursor = self.session.cursor() + try: + yield cursor + finally: + cursor.close() + self.session.rollback() + + @contextlib.contextmanager + def readwrite_cursor(self): + cursor = self.session.cursor() + try: + yield cursor + except Exception: + self.session.rollback() + raise + finally: + cursor.close() + self.session.commit() +DB = None if os.name == 'nt': warnings.warn('dft needs FUSE which is not available for windows') else: - _init_db() -# eof + DB = _DB() From 32e02d728654571c55e34965ada010de09fd1741 Mon Sep 17 00:00:00 2001 From: Chris Markiewicz Date: Mon, 19 Dec 2022 20:14:50 -0500 Subject: [PATCH 2/3] TEST: Create fresh in-memory database for each dft unit test --- nibabel/tests/test_dft.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/nibabel/tests/test_dft.py b/nibabel/tests/test_dft.py index b00c136312..30fafcd8db 100644 --- a/nibabel/tests/test_dft.py +++ b/nibabel/tests/test_dft.py @@ -29,12 +29,21 @@ def setUpModule(): raise unittest.SkipTest('Need pydicom for dft tests, skipping') -def test_init(): +@pytest.fixture +def db(monkeypatch): + """Build a dft database in memory to avoid cross-process races + and not modify the host filesystem.""" + database = dft._DB(fname=":memory:") + monkeypatch.setattr(dft, "DB", database) + yield database + + +def test_init(db): dft.clear_cache() dft.update_cache(data_dir) -def test_study(): +def test_study(db): studies = dft.get_studies(data_dir) assert len(studies) == 1 assert (studies[0].uid == @@ -48,7 +57,7 @@ def test_study(): assert studies[0].patient_sex == 'F' -def test_series(): +def test_series(db): studies = dft.get_studies(data_dir) assert len(studies[0].series) == 1 ser = studies[0].series[0] @@ -62,7 +71,7 @@ def test_series(): assert ser.bits_stored == 12 -def test_storage_instances(): +def test_storage_instances(db): studies = dft.get_studies(data_dir) sis = studies[0].series[0].storage_instances assert len(sis) == 2 @@ -74,19 +83,19 @@ def test_storage_instances(): '1.3.12.2.1107.5.2.32.35119.2010011420300180088599504.1') -def test_storage_instance(): +def test_storage_instance(db): pass @unittest.skipUnless(have_pil, 'could not import PIL.Image') -def test_png(): +def test_png(db): studies = dft.get_studies(data_dir) data = studies[0].series[0].as_png() im = PImage.open(BytesIO(data)) assert im.size == (256, 256) -def test_nifti(): +def test_nifti(db): studies = dft.get_studies(data_dir) data = studies[0].series[0].as_nifti() assert len(data) == 352 + 2 * 256 * 256 * 2 From 32bc89acaa4f5887036aa71f7c7c28092d0d224a Mon Sep 17 00:00:00 2001 From: Chris Markiewicz Date: Tue, 20 Dec 2022 08:33:53 -0500 Subject: [PATCH 3/3] TEST: Test _DB class, increase coverage a bit --- nibabel/tests/test_dft.py | 49 +++++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/nibabel/tests/test_dft.py b/nibabel/tests/test_dft.py index 30fafcd8db..61e031b8d3 100644 --- a/nibabel/tests/test_dft.py +++ b/nibabel/tests/test_dft.py @@ -5,6 +5,7 @@ from os.path import join as pjoin, dirname from io import BytesIO from ..testing import suppress_warnings +import sqlite3 with suppress_warnings(): from .. import dft @@ -29,6 +30,24 @@ def setUpModule(): raise unittest.SkipTest('Need pydicom for dft tests, skipping') +class Test_DBclass: + """Some tests on the database manager class that don't get exercised through the API""" + def setup_method(self): + self._db = dft._DB(fname=":memory:", verbose=False) + + def test_repr(self): + assert repr(self._db) == "" + + def test_cursor_conflict(self): + rwc = self._db.readwrite_cursor + statement = ("INSERT INTO directory (path, mtime) VALUES (?, ?)", ("/tmp", 0)) + with pytest.raises(sqlite3.IntegrityError): + # Whichever exits first will commit and make the second violate uniqueness + with rwc() as c1, rwc() as c2: + c1.execute(*statement) + c2.execute(*statement) + + @pytest.fixture def db(monkeypatch): """Build a dft database in memory to avoid cross-process races @@ -41,20 +60,24 @@ def db(monkeypatch): def test_init(db): dft.clear_cache() dft.update_cache(data_dir) + # Verify a second update doesn't crash + dft.update_cache(data_dir) def test_study(db): - studies = dft.get_studies(data_dir) - assert len(studies) == 1 - assert (studies[0].uid == - '1.3.12.2.1107.5.2.32.35119.30000010011408520750000000022') - assert studies[0].date == '20100114' - assert studies[0].time == '121314.000000' - assert studies[0].comments == 'dft study comments' - assert studies[0].patient_name == 'dft patient name' - assert studies[0].patient_id == '1234' - assert studies[0].patient_birth_date == '19800102' - assert studies[0].patient_sex == 'F' + # First pass updates the cache, second pass reads it out + for base_dir in (data_dir, None): + studies = dft.get_studies(base_dir) + assert len(studies) == 1 + assert (studies[0].uid == + '1.3.12.2.1107.5.2.32.35119.30000010011408520750000000022') + assert studies[0].date == '20100114' + assert studies[0].time == '121314.000000' + assert studies[0].comments == 'dft study comments' + assert studies[0].patient_name == 'dft patient name' + assert studies[0].patient_id == '1234' + assert studies[0].patient_birth_date == '19800102' + assert studies[0].patient_sex == 'F' def test_series(db): @@ -83,10 +106,6 @@ def test_storage_instances(db): '1.3.12.2.1107.5.2.32.35119.2010011420300180088599504.1') -def test_storage_instance(db): - pass - - @unittest.skipUnless(have_pil, 'could not import PIL.Image') def test_png(db): studies = dft.get_studies(data_dir)