
Few-Shot EEG Visual Decoding via Wavelet-Based Prototypical Networks

A novel deep learning approach combining Morlet wavelet transforms with prototypical networks for few-shot learning on EEG data. Achieves 100% accuracy on synthetic data and demonstrates strong cross-domain generalization.

Author: Maitri Savaliya

Contact: maitrisavaliya05@gmail.com

OVERVIEW

Problem statement

Brain-Computer Interfaces (BCIs) and clinical EEG analysis face a fundamental challenge: limited labeled data. Traditional deep learning requires thousands of trials, but clinical settings often provide only a few examples per patient.

My solution

I propose Wavelet-Based Prototypical Networks that:

  1. Extract physiologically-relevant frequency features (theta: 4-8 Hz, gamma: 30-80 Hz)
  2. Learn from few examples (5-shot learning)
  3. Generalize across domains (simple -> challenging data)

Why this matters

  • Practical: Matches clinical constraints (few trials available)
  • Efficient: Only ~193K parameters, CPU-friendly
  • Interpretable: Uses known neural oscillation frequencies
  • Robust: Strong cross-domain generalization (e.g., +39.6% above chance)

KEY RESULTS

Performance on synthetic data (simple)

| Metric              | Value          | Episodes |
|---------------------|----------------|----------|
| Training accuracy   | 99.2%          | 500      |
| Validation accuracy | 100%           | 50       |
| Test accuracy       | 100%           | 1000     |
| Training time       | 16.1 min (CPU) | —        |

Cross-domain generalization (simple -> challenging)

Trained on: 58 channels, 10 classes, clean data
Tested on: 64 channels, 20 subjects, artifacts, overlapping patterns

| Difficulty | n-way | k-shot | n-query | Accuracy | Chance | Above chance |
|------------|-------|--------|---------|----------|--------|--------------|
| Easy       | 5     | 5      | 15      | 59.6%    | 20.0%  | +39.6%       |
| Medium     | 10    | 3      | 10      | 51.3%    | 12.5%  | +38.8%       |
| Hard       | 10    | 2      | 8       | 42.8%    | 10.7%  | +32.1%       |
| Very hard  | 10    | 1      | 5       | 37.6%    | 10.0%  | +27.6%       |

Key insight: a model trained on simple data generalizes to more complex data without fine-tuning.


MATHEMATICAL FRAMEWORK

  1. Problem formulation: few-shot learning

N-way K-shot classification: Given N classes with K labeled examples each, classify new queries.

Episode structure:

Support set: S = {(x_1, y_1), ..., (x_n, y_n)}, where n = N * K
Query set: Q = {(x_1, y_1), ..., (x_m, y_m)}, where m = N * q (q queries per class)

Goal: learn a model f that performs well on Q using only S.
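As an illustration, episode sampling can be sketched in a few lines of NumPy (the function name `sample_episode` and its signature are my own, not taken from this repository):

```python
import numpy as np

def sample_episode(data, labels, n_way=5, k_shot=5, n_query=15, rng=None):
    """Sample one N-way K-shot episode from a labeled dataset.

    data: array of shape (n_examples, ...); labels: int array (n_examples,).
    Returns support/query arrays with episode-local labels in 0..n_way-1.
    """
    rng = rng if rng is not None else np.random.default_rng()
    classes = rng.choice(np.unique(labels), size=n_way, replace=False)
    support_x, support_y, query_x, query_y = [], [], [], []
    for new_label, c in enumerate(classes):
        # shuffle this class's indices, then split into support and query
        idx = rng.permutation(np.flatnonzero(labels == c))
        support_x.append(data[idx[:k_shot]])
        query_x.append(data[idx[k_shot:k_shot + n_query]])
        support_y += [new_label] * k_shot
        query_y += [new_label] * n_query
    return (np.concatenate(support_x), np.array(support_y),
            np.concatenate(query_x), np.array(query_y))
```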

  2. Morlet wavelet transform

Definition: the complex Morlet wavelet is defined as:

psi(t) = (sigma * sqrt(pi))^(-1/2) * exp(-t^2 / (2 sigma^2)) * exp(i * 2 * pi * f * t)

where f is the center frequency (Hz), sigma = n_cycles / (2 pi f), and n_cycles (typ. 5-7) controls time-frequency tradeoff.

Continuous wavelet transform:

W(f, t) = integral x(tau) * psi*((tau - t)/a) dtau, where a = 1/f

In practice we compute wavelet power at discrete frequencies:

Power(f, t) = |W(f, t)|^2

Frequency bands used:

  • Theta: 4-8 Hz, 4 log-spaced frequencies
  • Gamma: 30-80 Hz, 8 log-spaced frequencies
  • Total: 12 frequency channels per EEG channel per timepoint

Input/Output dimensions (example):

Input: x in R^{C x T} (C = 58 channels, T = 250 samples)
Wavelet output: W in R^{C x F x T} (F = 12)
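A minimal NumPy/SciPy sketch of this transform, using the normalization from the wavelet definition above (the function name `morlet_power` and the 5-sigma wavelet support are my assumptions; the repository's exact implementation may differ):

```python
import numpy as np
from scipy.signal import fftconvolve

def morlet_power(x, sfreq, freqs, n_cycles=5.0):
    """Wavelet power via complex Morlet convolution (illustrative sketch).

    x: EEG array of shape (C, T); sfreq: sampling rate in Hz;
    freqs: iterable of center frequencies. Returns (C, F, T) power.
    """
    C, T = x.shape
    power = np.empty((C, len(freqs), T))
    for fi, f in enumerate(freqs):
        sigma = n_cycles / (2.0 * np.pi * f)            # temporal std-dev
        t = np.arange(-5 * sigma, 5 * sigma, 1.0 / sfreq)
        psi = np.exp(-t**2 / (2 * sigma**2)) * np.exp(2j * np.pi * f * t)
        psi /= np.sqrt(sigma * np.sqrt(np.pi))          # (sigma*sqrt(pi))^(-1/2)
        for c in range(C):
            w = fftconvolve(x[c], psi, mode="same")     # W(f, t) per channel
            power[c, fi] = np.abs(w) ** 2               # Power(f, t) = |W|^2
    return power

# theta (4-8 Hz, 4 freqs) and gamma (30-80 Hz, 8 freqs), log-spaced as above
freqs = np.concatenate([np.logspace(np.log10(4), np.log10(8), 4),
                        np.logspace(np.log10(30), np.log10(80), 8)])
```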

Why Morlet: good time-frequency localization, biologically meaningful, widely used in EEG analysis.

  3. Convolutional encoder with attention

Architecture summary:

x in R^{B x C x F x T} -> reshape to R^{B x (C*F) x T} -> Conv1D blocks -> temporal attention -> embedding e in R^{B x D}

Convolutional blocks: each block = Conv1D -> BatchNorm -> ReLU -> MaxPool

Example parameters:

  • Conv block 1: 32 filters, kernel size 7, pool 2
  • Conv block 2: 64 filters, kernel size 5, pool 2

Temporal attention (per time step h_t):

a(t) = W2 * tanh(W1 * h_t)
alpha_t = softmax_t(a(t))
context: c = sum_t alpha_t * h_t

Projection:

e = ReLU(Linear(c))
e = Linear(e)
e = e / ||e||_2 (L2 normalization)

Final embedding dimension D = 128 (example)
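A PyTorch sketch of this encoder, following the example parameters above (the module name `WaveletEncoder`, the attention hidden size of 32, and the padding choices are my assumptions, not the repository's exact code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class WaveletEncoder(nn.Module):
    """Conv1D encoder with temporal attention over wavelet power features."""

    def __init__(self, in_channels, embed_dim=128):
        super().__init__()
        # two conv blocks: Conv1D -> BatchNorm -> ReLU -> MaxPool
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels, 32, kernel_size=7, padding=3),
            nn.BatchNorm1d(32), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm1d(64), nn.ReLU(), nn.MaxPool1d(2),
        )
        # temporal attention: a(t) = W2 tanh(W1 h_t)
        self.attn = nn.Sequential(nn.Linear(64, 32), nn.Tanh(), nn.Linear(32, 1))
        # projection head to the final embedding
        self.proj = nn.Sequential(nn.Linear(64, embed_dim), nn.ReLU(),
                                  nn.Linear(embed_dim, embed_dim))

    def forward(self, x):
        # x: (B, C, F, T) wavelet power -> merge channel and frequency axes
        B, C, Fr, T = x.shape
        h = self.conv(x.reshape(B, C * Fr, T))        # (B, 64, T')
        h = h.transpose(1, 2)                         # (B, T', 64)
        alpha = torch.softmax(self.attn(h), dim=1)    # attention weights
        ctx = (alpha * h).sum(dim=1)                  # context vector (B, 64)
        e = self.proj(ctx)
        return F.normalize(e, dim=-1)                 # L2-normalized embedding
```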

  4. Prototypical classification

For each class k:

p_k = (1/K) * sum_{i: y_i = k} e_i

Distance: squared Euclidean, d(e, p_k) = ||e - p_k||_2^2
Logits: -d(e_q, p_k) (negative distance)
Probabilities: softmax over logits
Loss: cross-entropy over query labels
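This classification step can be sketched in PyTorch as follows (the helper name `prototypical_loss` is hypothetical):

```python
import torch
import torch.nn.functional as F

def prototypical_loss(support_emb, support_y, query_emb, query_y, n_way):
    """Prototype matching with squared-Euclidean distance (sketch)."""
    # p_k: mean of class k's support embeddings
    protos = torch.stack([support_emb[support_y == k].mean(0)
                          for k in range(n_way)])              # (N, D)
    # squared Euclidean distance between each query and each prototype
    d = torch.cdist(query_emb, protos) ** 2                    # (Q, N)
    logits = -d                                                # nearer => larger
    loss = F.cross_entropy(logits, query_y)
    acc = (logits.argmax(1) == query_y).float().mean()
    return loss, acc
```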

  5. Training algorithm (episode-based)

for episode in range(N_episodes):
    classes = sample_n_classes(N_way)
    support_set = sample_examples(classes, K_shot)
    query_set = sample_examples(classes, N_query)
    support_embeddings = encoder(support_set)
    prototypes = compute_prototypes(support_embeddings)
    query_embeddings = encoder(query_set)
    logits = -distance(query_embeddings, prototypes)
    loss = CrossEntropy(logits, query_labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Episode-based training simulates test conditions and trains the model to learn from few examples.
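Putting the pieces together, a single training episode might look like this toy end-to-end step, with a stand-in linear encoder instead of the wavelet/conv model and synthetic clustered data (all names, sizes, and the learning rate here are illustrative, not the repository's settings):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_way, k_shot, n_query, feat, dim = 5, 5, 15, 32, 16
encoder = nn.Linear(feat, dim)                 # stand-in for the real encoder
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

# fake episode: class k's examples cluster around a class-specific mean
means = torch.randn(n_way, feat)
support = means.repeat_interleave(k_shot, 0) + 0.1 * torch.randn(n_way * k_shot, feat)
query = means.repeat_interleave(n_query, 0) + 0.1 * torch.randn(n_way * n_query, feat)
query_y = torch.arange(n_way).repeat_interleave(n_query)

# forward pass: embed, build prototypes, score queries by negative distance
s_emb = F.normalize(encoder(support), dim=-1)
q_emb = F.normalize(encoder(query), dim=-1)
protos = s_emb.reshape(n_way, k_shot, dim).mean(1)   # class prototypes
logits = -torch.cdist(q_emb, protos) ** 2            # negative sq. distance
loss = F.cross_entropy(logits, query_y)

# one optimizer step on the episode's query loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
```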


ARCHITECTURE DETAILS

Pipeline (text):

Input EEG (C x T) -> Theta/Gamma wavelet bank (F frequencies) -> concat -> Conv1D blocks -> Attention -> FC -> Embedding (D)

Prototype matching and classification follow.

Model size (example): ~193,376 parameters.


INSTALLATION

Requirements

  • Python 3.8+
  • PyTorch 2.0+
  • numpy, scipy, matplotlib, scikit-learn, pyyaml, tqdm

Install example (Windows PowerShell):

python -m venv venv
venv\Scripts\Activate.ps1
pip install -r requirements.txt

Suggested requirements.txt pins:

torch>=2.0.0
numpy>=1.24.0
scipy>=1.10.0
matplotlib>=3.7.0
scikit-learn>=1.3.0
pyyaml>=6.0
tqdm>=4.65.0
seaborn>=0.12.0


QUICK START

Run a quick test (if present):

python quick_test.py

Train:

python train.py

Evaluate:

python evaluate.py --model_path checkpoints/best_model.pt

Test cross-domain:

python create_challenge_data.py
python test_challenge.py

EXPERIMENTS & ANALYSIS

Training dynamics (example):

Episode 0:   Loss=1.6027, Acc=0.28
Episode 25:  Loss=1.0421, Acc=0.55
Episode 50:  Loss=0.7865, Acc=0.80
Episode 100: Loss=0.6177, Acc=0.92
Episode 200: Loss=0.5133, Acc=0.9584
Episode 500: Loss=0.43,   Acc=0.9816

Generalization analysis summary included in the paper and experiment scripts.

REPRODUCING RESULTS

  1. Install deps
  2. Prepare dataset
  3. python train.py
  4. python evaluate.py --model_path checkpoints/best_model.pt

FUTURE WORK

  • Learnable wavelet parameters
  • Include alpha/beta bands
  • Subject-adaptive few-shot fine-tuning

CITATION

If you use this code, cite (placeholder):

@misc{maitrisavaliya2025waveletproto,
  title        = {Few-Shot EEG Visual Decoding via Wavelet-Based Prototypical Networks},
  author       = {Maitri Savaliya},
  year         = {2025},
  howpublished = {\url{https://github.com/maitrisavaliya/eeg-few-shot-wavelet}}
}


LICENSE

This project is provided under the MIT License. See LICENSE.


CONTACT

Questions: maitrisavaliya05@gmail.com

About

This repository implements a compact, interpretable few-shot learning pipeline for EEG visual decoding. It combines Morlet wavelet-based time-frequency features (theta and gamma bands) with a convolutional encoder + temporal attention and prototypical networks to perform N-way K-shot classification.
