diff --git a/emle/models/_mace.py b/emle/models/_mace.py index ce99fae..cb50662 100644 --- a/emle/models/_mace.py +++ b/emle/models/_mace.py @@ -199,19 +199,30 @@ def __init__( f"Unsupported MACE model: '{mace_model}'. Available MACE-OFF23 models are " "'mace-off23-small', 'mace-off23-medium', 'mace-off23-large'" ) - self._mace = _mace_off(model=size, device=device, return_raw_model=True) + source_model = _mace_off(model=size, device=device, return_raw_model=True) else: # Assuming that the model is a local model. if _os.path.exists(mace_model): - self._mace = _torch.load(mace_model, map_location=device) + source_model = _torch.load(mace_model, map_location=device) else: raise FileNotFoundError(f"MACE model file not found: {mace_model}") else: # If no MACE model is provided, use the default MACE-OFF23(S) model. - self._mace = _mace_off(model="small", device=device, return_raw_model=True) + source_model = _mace_off(model="small", device=device, return_raw_model=True) + + from mace.tools.scripts_utils import extract_config_mace_model + + # Extract the config from the model. + config = extract_config_mace_model(source_model) + + # Create the target model. + target_model = source_model.__class__(**config).to(device) + + # Load the state dict. + target_model.load_state_dict(source_model.state_dict()) # Compile the model. - self._mace = _e3nn_jit.compile(self._mace).to(self._dtype) + self._mace = _e3nn_jit.compile(target_model).to(self._dtype) # Create the z_table of the MACE model. self._z_table = [int(z.item()) for z in self._mace.atomic_numbers] diff --git a/environment.yaml b/environment.yaml index c5c8d01..4098e39 100644 --- a/environment.yaml +++ b/environment.yaml @@ -18,4 +18,4 @@ dependencies: - torchani - xtb-python - pip: - - mace-torch < 0.3.9 + - mace-torch diff --git a/environment_rascal.yaml b/environment_rascal.yaml index d6d8a01..dc2bac3 100644 --- a/environment_rascal.yaml +++ b/environment_rascal.yaml @@ -21,4 +21,4 @@ dependencies: - xtb-python - pip: - git+https://github.com/lab-cosmo/librascal.git - - mace-torch < 0.3.9 + - mace-torch diff --git a/environment_sire.yaml b/environment_sire.yaml index a3e991d..6850064 100644 --- a/environment_sire.yaml +++ b/environment_sire.yaml @@ -21,4 +21,4 @@ dependencies: - torchani - xtb-python - pip: - - mace-torch < 0.3.9 + - mace-torch