diff --git a/model/tgn.py b/model/tgn.py index 05704f0..30ec65b 100644 --- a/model/tgn.py +++ b/model/tgn.py @@ -178,12 +178,15 @@ def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_ source_nodes, source_node_embedding, edge_times, edge_idxs) - if self.memory_update_at_start: - self.memory.store_raw_messages(unique_sources, source_id_to_messages) - self.memory.store_raw_messages(unique_destinations, destination_id_to_messages) - else: - self.update_memory(unique_sources, source_id_to_messages) - self.update_memory(unique_destinations, destination_id_to_messages) + + self.memory.store_raw_messages(unique_sources, source_id_to_messages) + self.memory.store_raw_messages(unique_destinations, destination_id_to_messages) + + if not self.memory_update_at_start: + unique_node_ids = np.unique(np.concatenate((unique_sources, unique_destinations))) + self.update_memory(unique_node_ids, + self.memory.messages) + self.memory.clear_messages(unique_node_ids) if self.dyrep: source_node_embedding = memory[source_nodes]