Skip to content

Commit

Permalink
fix database issue (#34)
Browse files Browse the repository at this point in the history
There is a mistake in the database design, where if there is more than one image with the same `name/version:tag`, e.g.
* `netcdf-tools/2024:v1` on `balfrin`
* `netcdf-tools/2024:v1` on `santis`
Only one of the instances will be in the in-memory database that is formed when `uenv image find` is called, leading to missing entries.

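This PR moves the additional data, `uarch` and `system`, out of the `images` table and into the unique uenv identifier table (`uenv`), so that a uenv is identified by `system/uarch/name/version` rather than by `name/version` alone.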

Adds unit tests that fail without these changes and pass with them.
  • Loading branch information
bcumming authored Jun 6, 2024
1 parent 3b5d8f0 commit 8b9b6f3
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 27 deletions.
95 changes: 75 additions & 20 deletions lib/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def __init__(self, message):
def __str__(self):
return self.message

create_db_command = """
create_db_commands = {
"v1": """
BEGIN;
PRAGMA foreign_keys=on;
Expand Down Expand Up @@ -81,7 +82,64 @@ def __str__(self):
INNER JOIN images ON images.sha256 = tags.sha256;
COMMIT;
"""
""",
"v2": """
BEGIN;
PRAGMA foreign_keys=on;
CREATE TABLE images (
sha256 TEXT PRIMARY KEY CHECK(length(sha256)==64),
id TEXT UNIQUE CHECK(length(id)==16),
date TEXT NOT NULL,
size INTEGER NOT NULL
);
CREATE TABLE uenv (
version_id INTEGER PRIMARY KEY,
system TEXT NOT NULL,
uarch TEXT NOT NULL,
name TEXT NOT NULL,
version TEXT NOT NULL,
UNIQUE (system, uarch, name, version)
);
CREATE TABLE tags (
version_id INTEGER,
tag TEXT NOT NULL,
sha256 TEXT NOT NULL,
PRIMARY KEY (version_id, tag),
FOREIGN KEY (version_id)
REFERENCES uenv (version_id)
ON DELETE CASCADE
ON UPDATE CASCADE,
FOREIGN KEY (sha256)
REFERENCES images (sha256)
ON DELETE CASCADE
ON UPDATE CASCADE
);
-- for convenient generation of the Record type used internally by uenv-image
CREATE VIEW records AS
SELECT
uenv.system AS system,
uenv.uarch AS uarch,
uenv.name AS name,
uenv.version AS version,
tags.tag AS tag,
images.date AS date,
images.size AS size,
tags.sha256 AS sha256,
images.id AS id
FROM tags
INNER JOIN uenv ON uenv.version_id = tags.version_id
INNER JOIN images ON images.sha256 = tags.sha256;
COMMIT;
"""}

db_version = "v2"
create_db_command = create_db_commands[db_version]

class RecordSet():
def __init__(self, records, request):
Expand Down Expand Up @@ -172,16 +230,16 @@ def add_record(self, r: Record):

cursor.execute("BEGIN;")
cursor.execute("PRAGMA foreign_keys=on;")
cursor.execute("INSERT OR IGNORE INTO images (sha256, id, date, size, uarch, system) VALUES (?, ?, ?, ?, ?, ?)",
(r.sha256, r.id, r.date, r.size, r.uarch, r.system))
# Insert a new name/version to the uenv table if no existing images with that pair exist
cursor.execute("INSERT OR IGNORE INTO uenv (name, version) VALUES (?, ?)",
(r.name, r.version))
# Retrieve the version_id of the name/version pair
cursor.execute("INSERT OR IGNORE INTO images (sha256, id, date, size) VALUES (?, ?, ?, ?)",
(r.sha256, r.id, r.date, r.size))
# Insert a new system/uarch/name/version row into the uenv table if that combination is not already present
cursor.execute("INSERT OR IGNORE INTO uenv (system, uarch, name, version) VALUES (?, ?, ?, ?)",
(r.system, r.uarch, r.name, r.version))
# Retrieve the version_id of the system/uarch/name/version identifier
# This requires a SELECT query to get the correct version_id whether or not
# a new row was added in the last INSERT
cursor.execute("SELECT version_id FROM uenv WHERE name = ? AND version = ?",
(r.name, r.version))
cursor.execute("SELECT version_id FROM uenv WHERE system = ? AND uarch = ? AND name = ? AND version = ?",
(r.system, r.uarch, r.name, r.version))
version_id = cursor.fetchone()[0]
# Check whether an image with the same tag already exists in the repos
cursor.execute("SELECT version_id, tag, sha256 FROM tags WHERE version_id = ? AND tag = ?",
Expand Down Expand Up @@ -255,15 +313,12 @@ def find_records(self, **constraints):
# Find matching records for each constraint
items = self._store.execute(f"SELECT * FROM records WHERE {query_criteria}")

request = ""
if "name" in constraints:
request = constraints["name"]
if "version" in constraints:
request = request + f"/{constraints['version']}"
if "tag" in constraints:
request = request + f":{constraints['tag']}"
if "uarch" in constraints:
request = request + f"@{constraints['uarch']}"
ns = constraints.get("name", "*")
vs = constraints.get("version", "*")
ts = constraints.get("tag", "*")
us = constraints.get("uarch", "*")
ss = constraints.get("system", "all")
request = f"{ns}/{vs}:{ts}@{us} on {ss}"

results = [self.to_record(r) for r in items];
results.sort(reverse=True)
Expand All @@ -272,7 +327,7 @@ def find_records(self, **constraints):
@property
def images(self):
items = self._store.execute(f"SELECT * FROM records")
return RecordSet([self.to_record(r) for r in items], "all")
return RecordSet([self.to_record(r) for r in items], "{*}/{*}:{*}@{*} on all")

# return a list of records that match a sha
def get_record(self, sha: str) -> Record:
Expand Down
51 changes: 44 additions & 7 deletions test/unit/test_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@
record.Record("santis", "gh200", "icon", "2024", "v1", "2023/01/01", 5024, "2"*64),
]

# records with the same name/version:tag, that should be disambiguated by hash, system, uarch
duplicate_records = [
record.Record("santis", "gh200", "netcdf-tools", "2024", "v1", "2024/02/12", 1024, "w"*64),
record.Record("todi", "gh200", "netcdf-tools", "2024", "v1", "2024/02/12", 1024, "x"*64),
record.Record("balfrin", "a100", "netcdf-tools", "2024", "v1", "2024/02/12", 1024, "y"*64),
record.Record("balfrin", "zen3", "netcdf-tools", "2024", "v1", "2024/02/12", 1024, "z"*64),
]

def create_prgenv_repo(con):
# add some records that will be inserted into the database
# these defined different versions of prgenv-gnu
Expand All @@ -35,6 +43,9 @@ def create_full_repo(con):
for r in icon_records:
con.add_record(r)

for r in duplicate_records:
con.add_record(r)

class TestInMemory(unittest.TestCase):

def setUp(self):
Expand Down Expand Up @@ -90,11 +101,11 @@ def test_find_records(self):
self.assertEqual(2, len(store.find_records(name="prgenv-gnu", tag="v1").records))
self.assertEqual(1, len(store.find_records(name="prgenv-gnu", tag="v2").records))

self.assertEqual(3, len(store.find_records(tag="v1").records))
self.assertEqual(7, len(store.find_records(tag="v1").records))
self.assertEqual(2, len(store.find_records(tag="v2").records))
self.assertEqual(3, len(store.find_records(tag="default").records))

self.assertEqual(3, len(store.find_records(version="2024").records))
self.assertEqual(7, len(store.find_records(version="2024").records))
self.assertEqual(3, len(store.find_records(version="23.11").records))
self.assertEqual(0, len(store.find_records(name="icon", version="23.11").records))
self.assertEqual(3, len(store.find_records(name="prgenv-gnu", version="23.11").records))
Expand All @@ -106,16 +117,42 @@ def test_find_records(self):
self.assertEqual(store.find_records(name="icon", version="2024", tag="default").records,
store.find_records(name="icon", version="2024", tag="v2").records)

self.assertEqual(8, len(store.find_records(uarch="gh200").records))
self.assertEqual(8, len(store.find_records(system="santis").records))

self.assertEqual(store.find_records(system="santis").records,
store.find_records(uarch="gh200").records)
self.assertEqual(10, len(store.find_records(uarch="gh200").records))
self.assertEqual(9, len(store.find_records(system="santis").records))

# expect an exception when an invalid field is passed (sustem is a typo for system)
with self.assertRaises(ValueError):
result = store.find_records(sustem="santis")

def test_find_duplicates(self):
store = datastore.DataStore(path=None)
create_full_repo(store)

# there are 4 records that match "netcdf-tools/2024:v1" that are disambiguated by
# system and uarch
#
# - santis, gh200, sha=wwwww...
# - todi, gh200, sha=xxxxx...
# - balfrin, a100, sha=yyyyy...
# - balfrin, zen3, sha=zzzzz...
#
# all four are expected in the DB, because each has a different system/uarch combination.
self.assertEqual(4, len(store.find_records(name="netcdf-tools").records))

self.assertEqual(1, len(store.find_records(name="netcdf-tools", system="santis").records))
self.assertEqual(1, len(store.find_records(name="netcdf-tools", system="todi").records))
self.assertEqual(2, len(store.find_records(name="netcdf-tools", system="balfrin").records))
self.assertEqual(1, len(store.find_records(name="netcdf-tools", system="balfrin", uarch="a100").records))
self.assertEqual("y"*64,
store.find_records(
name="netcdf-tools", system="balfrin", uarch="a100"
).records[0].sha256)
self.assertEqual(1, len(store.find_records(name="netcdf-tools", system="balfrin", uarch="zen3").records))
self.assertEqual("z"*64,
store.find_records(
name="netcdf-tools", system="balfrin", uarch="zen3"
).records[0].sha256)

def test_get_record(self):
store = self.store

Expand Down

0 comments on commit 8b9b6f3

Please sign in to comment.