Skip to content

Commit ede08ca

Browse files
Fix solutes with v-sites (#76)
1 parent 697a106 commit ede08ca

File tree

10 files changed

+331
-149
lines changed

10 files changed

+331
-149
lines changed

absolv/fep.py

Lines changed: 2 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Prepare OpenMM systems for FEP calculations."""
2+
23
import copy
34
import itertools
45

@@ -22,54 +23,6 @@
2223
)
2324

2425

25-
def _find_v_sites(
26-
system: openmm.System, atom_indices: list[set[int]]
27-
) -> list[set[int]]:
28-
"""Finds any virtual sites in the system and ensures their indices get appended
29-
to the atom index list.
30-
31-
Args:
32-
system: The system that may contain v-sites.
33-
atom_indices: A list of per-molecule atom indices
34-
35-
Returns:
36-
A list of the per molecule **particle** indices.
37-
"""
38-
39-
atom_to_molecule_idx = {
40-
atom_idx: i for i, indices in enumerate(atom_indices) for atom_idx in indices
41-
}
42-
43-
particle_to_atom_idx = {}
44-
atom_idx = 0
45-
46-
for particle_idx in range(system.getNumParticles()):
47-
if system.isVirtualSite(particle_idx):
48-
continue
49-
50-
particle_to_atom_idx[particle_idx] = atom_idx
51-
atom_idx += 1
52-
53-
atom_idx = 0
54-
55-
remapped_atom_indices: list[set[int]] = [set() for _ in range(len(atom_indices))]
56-
57-
for particle_idx in range(system.getNumParticles()):
58-
if not system.isVirtualSite(particle_idx):
59-
molecule_idx = atom_to_molecule_idx[atom_idx]
60-
atom_idx += 1
61-
62-
else:
63-
v_site = system.getVirtualSite(particle_idx)
64-
parent_atom_idx = particle_to_atom_idx[v_site.getParticle(0)]
65-
66-
molecule_idx = atom_to_molecule_idx[parent_atom_idx]
67-
68-
remapped_atom_indices[molecule_idx].add(particle_idx)
69-
70-
return remapped_atom_indices
71-
72-
7326
def _find_nonbonded_forces(
7427
system: openmm.System,
7528
) -> tuple[
@@ -468,7 +421,7 @@ def apply_fep(
468421
system: The chemical system to generate the alchemical system from
469422
alchemical_indices: The atom indices corresponding to each molecule that
470423
should be alchemically transformable. The atom indices **must**
471-
correspond to **all** atoms in each molecule as alchemically
424+
correspond to **all** atoms / v-sites in each molecule as alchemically
472425
transforming part of a molecule is not supported.
473426
persistent_indices: The atom indices corresponding to each molecule that
474427
should **not** be alchemically transformable.
@@ -481,15 +434,6 @@ def apply_fep(
481434

482435
system = copy.deepcopy(system)
483436

484-
# Make sure we track v-sites attached to any solutes that may be alchemically
485-
# turned off. We do this as a post-process step as the OpenFF toolkit does not
486-
# currently expose a clean way to access this information.
487-
atom_indices = alchemical_indices + persistent_indices
488-
atom_indices = _find_v_sites(system, atom_indices)
489-
490-
alchemical_indices = atom_indices[: len(alchemical_indices)]
491-
persistent_indices = atom_indices[len(alchemical_indices) :]
492-
493437
(
494438
nonbonded_force,
495439
custom_nonbonded_force,

absolv/runner.py

Lines changed: 158 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Run calculations defined by a config."""
22

3+
import collections
34
import functools
45
import multiprocessing
56
import pathlib
@@ -17,6 +18,7 @@
1718
import openff.toolkit
1819
import openff.utilities
1920
import openmm
21+
import openmm.app
2022
import openmm.unit
2123
import pymbar
2224
import tqdm
@@ -35,12 +37,152 @@ class PreparedSystem(typing.NamedTuple):
3537
system: openmm.System
3638
"""The alchemically modified OpenMM system."""
3739

38-
topology: openff.toolkit.Topology
39-
"""The OpenFF topology with any box vectors set."""
40+
topology: openmm.app.Topology
41+
"""The OpenMM topology with any box vectors set."""
4042
coords: openmm.unit.Quantity
4143
"""The coordinates of the system."""
4244

4345

46+
def _rebuild_topology(
47+
orig_top: openff.toolkit.Topology,
48+
orig_coords: openmm.unit.Quantity,
49+
system: openmm.System,
50+
) -> tuple[openmm.app.Topology, openmm.unit.Quantity, list[set[int]]]:
51+
"""Rebuild the topology to also include virtual sites."""
52+
atom_idx_to_residue_idx = {}
53+
atom_idx = 0
54+
55+
for residue_idx, molecule in enumerate(orig_top.molecules):
56+
for _ in molecule.atoms:
57+
atom_idx_to_residue_idx[atom_idx] = residue_idx
58+
atom_idx += 1
59+
60+
particle_idx_to_atom_idx = {}
61+
atom_idx = 0
62+
63+
for particle_idx in range(system.getNumParticles()):
64+
if system.isVirtualSite(particle_idx):
65+
continue
66+
67+
particle_idx_to_atom_idx[particle_idx] = atom_idx
68+
atom_idx += 1
69+
70+
atoms_off = [*orig_top.atoms]
71+
particles = []
72+
73+
for particle_idx in range(system.getNumParticles()):
74+
if system.isVirtualSite(particle_idx):
75+
v_site = system.getVirtualSite(particle_idx)
76+
77+
parent_idxs = {
78+
particle_idx_to_atom_idx[v_site.getParticle(i)]
79+
for i in range(v_site.getNumParticles())
80+
}
81+
parent_residue = atom_idx_to_residue_idx[next(iter(parent_idxs))]
82+
83+
particles.append((-1, parent_residue))
84+
continue
85+
86+
atom_idx = particle_idx_to_atom_idx[particle_idx]
87+
residue_idx = atom_idx_to_residue_idx[atom_idx]
88+
89+
particles.append((atoms_off[atom_idx].atomic_number, residue_idx))
90+
91+
topology = openmm.app.Topology()
92+
93+
if orig_top.box_vectors is not None:
94+
topology.setPeriodicBoxVectors(orig_top.box_vectors.to_openmm())
95+
96+
chain = topology.addChain()
97+
98+
atom_counts_per_residue = collections.defaultdict(
99+
lambda: collections.defaultdict(int)
100+
)
101+
atoms = []
102+
103+
last_residue_idx = -1
104+
residue = None
105+
106+
residue_to_particle_idx = collections.defaultdict(list)
107+
108+
for particle_idx, (atomic_num, residue_idx) in enumerate(particles):
109+
if residue_idx != last_residue_idx:
110+
last_residue_idx = residue_idx
111+
residue = topology.addResidue("UNK", chain)
112+
113+
element = (
114+
None if atomic_num < 0 else openmm.app.Element.getByAtomicNumber(atomic_num)
115+
)
116+
symbol = "X" if element is None else element.symbol
117+
118+
atom_counts_per_residue[residue_idx][atomic_num] += 1
119+
atom = topology.addAtom(
120+
f"{symbol}{atom_counts_per_residue[residue_idx][atomic_num]}".ljust(3, "x"),
121+
element,
122+
residue,
123+
)
124+
atoms.append(atom)
125+
126+
residue_to_particle_idx[residue_idx].append(particle_idx)
127+
128+
_rename_residues(topology)
129+
130+
atom_idx_to_particle_idx = {j: i for i, j in particle_idx_to_atom_idx.items()}
131+
132+
for bond in orig_top.bonds:
133+
if atoms[atom_idx_to_particle_idx[bond.atom1_index]].residue.name == "HOH":
134+
continue
135+
136+
topology.addBond(
137+
atoms[atom_idx_to_particle_idx[bond.atom1_index]],
138+
atoms[atom_idx_to_particle_idx[bond.atom2_index]],
139+
)
140+
141+
coords_full = []
142+
143+
for particle_idx in range(system.getNumParticles()):
144+
if particle_idx in particle_idx_to_atom_idx:
145+
coords_i = orig_coords[particle_idx_to_atom_idx[particle_idx]]
146+
coords_full.append(coords_i.value_in_unit(openmm.unit.angstrom))
147+
else:
148+
coords_full.append(numpy.zeros((1, 3)))
149+
150+
coords_full = numpy.vstack(coords_full) * openmm.unit.angstrom
151+
152+
if len(orig_coords) != len(coords_full):
153+
context = openmm.Context(system, openmm.VerletIntegrator(1.0))
154+
context.setPositions(coords_full)
155+
context.computeVirtualSites()
156+
157+
coords_full = context.getState(getPositions=True).getPositions(asNumpy=True)
158+
159+
residues = [
160+
set(residue_to_particle_idx[residue_idx])
161+
for residue_idx in range(len(residue_to_particle_idx))
162+
]
163+
164+
return topology, coords_full, residues
165+
166+
167+
def _rename_residues(topology: openmm.app.Topology):
168+
"""Attempts to assign standard residue names to known residues"""
169+
170+
for residue in topology.residues():
171+
symbols = sorted(
172+
(
173+
atom.element.symbol
174+
for atom in residue.atoms()
175+
if atom.element is not None
176+
)
177+
)
178+
179+
if symbols == ["H", "H", "O"]:
180+
residue.name = "HOH"
181+
182+
for i, atom in enumerate(residue.atoms()):
183+
atom.name = "OW" if atom.element.symbol == "O" else f"HW{i}"
184+
185+
44186
def _setup_solvent(
45187
solvent_idx: typing.Literal["solvent-a", "solvent-b"],
46188
components: list[tuple[str, int]],
@@ -67,19 +209,21 @@ def _setup_solvent(
67209

68210
is_vacuum = n_solvent_molecules == 0
69211

70-
topology, coords = absolv.setup.setup_system(components)
71-
topology.box_vectors = None if is_vacuum else topology.box_vectors
212+
topology_off, coords = absolv.setup.setup_system(components)
213+
topology_off.box_vectors = None if is_vacuum else topology_off.box_vectors
214+
215+
if isinstance(force_field, openff.toolkit.ForceField):
216+
original_system = force_field.create_openmm_system(topology_off)
217+
else:
218+
original_system: openmm.System = force_field(topology_off, coords, solvent_idx)
72219

73-
atom_indices = absolv.utils.topology.topology_to_atom_indices(topology)
220+
topology, coords, atom_indices = _rebuild_topology(
221+
topology_off, coords, original_system
222+
)
74223

75224
alchemical_indices = atom_indices[:n_solute_molecules]
76225
persistent_indices = atom_indices[n_solute_molecules:]
77226

78-
if isinstance(force_field, openff.toolkit.ForceField):
79-
original_system = force_field.create_openmm_system(topology)
80-
else:
81-
original_system: openmm.System = force_field(topology, coords, solvent_idx)
82-
83227
alchemical_system = absolv.fep.apply_fep(
84228
original_system,
85229
alchemical_indices,
@@ -196,7 +340,7 @@ def _run_eq_phase(
196340
"""
197341
platform = (
198342
femto.md.constants.OpenMMPlatform.REFERENCE
199-
if prepared_system.topology.box_vectors is None
343+
if prepared_system.topology.getPeriodicBoxVectors() is None
200344
else platform
201345
)
202346

@@ -312,7 +456,7 @@ def _run_phase_end_states(
312456
):
313457
platform = (
314458
femto.md.constants.OpenMMPlatform.REFERENCE
315-
if prepared_system.topology.box_vectors is None
459+
if prepared_system.topology.getPeriodicBoxVectors() is None
316460
else platform
317461
)
318462

@@ -363,11 +507,11 @@ def _run_switching(
363507
):
364508
platform = (
365509
femto.md.constants.OpenMMPlatform.REFERENCE
366-
if prepared_system.topology.box_vectors is None
510+
if prepared_system.topology.getPeriodicBoxVectors() is None
367511
else platform
368512
)
369513

370-
mdtraj_topology = mdtraj.Topology.from_openmm(prepared_system.topology.to_openmm())
514+
mdtraj_topology = mdtraj.Topology.from_openmm(prepared_system.topology)
371515

372516
trajectory_0 = mdtraj.load_dcd(str(output_dir / "state-0.dcd"), mdtraj_topology)
373517
trajectory_1 = mdtraj.load_dcd(str(output_dir / "state-1.dcd"), mdtraj_topology)

absolv/tests/test_fep.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,33 +13,11 @@
1313
_add_electrostatics_lambda,
1414
_add_lj_vdw_lambda,
1515
_find_nonbonded_forces,
16-
_find_v_sites,
1716
apply_fep,
1817
)
1918
from absolv.tests import is_close
2019

2120

22-
def test_find_v_sites():
23-
"""Ensure that v-sites are correctly detected from an OMM system and assigned
24-
to the right parent molecule."""
25-
26-
# Construct a mock system of V A A A V A A where (0, 5, 6), (3,), (4, 1, 2)
27-
# are the core molecules.
28-
system = openmm.System()
29-
30-
for _ in range(7):
31-
system.addParticle(1.0)
32-
33-
system.setVirtualSite(0, openmm.TwoParticleAverageSite(5, 6, 0.5, 0.5))
34-
system.setVirtualSite(4, openmm.TwoParticleAverageSite(1, 2, 0.5, 0.5))
35-
36-
atom_indices = [{0, 1}, {2}, {3, 4}]
37-
38-
particle_indices = _find_v_sites(system, atom_indices)
39-
40-
assert particle_indices == [{1, 2, 4}, {3}, {0, 5, 6}]
41-
42-
4321
def test_find_nonbonded_forces_lj_only(aq_nacl_lj_system):
4422
(
4523
nonbonded_force,

0 commit comments

Comments
 (0)