1
1
"""Run calculations defined by a config."""
2
2
3
+ import collections
3
4
import functools
4
5
import multiprocessing
5
6
import pathlib
17
18
import openff .toolkit
18
19
import openff .utilities
19
20
import openmm
21
+ import openmm .app
20
22
import openmm .unit
21
23
import pymbar
22
24
import tqdm
@@ -35,12 +37,152 @@ class PreparedSystem(typing.NamedTuple):
35
37
system : openmm .System
36
38
"""The alchemically modified OpenMM system."""
37
39
38
- topology : openff . toolkit .Topology
39
- """The OpenFF topology with any box vectors set."""
40
+ topology : openmm . app .Topology
41
+ """The OpenMM topology with any box vectors set."""
40
42
coords : openmm .unit .Quantity
41
43
"""The coordinates of the system."""
42
44
43
45
46
+ def _rebuild_topology (
47
+ orig_top : openff .toolkit .Topology ,
48
+ orig_coords : openmm .unit .Quantity ,
49
+ system : openmm .System ,
50
+ ) -> tuple [openmm .app .Topology , openmm .unit .Quantity , list [set [int ]]]:
51
+ """Rebuild the topology to also include virtual sites."""
52
+ atom_idx_to_residue_idx = {}
53
+ atom_idx = 0
54
+
55
+ for residue_idx , molecule in enumerate (orig_top .molecules ):
56
+ for _ in molecule .atoms :
57
+ atom_idx_to_residue_idx [atom_idx ] = residue_idx
58
+ atom_idx += 1
59
+
60
+ particle_idx_to_atom_idx = {}
61
+ atom_idx = 0
62
+
63
+ for particle_idx in range (system .getNumParticles ()):
64
+ if system .isVirtualSite (particle_idx ):
65
+ continue
66
+
67
+ particle_idx_to_atom_idx [particle_idx ] = atom_idx
68
+ atom_idx += 1
69
+
70
+ atoms_off = [* orig_top .atoms ]
71
+ particles = []
72
+
73
+ for particle_idx in range (system .getNumParticles ()):
74
+ if system .isVirtualSite (particle_idx ):
75
+ v_site = system .getVirtualSite (particle_idx )
76
+
77
+ parent_idxs = {
78
+ particle_idx_to_atom_idx [v_site .getParticle (i )]
79
+ for i in range (v_site .getNumParticles ())
80
+ }
81
+ parent_residue = atom_idx_to_residue_idx [next (iter (parent_idxs ))]
82
+
83
+ particles .append ((- 1 , parent_residue ))
84
+ continue
85
+
86
+ atom_idx = particle_idx_to_atom_idx [particle_idx ]
87
+ residue_idx = atom_idx_to_residue_idx [atom_idx ]
88
+
89
+ particles .append ((atoms_off [atom_idx ].atomic_number , residue_idx ))
90
+
91
+ topology = openmm .app .Topology ()
92
+
93
+ if orig_top .box_vectors is not None :
94
+ topology .setPeriodicBoxVectors (orig_top .box_vectors .to_openmm ())
95
+
96
+ chain = topology .addChain ()
97
+
98
+ atom_counts_per_residue = collections .defaultdict (
99
+ lambda : collections .defaultdict (int )
100
+ )
101
+ atoms = []
102
+
103
+ last_residue_idx = - 1
104
+ residue = None
105
+
106
+ residue_to_particle_idx = collections .defaultdict (list )
107
+
108
+ for particle_idx , (atomic_num , residue_idx ) in enumerate (particles ):
109
+ if residue_idx != last_residue_idx :
110
+ last_residue_idx = residue_idx
111
+ residue = topology .addResidue ("UNK" , chain )
112
+
113
+ element = (
114
+ None if atomic_num < 0 else openmm .app .Element .getByAtomicNumber (atomic_num )
115
+ )
116
+ symbol = "X" if element is None else element .symbol
117
+
118
+ atom_counts_per_residue [residue_idx ][atomic_num ] += 1
119
+ atom = topology .addAtom (
120
+ f"{ symbol } { atom_counts_per_residue [residue_idx ][atomic_num ]} " .ljust (3 , "x" ),
121
+ element ,
122
+ residue ,
123
+ )
124
+ atoms .append (atom )
125
+
126
+ residue_to_particle_idx [residue_idx ].append (particle_idx )
127
+
128
+ _rename_residues (topology )
129
+
130
+ atom_idx_to_particle_idx = {j : i for i , j in particle_idx_to_atom_idx .items ()}
131
+
132
+ for bond in orig_top .bonds :
133
+ if atoms [atom_idx_to_particle_idx [bond .atom1_index ]].residue .name == "HOH" :
134
+ continue
135
+
136
+ topology .addBond (
137
+ atoms [atom_idx_to_particle_idx [bond .atom1_index ]],
138
+ atoms [atom_idx_to_particle_idx [bond .atom2_index ]],
139
+ )
140
+
141
+ coords_full = []
142
+
143
+ for particle_idx in range (system .getNumParticles ()):
144
+ if particle_idx in particle_idx_to_atom_idx :
145
+ coords_i = orig_coords [particle_idx_to_atom_idx [particle_idx ]]
146
+ coords_full .append (coords_i .value_in_unit (openmm .unit .angstrom ))
147
+ else :
148
+ coords_full .append (numpy .zeros ((1 , 3 )))
149
+
150
+ coords_full = numpy .vstack (coords_full ) * openmm .unit .angstrom
151
+
152
+ if len (orig_coords ) != len (coords_full ):
153
+ context = openmm .Context (system , openmm .VerletIntegrator (1.0 ))
154
+ context .setPositions (coords_full )
155
+ context .computeVirtualSites ()
156
+
157
+ coords_full = context .getState (getPositions = True ).getPositions (asNumpy = True )
158
+
159
+ residues = [
160
+ set (residue_to_particle_idx [residue_idx ])
161
+ for residue_idx in range (len (residue_to_particle_idx ))
162
+ ]
163
+
164
+ return topology , coords_full , residues
165
+
166
+
167
+ def _rename_residues (topology : openmm .app .Topology ):
168
+ """Attempts to assign standard residue names to known residues"""
169
+
170
+ for residue in topology .residues ():
171
+ symbols = sorted (
172
+ (
173
+ atom .element .symbol
174
+ for atom in residue .atoms ()
175
+ if atom .element is not None
176
+ )
177
+ )
178
+
179
+ if symbols == ["H" , "H" , "O" ]:
180
+ residue .name = "HOH"
181
+
182
+ for i , atom in enumerate (residue .atoms ()):
183
+ atom .name = "OW" if atom .element .symbol == "O" else f"HW{ i } "
184
+
185
+
44
186
def _setup_solvent (
45
187
solvent_idx : typing .Literal ["solvent-a" , "solvent-b" ],
46
188
components : list [tuple [str , int ]],
@@ -67,19 +209,21 @@ def _setup_solvent(
67
209
68
210
is_vacuum = n_solvent_molecules == 0
69
211
70
- topology , coords = absolv .setup .setup_system (components )
71
- topology .box_vectors = None if is_vacuum else topology .box_vectors
212
+ topology_off , coords = absolv .setup .setup_system (components )
213
+ topology_off .box_vectors = None if is_vacuum else topology_off .box_vectors
214
+
215
+ if isinstance (force_field , openff .toolkit .ForceField ):
216
+ original_system = force_field .create_openmm_system (topology_off )
217
+ else :
218
+ original_system : openmm .System = force_field (topology_off , coords , solvent_idx )
72
219
73
- atom_indices = absolv .utils .topology .topology_to_atom_indices (topology )
220
+ topology , coords , atom_indices = _rebuild_topology (
221
+ topology_off , coords , original_system
222
+ )
74
223
75
224
alchemical_indices = atom_indices [:n_solute_molecules ]
76
225
persistent_indices = atom_indices [n_solute_molecules :]
77
226
78
- if isinstance (force_field , openff .toolkit .ForceField ):
79
- original_system = force_field .create_openmm_system (topology )
80
- else :
81
- original_system : openmm .System = force_field (topology , coords , solvent_idx )
82
-
83
227
alchemical_system = absolv .fep .apply_fep (
84
228
original_system ,
85
229
alchemical_indices ,
@@ -196,7 +340,7 @@ def _run_eq_phase(
196
340
"""
197
341
platform = (
198
342
femto .md .constants .OpenMMPlatform .REFERENCE
199
- if prepared_system .topology .box_vectors is None
343
+ if prepared_system .topology .getPeriodicBoxVectors () is None
200
344
else platform
201
345
)
202
346
@@ -312,7 +456,7 @@ def _run_phase_end_states(
312
456
):
313
457
platform = (
314
458
femto .md .constants .OpenMMPlatform .REFERENCE
315
- if prepared_system .topology .box_vectors is None
459
+ if prepared_system .topology .getPeriodicBoxVectors () is None
316
460
else platform
317
461
)
318
462
@@ -363,11 +507,11 @@ def _run_switching(
363
507
):
364
508
platform = (
365
509
femto .md .constants .OpenMMPlatform .REFERENCE
366
- if prepared_system .topology .box_vectors is None
510
+ if prepared_system .topology .getPeriodicBoxVectors () is None
367
511
else platform
368
512
)
369
513
370
- mdtraj_topology = mdtraj .Topology .from_openmm (prepared_system .topology . to_openmm () )
514
+ mdtraj_topology = mdtraj .Topology .from_openmm (prepared_system .topology )
371
515
372
516
trajectory_0 = mdtraj .load_dcd (str (output_dir / "state-0.dcd" ), mdtraj_topology )
373
517
trajectory_1 = mdtraj .load_dcd (str (output_dir / "state-1.dcd" ), mdtraj_topology )
0 commit comments