Skip to content

Commit 882a986

Browse files
author
Charlles Abreu
authored
Improves public API (#71)
* Improved getEffectiveMass example * Removed setUnit method * Removed getArgument method from public API * Added example for getValue method * Included addToSystem method to all CVs * Refactored tests and doctests * Fixed code formatting
1 parent fc2eb1a commit 882a986

20 files changed

+159
-206
lines changed

cvpack/atomic_function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ class AtomicFunction(openmm.CustomCompoundBondForce, BaseCustomFunction):
9393
... k = 1000 * unit.kilojoules_per_mole/unit.radian**2,
9494
... theta0 = [np.pi/2, np.pi/3] * unit.radian,
9595
... )
96-
>>> [model.system.addForce(f) for f in [angle1, angle2, colvar]]
97-
[5, 6, 7]
96+
>>> for cv in [angle1, angle2, colvar]:
97+
... cv.addToSystem(model.system)
9898
>>> integrator = openmm.VerletIntegrator(0)
9999
>>> platform = openmm.Platform.getPlatformByName('Reference')
100100
>>> context = openmm.Context(model.system, integrator, platform)

cvpack/attraction_strength.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,7 @@ class AttractionStrength(openmm.CustomNonbondedForce, BaseCollectiveVariable):
128128
>>> host = [a.index for a in model.topology.atoms() if a.residue.name == "CUC"]
129129
>>> forces = {f.getName(): f for f in model.system.getForces()}
130130
>>> cv1 = cvpack.AttractionStrength(guest, host, forces["NonbondedForce"])
131-
>>> _ = cv1.setUnusedForceGroup(0, model.system)
132-
>>> _ = model.system.addForce(cv1)
131+
>>> cv1.addToSystem(model.system)
133132
>>> platform = openmm.Platform.getPlatformByName("Reference")
134133
>>> integrator = openmm.VerletIntegrator(1.0 * mmunit.femtoseconds)
135134
>>> context = openmm.Context(model.system, integrator, platform)
@@ -139,15 +138,13 @@ class AttractionStrength(openmm.CustomNonbondedForce, BaseCollectiveVariable):
139138
140139
>>> water = [a.index for a in model.topology.atoms() if a.residue.name == "HOH"]
141140
>>> cv2 = cvpack.AttractionStrength(guest, water, forces["NonbondedForce"])
142-
>>> _ = cv2.setUnusedForceGroup(0, model.system)
143-
>>> _ = model.system.addForce(cv2)
141+
>>> cv2.addToSystem(model.system)
144142
>>> context.reinitialize(preserveState=True)
145143
>>> print(cv2.getValue(context))
146144
2063.3... dimensionless
147145
148146
>>> cv3 = cvpack.AttractionStrength(guest, host, forces["NonbondedForce"], water)
149-
>>> _ = cv3.setUnusedForceGroup(0, model.system)
150-
>>> _ = model.system.addForce(cv3)
147+
>>> cv3.addToSystem(model.system)
151148
>>> context.reinitialize(preserveState=True)
152149
>>> print(cv3.getValue(context))
153150
2849.17... dimensionless
@@ -157,8 +154,7 @@ class AttractionStrength(openmm.CustomNonbondedForce, BaseCollectiveVariable):
157154
>>> cv4 = cvpack.AttractionStrength(
158155
... guest, host, forces["NonbondedForce"], water, contrastScaling=0.5
159156
... )
160-
>>> _ = cv4.setUnusedForceGroup(0, model.system)
161-
>>> _ = model.system.addForce(cv4)
157+
>>> cv4.addToSystem(model.system)
162158
>>> context.reinitialize(preserveState=True)
163159
>>> print(cv4.getValue(context))
164160
3880.8... dimensionless

cvpack/centroid_function.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,7 @@ class CentroidFunction(openmm.CustomCentroidBondForce, BaseCustomFunction):
121121
>>> res_coord = cvpack.ResidueCoordination(
122122
... residues[115:124], residues[126:135], stepFunction="step(1-x)"
123123
... )
124-
>>> res_coord.setUnusedForceGroup(0, model.system)
125-
1
126-
>>> model.system.addForce(res_coord)
127-
6
124+
>>> res_coord.addToSystem(model.system)
128125
>>> integrator = openmm.VerletIntegrator(0)
129126
>>> platform = openmm.Platform.getPlatformByName('Reference')
130127
>>> context = openmm.Context(model.system, integrator, platform)
@@ -142,10 +139,7 @@ class CentroidFunction(openmm.CustomCentroidBondForce, BaseCustomFunction):
142139
... atoms[115:124] + atoms[126:135],
143140
... list(it.product(range(9), range(9, 18))),
144141
... )
145-
>>> colvar.setUnusedForceGroup(0, model.system)
146-
2
147-
>>> model.system.addForce(colvar)
148-
7
142+
>>> colvar.addToSystem(model.system)
149143
>>> context.reinitialize(preserveState=True)
150144
>>> print(colvar.getValue(context))
151145
33.0 dimensionless

cvpack/composite_rmsd.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,7 @@ class CompositeRMSD(CompositeRMSDForce, BaseCollectiveVariable):
105105
... )
106106
... except ImportError:
107107
... pytest.skip("openmm-cpp-forces is not installed")
108-
>>> composite_rmsd.setUnusedForceGroup(0, model.system)
109-
1
110-
>>> model.system.addForce(composite_rmsd)
111-
5
108+
>>> composite_rmsd.addToSystem(model.system)
112109
>>> context = mm.Context(
113110
... model.system,
114111
... mm.VerletIntegrator(1.0 * unit.femtoseconds),

cvpack/cvpack.py

Lines changed: 94 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,9 @@ def _registerCV(
128128
"""
129129
cls = self.__class__
130130
self.setName(cls.__name__)
131-
self.setUnit(unit)
131+
self._unit = unit
132132
self._mass_unit = mmunit.dalton * (mmunit.nanometers / self.getUnit()) ** 2
133-
arguments, _ = self.getArguments()
133+
arguments, _ = self._getArguments()
134134
self._args = dict(zip(arguments, args))
135135
self._args.update(kwargs)
136136

@@ -149,21 +149,22 @@ def _registerPeriod(self, period: float) -> None:
149149
self._period = period
150150

151151
@classmethod
152-
def getArguments(cls) -> t.Tuple[collections.OrderedDict, collections.OrderedDict]:
152+
def _getArguments(cls) -> t.Tuple[collections.OrderedDict, collections.OrderedDict]:
153153
"""
154154
Inspect the arguments needed for constructing an instance of this collective
155155
variable.
156156
157157
Returns
158158
-------
159+
OrderedDict
159160
A dictionary with the type annotations of all arguments
160-
161+
OrderedDict
161162
A dictionary with the default values of optional arguments
162163
163164
Example
164165
-------
165166
>>> import cvpack
166-
>>> args, defaults = cvpack.RadiusOfGyration.getArguments()
167+
>>> args, defaults = cvpack.RadiusOfGyration._getArguments()
167168
>>> for name, annotation in args.items():
168169
... print(f"{name}: {annotation}")
169170
group: typing.Iterable[int]
@@ -180,16 +181,32 @@ def getArguments(cls) -> t.Tuple[collections.OrderedDict, collections.OrderedDic
180181
defaults[name] = parameter.default
181182
return arguments, defaults
182183

183-
def setUnit(self, unit: mmunit.Unit) -> None:
184+
def _setUnusedForceGroup(self, system: openmm.System) -> None:
184185
"""
185-
Set the unit of measurement of this collective variable.
186+
Set the force group of this collective variable to the one at a given position
187+
in the ascending ordered list of unused force groups in an :OpenMM:`System`.
188+
189+
.. note::
190+
191+
Evaluating a collective variable (see :meth:`getValue`) or computing its
192+
effective mass (see :meth:`getEffectiveMass`) is more efficient when the
193+
collective variable is the only force in its own force group.
186194
187195
Parameters
188196
----------
189-
unit
190-
The unit of measurement of this collective variable
197+
system
198+
The system to search for unused force groups
199+
200+
Raises
201+
------
202+
RuntimeError
203+
If all force groups are already in use
191204
"""
192-
self._unit = unit
205+
used_groups = {force.getForceGroup() for force in system.getForces()}
206+
new_group = next(filter(lambda i: i not in used_groups, range(32)), None)
207+
if new_group is None:
208+
raise RuntimeError("All force groups are already in use.")
209+
self.setForceGroup(new_group)
193210

194211
def getUnit(self) -> mmunit.Unit:
195212
"""
@@ -209,42 +226,23 @@ def getPeriod(self) -> t.Optional[mmunit.SerializableQuantity]:
209226
return None
210227
return mmunit.SerializableQuantity(self._period, self.getUnit())
211228

212-
def setUnusedForceGroup(self, position: int, system: openmm.System) -> int:
229+
def addToSystem(
230+
self, system: openmm.System, setUnusedForceGroup: bool = True
231+
) -> None:
213232
"""
214-
Set the force group of this collective variable to the one at a given position
215-
in the ascending ordered list of unused force groups in an :OpenMM:`System`.
216-
217-
.. note::
218-
219-
Evaluating a collective variable (see :meth:`getValue`) or computing its
220-
effective mass (see :meth:`getEffectiveMass`) is more efficient when the
221-
collective variable is the only force in its own force group.
233+
Add this collective variable to an :OpenMM:`System`.
222234
223235
Parameters
224236
----------
225-
position
226-
The position of the force group in the ascending ordered list of unused
227-
force groups in the system
228-
system
229-
The system to search for unused force groups
230-
231-
Returns
232-
-------
233-
The index of the force group that was set
234-
235-
Raises
236-
------
237-
RuntimeError
238-
If all force groups are already in use
237+
system
238+
The system to which this collective variable should be added
239+
setUnusedForceGroup
240+
If True, the force group of this collective variable will be set to the
241+
first available force group in the system
239242
"""
240-
free_groups = sorted(
241-
set(range(32)) - {force.getForceGroup() for force in system.getForces()}
242-
)
243-
if not free_groups:
244-
raise RuntimeError("All force groups are already in use.")
245-
new_group = free_groups[position]
246-
self.setForceGroup(new_group)
247-
return new_group
243+
if setUnusedForceGroup:
244+
self._setUnusedForceGroup(system)
245+
system.addForce(self)
248246

249247
def getValue(self, context: openmm.Context) -> mmunit.Quantity:
250248
"""
@@ -257,12 +255,43 @@ def getValue(self, context: openmm.Context) -> mmunit.Quantity:
257255
258256
Parameters
259257
----------
260-
context
261-
The context at which this collective variable should be evaluated
258+
context
259+
The context at which this collective variable should be evaluated
262260
263261
Returns
264262
-------
263+
unit.Quantity
265264
The value of this collective variable at the given context
265+
266+
267+
Example
268+
-------
269+
In this example, we compute the values of the backbone dihedral angles and
270+
the radius of gyration of an alanine dipeptide molecule in water:
271+
272+
>>> import cvpack
273+
>>> import openmm
274+
>>> from openmmtools import testsystems
275+
>>> model = testsystems.AlanineDipeptideExplicit()
276+
>>> top = model.mdtraj_topology
277+
>>> backbone_atoms = top.select("name N C CA and resid 1 2")
278+
>>> phi = cvpack.Torsion(*backbone_atoms[0:4])
279+
>>> psi = cvpack.Torsion(*backbone_atoms[1:5])
280+
>>> radius_of_gyration = cvpack.RadiusOfGyration(
281+
... top.select('not water')
282+
... )
283+
>>> for cv in [phi, psi, radius_of_gyration]:
284+
... cv.addToSystem(model.system)
285+
>>> context = openmm.Context(
286+
... model.system, openmm.VerletIntegrator(0)
287+
... )
288+
>>> context.setPositions(model.positions)
289+
>>> print(phi.getValue(context))
290+
3.1415... rad
291+
>>> print(psi.getValue(context))
292+
3.1415... rad
293+
>>> print(radius_of_gyration.getValue(context))
294+
0.29514... nm
266295
"""
267296
state = get_single_force_state(self, context, getEnergy=True)
268297
value = value_in_md_units(state.getPotentialEnergy())
@@ -291,36 +320,41 @@ def getEffectiveMass(self, context: openmm.Context) -> mmunit.Quantity:
291320
292321
Parameters
293322
----------
294-
context
295-
The context at which this collective variable's effective mass should be
296-
evaluated
323+
context
324+
The context at which this collective variable's effective mass should be
325+
evaluated
297326
298327
Returns
299328
-------
329+
unit.Quantity
300330
The effective mass of this collective variable at the given context
301331
302332
Example
303333
-------
334+
In this example, we compute the effective masses of the backbone dihedral
335+
angles and the radius of gyration of an alanine dipeptide molecule in water:
336+
304337
>>> import cvpack
305338
>>> import openmm
306339
>>> from openmmtools import testsystems
307-
>>> model = testsystems.AlanineDipeptideImplicit()
308-
>>> peptide = [
309-
... a.index
310-
... for a in model.topology.atoms()
311-
... if a.residue.name != 'HOH'
312-
... ]
313-
>>> radius_of_gyration = cvpack.RadiusOfGyration(peptide)
314-
>>> radius_of_gyration.setForceGroup(1)
315-
>>> radius_of_gyration.setUnusedForceGroup(0, model.system)
316-
1
317-
>>> model.system.addForce(radius_of_gyration)
318-
6
319-
>>> platform = openmm.Platform.getPlatformByName('Reference')
340+
>>> model = testsystems.AlanineDipeptideExplicit()
341+
>>> top = model.mdtraj_topology
342+
>>> backbone_atoms = top.select("name N C CA and resid 1 2")
343+
>>> phi = cvpack.Torsion(*backbone_atoms[0:4])
344+
>>> psi = cvpack.Torsion(*backbone_atoms[1:5])
345+
>>> radius_of_gyration = cvpack.RadiusOfGyration(
346+
... top.select('not water')
347+
... )
348+
>>> for cv in [phi, psi, radius_of_gyration]:
349+
... cv.addToSystem(model.system)
320350
>>> context = openmm.Context(
321-
... model.system,openmm.VerletIntegrator(0), platform
351+
... model.system, openmm.VerletIntegrator(0)
322352
... )
323353
>>> context.setPositions(model.positions)
354+
>>> print(phi.getEffectiveMass(context))
355+
0.05119... nm**2 Da/(rad**2)
356+
>>> print(psi.getEffectiveMass(context))
357+
0.05186... nm**2 Da/(rad**2)
324358
>>> print(radius_of_gyration.getEffectiveMass(context))
325359
30.946... Da
326360
"""

cvpack/helix_angle_content.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,7 @@ class HelixAngleContent(openmm.CustomAngleForce, BaseCollectiveVariable):
8484
>>> print(*[r.name for r in residues]) # doctest: +ELLIPSIS
8585
LYS ASP GLU ... ILE LEU ARG
8686
>>> helix_content = cvpack.HelixAngleContent(residues)
87-
>>> helix_content.setUnusedForceGroup(0, model.system)
88-
1
89-
>>> model.system.addForce(helix_content)
90-
6
87+
>>> helix_content.addToSystem(model.system)
9188
>>> platform = openmm.Platform.getPlatformByName('Reference')
9289
>>> integrator = openmm.VerletIntegrator(0)
9390
>>> context = openmm.Context(model.system, integrator, platform)

cvpack/helix_hbond_content.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,7 @@ class HelixHBondContent(openmm.CustomBondForce, BaseCollectiveVariable):
7272
>>> print(*[r.name for r in residues])
7373
LYS ASP GLU ... ILE LEU ARG
7474
>>> helix_content = cvpack.HelixHBondContent(residues)
75-
>>> helix_content.setUnusedForceGroup(0, model.system)
76-
1
77-
>>> model.system.addForce(helix_content)
78-
6
75+
>>> helix_content.addToSystem(model.system)
7976
>>> platform = openmm.Platform.getPlatformByName('Reference')
8077
>>> integrator = openmm.VerletIntegrator(0)
8178
>>> context = openmm.Context(model.system, integrator, platform)

cvpack/helix_rmsd_content.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,7 @@ class HelixRMSDContent(BaseRMSDContent):
111111
... )
112112
>>> helix_content.getNumResidueBlocks()
113113
16
114-
>>> helix_content.setUnusedForceGroup(0, model.system)
115-
1
116-
>>> model.system.addForce(helix_content)
117-
6
114+
>>> helix_content.addToSystem(model.system)
118115
>>> platform = openmm.Platform.getPlatformByName('Reference')
119116
>>> integrator = openmm.VerletIntegrator(0)
120117
>>> context = openmm.Context(model.system, integrator, platform)

cvpack/helix_torsion_content.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,7 @@ class HelixTorsionContent(openmm.CustomTorsionForce, BaseCollectiveVariable):
9494
>>> print(*[r.name for r in residues])
9595
LYS ASP GLU ... ILE LEU ARG
9696
>>> helix_content = cvpack.HelixTorsionContent(residues)
97-
>>> helix_content.setUnusedForceGroup(0, model.system)
98-
1
99-
>>> model.system.addForce(helix_content)
100-
6
97+
>>> helix_content.addToSystem(model.system)
10198
>>> platform = openmm.Platform.getPlatformByName('Reference')
10299
>>> integrator = openmm.VerletIntegrator(0)
103100
>>> context = openmm.Context(model.system, integrator, platform)

cvpack/number_of_contacts.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,7 @@ class NumberOfContacts(openmm.CustomNonbondedForce, BaseCollectiveVariable):
100100
... forces["NonbondedForce"],
101101
... stepFunction="step(1-x)",
102102
... )
103-
>>> nc.setUnusedForceGroup(0, model.system)
104-
1
105-
>>> model.system.addForce(nc)
106-
5
103+
>>> nc.addToSystem(model.system)
107104
>>> platform = openmm.Platform.getPlatformByName("Reference")
108105
>>> integrator = openmm.VerletIntegrator(1.0 * mmunit.femtoseconds)
109106
>>> context = openmm.Context(model.system, integrator, platform)
@@ -117,10 +114,7 @@ class NumberOfContacts(openmm.CustomNonbondedForce, BaseCollectiveVariable):
117114
... stepFunction="step(1-x)",
118115
... reference=context,
119116
... )
120-
>>> nc_normalized.setUnusedForceGroup(0, model.system)
121-
2
122-
>>> model.system.addForce(nc_normalized)
123-
6
117+
>>> nc_normalized.addToSystem(model.system)
124118
>>> context.reinitialize(preserveState=True)
125119
>>> print(nc_normalized.getValue(context))
126120
0.99999... dimensionless

0 commit comments

Comments
 (0)