diff --git a/src/schnetpack/data/atoms.py b/src/schnetpack/data/atoms.py
index 46ed7da9..f01f866f 100644
--- a/src/schnetpack/data/atoms.py
+++ b/src/schnetpack/data/atoms.py
@@ -374,13 +374,14 @@ def _get_properties(
         return properties
 
     # Metadata
-
     @property
     def metadata(self):
-        return self.conn.metadata
+        with connect(self.datapath, use_lock_file=False) as conn:
+            return conn.metadata
 
     def _set_metadata(self, val: Dict[str, Any]):
-        self.conn.metadata = val
+        with connect(self.datapath, use_lock_file=False) as conn:
+            conn.metadata = val
 
     def update_metadata(self, **kwargs):
         assert all(
@@ -477,7 +478,7 @@ def add_system(
                 `available_properties` of the dataset.
 
         """
-        self._add_system(self.conn, atoms, atoms_metadata, **properties)
+        self._add_system(atoms, atoms_metadata, **properties)
 
     def add_systems(
         self,
@@ -511,7 +512,6 @@ def add_systems(
             atoms_list, property_list, atoms_metadata_list
         ):
             self._add_system(
-                self.conn,
                 atoms,
                 atoms_metadata,
                 **prop,
@@ -519,7 +519,6 @@ def add_systems(
 
     def _add_system(
         self,
-        conn,
         atoms: Optional[Atoms] = None,
         atoms_metadata: Optional[Dict[str, Any]] = None,
         **properties,
@@ -543,28 +542,30 @@ def _add_system(
         if atoms_metadata is None:
             atoms_metadata = {}
 
-        # add available properties to database
-        valid_props = set().union(
-            conn.metadata["_property_unit_dict"].keys(),
-            [structure.Z, structure.R, structure.cell, structure.pbc],
-        )
-        for pname in properties:
-            if pname not in valid_props:
-                logger.warning(
-                    f"Property `{pname}` is not a defined property for this dataset and "
-                    + f"will be ignored. If it should be included, it has to be "
-                    + f"provided together with its unit when calling "
-                    + f"AseAtomsData.create()."
-                )
-
-        data = {}
-        for pname in conn.metadata["_property_unit_dict"].keys():
-            if pname in properties:
-                data[pname] = properties[pname]
-            else:
-                raise AtomsDataError("Required property missing:" + pname)
+        with connect(self.datapath, use_lock_file=False) as conn:
+            prop_keys = conn.metadata["_property_unit_dict"].keys()
 
-        conn.write(atoms, data=data, key_value_pairs=atoms_metadata)
+            valid_props = set().union(
+                prop_keys,
+                [structure.Z, structure.R, structure.cell, structure.pbc],
+            )
+            for pname in properties:
+                if pname not in valid_props:
+                    logger.warning(
+                        f"Property `{pname}` is not a defined property for this dataset and "
+                        + f"will be ignored. If it should be included, it has to be "
+                        + f"provided together with its unit when calling "
+                        + f"AseAtomsData.create()."
+                    )
+
+            data = {}
+            for pname in prop_keys:
+                if pname in properties:
+                    data[pname] = properties[pname]
+                else:
+                    raise AtomsDataError("Required property missing:" + pname)
+
+            conn.write(atoms, data=data, key_value_pairs=atoms_metadata)
 
 
 def create_dataset(
diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py
index 89c75ac5..3085a21a 100644
--- a/src/schnetpack/datasets/qm9.py
+++ b/src/schnetpack/datasets/qm9.py
@@ -31,9 +31,18 @@ class QM9(AtomsDataModule):
     References:
 
         .. [#qm9_1] https://ndownloader.figshare.com/files/3195404
-
     """
 
+    base_urls = [
+        "https://ndownloader.figshare.com/files/",
+        "https://springernature.figshare.com/ndownloader/files/",
+    ]
+    file_ids = {
+        "data": "3195389",
+        "atomrefs": "3195395",
+        "uncharacterized": "3195404",
+    }
+
     # properties
     A = "rotational_constant_A"
     B = "rotational_constant_B"
@@ -127,6 +136,34 @@ def __init__(
 
         self.remove_uncharacterized = remove_uncharacterized
 
+    def _download_file(self, file_id: str, destination: str):
+        """
+        Download a figshare file, falling back to the next mirror on failure.
+
+        Args:
+            file_id: figshare file id that is appended to each entry of
+                `base_urls`.
+            destination: local path the downloaded file is written to.
+
+        Raises:
+            AtomsDataModuleError: if the download fails for every base URL.
+        """
+        last_error = None
+        for base_url in self.base_urls:
+            url = f"{base_url}{file_id}"
+            try:
+                request.urlretrieve(url, destination)
+                return
+            except Exception as err:
+                # remember the cause so a total failure stays diagnosable
+                last_error = err
+                logging.warning(
+                    f"Could not download from {url} ({err}), trying next source..."
+                )
+        raise AtomsDataModuleError(
+            f"Could not download file with id {file_id} from any source."
+        ) from last_error
+
     def prepare_data(self):
         if not os.path.exists(self.datapath):
             property_unit_dict = {
@@ -179,9 +216,8 @@ def prepare_data(self):
 
     def _download_uncharacterized(self, tmpdir):
         logging.info("Downloading list of uncharacterized molecules...")
-        at_url = "https://ndownloader.figshare.com/files/3195404"
         tmp_path = os.path.join(tmpdir, "uncharacterized.txt")
-        request.urlretrieve(at_url, tmp_path)
+        self._download_file(self.file_ids["uncharacterized"], tmp_path)
         logging.info("Done.")
 
         uncharacterized = []
@@ -193,9 +229,8 @@ def _download_uncharacterized(self, tmpdir):
 
     def _download_atomrefs(self, tmpdir):
         logging.info("Downloading GDB-9 atom references...")
-        at_url = "https://ndownloader.figshare.com/files/3195395"
         tmp_path = os.path.join(tmpdir, "atomrefs.txt")
-        request.urlretrieve(at_url, tmp_path)
+        self._download_file(self.file_ids["atomrefs"], tmp_path)
         logging.info("Done.")
 
         props = [QM9.zpve, QM9.U0, QM9.U, QM9.H, QM9.G, QM9.Cv]
@@ -214,9 +249,7 @@ def _download_data(
         logging.info("Downloading GDB-9 data...")
         tar_path = os.path.join(tmpdir, "gdb9.tar.gz")
         raw_path = os.path.join(tmpdir, "gdb9_xyz")
-        url = "https://ndownloader.figshare.com/files/3195389"
-
-        request.urlretrieve(url, tar_path)
+        self._download_file(self.file_ids["data"], tar_path)
         logging.info("Done.")
 
         logging.info("Extracting files...")