diff --git a/gdown/cached_download.py b/gdown/cached_download.py
index 7c196773..e00cfd73 100644
--- a/gdown/cached_download.py
+++ b/gdown/cached_download.py
@@ -4,6 +4,8 @@
 import shutil
 import sys
 import tempfile
+import warnings
+from typing import Optional
 
 import filelock
 
@@ -18,6 +20,10 @@
 
 
 def md5sum(filename, blocksize=None):
+    warnings.warn(
+        "md5sum is deprecated and will be removed in the future.", FutureWarning
+    )
+
     if blocksize is None:
         blocksize = 65536
 
@@ -29,6 +35,10 @@ def md5sum(filename, blocksize=None):
 
 
 def assert_md5sum(filename, md5, quiet=False, blocksize=None):
+    warnings.warn(
+        "assert_md5sum is deprecated and will be removed in the future.", FutureWarning
+    )
+
     if not (isinstance(md5, str) and len(md5) == 32):
         raise ValueError(f"MD5 must be 32 chars: {md5}")
 
@@ -45,7 +55,13 @@ def assert_md5sum(filename, md5, quiet=False, blocksize=None):
 
 
 def cached_download(
-    url=None, path=None, md5=None, quiet=False, postprocess=None, **kwargs
+    url=None,
+    path=None,
+    md5=None,
+    quiet=False,
+    postprocess=None,
+    hash: Optional[str] = None,
+    **kwargs,
 ):
     """Cached download from URL.
 
@@ -56,11 +72,14 @@ def cached_download(
     path: str, optional
         Output filename. Default is basename of URL.
     md5: str, optional
-        Expected MD5 for specified file.
+        Expected MD5 for specified file. Deprecated in favor of `hash`.
     quiet: bool
         Suppress terminal output. Default is False.
-    postprocess: callable
+    postprocess: callable, optional
         Function called with filename as postprocess.
+    hash: str, optional
+        Hash value of file in the format of {algorithm}:{hash_value}
+        such as sha256:abcdef.... Supported algorithms: md5, sha1, sha256, sha512.
     kwargs: dict
         Keyword arguments to be passed to `download`.
 
@@ -78,14 +97,25 @@ def cached_download(
         )
         path = osp.join(cache_root, path)
 
+    if md5 is not None and hash is not None:
+        raise ValueError("md5 and hash cannot be specified at the same time.")
+
+    if md5 is not None:
+        warnings.warn(
+            "md5 is deprecated in favor of hash. Please use hash='md5:xxx...' instead.",
+            FutureWarning,
+        )
+        hash = f"md5:{md5}"
+        del md5
+
     # check existence
-    if osp.exists(path) and not md5:
+    if osp.exists(path) and not hash:
         if not quiet:
             print(f"File exists: {path}", file=sys.stderr)
         return path
-    elif osp.exists(path) and md5:
+    elif osp.exists(path) and hash:
         try:
-            assert_md5sum(path, md5, quiet=quiet)
+            _assert_filehash(path=path, hash=hash, quiet=quiet)
             return path
         except AssertionError as e:
             # show warning and overwrite if md5 doesn't match
@@ -116,11 +146,47 @@ def cached_download(
         shutil.rmtree(temp_root)
         raise
 
-    if md5:
-        assert_md5sum(path, md5, quiet=quiet)
+    if hash:
+        _assert_filehash(path=path, hash=hash, quiet=quiet)
 
     # postprocess
     if postprocess is not None:
         postprocess(path)
 
     return path
+
+
+def _compute_filehash(path, algorithm):
+    BLOCKSIZE = 65536
+
+    if algorithm not in hashlib.algorithms_guaranteed:
+        raise ValueError(
+            f"Unsupported hash algorithm: {algorithm}. "
+            f"Supported algorithms: {hashlib.algorithms_guaranteed}"
+        )
+
+    algorithm_instance = getattr(hashlib, algorithm)()
+    with open(path, "rb") as f:
+        for block in iter(lambda: f.read(BLOCKSIZE), b""):
+            algorithm_instance.update(block)
+    return f"{algorithm}:{algorithm_instance.hexdigest()}"
+
+
+def _assert_filehash(path, hash, quiet=False, blocksize=None):
+    if ":" not in hash:
+        raise ValueError(
+            f"Invalid hash: {hash}. "
+            "Hash must be in the format of {algorithm}:{hash_value}."
+        )
+    algorithm = hash.split(":")[0]
+
+    hash_actual = _compute_filehash(path=path, algorithm=algorithm)
+
+    if hash_actual == hash:
+        if not quiet:
+            print(f"File hash matches: {path!r} == {hash!r}", file=sys.stderr)
+        return True
+
+    raise AssertionError(
+        f"File hash doesn't match:\nactual: {hash_actual}\nexpected: {hash}"
+    )
diff --git a/tests/test___main__.py b/tests/test___main__.py
index 6f5e8d07..403aee7c 100644
--- a/tests/test___main__.py
+++ b/tests/test___main__.py
@@ -4,7 +4,7 @@
 import sys
 import tempfile
 
-from gdown.cached_download import assert_md5sum
+from gdown.cached_download import _assert_filehash
 
 here = os.path.dirname(os.path.abspath(__file__))
 
@@ -15,7 +15,7 @@ def _test_cli_with_md5(url_or_id, md5, options=None):
         if options is not None:
             cmd = f"{cmd} {options}"
         subprocess.call(shlex.split(cmd))
-        assert_md5sum(filename=f.name, md5=md5)
+        _assert_filehash(path=f.name, hash=f"md5:{md5}")
 
 
 def _test_cli_with_content(url_or_id, content):
diff --git a/tests/test_cached_download.py b/tests/test_cached_download.py
new file mode 100644
index 00000000..521c45e2
--- /dev/null
+++ b/tests/test_cached_download.py
@@ -0,0 +1,37 @@
+import tempfile
+import warnings
+
+import gdown
+
+
+def test_cached_download_md5_deprecated():
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        gdown.cached_download(
+            url="https://drive.google.com/uc?id=0B9P1L--7Wd2vU3VUVlFnbTgtS2c",
+            md5="cb31a703b96c1ab2f80d164e9676fe7d",
+        )
+    assert len(w) == 1
+    assert issubclass(w[-1].category, FutureWarning)
+    assert "md5" in str(w[-1].message)
+
+
+def _cached_download(**kwargs):
+    url = "https://drive.google.com/uc?id=0B9P1L--7Wd2vU3VUVlFnbTgtS2c"
+    with tempfile.NamedTemporaryFile() as f:
+        for _ in range(2):
+            gdown.cached_download(url=url, path=f.name, **kwargs)
+
+
+def test_cached_download_md5():
+    _cached_download(hash="md5:cb31a703b96c1ab2f80d164e9676fe7d")
+
+
+def test_cached_download_sha1():
+    _cached_download(hash="sha1:69a5a1000f98237efea9231c8a39d05edf013494")
+
+
+def test_cached_download_sha256():
+    _cached_download(
+        hash="sha256:284e3029cce3ae5ee0b05866100e300046359f53ae4c77fe6b34c05aa7a72cee"
+    )
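For reviewers, a minimal usage sketch of the new `hash` argument (not part of the patch; the URL and hash values below are the ones used in tests/test_cached_download.py):

    import gdown

    # Download the file, or reuse the cached copy if its SHA-256 already matches.
    path = gdown.cached_download(
        url="https://drive.google.com/uc?id=0B9P1L--7Wd2vU3VUVlFnbTgtS2c",
        hash="sha256:284e3029cce3ae5ee0b05866100e300046359f53ae4c77fe6b34c05aa7a72cee",
    )

    # The md5 argument still works for now, but emits a FutureWarning:
    # gdown.cached_download(url=..., md5="cb31a703b96c1ab2f80d164e9676fe7d")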