|
9 | 9 | import torch
|
10 | 10 | import torch.nn as nn
|
11 | 11 | from torch.optim.lr_scheduler import ReduceLROnPlateau
|
12 |
| -from torch_geometric.utils import subgraph, is_undirected |
| 12 | +from torch_sparse import transpose |
13 | 13 | from torch_geometric.loader import DataLoader
|
| 14 | +from torch_geometric.utils import subgraph, is_undirected |
14 | 15 | from ogb.graphproppred import Evaluator
|
15 | 16 | from sklearn.metrics import roc_auc_score
|
16 | 17 | from rdkit import Chem
|
17 | 18 |
|
18 | 19 | from pretrain_clf import train_clf_one_seed
|
19 | 20 | from utils import Writer, Criterion, MLP, visualize_a_graph, save_checkpoint, load_checkpoint, get_preds, get_lr, set_seed, process_data
|
20 |
| -from utils import get_local_config_name, get_model, get_data_loaders, write_stat_from_metric_dicts, init_metric_dict |
| 21 | +from utils import get_local_config_name, get_model, get_data_loaders, write_stat_from_metric_dicts, reorder_like, init_metric_dict |
21 | 22 |
|
22 | 23 |
|
23 | 24 | class GSAT(nn.Module):
|
@@ -75,9 +76,9 @@ def forward_pass(self, data, epoch, training):
|
75 | 76 |
|
76 | 77 | if self.learn_edge_att:
|
77 | 78 | if is_undirected(data.edge_index):
|
78 |
| - nodesize = data.x.shape[0] |
79 |
| - sp_csr = scipy.sparse.csr_matrix((torch.arange(att.shape[0]), (data.edge_index[0].cpu(), data.edge_index[1].cpu())), (nodesize, nodesize)) |
80 |
| - edge_att = (att + att[sp_csr[data.edge_index[1].tolist(), data.edge_index[0].tolist()].A1]) / 2 |
| 79 | + trans_idx, trans_val = transpose(data.edge_index, att, None, None, coalesced=False) |
| 80 | + trans_val_perm = reorder_like(trans_idx, data.edge_index, trans_val) |
| 81 | + edge_att = (att + trans_val_perm) / 2 |
81 | 82 | else:
|
82 | 83 | edge_att = att
|
83 | 84 | else:
|
|
0 commit comments