Skip to content

Commit db44583

Browse files
authored
Expose sparse matrix normalization (#24)
* extract utility to normalize a sparse matrix * add test for sparse matrix normalization utility * remove unused flake8 ignores & fix comments * update GHA to test on torch 1.13 too
1 parent a5de688 commit db44583

File tree

5 files changed

+56
-15
lines changed

5 files changed

+56
-15
lines changed

.github/workflows/tests.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ jobs:
5555
matrix:
5656
os: [ "ubuntu-latest" ]
5757
python-version: [ "3.8", "3.9", "3.10" ]
58-
torch-version: [ "torch-1.11", "torch-1.12" ]
58+
torch-version: [ "torch-1.11", "torch-1.12", "torch-1.13" ]
5959
runs-on: ${{ matrix.os }}
6060
steps:
6161
- uses: actions/checkout@v2

setup.cfg

+4-6
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,10 @@ strictness = short
131131
#########################
132132
[flake8]
133133
ignore =
134-
S301 # pickle
135-
S403 # pickle
136-
S404
137-
S603
138-
W503 # Line break before binary operator (flake8 is wrong)
139-
E203 # whitespace before ':'
134+
# Line break before binary operator (flake8 is wrong)
135+
W503
136+
# whitespace before ':'
137+
#E203
140138
exclude =
141139
.tox,
142140
.git,

src/torch_ppr/utils.py

+33-7
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"prepare_x0",
1818
"power_iteration",
1919
"batched_personalized_page_rank",
20+
"sparse_normalize",
2021
]
2122

2223
logger = logging.getLogger(__name__)
@@ -172,6 +173,33 @@ def sparse_diagonal(values: torch.Tensor) -> torch.Tensor:
172173
)
173174

174175

176+
def sparse_normalize(matrix: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """
    Normalize a sparse matrix such that its sums along one dimension are 1.

    :param matrix:
        the sparse matrix; it needs a floating point dtype, since the smallest divisor is taken from
        :func:`torch.finfo`
    :param dim:
        the dimension which is summed over. With ``dim=0`` the entries are summed over the rows, i.e., every
        column of the result sums to 1; with ``dim=1`` every row of the result sums to 1. Negative values count
        from the last dimension, i.e., ``-2`` is equivalent to ``0`` and ``-1`` is equivalent to ``1``.

    :return:
        the normalized sparse matrix. Sums smaller than the dtype's machine epsilon (e.g., for empty
        rows/columns) are replaced by the epsilon before inverting, so the corresponding rows/columns are scaled
        by ``1 / eps`` instead of being divided by zero.
    """
    # sum over `dim`; an out-of-range `dim` is rejected by torch at this point
    sums = torch.sparse.sum(matrix, dim=dim).to_dense()
    # clamp to avoid a division by zero for empty rows/columns
    sums = sums.clamp_min(min=torch.finfo(matrix.dtype).eps)
    # invert and create diagonal matrix
    scaling_matrix = sparse_diagonal(values=torch.reciprocal(sums))
    # Summing over dimension 0 yields one value per column, thus the scaling has to be applied from the right;
    # otherwise, there is one value per row and the scaling is applied from the left.
    # The modulo maps negative dimensions (-2, -1) to their non-negative counterparts (0, 1).
    if dim % matrix.ndim == 0:
        args = (matrix, scaling_matrix)
    else:
        args = (scaling_matrix, matrix)
    # note: we do not pass by keyword due to unstable API
    return torch.sparse.mm(*args)
201+
202+
175203
def prepare_page_rank_adjacency(
176204
adj: Optional[torch.Tensor] = None,
177205
edge_index: Optional[torch.LongTensor] = None,
@@ -219,12 +247,8 @@ def prepare_page_rank_adjacency(
219247
if add_identity:
220248
adj = adj + sparse_diagonal(torch.ones(adj.shape[0], dtype=adj.dtype, device=adj.device))
221249

222-
# adjacency normalization: normalize to col-sum = 1
223-
degree_inv = torch.reciprocal(
224-
torch.sparse.sum(adj, dim=0).to_dense().clamp_min(min=torch.finfo(adj.dtype).eps)
225-
)
226-
degree_inv = sparse_diagonal(values=degree_inv)
227-
return torch.sparse.mm(adj, degree_inv)
250+
# adjacency normalization: normalize to col-sum = 1 (dim=0 sums over the rows)
251+
return sparse_normalize(matrix=adj, dim=0)
228252

229253

230254
def validate_x(x: torch.Tensor, n: Optional[int] = None) -> None:
@@ -259,7 +283,9 @@ def validate_x(x: torch.Tensor, n: Optional[int] = None) -> None:
259283

260284

261285
def prepare_x0(
262-
x0: Optional[torch.Tensor] = None, indices: Collection[int] = None, n: Optional[int] = None
286+
x0: Optional[torch.Tensor] = None,
287+
indices: Optional[Collection[int]] = None,
288+
n: Optional[int] = None,
263289
) -> torch.Tensor:
264290
"""
265291
Prepare a start value.

tests/test_utils.py

+16
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,19 @@ def test_sparse_diagonal(n: int):
181181
assert matrix.shape == (n, n)
182182
assert matrix.is_sparse
183183
assert torch.allclose(matrix.to_dense(), torch.diag(values))
184+
185+
186+
@pytest.mark.parametrize("seed", [21, 42, 63])
def test_sparse_normalize(seed: int):
    """Check that a normalized sparse matrix sums to one along the normalized dimension."""
    rng = torch.manual_seed(seed=seed)
    # draw a random shape, and a random dense matrix of that shape
    n_rows, n_cols = torch.randint(10, 20, size=(2,), generator=rng)
    dense = torch.rand(size=(n_rows, n_cols), generator=rng)
    # drop all entries below 0.5 before converting to a sparse tensor
    dense[dense < 0.5] = 0
    matrix = dense.to_sparse()
    for dim in (0, 1):
        normalized = utils.sparse_normalize(matrix=matrix, dim=dim)
        # only the explicitly stored sums are compared against one
        totals = torch.sparse.sum(normalized, dim=dim).values()
        assert torch.allclose(totals, torch.ones_like(totals))

tox.ini

+2-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ envlist =
2323
docstr-coverage
2424
docs-test
2525
# the actual tests
26-
py-torch-{1.11,1.12}
26+
py-torch-{1.11,1.12,1.13}
2727
# always keep coverage-report last
2828
# coverage-report
2929

@@ -39,6 +39,7 @@ setenv =
3939
deps =
4040
torch-1.11: torch~=1.11.0
4141
torch-1.12: torch~=1.12.0
42+
torch-1.13: torch~=1.13.0
4243
extras =
4344
# See the [options.extras_require] entry in setup.cfg for "tests"
4445
tests

0 commit comments

Comments
 (0)