Skip to content

Commit

Permalink
with hmmlearn0.3.2
Browse files Browse the repository at this point in the history
  • Loading branch information
taoliu committed Mar 2, 2024
1 parent 34f9a81 commit 43df1d4
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 35 deletions.
37 changes: 5 additions & 32 deletions MACS3/Signal/HMMR_HMM.pyx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# cython: language_level=3
# cython: profile=True
# Time-stamp: <2024-02-18 16:21:00 Tao Liu>
# Time-stamp: <2024-03-01 23:34:51 Tao Liu>

"""Module description:
Expand All @@ -20,7 +20,8 @@ from math import sqrt
import numpy as np
cimport numpy as np
from cpython cimport bool
from hmmlearn import hmm, _utils
import hmmlearn
from hmmlearn.hmm import GaussianHMM
from sklearn import cluster
import json
# from hmmlearn cimport hmm
Expand Down Expand Up @@ -51,34 +52,6 @@ cdef inline float get_weighted_density( int x, float m, float v, w ):
# Classes
# ------------------------------------

class GaussianHMM_modified( hmm.GaussianHMM ):
def _init(self, X, lengths=None):
super()._init(X, lengths)
# we will overwrite initial means_ and covars_
kmeans = cluster.KMeans(n_clusters=self.n_components,
random_state=self.random_state,
n_init=10) # https://github.com/hmmlearn/hmmlearn/pull/545
# the idea is to do the random seeds
# for 10 times orginally, hmmlearn 0.3
# will do this only once. However,
# due to the change in scikit-learn
# 1.3, the random seeding in KMeans
# will generate different results with
# previous scikit-learn. It will make
# the results irreproducible between
# sklearn <1.3 and sklearn
# >=1.3. Hopefully, if we choose to do
# the process 10 times, the results
# will be more similar.
kmeans.fit(X)
self.means_ = kmeans.cluster_centers_

cv = np.cov(X.T) + self.min_covar * np.eye(X.shape[1])
if not cv.shape:
cv.shape = (1, 1)
self.covars_ = \
_utils.distribute_covar_matrix_to_match_covariance_type( cv, self.covariance_type, self.n_components ).copy()

# ------------------------------------
# public functions
# ------------------------------------
Expand All @@ -90,7 +63,7 @@ cpdef hmm_training( list training_data, list training_data_lengths, int n_states
# according to base documentation, if init_prob not stated, it is set to be equally likely for any state (1/ # of components)
# if we have other known parameters, we should set these (ie: means_weights, covariance_type etc.)
rs = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(random_seed)))
hmm_model = GaussianHMM_modified( n_components= n_states, covariance_type = covar, random_state = rs, verbose = False )
hmm_model = GaussianHMM( n_components= n_states, covariance_type = covar, random_state = rs, verbose = False )
hmm_model = hmm_model.fit( training_data, training_data_lengths )
assert hmm_model.n_features == 4
return hmm_model
Expand Down Expand Up @@ -121,7 +94,7 @@ cpdef void hmm_model_save( str model_file, object hmm_model, int hmm_binsize, in
cpdef list hmm_model_init( str model_file ):
with open( model_file ) as f:
m = json.load( f )
hmm_model = GaussianHMM_modified( n_components=3, covariance_type=m["covariance_type"] )
hmm_model = GaussianHMM( n_components=3, covariance_type=m["covariance_type"] )
hmm_model.startprob_ = np.array(m["startprob"])
hmm_model.transmat_ = np.array(m["transmat"])
hmm_model.means_ = np.array(m["means"])
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[build-system]
requires=['setuptools>=60.0', 'numpy>=1.24.2', 'scipy>=1.11.4', 'cykhash>=2.0,<3.0', 'Cython~=3.0', 'scikit-learn>=1.2.1', 'hmmlearn==0.3.0']
requires=['setuptools>=60.0', 'numpy>=1.24.2', 'scipy>=1.11.4', 'cykhash>=2.0,<3.0', 'Cython~=3.0', 'scikit-learn>=1.2.1', 'hmmlearn>=0.3.2']

2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Cython~=3.0
numpy>=1.24.2
scipy>=1.11.4
scikit-learn>=1.2.1
hmmlearn==0.3.0
hmmlearn>=0.3.2
cykhash>=2.0,<3.0
pytest>=7.0
setuptools>=60.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

install_requires = [ "numpy>=1.24.2",
"scipy>=1.11.4",
"hmmlearn==0.3.0",
"hmmlearn>=0.3.2",
"scikit-learn>=1.2.1",
"cykhash>=2.0,<3.0"]

Expand Down

0 comments on commit 43df1d4

Please sign in to comment.