
Commit b8110e5

Dataset Downloader (#5)
1 parent 08c7213 commit b8110e5

15 files changed: +415 -26 lines

.gitignore

+1 -1

@@ -161,4 +161,4 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 
-datasets
+construe/datasets/fixtures

MANIFEST.in

+2 -1

@@ -3,8 +3,8 @@ include *.rst
 include *.txt
 include *.yml
 include *.cfg
-include MANIFEST.in
 
+include MANIFEST.in
 include LICENSE
 
 graft docs
@@ -14,6 +14,7 @@ graft tests
 prune tests/fixtures
 
 graft construe
+prune construe/datasets/fixtures
 
 global-exclude __pycache__
 global-exclude *.py[co]

construe/__main__.py

+40 -4

@@ -6,10 +6,12 @@
 import torch
 import platform
 
-from click import ClickException
 from .version import get_version
 from .basic import BasicBenchmark
+from .exceptions import DeviceError
 from .moondream import MoonDreamBenchmark
+from .datasets.manifest import generate_manifest
+from .datasets.path import get_data_home, FIXTURES
 
 
 CONTEXT_SETTINGS = {
@@ -34,13 +36,26 @@
     envvar=["CONSTRUE_ENV", "ENV"],
     help="name of the experimental environment for comparison (default is hostname)",
 )
+@click.option(
+    "-D",
+    "--datadir",
+    default=None,
+    envvar="CONSTRUE_DATA",
+    help="specify the location to download datasets to",
+)
+@click.option(
+    "-C",
+    "--cleanup/--no-cleanup",
+    default=True,
+    help="cleanup all downloaded datasets after the benchmark is run",
+)
 @click.pass_context
-def main(ctx, env=None, device=None):
+def main(ctx, env=None, device=None, datadir=None, cleanup=True):
     if device is not None:
         try:
             torch.set_default_device(device)
         except RuntimeError as e:
-            raise ClickException(str(e))
+            raise DeviceError(e)
 
     click.echo(f"using torch.device(\"{device}\")")
 
@@ -50,6 +65,8 @@ def main(ctx, env=None, device=None):
     ctx.ensure_object(dict)
     ctx.obj["device"] = device
     ctx.obj["env"] = env
+    ctx.obj["data_home"] = get_data_home(datadir)
+    ctx.obj["cleanup"] = cleanup
 
 
 @main.command()
@@ -89,11 +106,30 @@ def basic(ctx, **kwargs):
 @main.command()
 @click.pass_context
 def moondream(ctx, **kwargs):
-    kwargs["env"] = ctx["env"]
+    kwargs["env"] = ctx.obj["env"]
     benchmark = MoonDreamBenchmark(**kwargs)
     benchmark.run()
 
 
+@main.command()
+@click.option(
+    "-f",
+    "--fixtures",
+    type=str,
+    default=FIXTURES,
+    help="path to fixtures directory to generate manifest from",
+)
+@click.option(
+    "-o",
+    "--out",
+    type=str,
+    default=None,
+    help="path to write the manifest to",
+)
+def manifest(fixtures=FIXTURES, out=None):
+    generate_manifest(fixtures, out)
+
+
 if __name__ == "__main__":
     main(
         obj={},
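
Taken together, the new global options sit on the main group and are given before the subcommand, while manifest takes its own -f/-o options. A rough command-line sketch of how this might be invoked (running the package as a module, which the __main__ guard above supports; the paths shown are illustrative):

    # run a benchmark against a custom data directory, keeping downloaded datasets afterwards
    python -m construe --datadir ./data --no-cleanup moondream

    # regenerate the dataset manifest from a local fixtures directory
    python -m construe manifest -f construe/datasets/fixtures -o construe/datasets/manifest.json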

construe/datasets.py

-10
This file was deleted.

construe/datasets/__init__.py

+7

@@ -0,0 +1,7 @@
+"""
+Manages datasets used for inferencing
+"""
+
+from .loaders import *  # noqa
+from .download import download_data
+from .path import get_data_home, cleanup_dataset

construe/datasets/download.py

+68

@@ -0,0 +1,68 @@
+"""
+Handle downloading datasets from our content URL
+"""
+
+import os
+import zipfile
+
+from urllib.request import urlopen
+
+from .signature import sha256sum
+from .path import get_data_home, cleanup_dataset
+
+from construe.exceptions import DatasetsError
+
+
+# Download chunk size
+CHUNK = 524288
+
+
+def download_data(url, signature, data_home=None, replace=False, extract=True):
+    """
+    Downloads the zipped data set specified at the given URL, saving it to
+    the data directory specified by ``get_data_home``. This function verifies
+    the download with the given signature and extracts the archive.
+    """
+    data_home = get_data_home(data_home)
+
+    # Get the name of the file from the URL
+    basename = os.path.basename(url)
+    name, _ = os.path.splitext(basename)
+
+    # Get the archive and data directory paths
+    archive = os.path.join(data_home, basename)
+    datadir = os.path.join(data_home, name)
+
+    # If the archive already exists, clean it up when replace is set, otherwise raise
+    if os.path.exists(archive):
+        if not replace:
+            raise DatasetsError(
+                ("dataset already exists at {}, set replace=True to overwrite").format(
+                    archive
+                )
+            )
+
+        cleanup_dataset(name, data_home=data_home)
+
+    # Create the output directory if it does not exist
+    if not os.path.exists(datadir):
+        os.mkdir(datadir)
+
+    # Fetch the response in a streaming fashion and write it to disk.
+    response = urlopen(url)
+
+    with open(archive, "wb") as f:
+        while True:
+            chunk = response.read(CHUNK)
+            if not chunk:
+                break
+            f.write(chunk)
+
+    # Compare the signature of the archive to the expected one
+    if sha256sum(archive) != signature:
+        raise ValueError("Download signature does not match expected signature!")
+
+    # If extract, extract the zipfile.
+    if extract:
+        zf = zipfile.ZipFile(archive)
+        zf.extractall(path=data_home)
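
As a point of reference, a minimal sketch of driving download_data directly from a manifest entry; in practice the loaders module below does this for you, so treat this as illustrative only:

    from construe.datasets.download import download_data
    from construe.datasets.manifest import load_manifest

    info = load_manifest()["content-moderation"]
    download_data(
        info["url"],        # zipped dataset hosted at the content URL
        info["signature"],  # expected SHA-256 checksum of the archive
        data_home=None,     # defaults to the directory returned by get_data_home
        replace=False,      # raise DatasetsError if the archive already exists
        extract=True,       # unzip the archive into data_home after verification
    )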

construe/datasets/loaders.py

+47

@@ -0,0 +1,47 @@
+"""
+Managers for loading datasets
+"""
+
+import os
+import glob
+
+from .manifest import load_manifest
+from .download import download_data
+from ..exceptions import DatasetsError
+from .path import dataset_archive, find_dataset_path, cleanup_dataset
+
+
+__all__ = ["load_content_moderation", "cleanup_content_moderation"]
+
+
+DATASETS = load_manifest()
+CONTENT_MODERATION = "content-moderation"
+
+
+def _info(dataset):
+    if dataset not in DATASETS:
+        raise DatasetsError(f"no dataset named {dataset} exists")
+    return DATASETS[dataset]
+
+
+def load_content_moderation(data_home=None):
+    """
+    Downloads the content moderation dataset if it does not exist then
+    yields all of the paths for the images in the dataset.
+    """
+    info = _info(CONTENT_MODERATION)
+    if not dataset_archive(CONTENT_MODERATION, info["signature"], data_home=data_home):
+        # If the dataset does not exist, download and extract it
+        info.update({"data_home": data_home, "replace": False, "extract": True})
+        download_data(**info)
+
+    data_path = find_dataset_path(CONTENT_MODERATION, fname=None, ext=None)
+    for path in glob.glob(os.path.join(data_path, "**", "*")):
+        yield path
+
+
+def cleanup_content_moderation(data_home=None):
+    """
+    Removes the content moderation dataset and archive.
+    """
+    return cleanup_dataset(CONTENT_MODERATION, data_home=data_home)
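
A short usage sketch of the loader and its cleanup helper as exposed through the package __init__; because load_content_moderation is a generator, the download and verification only happen once iteration begins:

    from construe.datasets import load_content_moderation, cleanup_content_moderation

    # downloads and extracts the dataset on first iteration if it is not already present
    for path in load_content_moderation():
        print(path)  # each yielded path points at a file in the extracted dataset

    # remove the extracted dataset and its archive when finished
    cleanup_content_moderation()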

construe/datasets/manifest.json

+6

@@ -0,0 +1,6 @@
+{
+  "content-moderation": {
+    "url": "https://storage.googleapis.com/construe/v0.2.0/content-moderation.zip",
+    "signature": "4925f2733bf3b1596e4d950e49d12bf4fff2379f08cbf0a105807ff03d4c2b4e"
+  }
+}

construe/datasets/manifest.py

+40

@@ -0,0 +1,40 @@
+"""
+Manifest handlers for downloading and checking signatures
+"""
+
+import os
+import json
+import glob
+
+from urllib.parse import urljoin
+
+from .signature import sha256sum
+from ..version import get_version
+from .path import FIXTURES, MANIFEST
+
+
+BASE_URL = "https://storage.googleapis.com/"
+
+
+def load_manifest(path=MANIFEST):
+    with open(path, "r") as f:
+        return json.load(f)
+
+
+def generate_manifest(fixtures=FIXTURES, out=MANIFEST):
+    out = out or MANIFEST
+
+    manifest = {}
+    version = get_version(short=True)
+
+    for path in glob.glob(os.path.join(fixtures, "*.zip")):
+        fname = os.path.basename(path)
+        name, _ = os.path.splitext(fname)
+
+        manifest[name] = {
+            "url": urljoin(BASE_URL, f"construe/v{version}/{fname}"),
+            "signature": sha256sum(path),
+        }
+
+    with open(out, "w") as o:
+        json.dump(manifest, o, indent=2)
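
For completeness, a small sketch of the manifest round trip: generate_manifest hashes every *.zip in the fixtures directory and writes JSON like the file above, and load_manifest reads it back (the fixtures path is illustrative):

    from construe.datasets.manifest import generate_manifest, load_manifest

    # rebuild manifest.json from locally staged archives
    generate_manifest(fixtures="construe/datasets/fixtures")

    # read the manifest back; keys are dataset names, values carry "url" and "signature"
    for name, info in load_manifest().items():
        print(name, info["url"], info["signature"])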
