Skip to content

Commit 3f613b0

Browse files
authored
Merge pull request #561 from pynapple-org/runtime-import
Runtime import
2 parents 0234dbb + 7faab6b commit 3f613b0

9 files changed

Lines changed: 75 additions & 10 deletions

File tree

pynapple/core/_core_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from typing import Literal
1212

1313
import numpy as np
14-
from scipy import signal
1514

1615
from ._jitted_functions import ( # pjitconvolve,
1716
jitbin_array,
@@ -124,6 +123,8 @@ def _dropna(time_array, data_array, starts, ends, update_time_support, ndim):
124123

125124

126125
def _convolve(time_array, data_array, starts, ends, array, trim="both"):
126+
from scipy import signal
127+
127128
if get_backend() == "jax":
128129
from pynajax.jax_core_convolve import convolve
129130

pynapple/core/time_series.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import numpy as np
2323
import pandas as pd
2424
from numpy.lib.mixins import NDArrayOperatorsMixin
25-
from scipy import signal
2625
from tabulate import tabulate
2726

2827
from ._core_functions import (
@@ -574,6 +573,8 @@ def smooth(self, std, windowsize=None, time_units="s", size_factor=100, norm=Tru
574573
Time series convolved with a gaussian kernel
575574
576575
"""
576+
from scipy import signal
577+
577578
if not isinstance(std, (int, float)):
578579
raise IOError("std should be type int or float")
579580
if not isinstance(size_factor, int):
@@ -648,6 +649,8 @@ def decimate(self, down, order=8, filter_type="iir", ep=None):
648649
>>> plt.show()
649650
650651
"""
652+
from scipy import signal
653+
651654
if not isinstance(down, int):
652655
raise IOError(
653656
"Invalid value for 'down': Parameter 'down' should be of type int."

pynapple/io/neurosuite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def load_mean_waveforms(self, epoch=None, waveform_window=None, spike_count=1000
370370

371371
for index, timestep in enumerate(batches):
372372
print(
373-
f"Extracting waveforms from dat file: window {index+1} / {len(windows)}",
373+
f"Extracting waveforms from dat file: window {index + 1} / {len(windows)}",
374374
end="\r",
375375
)
376376

pynapple/process/decoding.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@
77
from functools import wraps
88

99
import numpy as np
10-
import xarray as xr
11-
from scipy.spatial.distance import cdist
1210

1311
from .. import core as nap
1412

1513

1614
def _format_decoding_inputs(func):
1715
@wraps(func)
1816
def wrapper(*args, **kwargs):
17+
import xarray as xr
18+
1919
# Validate each positional argument
2020
sig = inspect.signature(func)
2121
bound = sig.bind(*args, **kwargs)
@@ -621,6 +621,8 @@ def decode_template(
621621
99.9 1.0 1.0
622622
dtype: float64, shape: (1000, 2)
623623
"""
624+
from scipy.spatial.distance import cdist
625+
624626
tc = tuning_curves.values.reshape(tuning_curves.sizes["unit"], -1)
625627
ct = data.values
626628

@@ -644,6 +646,8 @@ def decode_1d(tuning_curves, group, ep, bin_size, time_units="s", feature=None):
644646
`decode_1d` will be removed in Pynapple 1.0.0, it is replaced by
645647
`decode_bayes` because the latter works for N dimensions.
646648
"""
649+
import xarray as xr
650+
647651
warnings.warn(
648652
"decode_1d is deprecated and will be removed in a future version; use decode_bayes instead.",
649653
FutureWarning,
@@ -684,6 +688,8 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N
684688
`decode_2d` will be removed in Pynapple 1.0.0, it is replaced by
685689
`decode_bayes` because the latter works for N dimensions.
686690
"""
691+
import xarray as xr
692+
687693
warnings.warn(
688694
"decode_2d is deprecated and will be removed in a future version; use decode_bayes instead.",
689695
FutureWarning,

pynapple/process/filtering.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import numpy as np
99
import pandas as pd
10-
from scipy.signal import butter, filtfilt, sosfiltfilt, sosfreqz
1110

1211
from .. import core as nap
1312

@@ -61,6 +60,8 @@ def wrapper(*args, **kwargs):
6160

6261
def _get_butter_coefficients(cutoff, filter_type, sampling_frequency, order=4):
6362
"""Calls scipy butter"""
63+
from scipy.signal import butter
64+
6465
return butter(order, cutoff, btype=filter_type, fs=sampling_frequency, output="sos")
6566

6667

@@ -84,6 +85,8 @@ def _compute_butterworth_filter(
8485
)
8586

8687
else:
88+
from scipy.signal import sosfiltfilt
89+
8790
out = np.zeros_like(data.d)
8891
for ep in data.time_support:
8992
slc = data.get_slice(start=ep.start[0], end=ep.end[0])
@@ -494,6 +497,8 @@ def get_filter_frequency_response(
494497
cutoff = np.array(cutoff)
495498

496499
if mode == "butter":
500+
from scipy.signal import sosfreqz
501+
497502
sos = _get_butter_coefficients(cutoff, filter_type, fs, order)
498503
w, h = sosfreqz(sos, worN=1024, fs=fs)
499504
return pd.Series(index=w, data=np.abs(h))
@@ -550,6 +555,8 @@ def detect_oscillatory_events(
550555
The interval set of detected events with metadata containing
551556
the power, amplitude, and peak_time
552557
"""
558+
from scipy.signal import filtfilt
559+
553560
data = data.restrict(epoch)
554561

555562
if fs is None:

pynapple/process/spectrum.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import numpy as np
1010
import pandas as pd
1111
from numba import njit
12-
from scipy import signal
1312

1413
from .. import core as nap
1514

@@ -239,6 +238,8 @@ def compute_mean_power_spectral_density(
239238
ValueError
240239
If overlap is not within [0, 1).
241240
"""
241+
from scipy import signal
242+
242243
if not (0.0 <= overlap < 1.0):
243244
raise ValueError("Overlap should be in intervals [0.0, 1.0).")
244245

pynapple/process/tuning_curves.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import numpy as np
1111
import pandas as pd
12-
import xarray as xr
1312

1413
from .. import core as nap
1514

@@ -183,6 +182,7 @@ def compute_tuning_curves(
183182
occupancy: [100. 100. 100. 100. 100. 100. 100. 100. 100. 100.]
184183
bin_edges: [array([0. , 0.09, 0.18, 0.27, 0.36, 0.45, 0.54, 0.63, 0.72,...
185184
"""
185+
import xarray as xr
186186

187187
# check data
188188
if not isinstance(data, (nap.TsdFrame, nap.TsGroup, nap.Ts, nap.Tsd)):
@@ -367,6 +367,8 @@ def compute_response_per_epoch(data, epochs_dict, return_pandas=False):
367367
* unit (unit) int64 24B 0 1 2
368368
* epochs (epochs) <U5 40B 'stim0' 'stim1'
369369
"""
370+
import xarray as xr
371+
370372
# check data
371373
if not isinstance(data, (nap.TsdFrame, nap.TsGroup, nap.Ts, nap.Tsd)):
372374
raise TypeError("data should be a TsdFrame, TsGroup, Ts, or Tsd.")
@@ -508,6 +510,8 @@ def compute_mutual_information(tuning_curves, rates=None):
508510
1 33.480966 3.301870
509511
2 33.369159 3.310432
510512
"""
513+
import xarray as xr
514+
511515
if not isinstance(tuning_curves, xr.DataArray):
512516
raise TypeError(
513517
"tuning_curves should be an xr.DataArray as computed by compute_tuning_curves."
@@ -763,6 +767,8 @@ def compute_2d_mutual_info(dict_tc, features, ep=None, minmax=None, bitssec=Fals
763767
`compute_2d_mutual_info` will be removed in Pynapple 1.0.0, it is replaced by
764768
`compute_mutual_information` because the latter works for N dimensions.
765769
"""
770+
import xarray as xr
771+
766772
warnings.warn(
767773
"compute_2d_mutual_info is deprecated and will be removed in a future version;"
768774
"use compute_mutual_information instead.",
@@ -820,6 +826,8 @@ def compute_1d_mutual_info(tc, feature, ep=None, minmax=None, bitssec=False):
820826
`compute_1d_mutual_info` will be removed in Pynapple 1.0.0, it is replaced by
821827
`compute_mutual_information` because the latter works for N dimensions.
822828
"""
829+
import xarray as xr
830+
823831
warnings.warn(
824832
"compute_1d_mutual_info is deprecated and will be removed in a future version;"
825833
"use compute_mutual_information instead.",

tests/test_lazy_loading.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,3 +301,42 @@ def test_tsgroup_no_warnings(tmp_path): # default fixture
301301
# file_path = Path(f'data_{k}.h5')
302302
# if file_path.exists():
303303
# file_path.unlink()
304+
305+
306+
def test_lazy_import_heavy_modules():
307+
"""Test that importing pynapple does not eagerly load scipy.signal, scipy.spatial, or xarray."""
308+
import subprocess
309+
import sys
310+
311+
# Run in subprocess to get clean sys.modules state
312+
code = """
313+
import sys
314+
before = set(sys.modules.keys())
315+
import pynapple
316+
after = set(sys.modules.keys())
317+
new_modules = after - before
318+
319+
# Check for scipy submodules (scipy base may be loaded by numba, but submodules should not)
320+
scipy_signal = [m for m in new_modules if m.startswith('scipy.signal')]
321+
scipy_spatial = [m for m in new_modules if m.startswith('scipy.spatial')]
322+
xarray_mods = [m for m in new_modules if m == 'xarray' or m.startswith('xarray.')]
323+
324+
failures = []
325+
if scipy_signal:
326+
failures.append(f"scipy.signal modules loaded: {scipy_signal}")
327+
if scipy_spatial:
328+
failures.append(f"scipy.spatial modules loaded: {scipy_spatial}")
329+
if xarray_mods:
330+
failures.append(f"xarray modules loaded: {xarray_mods}")
331+
332+
if failures:
333+
print("FAIL: " + "; ".join(failures))
334+
sys.exit(1)
335+
print("PASS")
336+
"""
337+
result = subprocess.run(
338+
[sys.executable, "-c", code],
339+
capture_output=True,
340+
text=True,
341+
)
342+
assert result.returncode == 0, f"Test failed: {result.stdout}{result.stderr}"

tests/test_tuning_curves.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,7 +1588,7 @@ def test_compute_1d_mutual_info(args, kwargs, expected):
15881588
assert list(si.columns) == ["SI"]
15891589
if isinstance(tc, pd.DataFrame):
15901590
assert list(si.index.values) == list(tc.columns)
1591-
np.testing.assert_approx_equal(si.values, expected)
1591+
np.testing.assert_allclose(si.values, expected)
15921592

15931593

15941594
@pytest.mark.filterwarnings("ignore")
@@ -1660,7 +1660,7 @@ def test_compute_2d_mutual_info(args, kwargs, expected):
16601660
assert list(si.columns) == ["SI"]
16611661
if isinstance(dict_tc, dict):
16621662
assert list(si.index.values) == list(dict_tc.keys())
1663-
np.testing.assert_approx_equal(si.values, expected)
1663+
np.testing.assert_allclose(si.values, expected)
16641664

16651665

16661666
# ------------------------------------------------------------------------------------

0 commit comments

Comments
 (0)