diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..6c2ff60
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,5 @@
+{
+    "githubPullRequests.ignoredPullRequestBranches": [
+        "master"
+    ]
+}
\ No newline at end of file
diff --git a/README.md b/README.md
index 420f510..e2b4c65 100644
--- a/README.md
+++ b/README.md
@@ -1,27 +1,14 @@
-# TGN: Temporal Graph Networks [[arXiv](https://arxiv.org/abs/2006.10637), [YouTube](https://www.youtube.com/watch?v=W1GvX2ZcUmY), [Blog Post](https://towardsdatascience.com/temporal-graph-networks-ab8f327f2efe)]
+# Experiments with Temporal Graph Networks

-Dynamic Graph | TGN
-:-------------------------:|:-------------------------:
-![](figures/dynamic_graph.png) | ![](figures/tgn.png)
+## Introduction
+Starting from the paper [Temporal Graph Networks for Deep Learning on Dynamic Graphs](https://arxiv.org/abs/2006.10637) and the repository https://github.com/twitterresearch/tgn, this project is an experimental evaluation of the Temporal Graph
+Networks defined in that paper, exploring the future lines of work it proposes and assessing
+the results obtained.
+## Running the experiments
-
-
-## Introduction
-
-Despite the plethora of different models for deep learning on graphs, few approaches have been proposed thus far for dealing with graphs that present some sort of dynamic nature (e.g. evolving features or connectivity over time).
-
-In this paper, we present Temporal Graph Networks (TGNs), a generic, efficient framework for deep learning on dynamic graphs represented as sequences of timed events. Thanks to a novel combination of memory modules and graph-based operators, TGNs are able to significantly outperform previous approaches being at the same time more computationally efficient.
-
-We furthermore show that several previous models for learning on dynamic graphs can be cast as specific instances of our framework. We perform a detailed ablation study of different components of our framework and devise the best configuration that achieves state-of-the-art performance on several transductive and inductive prediction tasks for dynamic graphs.
-
-
-#### Paper link: [Temporal Graph Networks for Deep Learning on Dynamic Graphs](https://arxiv.org/abs/2006.10637)
-
-
-## Running the experiments
-
-### Requirements
+### Requirements

Dependencies (with python >= 3.7):

@@ -31,81 +18,78 @@ torch==1.6.0
scikit_learn==0.23.1
```

-### Dataset and Preprocessing
+### Datasets and preprocessing
+
+#### TGN datasets

-#### Download the public data
-Download the sample datasets (eg. wikipedia and reddit) from
-[here](http://snap.stanford.edu/jodie/) and store their csv files in a folder named
-```data/```.
+##### Downloading the data
+The wikipedia and reddit datasets can be downloaded from [here](http://snap.stanford.edu/jodie/) and must be stored in the folders
+```data/tgn_wikipedia``` and ```data/tgn_reddit```, respectively.

-#### Preprocess the data
-We use the dense `npy` format to save the features in binary format. If edge features or nodes
-features are absent, they will be replaced by a vector of zeros.
+#### Preprocessing the data
+Dense `.npy` files are used to store the generated data. If node or edge features are missing, they are filled with vectors of zeros (see the illustrative sketch below, followed by the preprocessing commands).
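+As an illustration of that zero-filling convention, the following minimal sketch mimics what the preprocessing produces. The sizes, file names and the 172-dimensional default used here are assumptions for the example, not the exact behaviour of `utils/tgn_preprocess_data.py`:
+
+```{python}
+import numpy as np
+
+# Hypothetical sizes, chosen only for this example.
+n_nodes, n_edges, feat_dim = 1000, 5000, 172
+
+# Edge features would normally be parsed from the dataset csv;
+# random placeholders are used here instead.
+edge_features = np.random.rand(n_edges, feat_dim)
+
+# Node features are absent in wikipedia/reddit, so they are
+# replaced by vectors of zeros, as described above.
+node_features = np.zeros((n_nodes + 1, feat_dim))
+
+# Dense binary .npy files, analogous to the ones written by the preprocessing.
+np.save("ml_example_edge_feat.npy", edge_features)
+np.save("ml_example_node_feat.npy", node_features)
+```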
```{bash}
-python utils/preprocess_data.py --data wikipedia --bipartite
-python utils/preprocess_data.py --data reddit --bipartite
+python3 utils/tgn_preprocess_data.py --data wikipedia --bipartite
+python3 utils/tgn_preprocess_data.py --data reddit --bipartite
```

+### Model training
-
-
-### Model Training
-
-Self-supervised learning using the link prediction task:
+For link prediction:

```{bash}
# TGN-attn: Supervised learning on the wikipedia dataset
-python train_self_supervised.py --use_memory --prefix tgn-attn --n_runs 10
+python3 tgn_link_prediction.py --use_memory --prefix tgn-attn --n_runs 10

# TGN-attn-reddit: Supervised learning on the reddit dataset
-python train_self_supervised.py -d reddit --use_memory --prefix tgn-attn-reddit --n_runs 10
+python3 tgn_link_prediction.py -d reddit --use_memory --prefix tgn-attn-reddit --n_runs 10
```

-Supervised learning on dynamic node classification (this requires a trained model from
-the self-supervised task, by eg. running the commands above):
+For node classification (this requires the model trained on the link prediction task):
```{bash}
# TGN-attn: self-supervised learning on the wikipedia dataset
-python train_supervised.py --use_memory --prefix tgn-attn --n_runs 10
+python3 tgn_node_classification.py --use_memory --prefix tgn-attn --n_runs 10

# TGN-attn-reddit: self-supervised learning on the reddit dataset
-python train_supervised.py -d reddit --use_memory --prefix tgn-attn-reddit --n_runs 10
+python3 tgn_node_classification.py -d reddit --use_memory --prefix tgn-attn-reddit --n_runs 10
```

-### Baselines
+### JODIE and DyRep

```{bash}
-### Wikipedia Self-supervised
+### Link prediction on Wikipedia
# Jodie
-python train_self_supervised.py --use_memory --memory_updater rnn --embedding_module time --prefix jodie_rnn --n_runs 10
+python3 tgn_link_prediction.py --use_memory --memory_updater rnn --embedding_module time --prefix jodie_rnn --n_runs 10

# DyRep
-python train_self_supervised.py --use_memory --memory_updater rnn --dyrep --use_destination_embedding_in_message --prefix dyrep_rnn --n_runs 10
+python3 tgn_link_prediction.py --use_memory --memory_updater rnn --dyrep --use_destination_embedding_in_message --prefix dyrep_rnn --n_runs 10

-### Reddit Self-supervised
+### Link prediction on Reddit
# Jodie
-python train_self_supervised.py -d reddit --use_memory --memory_updater rnn --embedding_module time --prefix jodie_rnn_reddit --n_runs 10
+python3 tgn_link_prediction.py -d reddit --use_memory --memory_updater rnn --embedding_module time --prefix jodie_rnn_reddit --n_runs 10

# DyRep
-python train_self_supervised.py -d reddit --use_memory --memory_updater rnn --dyrep --use_destination_embedding_in_message --prefix dyrep_rnn_reddit --n_runs 10
+python3 tgn_link_prediction.py -d reddit --use_memory --memory_updater rnn --dyrep --use_destination_embedding_in_message --prefix dyrep_rnn_reddit --n_runs 10

-### Wikipedia Supervised
+### Node classification on Wikipedia
# Jodie
-python train_supervised.py --use_memory --memory_updater rnn --embedding_module time --prefix jodie_rnn --n_runs 10
+python3 tgn_node_classification.py --use_memory --memory_updater rnn --embedding_module time --prefix jodie_rnn --n_runs 10

# DyRep
-python train_supervised.py --use_memory --memory_updater rnn --dyrep --use_destination_embedding_in_message --prefix dyrep_rnn --n_runs 10
+python3 tgn_node_classification.py --use_memory --memory_updater rnn --dyrep --use_destination_embedding_in_message --prefix dyrep_rnn --n_runs 10
-### Reddit Supervised
+### Node classification on Reddit
# Jodie
-python train_supervised.py -d reddit --use_memory --memory_updater rnn --embedding_module time --prefix jodie_rnn_reddit --n_runs 10
+python3 tgn_node_classification.py -d reddit --use_memory --memory_updater rnn --embedding_module time --prefix jodie_rnn_reddit --n_runs 10

# DyRep
-python train_supervised.py -d reddit --use_memory --memory_updater rnn --dyrep --use_destination_embedding_in_message --prefix dyrep_rnn_reddit --n_runs 10
+python3 tgn_node_classification.py -d reddit --use_memory --memory_updater rnn --dyrep --use_destination_embedding_in_message --prefix dyrep_rnn_reddit --n_runs 10
```
diff --git a/model/tgn.py b/model/tgn.py
index 05704f0..4d9f9dc 100644
--- a/model/tgn.py
+++ b/model/tgn.py
@@ -10,6 +10,7 @@ from modules.memory_updater import get_memory_updater
 from modules.embedding_module import get_embedding_module
 from model.time_encoding import TimeEncode
+from modules.feature_embedding import get_feature_embedding


 class TGN(torch.nn.Module):
@@ -17,13 +18,15 @@ def __init__(self, neighbor_finder, node_features, edge_features, device, n_laye
               n_heads=2, dropout=0.1, use_memory=False,
               memory_update_at_start=True, message_dimension=100,
               memory_dimension=500, embedding_module_type="graph_attention",
-               message_function="mlp",
+               message_function="identity",
               mean_time_shift_src=0, std_time_shift_src=1, mean_time_shift_dst=0,
               std_time_shift_dst=1, n_neighbors=None, aggregator_type="last",
               memory_updater_type="gru",
               use_destination_embedding_in_message=False,
               use_source_embedding_in_message=False,
-               dyrep=False):
+               dyrep=False,
+               feature_embedding_type="identity",
+               feature_dimension=50):
    super(TGN, self).__init__()

    self.n_layers = n_layers
@@ -43,6 +46,8 @@ def __init__(self, neighbor_finder, node_features, edge_features, device, n_laye
    self.use_destination_embedding_in_message = use_destination_embedding_in_message
    self.use_source_embedding_in_message = use_source_embedding_in_message
    self.dyrep = dyrep
+    self.feature_embedding_type = feature_embedding_type
+    self.feature_dimension = feature_dimension

    self.use_memory = use_memory
    self.time_encoder = TimeEncode(dimension=self.n_node_features)
@@ -65,7 +70,7 @@ def __init__(self, neighbor_finder, node_features, edge_features, device, n_laye
                                                      message_dimension=message_dimension,
                                                      device=device)
      self.message_aggregator = get_message_aggregator(aggregator_type=aggregator_type,
-                                                       device=device)
+                                                       device=device, raw_message_dimension=raw_message_dimension)
      self.message_function = get_message_function(module_type=message_function,
                                                   raw_message_dimension=raw_message_dimension,
                                                   message_dimension=message_dimension)
@@ -74,6 +79,9 @@ def __init__(self, neighbor_finder, node_features, edge_features, device, n_laye
                                               message_dimension=message_dimension,
                                               memory_dimension=self.memory_dimension,
                                               device=device)
+      self.feature_embedding = get_feature_embedding(module_type=self.feature_embedding_type,
+                                                     raw_features_dimension=self.n_edge_features,
+                                                     features_dimension=self.feature_dimension)

    self.embedding_module_type = embedding_module_type

@@ -124,25 +132,22 @@ def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_
    if self.memory_update_at_start:
      # Update memory for all nodes with messages stored in previous batches
      memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
-                                                    self.memory.messages)
+                                                    self.memory.messages)
    else:
      memory = self.memory.get_memory(list(range(self.n_nodes)))
      last_update = self.memory.last_update
-
### Compute differences between the time the memory of a node was last updated, - ### and the time for which we want to compute the embedding of a node - source_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[ - source_nodes].long() - source_time_diffs = (source_time_diffs - self.mean_time_shift_src) / self.std_time_shift_src - destination_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[ - destination_nodes].long() - destination_time_diffs = (destination_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst - negative_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[ - negative_nodes].long() - negative_time_diffs = (negative_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst - - time_diffs = torch.cat([source_time_diffs, destination_time_diffs, negative_time_diffs], - dim=0) + ### Compute differences between the time the memory of a node was last updated, + ### and the time for which we want to compute the embedding of a node + source_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[source_nodes].long() + source_time_diffs = (source_time_diffs - self.mean_time_shift_src) / self.std_time_shift_src + destination_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[destination_nodes].long() + destination_time_diffs = (destination_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst + negative_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[negative_nodes].long() + negative_time_diffs = (negative_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst + + time_diffs = torch.cat([source_time_diffs, destination_time_diffs, negative_time_diffs], + dim=0) # Compute the embeddings using the embedding module node_embedding = self.embedding_module.compute_embedding(memory=memory, @@ -162,22 +167,22 @@ def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_ # new messages for them) self.update_memory(positives, self.memory.messages) - assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \ - "Something wrong in how the memory was updated" + # assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \ + # "Something wrong in how the memory was updated" # Remove messages for the positives since we have already updated the memory using them self.memory.clear_messages(positives) unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes, - source_node_embedding, - destination_nodes, - destination_node_embedding, - edge_times, edge_idxs) + source_node_embedding, + destination_nodes, + destination_node_embedding, + edge_times, edge_idxs) unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes, - destination_node_embedding, - source_nodes, - source_node_embedding, - edge_times, edge_idxs) + destination_node_embedding, + source_nodes, + source_node_embedding, + edge_times, edge_idxs) if self.memory_update_at_start: self.memory.store_raw_messages(unique_sources, source_id_to_messages) self.memory.store_raw_messages(unique_destinations, destination_id_to_messages) @@ -191,6 +196,7 @@ def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_ negative_node_embedding = memory[negative_nodes] return source_node_embedding, destination_node_embedding, negative_node_embedding + def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_nodes, edge_times, 
edge_idxs, n_neighbors=20): @@ -219,33 +225,48 @@ def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_n return pos_score.sigmoid(), neg_score.sigmoid() def update_memory(self, nodes, messages): + """ + 1. Agrega los mensajes pertenecientes a los mismos nodos. + 2. Calcula los mensajes únicos de cada nodo. + 3. Actualiza la memoria con los mensajes agregados y calculados. + """ # Aggregate messages for the same nodes - unique_nodes, unique_messages, unique_timestamps = \ - self.message_aggregator.aggregate( - nodes, - messages) - - if len(unique_nodes) > 0: + unique_nodes, unique_messages, unique_timestamps = self.message_aggregator.aggregate(nodes, messages) + + # Compute messages from raw messages + if len(unique_messages) > 0: unique_messages = self.message_function.compute_message(unique_messages) # Update the memory with the aggregated messages self.memory_updater.update_memory(unique_nodes, unique_messages, timestamps=unique_timestamps) + + def update_memory_new(self, positives, updated_memory, last_update): + + + # Update the memory with the aggregated messages + #self.memory_updater.update_memory_new(unique_nodes, updated_memory, last_update) + return def get_updated_memory(self, nodes, messages): + """ + 1. Agrega los mensajes pertenecientes a los mismos nodos. + 2. Calcula los mensajes únicos de cada nodo. + 3. Actualiza la memoria con los mensajes agregados y calculados. + 4. Devuelve la memoria actualizada y los tiempos de la última actualización. + """ + # Aggregate messages for the same nodes - unique_nodes, unique_messages, unique_timestamps = \ - self.message_aggregator.aggregate( - nodes, - messages) + unique_nodes, unique_messages, unique_timestamps = self.message_aggregator.aggregate(nodes, messages) + print(len(unique_nodes)) + if len(unique_nodes) > 0: unique_messages = self.message_function.compute_message(unique_messages) updated_memory, updated_last_update = self.memory_updater.get_updated_memory(unique_nodes, unique_messages, timestamps=unique_timestamps) - return updated_memory, updated_last_update def get_raw_messages(self, source_nodes, source_node_embedding, destination_nodes, @@ -253,6 +274,9 @@ def get_raw_messages(self, source_nodes, source_node_embedding, destination_node edge_times = torch.from_numpy(edge_times).float().to(self.device) edge_features = self.edge_raw_features[edge_idxs] + # Aprendizaje de características de aristas + edge_features = self.feature_embedding.compute_features(edge_features) + source_memory = self.memory.get_memory(source_nodes) if not \ self.use_source_embedding_in_message else source_node_embedding destination_memory = self.memory.get_memory(destination_nodes) if \ @@ -264,7 +288,7 @@ def get_raw_messages(self, source_nodes, source_node_embedding, destination_node source_message = torch.cat([source_memory, destination_memory, edge_features, source_time_delta_encoding], - dim=1) + dim=1) messages = defaultdict(list) unique_sources = np.unique(source_nodes) diff --git a/modules/feature_embedding.py b/modules/feature_embedding.py new file mode 100644 index 0000000..92b6652 --- /dev/null +++ b/modules/feature_embedding.py @@ -0,0 +1,41 @@ +from torch import nn + +class FeatureEmbedding(nn.Module): + """ + Embedding module for edge features. 
+ """ + def compute_features(self, raw_features): + return None + +class MLPFeatureEmbedding(FeatureEmbedding): + def __init__(self, raw_features_dimension, features_dimension): + super(MLPFeatureEmbedding, self).__init__() + + self.mlp = self.layers = nn.Sequential( + nn.Linear(raw_features_dimension, raw_features_dimension // 2), + nn.ReLU(), + nn.BatchNorm1d(raw_features_dimension // 2), + nn.Dropout(0.2), + + nn.Linear(raw_features_dimension // 2, raw_features_dimension // 4), + nn.ReLU(), + nn.BatchNorm1d(raw_features_dimension // 4), + nn.Dropout(0.2), + + nn.Linear(raw_features_dimension // 4, features_dimension), + ) + + def compute_features(self, raw_features): + messages = self.mlp(raw_features) + + return messages + +class IdentityFeatureEmbedding(FeatureEmbedding): + def compute_features(self, raw_features): + return raw_features + +def get_feature_embedding(module_type, raw_features_dimension, features_dimension=50): + if module_type == "mlp": + return MLPFeatureEmbedding(raw_features_dimension, features_dimension) + elif module_type == "identity": + return IdentityFeatureEmbedding() \ No newline at end of file diff --git a/modules/memory.py b/modules/memory.py index 8cfa18b..88bae99 100644 --- a/modules/memory.py +++ b/modules/memory.py @@ -32,6 +32,7 @@ def __init_memory__(self): self.messages = defaultdict(list) + def store_raw_messages(self, nodes, node_id_to_messages): for node in nodes: self.messages[node].extend(node_id_to_messages[node]) diff --git a/modules/memory_updater.py b/modules/memory_updater.py index 5f8dc24..f10c72c 100644 --- a/modules/memory_updater.py +++ b/modules/memory_updater.py @@ -33,8 +33,7 @@ def get_updated_memory(self, unique_node_ids, unique_messages, timestamps): if len(unique_node_ids) <= 0: return self.memory.memory.data.clone(), self.memory.last_update.data.clone() - assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), "Trying to " \ - "update memory to time in the past" + assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), "Trying to update memory to time in the past" updated_memory = self.memory.memory.data.clone() updated_memory[unique_node_ids] = self.memory_updater(unique_messages, updated_memory[unique_node_ids]) @@ -43,6 +42,25 @@ def get_updated_memory(self, unique_node_ids, unique_messages, timestamps): updated_last_update[unique_node_ids] = timestamps return updated_memory, updated_last_update + + def update_memory_new(self, unique_node_ids, unique_messages, timestamps): + if len(unique_node_ids) <= 0: + return + + assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), "Trying to " \ + "update memory to time in the past" + + memory = self.memory.get_memory(unique_node_ids) + self.memory.last_update[unique_node_ids] = timestamps + + updated_memory = self.memory_updater(unique_messages, memory) + + self.memory.set_memory(unique_node_ids, updated_memory) + + updated_memory = self.memory.memory.data.clone() + updated_last_update = self.memory.last_update.data.clone() + + return updated_memory, updated_last_update class GRUMemoryUpdater(SequenceMemoryUpdater): diff --git a/modules/message_aggregator.py b/modules/message_aggregator.py index 0c610dc..5e2ff90 100644 --- a/modules/message_aggregator.py +++ b/modules/message_aggregator.py @@ -1,6 +1,7 @@ from collections import defaultdict import torch import numpy as np +from torch import nn class MessageAggregator(torch.nn.Module): @@ -80,11 +81,110 @@ def aggregate(self, node_ids, messages): return 
to_update_node_ids, unique_messages, unique_timestamps +class RNNMessageAggregator(MessageAggregator): + def __init__(self, device, raw_message_dimension): + super(RNNMessageAggregator, self).__init__(device) -def get_message_aggregator(aggregator_type, device): + self.raw_message_dimension = raw_message_dimension + self.message_aggregator = nn.RNNCell(input_size=raw_message_dimension, hidden_size=raw_message_dimension) + + def aggregate(self, node_ids, messages): + unique_node_ids = np.unique(node_ids) + unique_messages = [] + unique_timestamps = [] + + to_update_node_ids = [] + + for node_id in unique_node_ids: + if len(messages[node_id]) > 0: + to_update_node_ids.append(node_id) + + hidden_state = nn.Parameter(torch.zeros(self.raw_message_dimension).to(self.device), + requires_grad=False) + + for message in messages[node_id]: + hidden_state = self.message_aggregator(message[0].squeeze(0), hidden_state) + + unique_messages.append(hidden_state) + unique_timestamps.append(messages[node_id][-1][1]) + + unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else [] + unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else [] + + return to_update_node_ids, unique_messages, unique_timestamps + +class GRUMessageAggregator(MessageAggregator): + def __init__(self, device, raw_message_dimension): + super(GRUMessageAggregator, self).__init__(device) + + self.raw_message_dimension = raw_message_dimension + self.message_aggregator = nn.GRUCell(input_size=raw_message_dimension, hidden_size=raw_message_dimension) + + def aggregate(self, node_ids, messages): + unique_node_ids = np.unique(node_ids) + unique_messages = [] + unique_timestamps = [] + + to_update_node_ids = [] + + for node_id in unique_node_ids: + if len(messages[node_id]) > 0: + to_update_node_ids.append(node_id) + + hidden_state = nn.Parameter(torch.zeros(self.raw_message_dimension).to(self.device), + requires_grad=False) + + for message in messages[node_id]: + hidden_state = self.message_aggregator(message[0].squeeze(0), hidden_state) + + unique_messages.append(hidden_state.detach()) + unique_timestamps.append(messages[node_id][-1][1]) + + unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else [] + unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else [] + + return to_update_node_ids, unique_messages, unique_timestamps + +class GLUMessageAggregator(MessageAggregator): + def __init__(self, device, raw_message_dimension): + super(GRUMessageAggregator, self).__init__(device) + + self.raw_message_dimension = raw_message_dimension + self.message_aggregator = nn.GRUCell(input_size=raw_message_dimension, hidden_size=raw_message_dimension) + + def aggregate(self, node_ids, messages): + unique_node_ids = np.unique(node_ids) + unique_messages = [] + unique_timestamps = [] + + to_update_node_ids = [] + + for node_id in unique_node_ids: + if len(messages[node_id]) > 0: + to_update_node_ids.append(node_id) + + hidden_state = nn.Parameter(torch.zeros(self.raw_message_dimension).to(self.device), + requires_grad=False) + + for message in messages[node_id]: + hidden_state = self.message_aggregator(message[0].squeeze(0), hidden_state) + + unique_messages.append(hidden_state.detach()) + unique_timestamps.append(messages[node_id][-1][1]) + + unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else [] + unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else [] + + return 
to_update_node_ids, unique_messages, unique_timestamps + +def get_message_aggregator(aggregator_type, device, raw_message_dimension=None): if aggregator_type == "last": return LastMessageAggregator(device=device) elif aggregator_type == "mean": return MeanMessageAggregator(device=device) + elif aggregator_type == "rnn": + return RNNMessageAggregator(device=device, raw_message_dimension=raw_message_dimension) + elif aggregator_type == "gru": + return GRUMessageAggregator(device=device, raw_message_dimension=raw_message_dimension) else: raise ValueError("Message aggregator {} not implemented".format(aggregator_type)) diff --git a/modules/message_function.py b/modules/message_function.py index 58c184c..85bef82 100644 --- a/modules/message_function.py +++ b/modules/message_function.py @@ -1,15 +1,12 @@ from torch import nn - class MessageFunction(nn.Module): """ Module which computes the message for a given interaction. """ - def compute_message(self, raw_messages): return None - class MLPMessageFunction(MessageFunction): def __init__(self, raw_message_dimension, message_dimension): super(MLPMessageFunction, self).__init__() @@ -17,7 +14,15 @@ def __init__(self, raw_message_dimension, message_dimension): self.mlp = self.layers = nn.Sequential( nn.Linear(raw_message_dimension, raw_message_dimension // 2), nn.ReLU(), - nn.Linear(raw_message_dimension // 2, message_dimension), + nn.BatchNorm1d(raw_message_dimension // 2), + nn.Dropout(0.2), + + nn.Linear(raw_message_dimension // 2, raw_message_dimension // 4), + nn.ReLU(), + nn.BatchNorm1d(raw_message_dimension // 4), + nn.Dropout(0.2), + + nn.Linear(raw_message_dimension // 4, message_dimension), ) def compute_message(self, raw_messages): @@ -25,14 +30,11 @@ def compute_message(self, raw_messages): return messages - class IdentityMessageFunction(MessageFunction): - def compute_message(self, raw_messages): return raw_messages - def get_message_function(module_type, raw_message_dimension, message_dimension): if module_type == "mlp": return MLPMessageFunction(raw_message_dimension, message_dimension) diff --git a/pruebas.py b/pruebas.py new file mode 100644 index 0000000..e8749f1 --- /dev/null +++ b/pruebas.py @@ -0,0 +1,69 @@ +import pandas as pd +import numpy as np +import random +import dgl +import torch +import torch.nn + +def simplificar_wikipedia(graph_df, edge_feat): + ''' + Simplificamos el conjunto de datos de Wikipedia para ahorrar en el tiempo de las pruebas. 
+ + Se eliminan un 50% de los usuarios y las páginas correspondientes, dejando un total de: + + * 4113/8227 usuario y 952/1000 páginas + ''' + unique_u = graph_df['u'].unique() + + num_nodes_u_to_remove = int(len(unique_u) * 0.5) + + # Establecer una semilla para reproducibilidad + random.seed(1) + + random_selection = random.sample(list(unique_u), num_nodes_u_to_remove) + + df_filtered = graph_df[graph_df['u'].isin(random_selection)] + + # Reemplazar los valores de 'u' e 'i' con nuevos índices comenzando desde 1 + new_index = {old_index: new_index + 1 for new_index, old_index in enumerate(sorted(set(df_filtered['u']).union(df_filtered['i'])))} + df_reindexed = df_filtered.replace({'u': new_index, 'i': new_index}).sort_values(by='ts') + + new_edge_feat = edge_feat[df_reindexed['idx']] + + # Reemplazar los valores del index e 'idx' con nuevos índices comenzando desde 0 + df_reindexed.reset_index(drop=True, inplace=True) + df_reindexed.iloc[:, 0]= df_reindexed.index + df_reindexed['idx'] = df_reindexed.index + + return df_reindexed, new_edge_feat + + +# graph_df = pd.read_csv('./data/wikipedia-tgn/ml_wikipedia_df.csv') +# edge_feat = np.load('./data/wikipedia-tgn/ml_wikipedia_edge_feat.npy') + +# print("Número de usuarios en el grafo original: ") +# print(len(graph_df['u'].unique())) +# print("Número de items en el grafo original: ") +# print(len(graph_df['i'].unique())) + +# df_reindexed, new_edge_feat = simplificar_wikipedia(graph_df, edge_feat) + +# print("Número de usuarios en el grafo simplificado: ") +# print(len(df_reindexed['u'].unique())) +# print("Número de items en el grafo simplificado: ") +# print(len(df_reindexed['i'].unique())) + +# df_reindexed.to_csv('./data/wikipedia-simplificada/ml_wikipedia_simplificada_df.csv', index=False) +# np.save('./data/wikipedia-simplificada/ml_wikipedia_simplificada_edge_feat.npy', new_edge_feat) + +# import pickle + +# # Ruta al archivo .pkl +# file_path = './results/link_prediction/wikipedia-simplificada/Original WorkFlow/tgn-attn/tgn-attn.pkl' + +# # Abrir el archivo .pkl +# with open(file_path, 'rb') as file: +# data = pickle.load(file) + +# # Mostrar el contenido del archivo +# print(data) \ No newline at end of file diff --git a/results_comparator.py b/results_comparator.py new file mode 100644 index 0000000..24f706a --- /dev/null +++ b/results_comparator.py @@ -0,0 +1,221 @@ +import pickle +import matplotlib.pyplot as plt +import pandas as pd +import os + +def load_results(file_path): + with open(file_path, "rb") as file: + return pickle.load(file) + +def plot_results(path, results): + # Crear una figura con múltiples subgráficas + _, axs = plt.subplots(2, 2, figsize=(12, 10)) + + # Gráfica 1: val_aps vs Epoch + axs[0, 0].set_title('APs Validación por Época') + axs[0, 0].set_xlabel('Época') + axs[0, 0].set_ylabel('AP') + + # Gráfica 2: new_nodes_val_aps vs Epoch + axs[0, 1].set_title('APs Validación nuevos nodos por Época') + axs[0, 1].set_xlabel('Época') + axs[0, 1].set_ylabel('AP') + + # Gráfica 3: train_losses vs Epoch + axs[1, 0].set_title('Pérdida de entrenamiento por Época') + axs[1, 0].set_xlabel('Época') + axs[1, 0].set_ylabel('Pérdida') + + # Gráfica 4: epoch_times vs Epoch + axs[1, 1].set_title('Tiempo por Época') + axs[1, 1].set_xlabel('Época') + axs[1, 1].set_ylabel('Tiempo') + + for name, result in results: + # Extraer los datos + val_aps = result["val_aps"] + new_nodes_val_aps = result["new_nodes_val_aps"] + train_losses = result["train_losses"] + epoch_times = result["epoch_times"] + + # Gráfica 1: val_aps vs Epoch + axs[0, 
0].plot(val_aps, label="{}".format(name)) + axs[0, 0].legend() + + # Gráfica 2: new_nodes_val_aps vs Epoch + axs[0, 1].plot(new_nodes_val_aps, label="{}".format(name)) + axs[0, 1].legend() + + # Gráfica 3: train_losses vs Epoch + axs[1, 0].plot(train_losses, label="{}".format(name)) + axs[1, 0].legend() + + # Gráfica 4: epoch_times vs Epoch + axs[1, 1].plot(epoch_times, label="{}".format(name)) + axs[1, 1].legend() + + # Ajustar el layout + plt.tight_layout() + + # Guardar gráfica + plt.savefig(path + "Resultados.png", dpi=300) + +def plot_model_results(root_path, file_paths): + for model in list(file_paths.keys()): + if len(file_paths[model]) > 0: + resultados = [] + for file_path in file_paths[model]: + # Cargar los resultados + results = load_results(file_path) + + resultados.append((model, results)) + + path = root_path + model + '/' + plot_results(path, resultados) + +def evolucion_nodos(graph_df): + # Inicializar sets para guardar nodos únicos que ya hemos visto + seen_u = set() + seen_i = set() + + # Lista para almacenar los resultados + results = [] + + for batch, group in graph_df.groupby('batch'): + # Contar nodos únicos en el batch actual + new_u = group[~group['u'].isin(seen_u)]['u'].nunique() + new_i = group[~group['i'].isin(seen_i)]['i'].nunique() + + # Actualizar los sets de nodos vistos + seen_u.update(group['u']) + seen_i.update(group['i']) + + # Obtener el rango de timestamps + ts_min = group['ts'].min() + ts_max = group['ts'].max() + + # Agregar el resultado actual a la lista + results.append({ + 'batch': batch, + 'new_u': new_u, + 'new_i': new_i, + 'ts_min': ts_min, + 'ts_max': ts_max + }) + + # Convertir la lista de resultados en un DataFrame + result_df = pd.DataFrame(results) + + return result_df + +def evolucion_mensajes(graph_df): + # Número de mensajes por nodo único + # Agrupar por batch y contar las ocurrencias de cada nodo "u" + count_u_per_batch = graph_df.groupby(['batch', 'u']).size().reset_index(name='count_u') + + # Agrupar por batch y contar las ocurrencias de cada nodo "i" + count_i_per_batch = graph_df.groupby(['batch', 'i']).size().reset_index(name='count_i') + + # Obtener el nodo "u" con máximas ocurrencias en cada batch + max_u_per_batch = count_u_per_batch.loc[count_u_per_batch.groupby('batch')['count_u'].idxmax()].reset_index(drop=True) + + # Obtener el nodo "i" con máximas ocurrencias en cada batch + max_i_per_batch = count_i_per_batch.loc[count_i_per_batch.groupby('batch')['count_i'].idxmax()].reset_index(drop=True) + + return count_u_per_batch, count_i_per_batch, max_u_per_batch, max_i_per_batch + +def grafica_evolucion_nodos(evol_df): + # Crear la figura y los ejes + fig, ax1 = plt.subplots(figsize=(10, 6)) + + # Crear un gráfico de barras para los nuevos nodos + ax1.bar(evol_df['batch'] - 0.2, evol_df['new_u'], width=0.4, label='Nuevos u', align='center') + ax1.bar(evol_df['batch'] + 0.2, evol_df['new_i'], width=0.4, label='Nuevos i', align='center') + + # Etiquetas y título para el gráfico de barras + ax1.set_xlabel('Batch') + ax1.set_ylabel('Número de Nuevos Nodos') + ax1.set_title('Nuevos Nodos por Batch') + ax1.legend(loc='upper left') + + # Crear un segundo eje y para el rango de timestamps + ax2 = ax1.twinx() + ax2.plot(evol_df['batch'], evol_df['ts_min'], color='green', marker='o', linestyle='dashed', linewidth=2, label='ts_min') + ax2.plot(evol_df['batch'], evol_df['ts_max'], color='red', marker='o', linestyle='dashed', linewidth=2, label='ts_max') + + # Etiquetas y título para el gráfico de líneas + ax2.set_ylabel('Timestamps') + 
ax2.legend(loc='upper right') + + # Mostrar el gráfico + plt.show() + +def grafica_evolucion_mensajes(max_u_per_batch, max_i_per_batch): + # Crear dos gráficos uno al lado del otro + fig, axs = plt.subplots(1, 2, figsize=(12, 6)) # 1 fila, 2 columnas + + # Graficar los datos + # Graficar max_u_per_batch + axs[0].bar(max_u_per_batch['batch'], max_u_per_batch['count_u'], color='blue') + axs[0].set_title('Gráfico 1: Count_u por Batch') + axs[0].set_xlabel('Batch') + axs[0].set_ylabel('Count_u') + + # Graficar max_i_per_batch + axs[1].bar(max_i_per_batch['batch'], max_i_per_batch['count_i'], color='red') + axs[1].set_title('Gráfico 2: Count_i por Batch') + axs[1].set_xlabel('Batch') + axs[1].set_ylabel('Count_i') + + plt.tight_layout() # Ajusta automáticamente el diseño de las figuras + plt.show() + +def plot_best_results_ap(root_path, file_paths): + mejores_resultados = [] + for model in list(file_paths.keys()): + if len(file_paths[model]) > 0: + mejor_val_aps = -float('inf') + mejor_resultado = None + + for file_path in file_paths[model]: + # Cargar los resultados + results = load_results(file_path) + + # Extraer los datos + val_aps = results["val_aps"] + + # Obtener el valor máximo de val_aps + max_val_aps = max(val_aps) + + if max_val_aps > mejor_val_aps: + mejor_val_aps = max_val_aps + mejor_resultado = results + + mejores_resultados.append((model, mejor_resultado)) + + print(mejores_resultados) + + plot_results(root_path, mejores_resultados) + +file_paths = {} +root_path = "./results/link_prediction/wikipedia-simplificada/" + +folders = os.listdir(root_path) + +for folder in folders: + if "tgn" in folder: + file_paths[folder] = [] + + dir_path = root_path + folder + '/' + runs = os.listdir(dir_path) + + for run in runs: + if "tgn" in run: + run_path = dir_path + run + file_paths[folder].append(run_path) + +plot_model_results(root_path, file_paths) + +plot_best_results_ap(root_path, file_paths) + + diff --git a/train_self_supervised.py b/tgn_link_prediction.py similarity index 81% rename from train_self_supervised.py rename to tgn_link_prediction.py index 3cebec1..ca07d41 100644 --- a/train_self_supervised.py +++ b/tgn_link_prediction.py @@ -18,8 +18,8 @@ ### Argument and global variables parser = argparse.ArgumentParser('TGN self-supervised training') -parser.add_argument('-d', '--data', type=str, help='Dataset name (eg. wikipedia or reddit)', - default='wikipedia') +parser.add_argument('-d', '--data', type=str, help='Dataset name (eg. wikipedia-tgn, reddit-tgn, wikipedia-tgb, review-tgb, coin-tgb, comment-tgb, flight-tgb)', + default='wikipedia-tgn') parser.add_argument('--bs', type=int, default=200, help='Batch_size') parser.add_argument('--prefix', type=str, default='', help='Prefix to name the checkpoints') parser.add_argument('--n_degree', type=int, default=10, help='Number of neighbors to sample') @@ -44,7 +44,7 @@ parser.add_argument('--memory_updater', type=str, default="gru", choices=[ "gru", "rnn"], help='Type of memory updater') parser.add_argument('--aggregator', type=str, default="last", help='Type of message ' - 'aggregator') + 'aggregator (e.g. 
mean, last, rnn or gru)') parser.add_argument('--memory_update_at_end', action='store_true', help='Whether to update memory at the end or at the start of the batch') parser.add_argument('--message_dim', type=int, default=100, help='Dimensions of the messages') @@ -62,7 +62,12 @@ help='Whether to use the embedding of the source node as part of the message') parser.add_argument('--dyrep', action='store_true', help='Whether to run the dyrep model') +parser.add_argument('--feature_embedding_type', type=str, default='identity', help='Type of feature embedding (e.g. identity or mlp)') +parser.add_argument('--feature_dim', type=int, default=50, help='Dimensions of the feature embedding') +parser.add_argument('--neg_sample', type=str, default='rnd', help='Strategy for the edge negative sampling.') + +torch.autograd.set_detect_anomaly(True) try: args = parser.parse_args() @@ -85,12 +90,15 @@ USE_MEMORY = args.use_memory MESSAGE_DIM = args.message_dim MEMORY_DIM = args.memory_dim +NEG_SAMPLE = args.neg_sample +FEATURE_TYPE = args.feature_embedding_type +FEATURE_DIM = args.feature_dim -Path("./saved_models/").mkdir(parents=True, exist_ok=True) -Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True) -MODEL_SAVE_PATH = f'./saved_models/{args.prefix}-{args.data}.pth' +Path("./saved_models/link_prediction").mkdir(parents=True, exist_ok=True) +Path("./saved_checkpoints/link_prediction").mkdir(parents=True, exist_ok=True) +MODEL_SAVE_PATH = f'./saved_models/link_prediction/{args.prefix}-{args.data}.pth' get_checkpoint_path = lambda \ - epoch: f'./saved_checkpoints/{args.prefix}-{args.data}-{epoch}.pth' + epoch: f'./saved_checkpoints/link_prediction/{args.prefix}-{args.data}-{epoch}.pth' ### set up logger logging.basicConfig(level=logging.INFO) @@ -122,14 +130,24 @@ # Initialize negative samplers. 
Set seeds for validation and testing so negatives are the same # across different runs # NB: in the inductive setting, negatives are sampled only amongst other new nodes +#if NEG_SAMPLE == 'rnd': train_rand_sampler = RandEdgeSampler(train_data.sources, train_data.destinations) val_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=0) nn_val_rand_sampler = RandEdgeSampler(new_node_val_data.sources, new_node_val_data.destinations, - seed=1) + seed=1) test_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=2) nn_test_rand_sampler = RandEdgeSampler(new_node_test_data.sources, - new_node_test_data.destinations, - seed=3) + new_node_test_data.destinations, + seed=3) +# else: +# train_rand_sampler = RandEdgeSampler_adversarial(train_data.sources, train_data.destinations, train_data.timestamps, NEG_SAMPLE) +# val_rand_sampler = RandEdgeSampler_adversarial(full_data.sources, full_data.destinations, seed=0) +# nn_val_rand_sampler = RandEdgeSampler_adversarial(new_node_val_data.sources, new_node_val_data.destinations, +# seed=1) +# test_rand_sampler = RandEdgeSampler_adversarial(full_data.sources, full_data.destinations, seed=2) +# nn_test_rand_sampler = RandEdgeSampler_adversarial(new_node_test_data.sources, +# new_node_test_data.destinations, +# seed=3) # Set device device_string = 'cuda:{}'.format(GPU) if torch.cuda.is_available() else 'cpu' @@ -140,8 +158,8 @@ compute_time_statistics(full_data.sources, full_data.destinations, full_data.timestamps) for i in range(args.n_runs): - results_path = "results/{}_{}.pkl".format(args.prefix, i) if i > 0 else "results/{}.pkl".format(args.prefix) - Path("results/").mkdir(parents=True, exist_ok=True) + results_path = "results/link_prediction/{}_{}.pkl".format(args.prefix, i) if i > 0 else "results/link_prediction/{}.pkl".format(args.prefix) + Path("results/link_prediction/").mkdir(parents=True, exist_ok=True) # Initialize Model tgn = TGN(neighbor_finder=train_ngh_finder, node_features=node_features, @@ -217,9 +235,10 @@ neg_label = torch.zeros(size, dtype=torch.float, device=device) tgn = tgn.train() + pos_prob, neg_prob = tgn.compute_edge_probabilities(sources_batch, destinations_batch, negatives_batch, timestamps_batch, edge_idxs_batch, NUM_NEIGHBORS) - + loss += criterion(pos_prob.squeeze(), pos_label) + criterion(neg_prob.squeeze(), neg_label) loss /= args.backprop_every @@ -246,9 +265,9 @@ train_memory_backup = tgn.memory.backup_memory() val_ap, val_auc = eval_edge_prediction(model=tgn, - negative_edge_sampler=val_rand_sampler, - data=val_data, - n_neighbors=NUM_NEIGHBORS) + negative_edge_sampler=val_rand_sampler, + data=val_data, + n_neighbors=NUM_NEIGHBORS) if USE_MEMORY: val_memory_backup = tgn.memory.backup_memory() # Restore memory we had at the end of training to be used when validating on new nodes. 
@@ -258,9 +277,9 @@ # Validate on unseen nodes nn_val_ap, nn_val_auc = eval_edge_prediction(model=tgn, - negative_edge_sampler=val_rand_sampler, - data=new_node_val_data, - n_neighbors=NUM_NEIGHBORS) + negative_edge_sampler=val_rand_sampler, + data=new_node_val_data, + n_neighbors=NUM_NEIGHBORS) if USE_MEMORY: # Restore memory we had at the end of validation @@ -310,18 +329,18 @@ ### Test tgn.embedding_module.neighbor_finder = full_ngh_finder test_ap, test_auc = eval_edge_prediction(model=tgn, - negative_edge_sampler=test_rand_sampler, - data=test_data, - n_neighbors=NUM_NEIGHBORS) + negative_edge_sampler=test_rand_sampler, + data=test_data, + n_neighbors=NUM_NEIGHBORS) if USE_MEMORY: tgn.memory.restore_memory(val_memory_backup) # Test on unseen nodes nn_test_ap, nn_test_auc = eval_edge_prediction(model=tgn, - negative_edge_sampler=nn_test_rand_sampler, - data=new_node_test_data, - n_neighbors=NUM_NEIGHBORS) + negative_edge_sampler=nn_test_rand_sampler, + data=new_node_test_data, + n_neighbors=NUM_NEIGHBORS) logger.info( 'Test statistics: Old nodes -- auc: {}, ap: {}'.format(test_auc, test_ap)) diff --git a/tgn_link_prediction_PyG.py b/tgn_link_prediction_PyG.py new file mode 100644 index 0000000..b9f3e18 --- /dev/null +++ b/tgn_link_prediction_PyG.py @@ -0,0 +1,195 @@ +# This code achieves a performance of around 96.60%. However, it is not +# directly comparable to the results reported by the TGN paper since a +# slightly different evaluation setup is used here. +# In particular, predictions in the same batch are made in parallel, i.e. +# predictions for interactions later in the batch have no access to any +# information whatsoever about previous interactions in the same batch. +# On the contrary, when sampling node neighborhoods for interactions later in +# the batch, the TGN paper code has access to previous interactions in the +# batch. +# While both approaches are correct, together with the authors of the paper we +# decided to present this version here as it is more realsitic and a better +# test bed for future methods. 
+ +import os.path as osp + +import torch +from sklearn.metrics import average_precision_score, roc_auc_score +from torch.nn import Linear + +from torch_geometric.datasets import JODIEDataset +from torch_geometric.loader import TemporalDataLoader +from torch_geometric.nn import TGNMemory, TransformerConv +from torch_geometric.nn.models.tgn import ( + IdentityMessage, + LastAggregator, + LastNeighborLoader, +) + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'JODIE') +dataset = JODIEDataset(path, name='wikipedia') +data = dataset[0] + +# For small datasets, we can put the whole dataset on GPU and thus avoid +# expensive memory transfer costs for mini-batches: +data = data.to(device) + +train_data, val_data, test_data = data.train_val_test_split( + val_ratio=0.15, test_ratio=0.15) + +train_loader = TemporalDataLoader( + train_data, + batch_size=200, + neg_sampling_ratio=1.0, +) +val_loader = TemporalDataLoader( + val_data, + batch_size=200, + neg_sampling_ratio=1.0, +) +test_loader = TemporalDataLoader( + test_data, + batch_size=200, + neg_sampling_ratio=1.0, +) +neighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device) + + +class GraphAttentionEmbedding(torch.nn.Module): + def __init__(self, in_channels, out_channels, msg_dim, time_enc): + super().__init__() + self.time_enc = time_enc + edge_dim = msg_dim + time_enc.out_channels + self.conv = TransformerConv(in_channels, out_channels // 2, heads=2, + dropout=0.1, edge_dim=edge_dim) + + def forward(self, x, last_update, edge_index, t, msg): + rel_t = last_update[edge_index[0]] - t + rel_t_enc = self.time_enc(rel_t.to(x.dtype)) + edge_attr = torch.cat([rel_t_enc, msg], dim=-1) + return self.conv(x, edge_index, edge_attr) + + +class LinkPredictor(torch.nn.Module): + def __init__(self, in_channels): + super().__init__() + self.lin_src = Linear(in_channels, in_channels) + self.lin_dst = Linear(in_channels, in_channels) + self.lin_final = Linear(in_channels, 1) + + def forward(self, z_src, z_dst): + h = self.lin_src(z_src) + self.lin_dst(z_dst) + h = h.relu() + return self.lin_final(h) + + +memory_dim = time_dim = embedding_dim = 100 + +memory = TGNMemory( + data.num_nodes, + data.msg.size(-1), + memory_dim, + time_dim, + message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim), + aggregator_module=LastAggregator(), +).to(device) + +gnn = GraphAttentionEmbedding( + in_channels=memory_dim, + out_channels=embedding_dim, + msg_dim=data.msg.size(-1), + time_enc=memory.time_enc, +).to(device) + +link_pred = LinkPredictor(in_channels=embedding_dim).to(device) + +optimizer = torch.optim.Adam( + set(memory.parameters()) | set(gnn.parameters()) + | set(link_pred.parameters()), lr=0.0001) +criterion = torch.nn.BCEWithLogitsLoss() + +# Helper vector to map global node indices to local ones. +assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device) + + +def train(): + memory.train() + gnn.train() + link_pred.train() + + memory.reset_state() # Start with a fresh memory. + neighbor_loader.reset_state() # Start with an empty graph. + + total_loss = 0 + for batch in train_loader: + optimizer.zero_grad() + batch = batch.to(device) + + n_id, edge_index, e_id = neighbor_loader(batch.n_id) + assoc[n_id] = torch.arange(n_id.size(0), device=device) + + # Get updated memory of all nodes involved in the computation. 
+ z, last_update = memory(n_id) + z = gnn(z, last_update, edge_index, data.t[e_id].to(device), + data.msg[e_id].to(device)) + pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]]) + neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]]) + + loss = criterion(pos_out, torch.ones_like(pos_out)) + loss += criterion(neg_out, torch.zeros_like(neg_out)) + + # Update memory and neighbor loader with ground-truth state. + memory.update_state(batch.src, batch.dst, batch.t, batch.msg) + neighbor_loader.insert(batch.src, batch.dst) + + loss.backward() + optimizer.step() + memory.detach() + total_loss += float(loss) * batch.num_events + + return total_loss / train_data.num_events + + +@torch.no_grad() +def test(loader): + memory.eval() + gnn.eval() + link_pred.eval() + + torch.manual_seed(12345) # Ensure deterministic sampling across epochs. + + aps, aucs = [], [] + for batch in loader: + batch = batch.to(device) + + n_id, edge_index, e_id = neighbor_loader(batch.n_id) + assoc[n_id] = torch.arange(n_id.size(0), device=device) + + z, last_update = memory(n_id) + z = gnn(z, last_update, edge_index, data.t[e_id].to(device), + data.msg[e_id].to(device)) + pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]]) + neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]]) + + y_pred = torch.cat([pos_out, neg_out], dim=0).sigmoid().cpu() + y_true = torch.cat( + [torch.ones(pos_out.size(0)), + torch.zeros(neg_out.size(0))], dim=0) + + aps.append(average_precision_score(y_true, y_pred)) + aucs.append(roc_auc_score(y_true, y_pred)) + + memory.update_state(batch.src, batch.dst, batch.t, batch.msg) + neighbor_loader.insert(batch.src, batch.dst) + return float(torch.tensor(aps).mean()), float(torch.tensor(aucs).mean()) + + +for epoch in range(1, 51): + loss = train() + print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}') + val_ap, val_auc = test(val_loader) + test_ap, test_auc = test(test_loader) + print(f'Val AP: {val_ap:.4f}, Val AUC: {val_auc:.4f}') + print(f'Test AP: {test_ap:.4f}, Test AUC: {test_auc:.4f}') \ No newline at end of file diff --git a/train_supervised.py b/tgn_node_classification.py similarity index 94% rename from train_supervised.py rename to tgn_node_classification.py index 65c32ec..40408d5 100644 --- a/train_supervised.py +++ b/tgn_node_classification.py @@ -12,7 +12,7 @@ from model.tgn import TGN from utils.utils import EarlyStopMonitor, get_neighbor_finder, MLP -from utils.data_processing import compute_time_statistics, get_data_node_classification +from utils.tgn_data_processing import compute_time_statistics, get_data_node_classification from evaluation.evaluation import eval_node_classification random.seed(0) @@ -92,12 +92,12 @@ MESSAGE_DIM = args.message_dim MEMORY_DIM = args.memory_dim -Path("./saved_models/").mkdir(parents=True, exist_ok=True) -Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True) -MODEL_SAVE_PATH = f'./saved_models/{args.prefix}-{args.data}' + '\ +Path("./saved_models/node_classification").mkdir(parents=True, exist_ok=True) +Path("./saved_checkpoints/node_classification").mkdir(parents=True, exist_ok=True) +MODEL_SAVE_PATH = f'./saved_models/node_classification/{args.prefix}-{args.data}' + '\ node-classification.pth' get_checkpoint_path = lambda \ - epoch: f'./saved_checkpoints/{args.prefix}-{args.data}-{epoch}' + '\ + epoch: f'./saved_checkpoints/node_classification/{args.prefix}-{args.data}-{epoch}' + '\ node-classification.pth' ### set up logger @@ -131,8 +131,8 @@ compute_time_statistics(full_data.sources, full_data.destinations, 
full_data.timestamps) for i in range(args.n_runs): - results_path = "results/{}_node_classification_{}.pkl".format(args.prefix, - i) if i > 0 else "results/{}_node_classification.pkl".format( + results_path = "results/node_classification/{}_node_classification_{}.pkl".format(args.prefix, + i) if i > 0 else "results/node_classification/{}_node_classification.pkl".format( args.prefix) Path("results/").mkdir(parents=True, exist_ok=True) @@ -167,7 +167,7 @@ logger.info('Start training node classification task') decoder = MLP(node_features.shape[1], drop=DROP_OUT) - decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=args.lr) + decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=LEARNING_RATE) decoder = decoder.to(device) decoder_loss_criterion = torch.nn.BCELoss() diff --git a/utils/data_processing.py b/utils/data_processing.py index 3ece032..8771cd6 100644 --- a/utils/data_processing.py +++ b/utils/data_processing.py @@ -51,9 +51,13 @@ def get_data_node_classification(dataset_name, use_validation=False): def get_data(dataset_name, different_new_nodes_between_val_and_test=False, randomize_features=False): ### Load data and train val test split - graph_df = pd.read_csv('./data/ml_{}.csv'.format(dataset_name)) - edge_features = np.load('./data/ml_{}.npy'.format(dataset_name)) - node_features = np.load('./data/ml_{}_node.npy'.format(dataset_name)) + graph_df = pd.read_csv('./data/{}/ml_{}_df.csv'.format(dataset_name, dataset_name.split('-')[0])) + edge_features = np.load('./data/{}/ml_{}_edge_feat.npy'.format(dataset_name, dataset_name.split('-')[0])) + if dataset_name == "flight": + node_features = np.load('./data/{}/ml_{}_node_feat.npy'.format(dataset_name, dataset_name.split('-')[0])) + else: + max_idx = max(graph_df.u.max(), graph_df.i.max()) + node_features = np.zeros((max_idx + 1, 172)) if randomize_features: node_features = np.random.rand(node_features.shape[0], node_features.shape[1]) @@ -78,7 +82,7 @@ def get_data(dataset_name, different_new_nodes_between_val_and_test=False, rando set(destinations[timestamps > val_time])) # Sample nodes which we keep as new nodes (to test inductiveness), so than we have to remove all # their edges from training - new_test_node_set = set(random.sample(test_node_set, int(0.1 * n_total_unique_nodes))) + new_test_node_set = set(random.sample(list(test_node_set), int(0.1 * n_total_unique_nodes))) # Mask saying for each source and destination whether they are new test nodes new_test_source_mask = graph_df.u.map(lambda x: x in new_test_node_set).values @@ -91,7 +95,7 @@ def get_data(dataset_name, different_new_nodes_between_val_and_test=False, rando # For train we keep edges happening before the validation time which do not involve any new node # used for inductiveness train_mask = np.logical_and(timestamps <= val_time, observed_edges_mask) - + train_data = Data(sources[train_mask], destinations[train_mask], timestamps[train_mask], edge_idxs[train_mask], labels[train_mask]) diff --git a/utils/link_data_preprocessing/tgb_data.py b/utils/link_data_preprocessing/tgb_data.py new file mode 100644 index 0000000..ba1c41f --- /dev/null +++ b/utils/link_data_preprocessing/tgb_data.py @@ -0,0 +1,353 @@ +import numpy as np +import pandas as pd +from pathlib import Path +import argparse +import pickle +import csv +import tqdm + +def preprocess_wikipedia(data_name): + ''' + [u, i, ts, label, feats] + ''' + u_list, i_list, ts_list, label_list = [], [], [], [] + feat_l = [] + idx_list = [] + + with open(data_name) as f: + s = next(f) + for idx, line in 
enumerate(f): + e = line.strip().split(',') + u = int(e[0]) + i = int(e[1]) + + ts = float(e[2]) + label = float(e[3]) + + feat = np.array([float(x) for x in e[4:]]) + + u_list.append(u) + i_list.append(i) + ts_list.append(ts) + label_list.append(label) + idx_list.append(idx) + + feat_l.append(feat) + return pd.DataFrame({'u': u_list, + 'i': i_list, + 'ts': ts_list, + 'label': label_list, + 'idx': idx_list}), np.array(feat_l), None + +def preprocess_review_coin(data_name): + ''' + [ts, u, i, weight] + ''' + u_list, i_list, ts_list, label_list, idx_list, feat_list = [], [], [], [], [], [] + + node_dict = {} + nodes_id = 0 + + with open(data_name) as f: + s = next(f) + for idx, line in enumerate(f): + e = line.strip().split(',') + + ts = float(e[0]) + if "coin" in data_name: + u = e[1] + i = e[2] + else: + u = int(e[1]) + i = int(e[2]) + + if u not in node_dict: + node_dict[u] = nodes_id + nodes_id += 1 + if i not in node_dict: + node_dict[i] = nodes_id + nodes_id += 1 + + u_list.append(node_dict[u]) + i_list.append(node_dict[i]) + ts_list.append(ts) + label_list.append(0) + idx_list.append(idx) + + feat_list.append(np.zeros(1)) + + + return pd.DataFrame({"u": u_list, + "i": i_list, + "ts": ts_list, + 'label': label_list, + "idx": idx_list}), np.array(feat_list), node_dict + + +def preprocess_comment(data_name): + ''' + [ts, u, i, subrredit, num_words, score] + ''' + u_list, i_list, ts_list, label_list, idx_list, feat_list = [], [], [], [], [], [] + + # Del código TGB + max_words = 500 + + node_dict = {} + nodes_id = 0 + + + with open(data_name) as f: + s = next(f) + for idx, line in enumerate(f): + e = line.strip().split(',') + + ts = float(e[0]) + u = e[1] + if u not in node_dict: + node_dict[u] = nodes_id + nodes_id += 1 + + i = e[2] + if i not in node_dict: + node_dict[i] = nodes_id + nodes_id += 1 + + u_list.append(node_dict[u]) + i_list.append(node_dict[i]) + ts_list.append(ts) + label_list.append(0) + idx_list.append(idx) + + feat_list.append(np.array([(float(e[4])/max_words)])) + + + return pd.DataFrame({"u": u_list, + "i": i_list, + "ts": ts_list, + "label": label_list, + "idx": idx_list}), np.array(feat_list), node_dict + +def preprocess_flight(data_name): + ''' + [ts, u, i, callsing, typecode] + ''' + u_list, i_list, ts_list, label_list, idx_list, feat_list = [], [], [], [], [], [] + + node_dict = {} + nodes_id = 0 + + + with open(data_name) as f: + s = next(f) + for idx, line in enumerate(f): + e = line.strip().split(',') + + ts = float(e[0]) + u = e[1] + if u not in node_dict: + node_dict[u] = nodes_id + nodes_id += 1 + + i = e[2] + if i not in node_dict: + node_dict[i] = nodes_id + nodes_id += 1 + + u_list.append(node_dict[u]) + i_list.append(node_dict[i]) + ts_list.append(ts) + label_list.append(0) + idx_list.append(idx) + + # Fix size to 8 with ! + if len(e[3]) == 0: + e[3] = "!!!!!!!!" + while len(e[3]) < 8: + e[3] += "!" + + if len(e[4]) == 0: + e[4] = "!!!!!!!!" + while len(e[4]) < 8: + e[4] += "!" + if len(e[4]) > 8: + e[4] = "!!!!!!!!" 
+
+            feat_str = e[3] + e[4]
+
+            feat_list.append(convert_str2int(feat_str))
+
+    return pd.DataFrame({"u": u_list,
+                         "i": i_list,
+                         "ts": ts_list,
+                         "label": label_list,
+                         "idx": idx_list}), np.array(feat_list), node_dict
+
+def reindex(df, node_dict, bipartite=True):
+    new_df = df.copy()
+    if bipartite:
+        assert (df.u.max() - df.u.min() + 1 == len(df.u.unique()))
+        assert (df.i.max() - df.i.min() + 1 == len(df.i.unique()))
+
+        upper_u = df.u.max() + 1
+        new_i = df.i + upper_u
+
+        new_df.i = new_i
+        new_df.u += 1
+        new_df.i += 1
+        new_df.idx += 1
+
+        new_dict = None
+    else:
+        new_df.u += 1
+        new_df.i += 1
+        new_df.idx += 1
+
+        new_dict = {key: value + 1 for key, value in node_dict.items()}
+
+    return new_df, new_dict
+
+def convert_str2int(in_str: str) -> np.ndarray:
+    """
+    convert a string to a vector of numbers based on its individual characters:
+    digits are kept as numbers, the '!' padding character becomes -1 and every
+    other letter is mapped to an integer derived from its character code
+    Parameters:
+        in_str: an input string to parse
+    Returns:
+        out: a numpy float array with one entry per character
+    """
+    out = []
+    for element in in_str:
+        if element.isnumeric():
+            out.append(element)
+        elif element == "!":
+            out.append(-1)
+        else:
+            out.append(ord(element.upper()) - 44 + 9)
+    out = np.array(out, dtype=np.float32)
+    return out
+
+def process_flight_node_feat(fname: str, node_ids):
+    """
+    1. uses the same node ids as preprocess_flight (node_dict)
+    2. processes the various node attributes into a single vector
+    3. returns a numpy array of node features whose index corresponds to the node id
+
+    airport_code,type,continent,iso_region,longitude,latitude
+    type: onehot encoding
+    continent: onehot encoding
+    iso_region: alphabet encoding, same as the edge features
+    longitude, latitude: raw float values
+    """
+    feat_size = 20
+    node_feat = np.zeros((len(node_ids), feat_size))
+    type_dict = {}
+    type_idx = 0
+    continent_dict = {}
+    cont_idx = 0
+
+    # First pass: build the one-hot dictionaries for airport type and continent
+    with open(fname, "r") as csv_file:
+        csv_reader = csv.reader(csv_file, delimiter=",")
+        idx = 0
+        # airport_code,type,continent,iso_region,longitude,latitude
+        for row in tqdm(csv_reader):
+            if idx == 0:
+                idx += 1
+                continue
+            else:
+                code = row[0]
+                if code not in node_ids:
+                    continue
+                else:
+                    node_id = node_ids[code]
+                    airport_type = row[1]
+                    if airport_type not in type_dict:
+                        type_dict[airport_type] = type_idx
+                        type_idx += 1
+                    continent = row[2]
+                    if continent not in continent_dict:
+                        continent_dict[continent] = cont_idx
+                        cont_idx += 1
+
+    # Second pass: build the feature vector of every airport present in the edge list
+    with open(fname, "r") as csv_file:
+        csv_reader = csv.reader(csv_file, delimiter=",")
+        idx = 0
+        # airport_code,type,continent,iso_region,longitude,latitude
+        for row in tqdm(csv_reader):
+            if idx == 0:
+                idx += 1
+                continue
+            else:
+                code = row[0]
+                if code not in node_ids:
+                    continue
+                else:
+                    node_id = node_ids[code]
+                    airport_type = type_dict[row[1]]
+                    type_vec = np.zeros(type_idx)
+                    type_vec[airport_type] = 1
+                    continent = continent_dict[row[2]]
+                    cont_vec = np.zeros(cont_idx)
+                    cont_vec[continent] = 1
+                    while len(row[3]) < 7:
+                        row[3] += "!"
+ iso_region = convert_str2int(row[3]) # numpy float array + lng = float(row[4]) + lat = float(row[5]) + coor_vec = np.array([lng, lat]) + final = np.concatenate( + (type_vec, cont_vec, iso_region, coor_vec), axis=0 + ) + node_feat[node_id] = final + return node_feat + +def run(data_name, bipartite=True): + Path("data/").mkdir(parents=True, exist_ok=True) + PATH = './data/{}/{}_edgelist.csv'.format(data_name, data_name.split('-')[0]) + OUT_DF = './data/{}/ml_{}_df.csv'.format(data_name, data_name.split('-')[0]) + OUT_EDGE_FEAT = './data/{}/ml_{}_edge_feat.npy'.format(data_name, data_name.split('-')[0]) + OUT_NODE_FEAT = './data/{}/ml_{}_node_feat.npy'.format(data_name, data_name.split('-')[0]) + OUT_NODE_DICT = './data/{}/ml_{}_node_dict.pkl'.format(data_name, data_name.split('-')[0]) + + if data_name == "wikipedia-tgb": + df, feat, node_dict = preprocess_wikipedia(PATH) + elif data_name in ["review-tgb", "coin-tgb"]: + df, feat, node_dict = preprocess_review_coin(PATH) + elif data_name == "comment-tgb": + df, feat, node_dict = preprocess_comment(PATH) + elif data_name == "flight-tgb": + df, feat, node_dict = preprocess_flight(PATH) + else: + print("Conjunto de datos no existente") + return + + new_df, new_dict = reindex(df, node_dict, bipartite) + + empty = np.zeros(feat.shape[1])[np.newaxis, :] + feat = np.vstack([empty, feat]) + + new_df.to_csv(OUT_DF) + np.save(OUT_EDGE_FEAT, feat) + + # Se deciden guardar únicamente las características de los nodos del conjunto "flight" ya que no son 0. + if data_name == "flight-tgb": + # Se calculan las características de los nodos de "flight" + node_feat = process_flight_node_feat("./data/flight-tgb/airport_node_feat.csv", node_dict) + np.save(OUT_NODE_FEAT, node_feat) + + if node_dict != None: + with open(OUT_NODE_DICT, 'wb') as archivo: + pickle.dump(new_dict, archivo) + + +parser = argparse.ArgumentParser('Interface for TGB data preprocessing') +parser.add_argument('--data_name', type=str, help='Dataset name (eg. 
+                    default='wikipedia-tgb')
+parser.add_argument('--bipartite', action='store_true', help='Whether the graph is bipartite')
+
+args = parser.parse_args()
+
+run(args.data_name, bipartite=args.bipartite)
\ No newline at end of file
diff --git a/utils/preprocess_data.py b/utils/link_data_preprocessing/tgn_data.py
similarity index 66%
rename from utils/preprocess_data.py
rename to utils/link_data_preprocessing/tgn_data.py
index 9339d39..6369fca 100644
--- a/utils/preprocess_data.py
+++ b/utils/link_data_preprocessing/tgn_data.py
@@ -18,7 +18,7 @@ def preprocess(data_name):
       i = int(e[1])
 
       ts = float(e[2])
-      label = float(e[3])  # int(e[3])
+      label = float(e[3])
 
       feat = np.array([float(x) for x in e[4:]])
 
@@ -35,7 +35,6 @@ def preprocess(data_name):
                        'label': label_list,
                        'idx': idx_list}), np.array(feat_l)
 
-
 def reindex(df, bipartite=True):
   new_df = df.copy()
   if bipartite:
@@ -58,30 +57,24 @@ def reindex(df, bipartite=True):
 
 
 def run(data_name, bipartite=True):
-  Path("data/").mkdir(parents=True, exist_ok=True)
-  PATH = './data/{}.csv'.format(data_name)
-  OUT_DF = './data/ml_{}.csv'.format(data_name)
-  OUT_FEAT = './data/ml_{}.npy'.format(data_name)
-  OUT_NODE_FEAT = './data/ml_{}_node.npy'.format(data_name)
-
-  df, feat = preprocess(PATH)
-  new_df = reindex(df, bipartite)
-
-  empty = np.zeros(feat.shape[1])[np.newaxis, :]
-  feat = np.vstack([empty, feat])
-
-  max_idx = max(new_df.u.max(), new_df.i.max())
-  rand_feat = np.zeros((max_idx + 1, 172))
-
-  new_df.to_csv(OUT_DF)
-  np.save(OUT_FEAT, feat)
-  np.save(OUT_NODE_FEAT, rand_feat)
-
+  Path("data/").mkdir(parents=True, exist_ok=True)
+  PATH = './data/{}/{}.csv'.format(data_name, data_name.split('-')[0])
+  OUT_DF = './data/{}/ml_{}_df.csv'.format(data_name, data_name.split('-')[0])
+  OUT_FEAT = './data/{}/ml_{}_edge_feat.npy'.format(data_name, data_name.split('-')[0])
+  df, feat = preprocess(PATH)
+  new_df = reindex(df, bipartite)
+
+  empty = np.zeros(feat.shape[1])[np.newaxis, :]
+  feat = np.vstack([empty, feat])
+
+  new_df.to_csv(OUT_DF)
+  np.save(OUT_FEAT, feat)
+
 parser = argparse.ArgumentParser('Interface for TGN data preprocessing')
-parser.add_argument('--data', type=str, help='Dataset name (eg. wikipedia or reddit)',
+parser.add_argument('--data_name', type=str, help='Dataset name (eg. wikipedia or reddit)',
                     default='wikipedia')
 parser.add_argument('--bipartite', action='store_true', help='Whether the graph is bipartite')
 
 args = parser.parse_args()
 
-run(args.data, bipartite=args.bipartite)
\ No newline at end of file
+run(args.data_name, bipartite=args.bipartite)
\ No newline at end of file
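
As a reference for the two preprocessing entry points above, a minimal invocation sketch (assuming the raw files have already been placed where the `run()` functions expect them, e.g. `data/wikipedia/wikipedia.csv` and `data/flight-tgb/flight_edgelist.csv`, following the `PATH` format strings):

```{bash}
# TGN-style datasets (bipartite user-item graphs such as wikipedia or reddit)
python3 utils/link_data_preprocessing/tgn_data.py --data_name wikipedia --bipartite

# TGB-style datasets (e.g. the non-bipartite flight graph)
python3 utils/link_data_preprocessing/tgb_data.py --data_name flight-tgb
```

Both scripts write `ml_<name>_df.csv` and `ml_<name>_edge_feat.npy` into the same dataset folder; `tgb_data.py` also stores the node-id mapping it builds (`ml_<name>_node_dict.pkl`) and, for `flight-tgb` only, the airport node features later consumed by `get_data`.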