Skip to content

Commit

Permalink
replace sparse matrix with sparse array
Browse files Browse the repository at this point in the history
  • Loading branch information
szsdk committed Aug 26, 2024
1 parent 02919b1 commit ce599be
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
12 changes: 6 additions & 6 deletions emcfile/_pattern_sone.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import h5py
import numpy as np
import numpy.typing as npt
from scipy.sparse import csr_matrix, hstack
from scipy.sparse import csr_array, hstack

from ._h5helper import PATH_TYPE, H5Path, check_remove_groups, make_path
from ._misc import pretty_size
Expand Down Expand Up @@ -243,15 +243,15 @@ def write(
compression=compression,
)

def _get_sparse_ones(self) -> csr_matrix:
def _get_sparse_ones(self) -> csr_array:
_one = np.ones(1, "i4")
_one = np.lib.stride_tricks.as_strided(
_one, shape=(self.place_ones.shape[0],), strides=(0,)
)
return csr_matrix((_one, self.place_ones, self.ones_idx), shape=self.shape)
return csr_array((_one, self.place_ones, self.ones_idx), shape=self.shape)

def _get_sparse_multi(self) -> csr_matrix:
return csr_matrix(
def _get_sparse_multi(self) -> csr_array:
return csr_array(
(self.count_multi, self.place_multi, self.multi_idx), shape=self.shape
)

Expand All @@ -272,7 +272,7 @@ def __array__(self) -> npt.NDArray[np.int32]:
def __matmul__(self, mtx: npt.NDArray[Any]) -> npt.NDArray[Any]:
return cast(
npt.NDArray[Any],
self._get_sparse_ones() * mtx + self._get_sparse_multi() * mtx,
self._get_sparse_ones() @ mtx + self._get_sparse_multi() @ mtx,
)

def __array_function__(
Expand Down
6 changes: 3 additions & 3 deletions emcfile/_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import numpy.typing as npt
from scipy.sparse import coo_matrix, spmatrix
from scipy.sparse import coo_array, coo_matrix, sparray, spmatrix

from ._h5helper import PATH_TYPE, H5Path
from ._misc import divide_range
Expand Down Expand Up @@ -160,9 +160,9 @@ def patterns(
if start is not None or end is not None:
raise Exception()
return dense_to_PatternsSOne(src)
elif isinstance(src, coo_matrix):
elif isinstance(src, (coo_array, coo_matrix)):
return coo_to_SOne_kernel(src)
elif isinstance(src, spmatrix):
elif isinstance(src, (sparray, spmatrix)):
return cast(
PatternsSOne,
np.concatenate(
Expand Down
10 changes: 5 additions & 5 deletions emcfile/tests/test_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
import pytest
from psutil import Process
from scipy.sparse import coo_matrix, csr_matrix
from scipy.sparse import coo_array, csr_array

import emcfile as ef
from emcfile.tests.utils import temp_seed
Expand Down Expand Up @@ -86,8 +86,8 @@ def gen_pattern_inputs():
ref = ef.patterns(dense)
yield ref, ref
yield dense, ref
yield coo_matrix(dense), ref
yield csr_matrix(dense), ref
yield coo_array(dense), ref
yield csr_array(dense), ref


def test_pattern_not_equal(small_data):
Expand Down Expand Up @@ -252,9 +252,9 @@ def test_pattern_mul(big_data):
np.testing.assert_almost_equal(big_data @ mtx, np.asarray(big_data) @ mtx)
mtx = mtx > 0.4
np.testing.assert_almost_equal(big_data @ mtx, big_data.todense() @ mtx)
mtx = coo_matrix(mtx)
mtx = coo_array(mtx)
np.testing.assert_equal((big_data @ mtx).todense(), big_data.todense() @ mtx)
mtx = csr_matrix(mtx)
mtx = csr_array(mtx)
np.testing.assert_equal((big_data @ mtx).todense(), big_data.todense() @ mtx)


Expand Down

0 comments on commit ce599be

Please sign in to comment.