diff --git a/synapse/cli/streaming.py b/synapse/cli/streaming.py index 6f08386..14b0a05 100644 --- a/synapse/cli/streaming.py +++ b/synapse/cli/streaming.py @@ -102,6 +102,8 @@ def read(args): if not config: console.print(f"[bold red]Failed to load config from {args.config}") return + + # Check for expected nodes stream_out = next( (n for n in config.nodes if n.type == NodeType.kStreamOut), None ) @@ -111,29 +113,42 @@ def read(args): broadband = next( (n for n in config.nodes if n.type == NodeType.kBroadbandSource), None ) - if not broadband: - console.print("[bold red]No BroadbandSource node found in config") - return - signal = broadband.signal - if not signal: - console.print("[bold red]No signal configured for BroadbandSource node") + spike_source = next( + (n for n in config.nodes if n.type == NodeType.kSpikeSource), None + ) + if not broadband and not spike_source: + console.print("[bold red]No BroadbandSource or SpikeSource node found in config") return - if not signal.electrode: - console.print( - "[bold red]No electrode signal configured for BroadbandSource node" - ) - return + if broadband: + signal = broadband.signal + if not signal: + console.print("[bold red]No signal configured for BroadbandSource node") + return + + if not signal.electrode: + console.print( + "[bold red]No electrode signal configured for BroadbandSource node" + ) + return - num_ch = len(signal.electrode.channels) - if args.num_ch: - num_ch = args.num_ch - offset = 0 - channels = [] - for ch in range(offset, offset + num_ch): - channels.append(channel.Channel(ch, 2 * ch, 2 * ch + 1)) + num_ch = len(signal.electrode.channels) + if args.num_ch: + num_ch = args.num_ch + offset = 0 + channels = [] + for ch in range(offset, offset + num_ch): + channels.append(channel.Channel(ch, 2 * ch, 2 * ch + 1)) + + broadband.signal.electrode.channels = channels + elif spike_source: + if not spike_source.electrodes: + console.print("[bold red]No electrodes configured for SpikeSource node") + return - broadband.signal.electrode.channels = channels + num_ch = len(spike_source.electrodes.channels) + if args.num_ch: + num_ch = args.num_ch with console.status( "Configuring device", spinner="bouncingBall", spinner_style="green" @@ -166,14 +181,6 @@ def read(args): if not device.start(): raise ValueError("Failed to start device") - # Get the sample rate from the device - # We need to look at the node configuration with type kBroadbandSource for the sample rate - broadband = next( - (n for n in config.nodes if n.type == NodeType.kBroadbandSource), None - ) - assert broadband is not None, "No BroadbandSource node found in config" - sample_rate_hz = broadband.sample_rate_hz - else: # TODO(gilbert): Get rid of this giant if-else block node = next( @@ -227,15 +234,22 @@ def read(args): if args.bin: threads.append( threading.Thread( - target=_binary_writer, args=(stop, q, num_ch, output_base) + target=_writer_binary, args=(stop, q, num_ch, output_base) ) ) else: threads.append( - threading.Thread(target=_data_writer, args=(stop, q, output_base)) + threading.Thread(target=_writer_jsonl, args=(stop, q, output_base)) ) if args.plot: + if not broadband: + console.print("[bold red]A BroadbandSource node is required to plot data") + return + + # Get the sample rate from the device + # We need to look at the node configuration with type kBroadbandSource for the sample rate + sample_rate_hz = broadband.sample_rate_hz threads.append( threading.Thread( target=_plot_data, args=(stop, plot_q, sample_rate_hz, num_ch) @@ -275,7 +289,6 @@ def read_packets( q: queue.Queue, plot_q: queue.Queue, duration: Optional[int] = None, - num_ch: int = 32, ): packet_count = 0 seq_number = None @@ -326,7 +339,7 @@ def read_packets( ) -def _binary_writer(stop, q, num_ch, output_base): +def _writer_binary(stop, q, num_ch, output_base): filename = f"{output_base}.dat" full_path = os.path.join(output_base, filename) if filename: @@ -339,6 +352,9 @@ def _binary_writer(stop, q, num_ch, output_base): except queue.Empty: continue + if data.data_type == ndtp_types.DataType.kBroadband: + continue + try: for ch_id, samples in data.samples: channel_data.append([ch_id, samples]) @@ -360,7 +376,7 @@ def _binary_writer(stop, q, num_ch, output_base): continue -def _data_writer(stop, q, output_base): +def _writer_jsonl(stop, q, output_base): filename = f"{output_base}.jsonl" full_path = os.path.join(output_base, filename) if filename: diff --git a/synapse/simulator/nodes/spike_source.py b/synapse/simulator/nodes/spike_source.py index 5afb7ee..59defa0 100644 --- a/synapse/simulator/nodes/spike_source.py +++ b/synapse/simulator/nodes/spike_source.py @@ -1,6 +1,8 @@ import asyncio +from enum import Enum import random import time +import numpy as np from synapse.api.node_pb2 import NodeType from synapse.api.nodes.spike_source_pb2 import SpikeSourceConfig @@ -11,11 +13,52 @@ def r_sample(bit_width: int): return random.randint(0, 2**bit_width - 1) +class SpikeGenerationMode(Enum): + kNoise = 0 + kSine = 1 + kLinear = 2 + +def generate_sine_wave_spikes( + num_channels: int, + phase: int, + base_spike_count: int, + spike_amplitude: int, + wave_period: int +) -> np.ndarray: + """Generate spike counts forming a sine wave pattern across channels. + """ + y = np.sin(2 * np.pi * phase / wave_period) + active_channel = int((y + 1) * (num_channels - 1) / 2) + + spike_counts = np.full(num_channels, base_spike_count) + for ch in range(max(0, active_channel - 1), min(num_channels, active_channel + 2)): + distance = abs(ch - active_channel) + if distance == 0: + spike_counts[ch] = base_spike_count + spike_amplitude + else: + spike_counts[ch] = base_spike_count + spike_amplitude // 2 + + return spike_counts + +def generate_gradient_spikes( + num_channels: int, + phase: int, + max_spikes: int, + period: int +) -> np.ndarray: + """Generate spike counts forming a diagonal gradient pattern across channels. + """ + spike_counts = np.zeros(num_channels) + for ch in range(num_channels): + ch_phase = (phase + ch * period // num_channels) % period + spike_counts[ch] = max_spikes * ch_phase / period + return spike_counts.astype(int) class SpikeSource(BaseNode): def __init__(self, id): super().__init__(id, NodeType.kSpikeSource) self.__config: SpikeSourceConfig = None + self.__phase = 0 def config(self): c = super().config() @@ -30,17 +73,15 @@ def configure( return Status() async def run(self): + mode = SpikeGenerationMode.kSine + if not self.__config: self.logger.error("node not configured") return c = self.__config - if not c.HasField("signal") or not c.signal: - self.logger.error("node signal not configured") - return - - if not c.signal.HasField("electrodes") or not c.signal.electrodes: + if not c.HasField("electrodes") or not c.electrodes: self.logger.error("node not configured with electrodes") return @@ -52,23 +93,51 @@ async def run(self): channels = e.channels spike_window_ms = c.spike_window_ms if c.spike_window_ms else 20.0 - window_s = spike_window_ms / 1000.0 - min_spikes = max(0, int(1 * window_s)) - max_spikes = min(15, int(200 * window_s)) + max_spikes = int(spike_window_ms) + base_spike_count = max_spikes // 4 + spike_amplitude = max_spikes // 2 + + num_channels = len(channels) + wave_period = num_channels * 2 - t0 = time.time_ns() - while self.running: - now = time.time_ns() + try: + t0 = time.time_ns() + while self.running: + now = time.time_ns() - spike_counts = [random.randint(min_spikes, max_spikes) for _ in channels] - data = SpiketrainData( - t0=t0, - bin_size_ms=spike_window_ms, - spike_counts=spike_counts - ) - - await self.emit_data(data) + if mode == SpikeGenerationMode.kSine: + spike_counts = generate_sine_wave_spikes( + num_channels=num_channels, + phase=self.__phase, + base_spike_count=base_spike_count, + spike_amplitude=spike_amplitude, + wave_period=wave_period + ) + elif mode == SpikeGenerationMode.kLinear: + spike_counts = generate_gradient_spikes( + num_channels=num_channels, + phase=self.__phase, + max_spikes=max_spikes, + period=wave_period + ) + else: + spike_counts = np.array([ + r_sample(4) if random.random() < 0.3 else 0 + for _ in range(num_channels) + ]) + + data = SpiketrainData( + t0=t0, + bin_size_ms=spike_window_ms, + spike_counts=spike_counts.tolist() + ) + + await self.emit_data(data) - t0 = now + self.__phase = (self.__phase + 1) % wave_period + t0 = now - await asyncio.sleep(spike_window_ms / 1000) + await asyncio.sleep(spike_window_ms / 1000) + except Exception as e: + self.logger.error(f"Error in SpikeSource: {e}") + raise e diff --git a/synapse/utils/ndtp_types.py b/synapse/utils/ndtp_types.py index 081a436..c53b4a6 100644 --- a/synapse/utils/ndtp_types.py +++ b/synapse/utils/ndtp_types.py @@ -163,3 +163,5 @@ def to_list(self): SynapseData = Union[SpiketrainData, ElectricalBroadbandData] + +__all__ = ['SynapseData', 'SpiketrainData', 'ElectricalBroadbandData', 'DataType']