Skip to content

Commit

Permalink
fix typos
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeleHardie committed Feb 6, 2024
1 parent 786eb04 commit 0463161
Showing 1 changed file with 37 additions and 25 deletions.
62 changes: 37 additions & 25 deletions ammo/msm/_msm.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,11 +648,9 @@ def plot_clusters(self, cluster_sets, shape, titles=None, x=0, y=1, features='in
fig, ax, cbar = self.plot_data(shape, titles, x, y, features, cmap)

# sort the cluster sets
if type(cluster_sets) == list:
cluster_sets = _np.array(cluster_sets)
if same_clusters: #if only one cluster set given
cluster_sets = _np.array([cluster_sets for i in range(len(titles))])
cluster_sets = [cluster_sets for i in range(len(titles))]

for row in range(shape[0]):
for col in range(shape[1]):
cluster = row*shape[1] + col
Expand Down Expand Up @@ -692,9 +690,9 @@ def get_pcca_clusters(self, n_states, titles=None):
for i in range(states):
if len(msm.pcca[states][i]) > 0:
centers.append(msm.cluster_centers[msm.pcca[states][i]])
all_cluster_centers.append(_np.array(centers))
all_cluster_centers.append(centers)

return _np.array(all_cluster_centers)
return all_cluster_centers

def mfpt(self, n_states, msm=None, timestep=None, titles=None, verbose=True, overwrite=False):
"""Compute mean first passage times for each MSM, based on specified pcca metastable state assignment
Expand Down Expand Up @@ -817,7 +815,7 @@ def compare_states_and_timescales(self, n_states, msm=None, timestep=None, title

return None

def bootstrapping(self, n_states, msm=None, titles=None, cluster_centers=None, min_iter=100, max_iter=100, tol=1, last=10, verbose=False, overwrite=False):
def bootstrapping(self, n_states, msm=None, titles=None, lag_time=None, cluster_centers=None, min_iter=100, max_iter=100, tol=1, last=10, verbose=False, overwrite=False):
"""
Compute bootstrapped probabilities until they have converged to a Gaussian distribution or until the maximum number
of iterations has been reached.
Expand All @@ -832,6 +830,9 @@ def bootstrapping(self, n_states, msm=None, titles=None, cluster_centers=None, m
titles : str, [str]
MSMs to compute probabilities for. If None, all will be used
lag_time : int, str
MSM lag time in trajectory steps (if int) or in format "value unit", e.g. "10 ps" (if str). If None, lag time of existing pyemma MSM will be used.
cluster_centers : [float], numpy.array
cluster centers to assign data to. If None, msm own cluster centers will be used
Expand Down Expand Up @@ -873,7 +874,7 @@ def bootstrapping(self, n_states, msm=None, titles=None, cluster_centers=None, m
print(f'Bootstrapped probabilities, based on {pcca} MSM, {n_states} states:')
for key in titles:
print(key)
probabilities[key] = self._MSMs[key].bootstrapping(n_states, self._MSMs[pcca], cluster_centers, min_iter, max_iter, tol, last, verbose)
probabilities[key] = self._MSMs[key].bootstrapping(n_states, self._MSMs[pcca], lag_time, cluster_centers, min_iter, max_iter, tol, last, verbose)
print('-'*30)

return probabilities
Expand Down Expand Up @@ -1725,12 +1726,14 @@ def __fit_gaus(self, data):

return coeff

def __build_bootstrapped_msm(self, cluster_centers=None):
def __build_bootstrapped_msm(self, lag, cluster_centers=None):
"""
Build an msm with randomly resampled data and return state probability
Parameters
----------
----------
lag : int
msm lag time in steps
cluster_centers : [float], numpy.array
cluster centers to assign data to. If None, msm own cluster centers will be used
Expand All @@ -1742,23 +1745,21 @@ def __build_bootstrapped_msm(self, cluster_centers=None):
traj_idxs : [int]
indices of trajectories used for bootstrapped msm
"""
if cluster_centers is None:
cluster_centers = self.cluster_centers

#get resampled data
traj_num = len(self.data)
traj_idxs = _np.array([int(idx*traj_num) for idx in _np.random.rand(traj_num)])

# get new trajectories
# if different clusters provided, assign new dtrajs
if cluster_centers is not None:
new_data = [self.data[idx] for idx in traj_idxs]
dtrajs = _assign_to_centers(new_data, cluster_centers)
# otherwise resample dtrajs directly
else:
dtrajs = [self.dtrajs[idx] for idx in traj_idxs]
# get new trajectories
# if different clusters provided, assign new dtrajs
if cluster_centers is not None:
new_data = [self.data[idx] for idx in traj_idxs]
dtrajs = _assign_to_centers(new_data, cluster_centers)
# otherwise resample dtrajs directly
else:
dtrajs = [self.dtrajs[idx] for idx in traj_idxs]
cluster_centers = self.cluster_centers
#build msm
bootstrap_msm = _bayesian_msm(dtrajs, 2000)
bootstrap_msm = _bayesian_msm(dtrajs, lag)

#get stationary distribution
stationary_distribution = bootstrap_msm.stationary_distribution
Expand Down Expand Up @@ -1816,7 +1817,7 @@ def bootstrapping_convergence(self, state_probabilities, tol=1, last=10):

return converged

def bootstrapping(self, n_states, msm=None, cluster_centers=None, min_iter=10, max_iter=100, tol=1, last=10, verbose=False, overwrite=False):
def bootstrapping(self, n_states, msm=None, lag_time=None, cluster_centers=None, min_iter=10, max_iter=100, tol=1, last=10, verbose=False, overwrite=False):
"""
Compute bootstrapped probabilities of a state until they have converged to a Gaussian distribution or until the maximum number
of iterations has been reached.
Expand All @@ -1829,6 +1830,9 @@ def bootstrapping(self, n_states, msm=None, cluster_centers=None, min_iter=10, m
msm : allostery.msm.MSM
MSMs whose pcca assignment to use. If None, own will be used
lag_time : int, str
MSM lag time in trajectory steps (if int) or in format "value unit", e.g. "10 ps" (if str). If None, lag time of existing pyemma MSM will be used.
cluster_centers : [float], numpy.array
cluster centers to assign data to. If None, msm own cluster centers will be used
Expand Down Expand Up @@ -1860,6 +1864,14 @@ def bootstrapping(self, n_states, msm=None, cluster_centers=None, min_iter=10, m
if msm is None:
msm = self

# fix lag time
if isinstance(lag_time, str):
traj_step = _parse_time(self.timestep, 'ps', output_type='number')
msm_step = _parse_time(lag_time, 'ps', output_type='number')
lag_time = msm_step//traj_step
elif lag_time is None:
lag_time = self.msm.lagtime

# check if bootstrapping is already done
pcca = f'{msm.title}, {n_states} states'
if pcca in self.bootstrapping_data and not overwrite:
Expand All @@ -1878,8 +1890,8 @@ def bootstrapping(self, n_states, msm=None, cluster_centers=None, min_iter=10, m
print('%3i/%i'%(i,max_iter), end='\r')
# build a bootstrapped msm
try:
stationary_distribution, trajectories = self.__build_bootstrapped_msm(cluster_centers)
except: # if msm stationary probabilities too low, an error is thrown - discard those
stationary_distribution, trajectories = self.__build_bootstrapped_msm(lag_time, cluster_centers)
except Exception as e: # if msm stationary probabilities too low, an error is thrown - discard those
continue
# add results
probability = _np.array([[round(stationary_distribution[list(state_clusters)].sum()*100, 2) for state_clusters in msm.pcca[n_states]]])
Expand Down

0 comments on commit 0463161

Please sign in to comment.