Skip to content

Commit 4d9042d

Browse files
committed
adds elstruct parameters/readers/writers for ase calls to psi4 and
autorun ase runners for single point energies elstruct now has MLIP models allowed in the par and reader/writers for energy jobs with them in ASE. autorun can run these for MACE, more to come
1 parent 115b6a8 commit 4d9042d

File tree

17 files changed

+1138
-30
lines changed

17 files changed

+1138
-30
lines changed

src/autorun/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from autorun._run import from_input_string
88
from autorun._run import run_script
99
from autorun._run import write_input
10+
from autorun._run import dump_input
11+
from autorun._run import dump_output
1012
from autorun._run import read_output
1113
from autorun._host import host_node
1214
from autorun._host import process_id
@@ -23,6 +25,7 @@
2325
from autorun import projrot
2426
from autorun import thermp
2527
from autorun import varecof
28+
from autorun import ase
2629

2730
# MultiProgram Runners
2831
from autorun._multiprog import projected_frequencies
@@ -35,6 +38,8 @@
3538
'run_script',
3639
'from_input_string',
3740
'write_input',
41+
'dump_input',
42+
'dump_output',
3843
'read_output',
3944
'host_node',
4045
'process_id',
@@ -50,6 +55,7 @@
5055
'projrot',
5156
'thermp',
5257
'varecof',
58+
'ase',
5359
# MultiProgram Runners
5460
'projected_frequencies',
5561
'thermo'

src/autorun/_run.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import subprocess
66
import warnings
77
import stat
8+
import yaml
89

910

1011
SCRIPT_NAME = 'run.sh'
@@ -120,6 +121,90 @@ def write_input(run_dir, input_str,
120121
aux_obj.write(fstring)
121122

122123

124+
def _simplify_data(obj):
125+
"""Convert numpy arrays, numpy scalars, and tuples to simple Python lists and values.
126+
127+
Args:
128+
obj: The object to simplify
129+
Returns:
130+
Simplified version of the object with standard Python types
131+
"""
132+
import numpy as np
133+
134+
if isinstance(obj, dict):
135+
return {k: _simplify_data(v) for k, v in obj.items()}
136+
elif isinstance(obj, (list, tuple)):
137+
return [_simplify_data(x) for x in obj]
138+
elif isinstance(obj, np.ndarray):
139+
return _simplify_data(obj.tolist())
140+
elif isinstance(obj, np.generic):
141+
return obj.item()
142+
return obj
143+
144+
145+
def dump_input(run_dir, input_dct,
146+
aux_dct=None,
147+
input_name=INPUT_NAME):
148+
""" Write input from a dictionary to a file in YAML format
149+
150+
Args:
151+
run_dir: Directory where input file will be written
152+
input_dict: Dictionary containing input data to be written
153+
aux_dct: Dictionary of auxiliary files to write {filename: content}
154+
input_name: Name of the main input file
155+
"""
156+
if not os.path.exists(run_dir):
157+
os.makedirs(run_dir)
158+
159+
# Simplify the data structures before dumping
160+
simplified_dct = _simplify_data(input_dct)
161+
162+
with EnterDirectory(run_dir):
163+
# Write the main input file as YAML
164+
with open(input_name, mode='w', encoding='utf-8') as input_obj:
165+
input_obj.write('__yaml__\n')
166+
yaml.dump(
167+
simplified_dct,
168+
input_obj,
169+
default_flow_style=False,
170+
sort_keys=False,
171+
allow_unicode=True)
172+
173+
# Write all auxiliary input files
174+
if aux_dct is not None:
175+
for fname, fstring in aux_dct.items():
176+
if fstring:
177+
with open(fname, mode='w', encoding='utf-8') as aux_obj:
178+
aux_obj.write(fstring)
179+
180+
181+
182+
def dump_output(run_dir, output_dct,
183+
output_name=OUTPUT_NAME):
184+
""" Write output from a dictionary to a file in YAML format
185+
Args:
186+
run_dir: Directory where output file will be written
187+
output_dct: Dictionary containing output data to be written
188+
output_name: Name of the main output file
189+
"""
190+
if not os.path.exists(run_dir):
191+
os.makedirs(run_dir)
192+
193+
# Simplify the data structures before dumping
194+
simplified_dct = _simplify_data(output_dct)
195+
196+
with EnterDirectory(run_dir):
197+
# Write the main output file as YAML
198+
with open(output_name, mode='w', encoding='utf-8') as output_obj:
199+
output_obj.write('__yaml__\n')
200+
yaml.dump(
201+
simplified_dct,
202+
output_obj,
203+
default_flow_style=False,
204+
sort_keys=False,
205+
allow_unicode=True)
206+
207+
123208
def read_output(run_dir, output_names=(OUTPUT_NAME,)):
124209
""" Read the output string from the run directory
125210
"""

src/autorun/ase.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
"""Functions to run ASE calculators."""
2+
3+
from ase import Atoms
4+
from ase import units
5+
from ase.optimize import BFGS
6+
# from sella import Sella
7+
8+
def get_calculator(calculator, family=None):
9+
"""Get a new instance of the appropriate ASE calculator.
10+
11+
Args:
12+
program: The program name ('psi4', 'mace_mp', etc.)
13+
Returns:
14+
The calculator class
15+
Raises:
16+
ValueError: If the program is not supported
17+
"""
18+
# Lazy import the appropriate calculator
19+
if calculator == 'psi4':
20+
import psi4
21+
from ase.calculators.psi4 import Psi4
22+
return Psi4
23+
elif calculator == 'mace':
24+
if family == 'mace_mp':
25+
from mace.calculators import mace_mp
26+
return mace_mp
27+
elif family == 'mace_off':
28+
from mace.calculators import mace_off
29+
return mace_off
30+
elif family == 'mace_anicc':
31+
from mace.calculators import mace_anicc
32+
return mace_anicc
33+
else:
34+
raise ValueError(f"Unsupported MACE family: {family}")
35+
# elif calculator == 'uma':
36+
# elif calculator == 'nwx':
37+
# from ase.calculators.nwchem import NWChem
38+
# return NWChem
39+
40+
41+
def from_calc_dictionary(input_dct, script_str):
42+
""" Run an ASE calculation from an input dictionary and return results
43+
44+
Args:
45+
input_dct: Dictionary containing:
46+
- calculator: Name of ASE calculator to use ('psi4' or 'nwx')
47+
- atom_parameters:
48+
- symbols: List of atomic symbols
49+
- positions: List of [x,y,z] coordinates
50+
- basis: Basis set name
51+
- method: Quantum chemistry method
52+
- charge: Molecular charge
53+
- multiplicity: Spin multiplicity
54+
- reference: Reference type ('rhf', 'uhf', etc.)
55+
script_str: String of script that contains ase_<calculator>
56+
Returns:
57+
dict: Results containing:
58+
- energy: Total energy in eV
59+
- forces: Forces on atoms in eV/Å
60+
- dipole: Dipole moment in e⋅Å
61+
- charges: Atomic charges
62+
- parameters: Input parameters used
63+
- version: Version string of the program used
64+
"""
65+
# Set up Atoms object
66+
atom_parameters = input_dct.get('atom_parameters', {})
67+
job = input_dct.get('job', 'energy')
68+
atoms = Atoms(
69+
symbols=atom_parameters.get('symbols', []),
70+
positions=atom_parameters.get('positions', []),
71+
)
72+
# Set up calculator
73+
calculator = None
74+
if len(script_str.split('ase_')) > 1:
75+
calculator = script_str.split('ase_')[-1].strip()
76+
if 'method' in input_dct.keys():
77+
kwargs = {
78+
'atoms': atoms,
79+
'basis': input_dct.get('basis'),
80+
'method': input_dct.get('method'),
81+
'charge': input_dct.get('charge', 0),
82+
'multiplicity': input_dct.get('multiplicity', 1),
83+
'reference': input_dct.get('reference')
84+
}
85+
family = None
86+
elif 'family' in input_dct.keys():
87+
family = input_dct.get('family')
88+
if 'mlip' in input_dct.keys():
89+
kwargs = {
90+
'model': input_dct.get('mlip'),
91+
'device': 'cpu'
92+
}
93+
else:
94+
kwargs = {}
95+
calc = get_calculator(calculator, family)(**kwargs)
96+
97+
# Attach calculator to atoms and run
98+
atoms.calc = calc
99+
100+
# Run calculation and get basic properties that all calculators support
101+
if job == 'energy':
102+
energy = atoms.get_potential_energy() / units.Hartree
103+
positions = atoms.get_positions()
104+
elif job == 'optimize':
105+
if not input_dct.get('saddle', False):
106+
dyn = BFGS(atoms)
107+
# else:
108+
# dyn = Sella(atoms)
109+
dyn.run(
110+
fmax=input_dct.get('gconv', 0.05),
111+
steps=input_dct.get('maxcyc', 100),
112+
opt_type=input_dct.get('opt_type', 'standard')
113+
)
114+
energy = atoms.get_potential_energy()
115+
positions = atoms.get_positions()
116+
117+
# Initialize results with guaranteed properties
118+
results = {
119+
'energy': energy,
120+
'positions': positions,
121+
# 'forces': forces.tolist(),
122+
'parameters': dict(calc.parameters) # Convert to regular dict for JSON serialization
123+
}
124+
125+
# Add calculator-specific properties based on calculator type
126+
if calculator == 'psi4':
127+
results['version'] = f'ase-psi4'
128+
results['normal_termination'] = True
129+
# Psi4 supports both dipole moments and charges
130+
# if 'properties' not in calc.parameters or 'dipole' in calc.parameters.get('properties', []):
131+
# results['dipole'] = atoms.get_dipole_moment().tolist()
132+
# if 'properties' not in calc.parameters or 'mulliken' in calc.parameters.get('properties', []):
133+
# results['charges'] = atoms.get_charges().tolist()
134+
elif 'mace' in calculator:
135+
# MACE specific properties
136+
results['version'] = f'ase-mace'
137+
results['normal_termination'] = True
138+
# MACE does not provide dipole moments or charges
139+
# if 'properties' not in calc.parameters or 'dipole' in calc.parameters.get('properties', []):
140+
# results['dipole'] = atoms.get_dipole_moment().tolist()
141+
# if 'properties' not in calc.parameters or 'mulliken' in calc.parameters.get('properties', []):
142+
# results['charges'] = atoms.get_charges().tolist()
143+
elif calculator == 'nwx':
144+
# NWChem specific properties
145+
pass
146+
# if 'dipole_moment' in calc.get_implemented_properties():
147+
# results['dipole'] = atoms.get_dipole_moment().tolist()
148+
# if 'charges' in calc.get_implemented_properties():
149+
# results['charges'] = atoms.get_charges().tolist()
150+
151+
return results
152+

src/elstruct/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,17 @@
1212
from elstruct.par import Program
1313
from elstruct.par import Reference
1414
from elstruct.par import Method
15+
from elstruct.par import Model
1516
from elstruct.par import Basis
1617
from elstruct.par import programs
18+
from elstruct.par import program_models
1719
from elstruct.par import program_methods
1820
from elstruct.par import program_dft_methods
1921
from elstruct.par import program_nondft_methods
2022
from elstruct.par import program_method_orbital_types
2123
from elstruct.par import program_bases
24+
from elstruct.par import method_is_mlip
25+
from elstruct.par import mlip_from_method
2226

2327

2428
__all__ = [
@@ -34,11 +38,15 @@
3438
'Program',
3539
'Reference',
3640
'Method',
41+
'Model',
3742
'Basis',
3843
'programs',
44+
'program_models',
3945
'program_methods',
4046
'program_dft_methods',
4147
'program_nondft_methods',
4248
'program_method_orbital_types',
4349
'program_bases',
50+
'method_is_mlip',
51+
'mlip_from_method'
4452
]

0 commit comments

Comments
 (0)