From 8b9b6f32e33ead4e7787bd43bd71595654e4ccd9 Mon Sep 17 00:00:00 2001 From: Ben Cumming Date: Thu, 6 Jun 2024 13:06:25 +0200 Subject: [PATCH] fix database issue (#34) There is a mistake in the database design, where if there is more than one image with the same `name/version:tag`, e.g. * `netcdf-tools/2024:v1` on `balfrin` * `netcdf-tools/2024:v1` on `santis` Only one of the instances will be in the in-memory database that is formed when `uenv image find` is called, leading to missing entries. This PR moves the additional data, `uarch` and `system` to the unique uenv identifier table Adds unit tests that are fixed by the changes. --- lib/datastore.py | 95 +++++++++++++++++++++++++++++-------- test/unit/test_datastore.py | 51 +++++++++++++++++--- 2 files changed, 119 insertions(+), 27 deletions(-) diff --git a/lib/datastore.py b/lib/datastore.py index 807cc75..c352d81 100644 --- a/lib/datastore.py +++ b/lib/datastore.py @@ -28,7 +28,8 @@ def __init__(self, message): def __str__(self): return self.message -create_db_command = """ +create_db_commands = { + "v1": """ BEGIN; PRAGMA foreign_keys=on; @@ -81,7 +82,64 @@ def __str__(self): INNER JOIN images ON images.sha256 = tags.sha256; COMMIT; -""" +""", + "v2": """ +BEGIN; + +PRAGMA foreign_keys=on; + +CREATE TABLE images ( + sha256 TEXT PRIMARY KEY CHECK(length(sha256)==64), + id TEXT UNIQUE CHECK(length(id)==16), + date TEXT NOT NULL, + size INTEGER NOT NULL +); + +CREATE TABLE uenv ( + version_id INTEGER PRIMARY KEY, + system TEXT NOT NULL, + uarch TEXT NOT NULL, + name TEXT NOT NULL, + version TEXT NOT NULL, + UNIQUE (system, uarch, name, version) +); + +CREATE TABLE tags ( + version_id INTEGER, + tag TEXT NOT NULL, + sha256 TEXT NOT NULL, + PRIMARY KEY (version_id, tag), + FOREIGN KEY (version_id) + REFERENCES uenv (version_id) + ON DELETE CASCADE + ON UPDATE CASCADE, + FOREIGN KEY (sha256) + REFERENCES images (sha256) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +-- for convenient generation of the Record type used internally by uenv-image +CREATE VIEW records AS +SELECT + uenv.system AS system, + uenv.uarch AS uarch, + uenv.name AS name, + uenv.version AS version, + tags.tag AS tag, + images.date AS date, + images.size AS size, + tags.sha256 AS sha256, + images.id AS id +FROM tags + INNER JOIN uenv ON uenv.version_id = tags.version_id + INNER JOIN images ON images.sha256 = tags.sha256; + +COMMIT; +"""} + +db_version = "v2" +create_db_command = create_db_commands[db_version] class RecordSet(): def __init__(self, records, request): @@ -172,16 +230,16 @@ def add_record(self, r: Record): cursor.execute("BEGIN;") cursor.execute("PRAGMA foreign_keys=on;") - cursor.execute("INSERT OR IGNORE INTO images (sha256, id, date, size, uarch, system) VALUES (?, ?, ?, ?, ?, ?)", - (r.sha256, r.id, r.date, r.size, r.uarch, r.system)) - # Insert a new name/version to the uenv table if no existing images with that pair exist - cursor.execute("INSERT OR IGNORE INTO uenv (name, version) VALUES (?, ?)", - (r.name, r.version)) - # Retrieve the version_id of the name/version pair + cursor.execute("INSERT OR IGNORE INTO images (sha256, id, date, size) VALUES (?, ?, ?, ?)", + (r.sha256, r.id, r.date, r.size)) + # Insert a new system/uarch/name/version to the uenv table if no existing images exist + cursor.execute("INSERT OR IGNORE INTO uenv (system, uarch, name, version) VALUES (?, ?, ?, ?)", + (r.system, r.uarch, r.name, r.version)) + # Retrieve the version_id of the system/uarch/name/version identifier # This requires a SELECT query to get the correct version_id whether or not # a new row was added in the last INSERT - cursor.execute("SELECT version_id FROM uenv WHERE name = ? AND version = ?", - (r.name, r.version)) + cursor.execute("SELECT version_id FROM uenv WHERE system = ? AND uarch = ? AND name = ? AND version = ?", + (r.system, r.uarch, r.name, r.version)) version_id = cursor.fetchone()[0] # Check whether an image with the same tag already exists in the repos cursor.execute("SELECT version_id, tag, sha256 FROM tags WHERE version_id = ? AND tag = ?", @@ -255,15 +313,12 @@ def find_records(self, **constraints): # Find matching records for each constraint items = self._store.execute(f"SELECT * FROM records WHERE {query_criteria}") - request = "" - if "name" in constraints: - request = constraints["name"] - if "version" in constraints: - request = request + f"/{constraints['version']}" - if "tag" in constraints: - request = request + f":{constraints['tag']}" - if "uarch" in constraints: - request = request + f"@{constraints['uarch']}" + ns = constraints.get("name", "*") + vs = constraints.get("version", "*") + ts = constraints.get("tag", "*") + us = constraints.get("uarch", "*") + ss = constraints.get("system", "all") + request = f"{ns}/{vs}:{ts}@{us} on {ss}" results = [self.to_record(r) for r in items]; results.sort(reverse=True) @@ -272,7 +327,7 @@ def find_records(self, **constraints): @property def images(self): items = self._store.execute(f"SELECT * FROM records") - return RecordSet([self.to_record(r) for r in items], "all") + return RecordSet([self.to_record(r) for r in items], "{*}/{*}:{*}@{*} on all") # return a list of records that match a sha def get_record(self, sha: str) -> Record: diff --git a/test/unit/test_datastore.py b/test/unit/test_datastore.py index 2966cce..f55338e 100644 --- a/test/unit/test_datastore.py +++ b/test/unit/test_datastore.py @@ -19,6 +19,14 @@ record.Record("santis", "gh200", "icon", "2024", "v1", "2023/01/01", 5024, "2"*64), ] +# records with the same name/version:tag, that should be disambiguated by hash, system, uarch +duplicate_records = [ + record.Record("santis", "gh200", "netcdf-tools", "2024", "v1", "2024/02/12", 1024, "w"*64), + record.Record("todi", "gh200", "netcdf-tools", "2024", "v1", "2024/02/12", 1024, "x"*64), + record.Record("balfrin", "a100", "netcdf-tools", "2024", "v1", "2024/02/12", 1024, "y"*64), + record.Record("balfrin", "zen3", "netcdf-tools", "2024", "v1", "2024/02/12", 1024, "z"*64), +] + def create_prgenv_repo(con): # add some records that will be inserted into the database # these defined different versions of prgenv-gnu @@ -35,6 +43,9 @@ def create_full_repo(con): for r in icon_records: con.add_record(r) + for r in duplicate_records: + con.add_record(r) + class TestInMemory(unittest.TestCase): def setUp(self): @@ -90,11 +101,11 @@ def test_find_records(self): self.assertEqual(2, len(store.find_records(name="prgenv-gnu", tag="v1").records)) self.assertEqual(1, len(store.find_records(name="prgenv-gnu", tag="v2").records)) - self.assertEqual(3, len(store.find_records(tag="v1").records)) + self.assertEqual(7, len(store.find_records(tag="v1").records)) self.assertEqual(2, len(store.find_records(tag="v2").records)) self.assertEqual(3, len(store.find_records(tag="default").records)) - self.assertEqual(3, len(store.find_records(version="2024").records)) + self.assertEqual(7, len(store.find_records(version="2024").records)) self.assertEqual(3, len(store.find_records(version="23.11").records)) self.assertEqual(0, len(store.find_records(name="icon", version="23.11").records)) self.assertEqual(3, len(store.find_records(name="prgenv-gnu", version="23.11").records)) @@ -106,16 +117,42 @@ def test_find_records(self): self.assertEqual(store.find_records(name="icon", version="2024", tag="default").records, store.find_records(name="icon", version="2024", tag="v2").records) - self.assertEqual(8, len(store.find_records(uarch="gh200").records)) - self.assertEqual(8, len(store.find_records(system="santis").records)) - - self.assertEqual(store.find_records(system="santis").records, - store.find_records(uarch="gh200").records) + self.assertEqual(10, len(store.find_records(uarch="gh200").records)) + self.assertEqual(9, len(store.find_records(system="santis").records)) # expect an exception when an invalid field is passed (sustem is a typo for system) with self.assertRaises(ValueError): result = store.find_records(sustem="santis") + def test_find_duplicates(self): + store = datastore.DataStore(path=None) + create_full_repo(store) + + # there are 4 records that match "netcdf-tools/2024:v1" that are disambiguated by + # system and uarch + # + # - santis, gh200, sha=wwwww... + # - todi, gh200, sha=xxxxx... + # - balfrin, a100, sha=yyyyy... + # - balfrin, zen3, sha=zzzzz... + # + # different vClusters are expected in the DB. + self.assertEqual(4, len(store.find_records(name="netcdf-tools").records)) + + self.assertEqual(1, len(store.find_records(name="netcdf-tools", system="santis").records)) + self.assertEqual(1, len(store.find_records(name="netcdf-tools", system="todi").records)) + self.assertEqual(2, len(store.find_records(name="netcdf-tools", system="balfrin").records)) + self.assertEqual(1, len(store.find_records(name="netcdf-tools", system="balfrin", uarch="a100").records)) + self.assertEqual("y"*64, + store.find_records( + name="netcdf-tools", system="balfrin", uarch="a100" + ).records[0].sha256) + self.assertEqual(1, len(store.find_records(name="netcdf-tools", system="balfrin", uarch="zen3").records)) + self.assertEqual("z"*64, + store.find_records( + name="netcdf-tools", system="balfrin", uarch="zen3" + ).records[0].sha256) + def test_get_record(self): store = self.store