Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 2 additions & 19 deletions sigpyproc/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,17 +289,9 @@ def dedisperse(
"""
delays = self.header.get_dmdelays(dm, ref_freq=ref_freq)
if only_valid_samples:
max_delay = delays.max()
valid_samps = self.data.shape[1] - max_delay
if valid_samps < 0:
msg = (
f"Insufficient time samples to dedisperse to {dm} (requires at "
f"least {max_delay} samples, given {self.data.shape[1]})."
)
raise ValueError(msg)
new_ar = kernels.roll_block_valid(self.data, delays)
new_ar = kernels.roll_block_valid(self.data, -delays)
else:
new_ar = kernels.roll_block(self.data, delays)
new_ar = kernels.roll_block(self.data, -delays)
return FilterbankBlock(
new_ar,
self.header.new_header({"nsamples": new_ar.shape[1]}),
Expand Down Expand Up @@ -337,15 +329,6 @@ def dmt_transform(
dm_arr = dm + np.linspace(-dm, dm, dmsteps)
dm_delays = self.header.get_dmdelays(dm_arr, ref_freq=ref_freq)
if only_valid_samples:
max_delay = dm_delays.max()
valid_samps = self.data.shape[1] - dm_delays.max()
if valid_samps < 0:
msg = (
f"Insufficient time samples to dedisperse to {dm_arr.max()} "
f"(requires at least {max_delay} samples, given "
f"{self.data.shape[1]})."
)
raise ValueError(msg)
new_ar = kernels.dmt_block_valid(self.data, dm_delays)
else:
new_ar = kernels.dmt_block(self.data, dm_delays)
Expand Down
126 changes: 79 additions & 47 deletions sigpyproc/core/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,17 +963,18 @@
) -> np.ndarray:
"""Roll array elements along a given axis.

Implemented in ``rocket-fft``. This function is a njit-compiled wrapper
around `numpy.roll`.
This is a Numba-compiled wrapper around `numpy.roll`, implemented via `rocket-fft`
to support the `axis` argument.

Parameters
----------
arr : ndarray
Input array.
shift : int | tuple[int, ...]
Number of bins to shift.
Number of positions to shift. Positive shifts right/down, negative left/up.
If a tuple, must match the length of `axis`.
axis : int | tuple[int, ...] | None, optional
Axis or axes along which to roll, by default None.
Axis or axes to roll along. If None, flattens array and rolls all elements.

Returns
-------
Expand All @@ -985,66 +986,72 @@

@njit(cache=True, fastmath=True)
def roll_block(arr: np.ndarray, shifts: np.ndarray) -> np.ndarray:
"""Roll the 2D array along the second axis by the specified shifts.
"""Roll each row of a 2D array along columns by per-row shifts.

Applies a circular shift to each row independently, wrapping elements around
the column axis. Positive shifts move elements right, negative shifts move left.

Parameters
----------
arr : np.ndarray
Input 2D array.
Input 2D array of shape (nrows, ncols).
shifts : np.ndarray
Array of shifts for each row.
1D array of integer shifts, length equal to nrows. Can be positive or negative.

Returns
-------
np.ndarray
Rolled 2D array.
Rolled 2D array with the same shape as `arr`.

Raises
------
ValueError
If `arr` is not 2D or `shifts` length does not match number of rows.

"""
if arr.ndim != 2:
msg = "Input array must be 2D."
raise ValueError(msg)
if len(shifts) != arr.shape[0]:
msg = "Number of shifts must be equal to the number of rows."
raise ValueError(msg)
res = np.empty_like(arr)
nrows, ncols = arr.shape
res = np.empty_like(arr)
for irow in range(nrows):
shift = shifts[irow] % ncols
res[irow, shift:] = arr[irow, : ncols - shift]
res[irow, :shift] = arr[irow, ncols - shift :]
return res


@njit(cache=True, fastmath=True)
def dmt_block(arr: np.ndarray, dm_delays: np.ndarray) -> np.ndarray:
if arr.ndim != 2 or dm_delays.ndim != 2:
msg = "Input array and delays must be 2D."
raise ValueError(msg)
if arr.shape[0] != dm_delays.shape[1]:
msg = "Number of chans must be same in both arrays."
raise ValueError(msg)
_, nsamps = arr.shape
ndms, _ = dm_delays.shape
res = np.empty((ndms, nsamps), dtype=arr.dtype)
for idm in range(ndms):
res[idm] = np.sum(roll_block(arr, dm_delays[idm]), axis=0)
if shift == 0:
res[irow] = arr[irow]
else:
res[irow, shift:] = arr[irow, : ncols - shift]
res[irow, :shift] = arr[irow, ncols - shift :]
return res


@njit(cache=True, fastmath=True)
def roll_block_valid(arr: np.ndarray, shifts: np.ndarray) -> np.ndarray:
"""Roll the 2D array along the second axis amd keep valid region.
"""Roll each row of a 2D array by per-row shifts, keeping only valid columns.

Similar to `roll_block` but only keeps the valid region where no wrapping occurs.
Positive shifts move elements right, negative shifts move elements left.

Parameters
----------
arr : np.ndarray
Input 2D array.
Input 2D array of shape (nrows, ncols).
shifts : np.ndarray
Array of shifts for each row.
1D array of integer shifts, length equal to nrows. Can be positive or negative.

Returns
-------
np.ndarray
Rolled 2D array (excluding invalid samples, not circular rolled).
Rolled 2D array with shape (nrows, ncols - shift_range).

Raises
------
ValueError
If `arr` is not 2D or `shifts` length doesn't match number of rows.
If the shift range exceeds the number of columns.

"""
if arr.ndim != 2:
msg = "Input array must be 2D."
Expand All @@ -1053,19 +1060,39 @@
msg = "Number of shifts must be equal to the number of rows."
raise ValueError(msg)
nrows, ncols = arr.shape
valid_samps = ncols - np.abs(shifts).max()
if valid_samps < 0:
msg = "Insufficient time samples to dedisperse."
max_pos_shift = max(0, np.max(shifts))
min_neg_shift = min(0, np.min(shifts))

# Calculate the valid region size
start_col = max_pos_shift
end_col = ncols + min_neg_shift
valid_cols = end_col - start_col
if valid_cols <= 0:
msg = (
f"Not enough samples. Required at least {max_pos_shift - min_neg_shift} "
f"samples, given {ncols}."
)
raise ValueError(msg)
res = np.empty((nrows, valid_samps), dtype=arr.dtype)
if np.any(shifts > 0):
for irow in range(nrows):
res[irow] = arr[irow, shifts[irow] : valid_samps + shifts[irow]]
else:
for irow in range(nrows):
end = ncols + shifts[irow]
start = end - valid_samps
res[irow] = arr[irow, start:end]
res = np.empty((nrows, valid_cols), dtype=arr.dtype)
for irow in range(nrows):
shift = shifts[irow]
res[irow, :] = arr[irow, start_col - shift : end_col - shift]
return res


@njit(cache=True, fastmath=True)
def dmt_block(arr: np.ndarray, dm_delays: np.ndarray) -> np.ndarray:
if arr.ndim != 2 or dm_delays.ndim != 2:
msg = "Input array and delays must be 2D."
raise ValueError(msg)
if arr.shape[0] != dm_delays.shape[1]:
msg = "Number of chans must be same in both arrays."
raise ValueError(msg)
_, nsamps = arr.shape
ndms, _ = dm_delays.shape
res = np.empty((ndms, nsamps), dtype=arr.dtype)
for idm in range(ndms):
res[idm] = np.sum(roll_block(arr, dm_delays[idm]), axis=0)
return res


Expand All @@ -1079,11 +1106,16 @@
raise ValueError(msg)
_, nsamps = arr.shape
ndms, _ = dm_delays.shape
valid_samps = nsamps - dm_delays.max()
if valid_samps < 0:
msg = "Insufficient time samples to dedisperse."
max_pos_shift = max(0, np.max(dm_delays))
min_neg_shift = min(0, np.min(dm_delays))
valid_samples = nsamps + min_neg_shift - max_pos_shift
if valid_samples <= 0:
msg = (
f"Not enough samples. Required at least {max_pos_shift - min_neg_shift} "
f"samples, given {nsamps}."
)
raise ValueError(msg)
res = np.empty((ndms, valid_samps), dtype=arr.dtype)
res = np.empty((ndms, valid_samples), dtype=arr.dtype)
for idm in range(ndms):
res[idm] = np.sum(roll_block_valid(arr, dm_delays[idm]), axis=0)
return res
Expand All @@ -1094,7 +1126,7 @@
arr: np.ndarray,
shifts: np.ndarray,
nsamps_out: int = 1,
tfactor: int = 1,

Check failure on line 1129 in sigpyproc/core/kernels.py

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, 3.13)

sigpyproc/core/kernels.py:1129:5: ARG001 Unused function argument: `tfactor`

Check failure on line 1129 in sigpyproc/core/kernels.py

View workflow job for this annotation

GitHub Actions / build (macos-latest, 3.13)

sigpyproc/core/kernels.py:1129:5: ARG001 Unused function argument: `tfactor`

Check failure on line 1129 in sigpyproc/core/kernels.py

View workflow job for this annotation

GitHub Actions / build (macos-latest, 3.12)

sigpyproc/core/kernels.py:1129:5: ARG001 Unused function argument: `tfactor`

Check failure on line 1129 in sigpyproc/core/kernels.py

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, 3.12)

sigpyproc/core/kernels.py:1129:5: ARG001 Unused function argument: `tfactor`

Check failure on line 1129 in sigpyproc/core/kernels.py

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, 3.11)

sigpyproc/core/kernels.py:1129:5: ARG001 Unused function argument: `tfactor`

Check failure on line 1129 in sigpyproc/core/kernels.py

View workflow job for this annotation

GitHub Actions / build (macos-latest, 3.11)

sigpyproc/core/kernels.py:1129:5: ARG001 Unused function argument: `tfactor`
) -> np.ndarray:
"""Roll the 2D array along the second axis.

Expand Down
45 changes: 26 additions & 19 deletions sigpyproc/simulation/furby.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def _compute_stats(self, ts_os: np.ndarray) -> PulseStats:
class SpectralStructure:
"""Class to simulate various spectral structures for the radio bursts.

This class generates spectral patterns with customisable characteristics,
Generates spectral patterns with customisable characteristics,
including frequency-dependent effects and various structural types.

Parameters
Expand All @@ -430,21 +430,24 @@ class SpectralStructure:
Notes
-----
Supported spectral structures are:
- flat: Equal gain in each channel, no evolution with frequency
- flat: Equal gain, no evolution with frequency (unless spec_index != 0).
- power_law: Gains follow a power-law with the given spectral index.
- smooth_envelope: Gains evolve smoothly as a Gaussian envelope.
- gaussian: Gains follow a Gaussian profile.
- polynomial_peaks: Gains are determined by a polynomial with random peaks.
- scintillation: Gains follow a sinusoidal (scintillating) profile.
- gaussian_blobs: Gains follow a patchy profile with Gaussian blobs.
- random: Randomly select one of the above structures.
- smooth_envelope: Smooth Gaussian-like envelope.
- gaussian: Single Gaussian profile.
- polynomial_peaks: Polynomial with random peaks, degree 2-5.
- scintillation: Sinusoidal (scintillating) profile.
- gaussian_blobs: Patchy profile with Gaussian blobs.
- random: Randomly selects one of the above.

The power-law weight (freqs / freqs[0]) ** spec_index is applied to all spectra.
Set spec_index=0 for no power-law weighting.
"""

def __init__(
self,
freqs: np.ndarray,
kind: SpecSimulMethods = "scintillation",
spec_index: float = -2,
spec_index: float = -2.0,
seed: int | None = None,
) -> None:
self.freqs = np.asarray(freqs, dtype=np.float32)
Expand Down Expand Up @@ -487,8 +490,8 @@ def generate(self) -> np.ndarray:
Then the spectrum is shifted to have a mean of 1 to conserve the total signal
after averaging along the frequency axis.
"""
spec = self._spec_generators[self.kind]()
return self._normalize_and_shift(spec * self.power_law_wt)
spec = self._spec_generators[self.kind]() * self.power_law_wt
return self._normalize_and_shift(spec)

def plot(
self,
Expand Down Expand Up @@ -524,26 +527,30 @@ def _spec_flat(self) -> np.ndarray:

def _spec_power_law(self) -> np.ndarray:
"""Generate a power-law spectrum."""
return self.power_law_wt
return np.ones_like(self.freqs)

def _spec_smooth_envelope(self) -> np.ndarray:
"""Generate a smooth envelope spectrum."""
center = self.rng.uniform(0, self.nchans)
roots = [center - self.nchans // 2, center + self.nchans // 2]
spec = -polynomial.polyvalfromroots(np.arange(self.nchans), roots)
center = self.rng.uniform(self.freqs.min(), self.freqs.max())
width = self.rng.uniform(np.ptp(self.freqs) / 4, np.ptp(self.freqs) / 2)
roots = [center - width, center + width]
spec = -polynomial.polyvalfromroots(self.freqs, roots)
return spec.astype(np.float32)

def _spec_gaussian(self) -> np.ndarray:
"""Generate a Gaussian spectrum."""
center = self.rng.normal(self.freqs.mean(), np.ptp(self.freqs) / 4)
center = self.rng.uniform(self.freqs.mean(), self.freqs.max())
width = self.rng.uniform(np.ptp(self.freqs) / 10, np.ptp(self.freqs) / 2)
return utils.gaussian(self.freqs, center, width)

def _spec_poly(self) -> np.ndarray:
"""Generate a polynomial spectrum with random peaks."""
npeaks = self.rng.geometric(p=1 / 3) + 1
roots = self.rng.uniform(self.freqs.min(), self.freqs.max(), size=2 * npeaks)
spec = -polynomial.polyvalfromroots(self.freqs, np.sort(roots))
degree = self.rng.integers(2, 6) # Low degree to prevent overflow
coeffs = self.rng.normal(size=degree + 1)
bandwidth = self.freqs.max() - self.freqs.min()
freqs_norm = (self.freqs - self.freqs.min()) / (bandwidth) * 2 - 1 # [-1,1]
poly = np.poly1d(coeffs)
spec = poly(freqs_norm)
return spec.astype(np.float32)

def _spec_scint(self) -> np.ndarray:
Expand Down
22 changes: 16 additions & 6 deletions tests/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_normalise(self, filfile_8bit_1: str) -> None:
np.testing.assert_allclose(block_norm.data.mean(), 0, atol=0.01)
np.testing.assert_allclose(block_norm.data.std(), 1, atol=0.01)
with pytest.raises(ValueError):
block.normalise(loc_method="invalid") # type: ignore[arg-type]
block.normalise(loc_method="invalid") # type: ignore[arg-type]

def test_pad_samples(self, filfile_8bit_1: str) -> None:
nsamps_final = 2048
Expand All @@ -81,7 +81,7 @@ def test_pad_samples(self, filfile_8bit_1: str) -> None:
np.testing.assert_equal(block_pad.data.shape[1], 2048)
np.testing.assert_equal(block_pad.header.nsamples, 2048)
with pytest.raises(ValueError):
block.pad_samples(nsamps_final, offset, pad_mode="invalid") # type: ignore[arg-type]
block.pad_samples(nsamps_final, offset, pad_mode="invalid") # type: ignore[arg-type]

def test_get_tim(self, filfile_8bit_1: str) -> None:
fil = FilReader(filfile_8bit_1)
Expand All @@ -100,26 +100,36 @@ def test_get_bandpass(self, filfile_8bit_1: str) -> None:
def test_dedisperse(self, filfile_8bit_1: str) -> None:
dm = 50
fil = FilReader(filfile_8bit_1)
block = fil.read_block(100, 1024)
data = np.zeros((fil.header.nchans, fil.header.nsamples), dtype=np.float32)
block = FilterbankBlock(data, fil.header)
block.data[:, block.nsamples // 2] = 1.0
block_dedisp = block.dedisperse(dm)
np.testing.assert_equal(block.data.shape, block_dedisp.data.shape)
np.testing.assert_equal(block_dedisp.data.shape, block.data.shape)
np.testing.assert_equal(block_dedisp.dm, dm)
np.testing.assert_array_equal(
block.data.mean(axis=1),
block_dedisp.data.mean(axis=1),
block.data.mean(axis=1),
)
# check the direction of the dedispersion
assert block_dedisp.data[:, block.nsamples // 2 + 1 :].sum() == 0
block_dedisp = block.dedisperse(-dm)
assert block_dedisp.data[:, : block.nsamples // 2].sum() == 0

def test_dedisperse_valid_samples(self, filfile_8bit_1: str) -> None:
dm = 50
fil = FilReader(filfile_8bit_1)
block = fil.read_block(100, 1024)
data = np.zeros((fil.header.nchans, fil.header.nsamples), dtype=np.float32)
block = FilterbankBlock(data, fil.header)
block.data[:, block.nsamples // 2] = 1.0
block_dedisp = block.dedisperse(dm, only_valid_samples=True)
np.testing.assert_equal(block_dedisp.dm, dm)
np.testing.assert_equal(block_dedisp.data.shape[0], block.data.shape[0])
np.testing.assert_equal(
block_dedisp.nsamples,
block.nsamples - block.header.get_dmdelays(dm).max(),
)
# check the direction of the dedispersion
assert block_dedisp.data[:, block.nsamples // 2 + 1 :].sum() == 0

def test_dedisperse_valid_samples_fail(self, filfile_8bit_1: str) -> None:
dm = 10000
Expand Down
Loading
Loading