Skip to content

Commit

Permalink
Fixing multiple warnings related to the conversion of arrays with ndi…
Browse files Browse the repository at this point in the history
…m > 0 to scalars.
  • Loading branch information
drivanov committed Jan 29, 2025
1 parent 51907e0 commit 5c30d1d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/dgl/nn/pytorch/explain/subgraphx.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def shapley(self, subgraph_nodes):
device = self.feat.device
for _ in range(self.shapley_steps):
permuted_space = np.random.permutation(coalition_space)
split_idx = int(np.where(permuted_space == split_point)[0])
split_idx = int(np.where(permuted_space == split_point)[0][0])

selected_nodes = permuted_space[:split_idx]

Expand Down Expand Up @@ -490,7 +490,7 @@ def shapley(self, subgraph_nodes):
selected_node_map = dict()
for ntype, nodes in coalition_space.items():
permuted_space = np.random.permutation(nodes)
split_idx = int(np.where(permuted_space == split_point)[0])
split_idx = int(np.where(permuted_space == split_point)[0][0])
selected_node_map[ntype] = permuted_space[:split_idx]

# Mask for coalition set S_i
Expand Down

0 comments on commit 5c30d1d

Please sign in to comment.