From 764af869bd8238db9e1ec87def29ad4a36f97422 Mon Sep 17 00:00:00 2001 From: AlbertDominguez Date: Thu, 21 Jul 2022 19:40:12 +0200 Subject: [PATCH] Fix pretrained initialization. Default initialization is now PCArchetypal --- neural_admixture/model/initializations.py | 17 +++++++++++++++-- neural_admixture/model/modules.py | 21 --------------------- neural_admixture/model/switchers.py | 1 - neural_admixture/src/train.py | 4 ++-- neural_admixture/src/utils.py | 2 +- setup.py | 2 +- 6 files changed, 19 insertions(+), 28 deletions(-) diff --git a/neural_admixture/model/initializations.py b/neural_admixture/model/initializations.py index 3387d42..6706510 100644 --- a/neural_admixture/model/initializations.py +++ b/neural_admixture/model/initializations.py @@ -53,7 +53,7 @@ def get_decoder_init(cls, X, K, path, run_name, n_components): centers = np.concatenate([obj.cluster_centers_ for obj in k_means_objs]) P_init = torch.as_tensor(pca_obj.inverse_transform(centers), dtype=torch.float32).view(sum(K), -1) else: - k_means_obj = KMeans(n_clusters=K, random_state=42, n_init=10, max_iter=10).fit(X_tsvd) + k_means_obj = KMeans(n_clusters=K, random_state=42, n_init=10, max_iter=10).fit(X_pca) P_init = torch.as_tensor(pca_obj.inverse_transform(k_means_obj.cluster_centers_), dtype=torch.float32).view(K, -1) te = time.time() log.info('Weights initialized in {} seconds.'.format(te-t0)) @@ -70,7 +70,7 @@ def get_decoder_init(cls, X, K, path, run_name, n_components): class PCArchetypal(object): @classmethod def get_decoder_init(cls, X, K, path, run_name, n_components, seed): - log.info('Running ArchetypalPCA initialization...') + log.info('Running PCArchetypal initialization...') np.random.seed(seed) t0 = time.time() try: @@ -123,3 +123,16 @@ def get_decoder_init(cls, X, y, K): te = time.time() log.info('Weights initialized in {} seconds.'.format(te-t0)) return P_init + + +class PretrainedInitialization(object): + @classmethod + def get_decoder_init(cls, X, K, path): + log.info('Fetching pretrained weights...') + if len(K) > 1: + raise NotImplementedError("Pretrained mode is only supported for single-head runs.") + # Loads standard ADMIXTURE output format + P_init = torch.as_tensor(1-np.genfromtxt(path, delimiter=' ').T, dtype=torch.float32) + assert P_init.shape[0] == K[0], 'Input P is not coherent with the value of K' + log.info('Weights fetched.') + return P_init diff --git a/neural_admixture/model/modules.py b/neural_admixture/model/modules.py index ace9354..2d307fb 100644 --- a/neural_admixture/model/modules.py +++ b/neural_admixture/model/modules.py @@ -62,24 +62,3 @@ def forward(self, hid_states): outputs = [torch.clamp(self._get_decoder_for_k(self.ks[i])(hid_states[i]), 0, 1) for i in range(len(self.ks))] return outputs -# class NonLinearMultiHeadDecoder(nn.Module): -# def __init__(self, ks, output_size, bias=False, -# hidden_size=512, hidden_activation=nn.ReLU(), -# inits=None): -# super().__init__() -# self.ks = ks -# self.hidden_size = hidden_size -# self.output_size = output_size -# self.heads_decoder = nn.Linear(sum(self.ks), self.hidden_size, bias=bias) -# self.common_decoder = nn.Linear(self.hidden_size, self.output_size) -# self.nonlinearity = hidden_activation -# self.sigmoid = nn.Sigmoid() - -# def forward(self, hid_states): -# if len(hid_states) > 1: -# concat_states = torch.cat(hid_states, 1) -# else: -# concat_states = hid_states[0] -# dec = self.nonlinearity(self.heads_decoder(concat_states)) -# rec = self.sigmoid(self.common_decoder(dec)) -# return rec diff --git a/neural_admixture/model/switchers.py b/neural_admixture/model/switchers.py index 4ab43f9..27b5903 100644 --- a/neural_admixture/model/switchers.py +++ b/neural_admixture/model/switchers.py @@ -15,7 +15,6 @@ class Switchers(object): 'pcarchetypal': lambda X, y, k, seed, path, run_name, n_comp: init.PCArchetypal.get_decoder_init(X, k, path, run_name, n_comp, seed), 'pretrained': lambda X, y, k, seed, path, run_name, n_comp: init.PretrainedInitialization.get_decoder_init(X, k, path), 'supervised': lambda X, y, k, seed, path, run_name, n_comp: init.SupervisedInitialization.get_decoder_init(X, y, k) - } _optimizers = { diff --git a/neural_admixture/src/train.py b/neural_admixture/src/train.py index 85409d1..6db6662 100644 --- a/neural_admixture/src/train.py +++ b/neural_admixture/src/train.py @@ -52,10 +52,10 @@ def fit_model(trX, args, valX=None, trY=None, valY=None): torch.manual_seed(seed) # Initialization log.info('Initializing...') - if init_file is None: + if init_file is None and decoder_init != "pretrained": log.warning(f'Initialization filename not provided. Going to store it to {save_dir}/{run_name}.pkl') init_file = f'{run_name}.pkl' - init_path = f'{save_dir}/{init_file}' + init_path = f'{save_dir}/{init_file}' if decoder_init != "pretrained" else init_file P_init = switchers['initializations'][decoder_init](trX, trY, Ks, seed, init_path, run_name, n_components) activation = switchers['activations'][activation_str](0) log.info('Variants: {}'.format(trX.shape[1])) diff --git a/neural_admixture/src/utils.py b/neural_admixture/src/utils.py index 2b80c2d..ab5d2d9 100644 --- a/neural_admixture/src/utils.py +++ b/neural_admixture/src/utils.py @@ -16,7 +16,7 @@ def parse_train_args(argv): description='Rapid population clustering with autoencoders - training mode') parser.add_argument('--learning_rate', required=False, default=0.0001, type=float, help='Learning rate') parser.add_argument('--max_epochs', required=False, type=int, default=50, help='Maximum number of epochs') - parser.add_argument('--initialization', required=False, type=str, default = 'pckmeans', + parser.add_argument('--initialization', required=False, type=str, default = 'pcarchetypal', choices=['pretrained', 'pckmeans', 'supervised', 'pcarchetypal'], help='Decoder initialization (overriden if supervised)') parser.add_argument('--optimizer', required=False, default='adam', type=str, choices=['adam', 'sgd'], help='Optimizer') diff --git a/setup.py b/setup.py index 79472f3..e56b705 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name='neural-admixture', - version='1.1.5', + version='1.1.6', long_description=(Path(__file__).parent / 'README.md').read_text(), long_description_content_type='text/markdown', description='Population clustering with autoencoders',