Skip to content

Commit 9079d9d

Browse files
committed
reorder_like
1 parent de6edb2 commit 9079d9d

File tree

3 files changed

+21
-10
lines changed

3 files changed

+21
-10
lines changed

example/gsat.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import scipy
55
import torch
66
import torch.nn as nn
7+
from torch_sparse import transpose
78
from torch_geometric.utils import is_undirected
8-
from utils import MLP
9+
from utils import MLP, reorder_like
910

1011

1112
class GSAT(nn.Module):
@@ -40,9 +41,9 @@ def forward_pass(self, data, epoch, training):
4041

4142
if self.learn_edge_att:
4243
if is_undirected(data.edge_index):
43-
nodesize = data.x.shape[0]
44-
sci_csr = scipy.sparse.csr_matrix((torch.arange(att.shape[0]), (data.edge_index[0].cpu(), data.edge_index[1].cpu())), (nodesize, nodesize))
45-
edge_att = (att + att[sci_csr[data.edge_index[1].tolist(), data.edge_index[0].tolist()].A1]) / 2
44+
trans_idx, trans_val = transpose(data.edge_index, att, None, None, coalesced=False)
45+
trans_val_perm = reorder_like(trans_idx, data.edge_index, trans_val)
46+
edge_att = (att + trans_val_perm) / 2
4647
else:
4748
edge_att = att
4849
else:

src/run_gsat.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,16 @@
99
import torch
1010
import torch.nn as nn
1111
from torch.optim.lr_scheduler import ReduceLROnPlateau
12-
from torch_geometric.utils import subgraph, is_undirected
12+
from torch_sparse import transpose
1313
from torch_geometric.loader import DataLoader
14+
from torch_geometric.utils import subgraph, is_undirected
1415
from ogb.graphproppred import Evaluator
1516
from sklearn.metrics import roc_auc_score
1617
from rdkit import Chem
1718

1819
from pretrain_clf import train_clf_one_seed
1920
from utils import Writer, Criterion, MLP, visualize_a_graph, save_checkpoint, load_checkpoint, get_preds, get_lr, set_seed, process_data
20-
from utils import get_local_config_name, get_model, get_data_loaders, write_stat_from_metric_dicts, init_metric_dict
21+
from utils import get_local_config_name, get_model, get_data_loaders, write_stat_from_metric_dicts, reorder_like, init_metric_dict
2122

2223

2324
class GSAT(nn.Module):
@@ -75,9 +76,9 @@ def forward_pass(self, data, epoch, training):
7576

7677
if self.learn_edge_att:
7778
if is_undirected(data.edge_index):
78-
nodesize = data.x.shape[0]
79-
sp_csr = scipy.sparse.csr_matrix((torch.arange(att.shape[0]), (data.edge_index[0].cpu(), data.edge_index[1].cpu())), (nodesize, nodesize))
80-
edge_att = (att + att[sp_csr[data.edge_index[1].tolist(), data.edge_index[0].tolist()].A1]) / 2
79+
trans_idx, trans_val = transpose(data.edge_index, att, None, None, coalesced=False)
80+
trans_val_perm = reorder_like(trans_idx, data.edge_index, trans_val)
81+
edge_att = (att + trans_val_perm) / 2
8182
else:
8283
edge_att = att
8384
else:

src/utils/utils.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from rdkit import Chem
66
import matplotlib.pyplot as plt
77
from torch_geometric.data import Data
8-
from torch_geometric.utils import to_networkx
8+
from torch_geometric.utils import to_networkx, sort_edge_index
99
from torch.utils.tensorboard import SummaryWriter
1010
from torch.utils.tensorboard.summary import hparams
1111

@@ -16,6 +16,15 @@
1616
'metric/best_x_precision_train': 0, 'metric/best_x_precision_valid': 0, 'metric/best_x_precision_test': 0}
1717

1818

19+
def reorder_like(from_edge_index, to_edge_index, values):
20+
from_edge_index, values = sort_edge_index(from_edge_index, values)
21+
ranking_score = to_edge_index[0] * (to_edge_index.max()+1) + to_edge_index[1]
22+
ranking = ranking_score.argsort().argsort()
23+
if not (from_edge_index[:, ranking] == to_edge_index).all():
24+
raise ValueError("Edges in from_edge_index and to_edge_index are different, impossible to match both.")
25+
return values[ranking]
26+
27+
1928
def process_data(data, use_edge_attr):
2029
if not use_edge_attr:
2130
data.edge_attr = None

0 commit comments

Comments
 (0)