Skip to content

Commit

Permalink
Removes 'digits' option from CV evaluators (breaks API) (#69)
Browse files Browse the repository at this point in the history
* Moved evaluation functions to the utils module

* Fixed code formatting

* Using ellipses instead of digits in doctests

* Removed 'digits' option from evaluators

* Removed accidentally added file
  • Loading branch information
Charlles Abreu authored Mar 21, 2024
1 parent 9e933fb commit f08a547
Show file tree
Hide file tree
Showing 19 changed files with 144 additions and 161 deletions.
4 changes: 2 additions & 2 deletions cvpack/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ class Angle(openmm.CustomAngleForce, BaseCollectiveVariable):
>>> context = openmm.Context(system, integrator, platform)
>>> positions = [[0, 0, 0], [1, 0, 0], [1, 1, 0]]
>>> context.setPositions([openmm.Vec3(*pos) for pos in positions])
>>> print(angle.getValue(context, digits=6))
1.570796 rad
>>> print(angle.getValue(context))
1.570796... rad
"""

Expand Down
20 changes: 10 additions & 10 deletions cvpack/atomic_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ class AtomicFunction(openmm.CustomCompoundBondForce, BaseCustomFunction):
>>> context.setPositions(model.positions)
>>> theta1 = angle1.getValue(context).value_in_unit(openmm.unit.radian)
>>> theta2 = angle2.getValue(context).value_in_unit(openmm.unit.radian)
>>> print(round(500*((theta1 - np.pi/2)**2 + (theta2 - np.pi/3)**2), 3))
429.479
>>> print(colvar.getValue(context, digits=6))
429.479 kJ/mol
>>> print(500*((theta1 - np.pi/2)**2 + (theta2 - np.pi/3)**2))
429.479...
>>> print(colvar.getValue(context))
429.479... kJ/mol
"""

yaml_tag = "!cvpack.AtomicFunction"
Expand Down Expand Up @@ -328,13 +328,13 @@ def fromOpenMMForce(
>>> for name in copies:
... state = context.getState(getEnergy=True, groups={indices[name]})
... value = state.getPotentialEnergy() / unit.kilojoules_per_mole
... copy_value = copies[name].getValue(context, digits=6)
... copy_value = copies[name].getValue(context)
... print(f"{name}: original={value:.6f}, copy={copy_value}")
HarmonicBondForce: original=2094.312483, copy=2094.312 kJ/mol
HarmonicAngleForce: original=3239.795215, copy=3239.795 kJ/mol
PeriodicTorsionForce: original=4226.051934, copy=4226.052 kJ/mol
CustomExternalForce: original=5.021558, copy=5.021558 kJ/mol
HelixTorsionContent: original=17.452849, copy=17.45285 dimensionless
HarmonicBondForce: original=2094.312..., copy=2094.312... kJ/mol
HarmonicAngleForce: original=3239.795..., copy=3239.795... kJ/mol
PeriodicTorsionForce: original=4226.05..., copy=4226.05... kJ/mol
CustomExternalForce: original=5.02155..., copy=5.02155... kJ/mol
HelixTorsionContent: original=17.4528..., copy=17.4528... dimensionless
"""
if isinstance(
force,
Expand Down
22 changes: 11 additions & 11 deletions cvpack/attraction_strength.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,35 +134,35 @@ class AttractionStrength(openmm.CustomNonbondedForce, BaseCollectiveVariable):
>>> integrator = openmm.VerletIntegrator(1.0 * mmunit.femtoseconds)
>>> context = openmm.Context(model.system, integrator, platform)
>>> context.setPositions(model.positions)
>>> print(cv1.getValue(context, 4))
4912.5 dimensionless
>>> print(cv1.getValue(context))
4912.5... dimensionless
>>> water = [a.index for a in model.topology.atoms() if a.residue.name == "HOH"]
>>> cv2 = cvpack.AttractionStrength(guest, water, forces["NonbondedForce"])
>>> _ = cv2.setUnusedForceGroup(0, model.system)
>>> _ = model.system.addForce(cv2)
>>> context.reinitialize(preserveState=True)
>>> print(cv2.getValue(context, 4))
2063.3 dimensionless
>>> print(cv2.getValue(context))
2063.3... dimensionless
>>> cv3 = cvpack.AttractionStrength(guest, host, forces["NonbondedForce"], water)
>>> _ = cv3.setUnusedForceGroup(0, model.system)
>>> _ = model.system.addForce(cv3)
>>> context.reinitialize(preserveState=True)
>>> print(cv3.getValue(context, 4))
2849.2 dimensionless
>>> print(cv1.getValue(context, 4) - cv2.getValue(context, 4))
2849.2 dimensionless
>>> print(cv3.getValue(context))
2849.17... dimensionless
>>> print(cv1.getValue(context) - cv2.getValue(context))
2849.17... dimensionless
>>> cv4 = cvpack.AttractionStrength(
... guest, host, forces["NonbondedForce"], water, contrastScaling=0.5
... )
>>> _ = cv4.setUnusedForceGroup(0, model.system)
>>> _ = model.system.addForce(cv4)
>>> context.reinitialize(preserveState=True)
>>> print(cv4.getValue(context, 4))
3880.8 dimensionless
>>> print(1 * cv1.getValue(context, 4) - 0.5 * cv2.getValue(context, 4))
>>> print(cv4.getValue(context))
3880.8... dimensionless
>>> print(1 * cv1.getValue(context) - 0.5 * cv2.getValue(context))
3880.8...
"""

Expand Down
116 changes: 8 additions & 108 deletions cvpack/cvpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,17 @@
"""

import collections
import functools
import inspect
import typing as t

import numpy as np
import openmm
import yaml
from openmm import app as mmapp

from cvpack import unit as mmunit

from .unit import value_in_md_units
from .utils import compute_effective_mass, get_single_force_state


class SerializableAtom(yaml.YAMLObject):
Expand Down Expand Up @@ -158,72 +157,6 @@ def _registerPeriod(self, period: float) -> None:
"""
self._period = period

def _getSingleForceState(
self, context: openmm.Context, getEnergy: bool = False, getForces: bool = False
) -> openmm.State:
"""
Get an OpenMM State containing the potential energy and/or force values computed
from this single force object.
Parameters
----------
context
The context from which the state should be extracted
getEnergy
If True, the potential energy will be computed
getForces
If True, the forces will be computed
Raises
------
ValueError
If this force is not present in the given context
"""
forces = context.getSystem().getForces()
if not any(force.this == self.this for force in forces):
raise RuntimeError("This force is not present in the given context.")
self_group = self.getForceGroup()
other_groups = {
force.getForceGroup() for force in forces if force.this != self.this
}
if self_group not in other_groups:
return context.getState(
getEnergy=getEnergy, getForces=getForces, groups=1 << self_group
)
old_group = self.getForceGroup()
new_group = self.setUnusedForceGroup(0, context.getSystem())
context.reinitialize(preserveState=True)
state = context.getState(
getEnergy=getEnergy, getForces=getForces, groups=1 << new_group
)
self.setForceGroup(old_group)
context.reinitialize(preserveState=True)
return state

def _precisionRound(self, number: float, digits: t.Optional[int] = None) -> float:
"""
Round a number to a specified number of precision digits (if specified).
The number of precision digits is defined as the number of digits after the
decimal point of the number's scientific notation representation.
Parameters
----------
number
The number to be rounded
digits
The number of digits to round to. If None, the number will not be
rounded.
Returns
-------
The rounded number
"""
if digits is None:
return number
power = f"{number:e}".split("e")[1]
return round(number, -(int(power) - digits))

@classmethod
def getArguments(cls) -> t.Tuple[collections.OrderedDict, collections.OrderedDict]:
"""
Expand Down Expand Up @@ -322,16 +255,10 @@ def setUnusedForceGroup(self, position: int, system: openmm.System) -> int:
self.setForceGroup(new_group)
return new_group

def getValue(
self, context: openmm.Context, digits: t.Optional[int] = None
) -> mmunit.Quantity:
def getValue(self, context: openmm.Context) -> mmunit.Quantity:
"""
Evaluate this collective variable at a given :OpenMM:`Context`.
Optionally, the value can be rounded to a specified number of precision digits,
which is the number of digits after the decimal point of the value in scientific
notation.
.. note::
This method will be more efficient if the collective variable is the only
Expand All @@ -341,21 +268,16 @@ def getValue(
----------
context
The context at which this collective variable should be evaluated
digits
The number of precision digits to round to. If None, the value will not
be rounded.
Returns
-------
The value of this collective variable at the given context
"""
state = self._getSingleForceState(context, getEnergy=True)
state = get_single_force_state(self, context, getEnergy=True)
value = value_in_md_units(state.getPotentialEnergy())
return mmunit.Quantity(self._precisionRound(value, digits), self.getUnit())
return mmunit.Quantity(value, self.getUnit())

def getEffectiveMass(
self, context: openmm.Context, digits: t.Optional[int] = None
) -> mmunit.Quantity:
def getEffectiveMass(self, context: openmm.Context) -> mmunit.Quantity:
r"""
Compute the effective mass of this collective variable at a given
:OpenMM:`Context`.
Expand All @@ -371,10 +293,6 @@ def getEffectiveMass(
\right\|^2
\right)^{-1}
Optionally, effective mass of this collective variable can be rounded to a
specified number of precision digits, which is the number of digits after the
decimal point of the effective mass in scientific notation.
.. note::
This method will be more efficient if the collective variable is the only
Expand All @@ -385,9 +303,6 @@ def getEffectiveMass(
context
The context at which this collective variable's effective mass should be
evaluated
digits
The number of precision digits to round to. If None, the value will not
be rounded.
Returns
-------
Expand Down Expand Up @@ -415,22 +330,7 @@ def getEffectiveMass(
... model.system,openmm.VerletIntegrator(0), platform
... )
>>> context.setPositions(model.positions)
>>> print(radius_of_gyration.getEffectiveMass(context, digits=6))
30.94693 Da
>>> print(radius_of_gyration.getEffectiveMass(context))
30.946... Da
"""
state = self._getSingleForceState(context, getForces=True)
# pylint: disable=protected-access,c-extension-no-member
get_mass = functools.partial(
openmm._openmm.System_getParticleMass, context.getSystem()
)
force_vectors = state.getForces(asNumpy=True)._value
# pylint: enable=protected-access,c-extension-no-member
squared_forces = np.sum(np.square(force_vectors), axis=1)
nonzeros = np.nonzero(squared_forces)[0]
if nonzeros.size == 0:
return mmunit.Quantity(np.inf, self._mass_unit)
mass_values = np.fromiter(map(get_mass, nonzeros), dtype=np.float64)
effective_mass = 1.0 / np.sum(squared_forces[nonzeros] / mass_values)
return mmunit.Quantity(
self._precisionRound(effective_mass, digits), self._mass_unit
)
return mmunit.Quantity(compute_effective_mass(self, context), self._mass_unit)
4 changes: 2 additions & 2 deletions cvpack/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class Distance(openmm.CustomBondForce, BaseCollectiveVariable):
>>> platform = openmm.Platform.getPlatformByName('Reference')
>>> context = openmm.Context(system, integrator, platform)
>>> context.setPositions([openmm.Vec3(0, 0, 0),openmm.Vec3(1, 1, 1)])
>>> print(distance.getValue(context, digits=5))
1.73205 nm
>>> print(distance.getValue(context))
1.7320... nm
"""

Expand Down
4 changes: 2 additions & 2 deletions cvpack/helix_angle_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ class HelixAngleContent(openmm.CustomAngleForce, BaseCollectiveVariable):
>>> integrator = openmm.VerletIntegrator(0)
>>> context = openmm.Context(model.system, integrator, platform)
>>> context.setPositions(model.positions)
>>> print(helix_content.getValue(context, digits=6))
18.76058 dimensionless
>>> print(helix_content.getValue(context))
18.7605... dimensionless
"""

yaml_tag = "!cvpack.HelixAngleContent"
Expand Down
4 changes: 2 additions & 2 deletions cvpack/helix_hbond_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ class HelixHBondContent(openmm.CustomBondForce, BaseCollectiveVariable):
>>> integrator = openmm.VerletIntegrator(0)
>>> context = openmm.Context(model.system, integrator, platform)
>>> context.setPositions(model.positions)
>>> print(helix_content.getValue(context, digits=6))
15.88038 dimensionless
>>> print(helix_content.getValue(context))
15.880... dimensionless
"""

yaml_tag = "!cvpack.HelixHBondContent"
Expand Down
4 changes: 2 additions & 2 deletions cvpack/helix_rmsd_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ class HelixRMSDContent(BaseRMSDContent):
>>> integrator = openmm.VerletIntegrator(0)
>>> context = openmm.Context(model.system, integrator, platform)
>>> context.setPositions(model.positions)
>>> print(helix_content.getValue(context, digits=4))
15.981 dimensionless
>>> print(helix_content.getValue(context))
15.98... dimensionless
"""

yaml_tag = "!cvpack.HelixRMSDContent"
Expand Down
4 changes: 2 additions & 2 deletions cvpack/helix_torsion_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ class HelixTorsionContent(openmm.CustomTorsionForce, BaseCollectiveVariable):
>>> integrator = openmm.VerletIntegrator(0)
>>> context = openmm.Context(model.system, integrator, platform)
>>> context.setPositions(model.positions)
>>> print(helix_content.getValue(context, digits=6))
17.45285 dimensionless
>>> print(helix_content.getValue(context))
17.452... dimensionless
"""

yaml_tag = "!cvpack.HelixTorsionContent"
Expand Down
8 changes: 4 additions & 4 deletions cvpack/number_of_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ class NumberOfContacts(openmm.CustomNonbondedForce, BaseCollectiveVariable):
>>> integrator = openmm.VerletIntegrator(1.0 * mmunit.femtoseconds)
>>> context = openmm.Context(model.system, integrator, platform)
>>> context.setPositions(model.positions)
>>> print(nc.getValue(context, 4))
30.0 dimensionless
>>> print(nc.getValue(context))
30.0... dimensionless
>>> nc_normalized = cvpack.NumberOfContacts(
... group1,
... group2,
Expand All @@ -122,8 +122,8 @@ class NumberOfContacts(openmm.CustomNonbondedForce, BaseCollectiveVariable):
>>> model.system.addForce(nc_normalized)
6
>>> context.reinitialize(preserveState=True)
>>> print(nc_normalized.getValue(context, 4))
1.0 dimensionless
>>> print(nc_normalized.getValue(context))
0.99999... dimensionless
"""

yaml_tag = "!cvpack.NumberOfContacts"
Expand Down
4 changes: 2 additions & 2 deletions cvpack/radius_of_gyration.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ class RadiusOfGyration(BaseRadiusOfGyration):
>>> integrator = openmm.VerletIntegrator(0)
>>> context = openmm.Context(model.system, integrator, platform)
>>> context.setPositions(model.positions)
>>> print(radius_of_gyration.getValue(context, digits=6))
0.2951431 nm
>>> print(radius_of_gyration.getValue(context))
0.2951... nm
"""

Expand Down
2 changes: 1 addition & 1 deletion cvpack/radius_of_gyration_sq.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class RadiusOfGyrationSq(BaseRadiusOfGyration):
>>> integrator = openmm.VerletIntegrator(0)
>>> context = openmm.Context(model.system, integrator, platform)
>>> context.setPositions(model.positions)
>>> print(rgsq.getValue(context, digits=6)) # doctest: +ELLIPSIS
>>> print(rgsq.getValue(context)) # doctest: +ELLIPSIS
0.0871... nm**2
"""
Expand Down
4 changes: 2 additions & 2 deletions cvpack/residue_coordination.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ class ResidueCoordination(openmm.CustomCentroidBondForce, BaseCollectiveVariable
26.0 dimensionless
>>> residue_coordination.setReferenceValue(26 * unit.dimensionless)
>>> context.reinitialize(preserveState=True)
>>> print(residue_coordination.getValue(context, digits=6))
1.0 dimensionless
>>> print(residue_coordination.getValue(context))
0.99999... dimensionless
"""

yaml_tag = "!cvpack.ResidueCoordination"
Expand Down
4 changes: 2 additions & 2 deletions cvpack/rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def getNullBondForce(self) -> openmm.HarmonicBondForce:
>>> context = openmm.Context(model.system, integrator, platform)
>>> context.setPositions(model.positions)
>>> integrator.step(100)
>>> print(rmsd.getValue(context, digits=6))
0.104363 nm
>>> print(rmsd.getValue(context))
0.10436... nm
"""
force = openmm.HarmonicBondForce()
Expand Down
Loading

0 comments on commit f08a547

Please sign in to comment.