diff --git a/src/schnetpack/datasets/rmd17.py b/src/schnetpack/datasets/rmd17.py index 50cf5938..dea81138 100644 --- a/src/schnetpack/datasets/rmd17.py +++ b/src/schnetpack/datasets/rmd17.py @@ -4,8 +4,8 @@ import tempfile import tarfile from typing import List, Optional, Dict -from urllib import request as request - +from urllib.request import Request, urlopen +from urllib.error import URLError, HTTPError import numpy as np from ase import Atoms @@ -194,13 +194,45 @@ def _download_data( ): logging.info("Downloading {} data".format(self.molecule)) raw_path = os.path.join(tmpdir, "rmd17") - tar_path = os.path.join(tmpdir, "rmd17.tar.gz") - url = "https://figshare.com/ndownloader/files/23950376" - request.urlretrieve(url, tar_path) + tar_path = os.path.join(tmpdir, "rmd17.tar") + urls = [ + "https://figshare.com/ndownloader/files/23950376", + "https://archive.materialscloud.org/records/pfffs-fff86/files/rmd17.tar.bz2?download=1", # Fallback mirror + ] + + downloaded = False + last_error = None + + for u in urls: + try: + logging.info(f"Downloading from: {u}") + req = Request(u) + with urlopen(req, timeout=600) as resp, open(tar_path, "wb") as f: + shutil.copyfileobj(resp, f) + + size = os.path.getsize(tar_path) + ctype = (resp.headers.get("Content-Type") or "").lower() + + if size == 0 or "text/html" in ctype: + raise RuntimeError( + f"Blocked or invalid download (size={size}, Content-Type={ctype})" + ) + logging.info(f"Download successful rMD17.") + downloaded = True + break + + except (HTTPError, URLError, RuntimeError) as e: + last_error = e + logging.warning(f"Download failed from {u}: {e}") + + if not downloaded: + raise RuntimeError( + "rMD17 download failed from both sources. " f"Error: {last_error}" + ) logging.info("Done.") logging.info("Extracting data...") - tar = tarfile.open(tar_path) + tar = tarfile.open(tar_path, mode="r:*") tar.extract( path=raw_path, member=f"rmd17/npz_data/{self.datasets_dict[self.molecule]}" )