Skip to content

Commit

Permalink
Improves registration of serializable objects (#70)
Browse files Browse the repository at this point in the history
* Improves registration of serializable objects

* Fixed code formatting
  • Loading branch information
Charlles Abreu authored Mar 21, 2024
1 parent f08a547 commit fc2eb1a
Show file tree
Hide file tree
Showing 27 changed files with 107 additions and 115 deletions.
29 changes: 0 additions & 29 deletions cvpack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
"""Useful Collective Variables for OpenMM"""

import yaml

# Add imports here
from ._version import __version__ # noqa: F401
from .angle import Angle # noqa: F401
from .atomic_function import AtomicFunction # noqa: F401
Expand All @@ -24,29 +21,3 @@
from .sheet_rmsd_content import SheetRMSDContent # noqa: F401
from .torsion import Torsion # noqa: F401
from .torsion_similarity import TorsionSimilarity # noqa: F401

for _cv in [
Angle,
AtomicFunction,
AttractionStrength,
CentroidFunction,
CompositeRMSD,
Distance,
HelixAngleContent,
HelixHBondContent,
HelixRMSDContent,
HelixTorsionContent,
NumberOfContacts,
OpenMMForceWrapper,
PathInCVSpace,
RadiusOfGyration,
RadiusOfGyrationSq,
ResidueCoordination,
RMSD,
SheetRMSDContent,
Torsion,
TorsionSimilarity,
]:
yaml.SafeDumper.add_representer(_cv, _cv.to_yaml)
yaml.SafeLoader.add_constructor(_cv.yaml_tag, _cv.from_yaml)
del _cv
5 changes: 3 additions & 2 deletions cvpack/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,12 @@ class Angle(openmm.CustomAngleForce, BaseCollectiveVariable):
"""

yaml_tag = "!cvpack.Angle"

def __init__(self, atom1: int, atom2: int, atom3: int, pbc: bool = False) -> None:
super().__init__("theta")
self.addAngle(atom1, atom2, atom3, [])
self.setUsesPeriodicBoundaryConditions(pbc)
self._registerCV(mmunit.radians, atom1, atom2, atom3, pbc)
self._registerPeriod(2 * math.pi)


Angle.registerTag("!cvpack.Angle")
5 changes: 3 additions & 2 deletions cvpack/atomic_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ class AtomicFunction(openmm.CustomCompoundBondForce, BaseCustomFunction):
429.479... kJ/mol
"""

yaml_tag = "!cvpack.AtomicFunction"

@mmunit.convert_quantities
def __init__( # pylint: disable=too-many-arguments
self,
Expand Down Expand Up @@ -354,3 +352,6 @@ def fromOpenMMForce(
if isinstance(force, openmm.PeriodicTorsionForce):
return cls._fromPeriodicTorsionForce(force, unit, pbc)
raise TypeError(f"Force {force} is not convertible to an AtomicFunction")


AtomicFunction.registerTag("!cvpack.AtomicFunction")
5 changes: 3 additions & 2 deletions cvpack/attraction_strength.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,6 @@ class AttractionStrength(openmm.CustomNonbondedForce, BaseCollectiveVariable):
3880.8...
"""

yaml_tag = "!cvpack.AttractionStrength"

@mmunit.convert_quantities
def __init__( # pylint: disable=too-many-arguments
self,
Expand Down Expand Up @@ -236,3 +234,6 @@ def __init__( # pylint: disable=too-many-arguments
reference,
contrastScaling,
)


AttractionStrength.registerTag("!cvpack.AttractionStrength")
5 changes: 3 additions & 2 deletions cvpack/centroid_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,6 @@ class CentroidFunction(openmm.CustomCentroidBondForce, BaseCustomFunction):
33.0 dimensionless
"""

yaml_tag = "!cvpack.CentroidFunction"

@mmunit.convert_quantities
def __init__( # pylint: disable=too-many-arguments
self,
Expand Down Expand Up @@ -195,3 +193,6 @@ def __init__( # pylint: disable=too-many-arguments
)
if period is not None:
self._registerPeriod(period)


CentroidFunction.registerTag("!cvpack.CentroidFunction")
5 changes: 3 additions & 2 deletions cvpack/composite_rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ class CompositeRMSD(CompositeRMSDForce, BaseCollectiveVariable):
0.0 nm
"""

yaml_tag = "!cvpack.CompositeRMSD"

@mmunit.convert_quantities
def __init__(
self,
Expand All @@ -149,3 +147,6 @@ def __init__(
for group in groups:
self.addGroup(group)
self._registerCV(mmunit.nanometers, defined_coords, groups, num_atoms)


CompositeRMSD.registerTag("!cvpack.CompositeRMSD")
23 changes: 7 additions & 16 deletions cvpack/cvpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,16 @@

from cvpack import unit as mmunit

from .serializer import Serializable
from .unit import value_in_md_units
from .utils import compute_effective_mass, get_single_force_state


class SerializableAtom(yaml.YAMLObject):
class SerializableAtom(Serializable):
r"""
A serializable version of OpenMM's Atom class.
"""

yaml_tag = "!cvpack.Atom"

def __init__( # pylint: disable=super-init-not-called
self, atom: t.Union[mmapp.topology.Atom, "SerializableAtom"]
) -> None:
Expand All @@ -48,16 +47,11 @@ def __setstate__(self, keywords: t.Dict[str, t.Any]) -> None:
self.__dict__.update(keywords)


yaml.SafeDumper.add_representer(SerializableAtom, SerializableAtom.to_yaml)
yaml.SafeLoader.add_constructor(SerializableAtom.yaml_tag, SerializableAtom.from_yaml)

SerializableAtom.registerTag("!cvpack.Atom")

class SerializableResidue(yaml.YAMLObject):
r"""
A serializable version of OpenMM's Residue class.
"""

yaml_tag = "!cvpack.Residue"
class SerializableResidue(Serializable):
r"""A serializable version of OpenMM's Residue class."""

def __init__( # pylint: disable=super-init-not-called
self, residue: t.Union[mmapp.topology.Residue, "SerializableResidue"]
Expand Down Expand Up @@ -85,13 +79,10 @@ def atoms(self):
return iter(self._atoms)


yaml.SafeDumper.add_representer(SerializableResidue, SerializableResidue.to_yaml)
yaml.SafeLoader.add_constructor(
SerializableResidue.yaml_tag, SerializableResidue.from_yaml
)
SerializableResidue.registerTag("!cvpack.Residue")


class BaseCollectiveVariable(openmm.Force, yaml.YAMLObject):
class BaseCollectiveVariable(openmm.Force, Serializable):
r"""
An abstract class with common attributes and method for all CVs.
"""
Expand Down
5 changes: 3 additions & 2 deletions cvpack/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ class Distance(openmm.CustomBondForce, BaseCollectiveVariable):
"""

yaml_tag = "!cvpack.Distance"

def __init__(self, atom1: int, atom2: int, pbc: bool = False) -> None:
super().__init__("r")
self.addBond(atom1, atom2, [])
self.setUsesPeriodicBoundaryConditions(pbc)
self._registerCV(mmunit.nanometers, atom1, atom2, pbc)


Distance.registerTag("!cvpack.Distance")
5 changes: 3 additions & 2 deletions cvpack/helix_angle_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ class HelixAngleContent(openmm.CustomAngleForce, BaseCollectiveVariable):
18.7605... dimensionless
"""

yaml_tag = "!cvpack.HelixAngleContent"

@mmunit.convert_quantities
def __init__( # pylint: disable=too-many-arguments
self,
Expand Down Expand Up @@ -139,3 +137,6 @@ def find_alpha_carbon(residue: mmapp.topology.Residue) -> int:
halfExponent,
normalize,
)


HelixAngleContent.registerTag("!cvpack.HelixAngleContent")
5 changes: 3 additions & 2 deletions cvpack/helix_hbond_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ class HelixHBondContent(openmm.CustomBondForce, BaseCollectiveVariable):
15.880... dimensionless
"""

yaml_tag = "!cvpack.HelixHBondContent"

@mmunit.convert_quantities
def __init__( # pylint: disable=too-many-arguments
self,
Expand Down Expand Up @@ -126,3 +124,6 @@ def find_atom(residue: mmapp.topology.Residue, pattern: t.Pattern) -> int:
halfExponent,
normalize,
)


HelixHBondContent.registerTag("!cvpack.HelixHBondContent")
5 changes: 3 additions & 2 deletions cvpack/helix_rmsd_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ class HelixRMSDContent(BaseRMSDContent):
15.98... dimensionless
"""

yaml_tag = "!cvpack.HelixRMSDContent"

@mmunit.convert_quantities
def __init__( # pylint: disable=too-many-arguments
self,
Expand Down Expand Up @@ -155,3 +153,6 @@ def __init__( # pylint: disable=too-many-arguments
stepFunction,
normalize,
)


HelixRMSDContent.registerTag("!cvpack.HelixRMSDContent")
5 changes: 3 additions & 2 deletions cvpack/helix_torsion_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,6 @@ class HelixTorsionContent(openmm.CustomTorsionForce, BaseCollectiveVariable):
17.452... dimensionless
"""

yaml_tag = "!cvpack.HelixTorsionContent"

@mmunit.convert_quantities
def __init__( # pylint: disable=too-many-arguments
self,
Expand Down Expand Up @@ -157,3 +155,6 @@ def find_atom(residue: mmapp.topology.Residue, name: str) -> int:
tolerance,
halfExponent,
)


HelixTorsionContent.registerTag("!cvpack.HelixTorsionContent")
5 changes: 3 additions & 2 deletions cvpack/number_of_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,6 @@ class NumberOfContacts(openmm.CustomNonbondedForce, BaseCollectiveVariable):
0.99999... dimensionless
"""

yaml_tag = "!cvpack.NumberOfContacts"

@mmunit.convert_quantities
def __init__( # pylint: disable=too-many-arguments
self,
Expand Down Expand Up @@ -175,3 +173,6 @@ def __init__( # pylint: disable=too-many-arguments
cutoffFactor,
switchFactor,
)


NumberOfContacts.registerTag("!cvpack.NumberOfContacts")
5 changes: 3 additions & 2 deletions cvpack/openmm_force_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ class OpenMMForceWrapper(BaseCollectiveVariable):
0.00538... nm**2 Da/(rad**2)
"""

yaml_tag = "!cvpack.OpenMMForceWrapper"

def __init__( # pylint: disable=too-many-arguments, super-init-not-called
self,
openmmForce: t.Union[openmm.Force, str],
Expand All @@ -79,3 +77,6 @@ def __init__( # pylint: disable=too-many-arguments, super-init-not-called
self._registerCV(unit, openmmForce, unit, period)
if period is not None:
self._registerPeriod(period)


OpenMMForceWrapper.registerTag("!cvpack.OpenMMForceWrapper")
13 changes: 9 additions & 4 deletions cvpack/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
"""

import yaml
from .serializer import Serializable


class Metric(yaml.YAMLObject):
class Metric(Serializable):
"""
A measure of progress or deviation with respect to a path in CV space
"""
Expand All @@ -26,9 +26,14 @@ def __repr__(self) -> str:
def __eq__(self, other: object) -> bool:
return isinstance(other, Metric) and self.name == other.name

def __getstate__(self) -> dict:
return {"name": self.name}

yaml.SafeDumper.add_representer(Metric, Metric.to_yaml)
yaml.SafeLoader.add_constructor(Metric.yaml_tag, Metric.from_yaml)
def __setstate__(self, state: dict) -> None:
self.name = state["name"]


Metric.registerTag("!cvpack.path.Metric")


progress: Metric = Metric("progress")
Expand Down
5 changes: 3 additions & 2 deletions cvpack/path_in_cv_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,6 @@ class PathInCVSpace(openmm.CustomCVForce, BaseCollectiveVariable):
z = 0.25... dimensionless
"""

yaml_tag = "!cvpack.PathInCVSpace"

@mmunit.convert_quantities
def __init__( # pylint: disable=too-many-arguments
self,
Expand Down Expand Up @@ -177,3 +175,6 @@ def __init__( # pylint: disable=too-many-arguments
sigma,
scales,
)


PathInCVSpace.registerTag("!cvpack.PathInCVSpace")
5 changes: 3 additions & 2 deletions cvpack/radius_of_gyration.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ class RadiusOfGyration(BaseRadiusOfGyration):
"""

yaml_tag = "!cvpack.RadiusOfGyration"

def __init__(
self, group: t.Iterable[int], pbc: bool = False, weighByMass: bool = False
) -> None:
Expand All @@ -91,3 +89,6 @@ def __init__(
)
self.addBond(list(range(num_groups)))
self._registerCV(mmunit.nanometers, group, pbc, weighByMass)


RadiusOfGyration.registerTag("!cvpack.RadiusOfGyration")
5 changes: 3 additions & 2 deletions cvpack/radius_of_gyration_sq.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ class RadiusOfGyrationSq(BaseRadiusOfGyration):
"""

yaml_tag = "!cvpack.RadiusOfGyrationSq"

def __init__(
self, group: t.Iterable[int], pbc: bool = False, weighByMass: bool = False
) -> None:
Expand All @@ -88,3 +86,6 @@ def __init__(
for atom in group:
self.addBond([atom, num_atoms])
self._registerCV(mmunit.nanometers**2, group, pbc, weighByMass)


RadiusOfGyrationSq.registerTag("!cvpack.RadiusOfGyrationSq")
5 changes: 3 additions & 2 deletions cvpack/residue_coordination.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@ class ResidueCoordination(openmm.CustomCentroidBondForce, BaseCollectiveVariable
0.99999... dimensionless
"""

yaml_tag = "!cvpack.ResidueCoordination"

@mmunit.convert_quantities
def __init__( # pylint: disable=too-many-arguments
self,
Expand Down Expand Up @@ -184,3 +182,6 @@ def setReferenceValue(self, value: mmunit.ScalarQuantity) -> None:
expression.replace(f"refval={self._ref_val}", f"refval={value}")
)
self._ref_val = value


ResidueCoordination.registerTag("!cvpack.ResidueCoordination")
5 changes: 3 additions & 2 deletions cvpack/rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ class RMSD(openmm.RMSDForce, BaseCollectiveVariable):
"""

yaml_tag = "!cvpack.RMSD"

@mmunit.convert_quantities
def __init__(
self,
Expand Down Expand Up @@ -149,3 +147,6 @@ def getNullBondForce(self) -> openmm.HarmonicBondForce:
for i, j in zip(group[:-1], group[1:]):
force.addBond(i, j, 0.0, 0.0)
return force


RMSD.registerTag("!cvpack.RMSD")
Loading

0 comments on commit fc2eb1a

Please sign in to comment.