A novel deep learning approach combining Morlet wavelet transforms with prototypical networks for few-shot learning on EEG data. Achieves 100% accuracy on synthetic data and demonstrates strong cross-domain generalization.
Author: Maitri Savaliya
Contact: maitrisavaliya05@gmail.com
OVERVIEW
Problem statement
Brain-Computer Interfaces (BCIs) and clinical EEG analysis face a fundamental challenge: limited labeled data. Traditional deep learning requires thousands of trials, but clinical settings often provide only a few examples per patient.
My solution
I propose Wavelet-Based Prototypical Networks that:
- Extract physiologically-relevant frequency features (theta: 4-8 Hz, gamma: 30-80 Hz)
- Learn from few examples (5-shot learning)
- Generalize across domains (simple -> challenging data)
Why this matters
- Practical: Matches clinical constraints (few trials available)
- Efficient: Only ~193K parameters, CPU-friendly
- Interpretable: Uses known neural oscillation frequencies
- Robust: Strong cross-domain generalization (e.g., 39.6 percentage points above chance in the 5-way, 5-shot setting)
KEY RESULTS
Performance on synthetic data (simple)
| Metric | Value | Episodes |
|---|---|---|
| Training Accuracy | 99.2% | 500 |
| Validation Accuracy | 100% | 50 |
| Test Accuracy | 100% | 1000 |
| Training Time | 16.1 min | CPU |
Cross-domain generalization (simple -> challenging)
Trained on: 58 channels, 10 classes, clean data
Tested on: 64 channels, 20 subjects, artifacts, overlapping patterns
| Difficulty | n-way | k-shot | n-query | Accuracy | Chance | Above Chance |
|---|---|---|---|---|---|---|
| Easy | 5 | 5 | 15 | 59.6% | 20.0% | +39.6% |
| Medium | 10 | 3 | 10 | 51.3% | 12.5% | +38.8% |
| Hard | 10 | 2 | 8 | 42.8% | 10.7% | +32.1% |
| Very Hard | 10 | 1 | 5 | 37.6% | 10.0% | +27.6% |
Key insight: a model trained only on the simple data stays well above chance on the challenging data without any fine-tuning.
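As a small worked example of how the "Chance" and "Above Chance" columns relate to the accuracy, the sketch below assumes balanced classes, so that chance for an N-way episode is 1/N. The helper name `chance_and_margin` is hypothetical and not part of the repository.

```python
from typing import Tuple

def chance_and_margin(n_way: int, accuracy_pct: float) -> Tuple[float, float]:
    """Return (chance level, margin above chance), both in percent.

    Assumes balanced classes, so chance is simply 100 / n_way.
    """
    chance = 100.0 / n_way
    return round(chance, 1), round(accuracy_pct - chance, 1)

print(chance_and_margin(5, 59.6))    # Easy row      -> (20.0, 39.6)
print(chance_and_margin(10, 37.6))   # Very Hard row -> (10.0, 27.6)
```

The margin is a difference in percentage points, not a relative improvement.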
MATHEMATICAL FRAMEWORK
- Problem formulation: few-shot learning
N-way K-shot classification: Given N classes with K labeled examples each, classify new queries.
Episode structure:
Support set: S = {(x_1, y_1), ..., (x_n, y_n)} where n = N * K
Query set: Q = {(x_1, y_1), ..., (x_m, y_m)} where m = N * N_q (N_q = number of query examples per class)
Goal: learn a model f that performs well on Q using only S.
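The episode structure can be sketched as an index sampler. This is a minimal illustration, not the repository's data loader: the function name `sample_episode` and the toy label array are assumptions, and support and query indices are drawn without overlap.

```python
import numpy as np

def sample_episode(labels, n_way, k_shot, n_query, rng):
    """Return (support_idx, query_idx, classes) for one N-way K-shot episode."""
    classes = rng.choice(np.unique(labels), size=n_way, replace=False)
    support_idx, query_idx = [], []
    for c in classes:
        # Shuffle this class's trials, then split into disjoint support/query
        idx = rng.permutation(np.flatnonzero(labels == c))
        support_idx.extend(idx[:k_shot])
        query_idx.extend(idx[k_shot:k_shot + n_query])
    return np.array(support_idx), np.array(query_idx), classes

rng = np.random.default_rng(0)
labels = np.repeat(np.arange(10), 30)   # toy dataset: 10 classes x 30 trials
s_idx, q_idx, classes = sample_episode(labels, n_way=5, k_shot=5, n_query=15, rng=rng)
print(len(s_idx), len(q_idx))           # 25 support (N*K), 75 query (N*n_query)
```

Each class needs at least `k_shot + n_query` trials for the split to be complete.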
- Morlet wavelet transform
Definition: the complex Morlet wavelet is defined as:
psi(t) = (sigma * sqrt(pi))^(-1/2) * exp(-t^2 / (2 * sigma^2)) * exp(i * 2 * pi * f * t)
where f is the center frequency (Hz), sigma = n_cycles / (2 pi f), and n_cycles (typ. 5-7) controls time-frequency tradeoff.
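The definition translates directly into NumPy. The sketch below is illustrative only: the 250 Hz sampling rate and the truncation at +/- 4 sigma are assumptions, and `morlet` is a hypothetical helper name. It also checks numerically that the normalization factor gives the wavelet unit energy.

```python
import numpy as np

def morlet(f, sfreq, n_cycles=6.0):
    """Complex Morlet wavelet sampled at sfreq Hz, centered on t = 0."""
    sigma = n_cycles / (2.0 * np.pi * f)        # temporal std in seconds
    half = int(np.ceil(4.0 * sigma * sfreq))    # assumed truncation at +/- 4 sigma
    t = np.arange(-half, half + 1) / sfreq
    norm = (sigma * np.sqrt(np.pi)) ** -0.5
    psi = norm * np.exp(-t**2 / (2.0 * sigma**2)) * np.exp(2j * np.pi * f * t)
    return t, psi

sfreq = 250.0                                   # assumed sampling rate (Hz)
t, psi = morlet(f=6.0, sfreq=sfreq)
energy = np.sum(np.abs(psi) ** 2) / sfreq       # Riemann sum of |psi(t)|^2 dt
print(round(float(energy), 3))                  # close to 1 (unit energy)
```

Larger `n_cycles` widens the Gaussian envelope, which sharpens frequency resolution at the cost of time resolution.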
Continuous wavelet transform:
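W(f, t) = integral x(tau) * psi_f*(tau - t) dtau, where psi_f is the Morlet wavelet with center frequency f and * denotes complex conjugation. Because sigma already scales with 1/f, no separate scale parameter is needed.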
In practice we compute wavelet power at discrete frequencies:
Power(f, t) = |W(f, t)|^2
Frequency bands used:
- Theta: 4-8 Hz, 4 log-spaced frequencies
- Gamma: 30-80 Hz, 8 log-spaced frequencies
- Total: 12 frequency channels per EEG channel per timepoint
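One way to build such a frequency grid is `np.geomspace`. This sketch assumes the band edges are included as the first and last frequency of each band, which the text does not state.

```python
import numpy as np

theta = np.geomspace(4.0, 8.0, num=4)     # 4 log-spaced theta frequencies
gamma = np.geomspace(30.0, 80.0, num=8)   # 8 log-spaced gamma frequencies
freqs = np.concatenate([theta, gamma])    # F = 12 analysis frequencies

print(np.round(theta, 2))                 # approximately 4, 5.04, 6.35, 8
print(freqs.shape)                        # (12,)
```

Log spacing keeps the ratio between neighbouring frequencies constant, which matches the constant relative bandwidth of a Morlet wavelet with fixed n_cycles.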
Input/Output dimensions (example):
Input: x in R^{C x T} (C = 58 channels, T = 250 samples)
Wavelet output: W in R^{C x F x T} (F = 12)
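To make the shapes concrete, here is a minimal NumPy sketch of the power computation on a toy two-channel signal. It is not the repository's implementation: the 250 Hz sampling rate, the helper names, and the plain time-domain convolution are all assumptions. The convolution is computed in "full" mode and cropped, because at low frequencies the wavelet is longer than the 250-sample input.

```python
import numpy as np

def morlet(f, sfreq, n_cycles=6.0):
    """Complex Morlet wavelet and its half-length in samples."""
    sigma = n_cycles / (2.0 * np.pi * f)
    half = int(np.ceil(4.0 * sigma * sfreq))
    t = np.arange(-half, half + 1) / sfreq
    norm = (sigma * np.sqrt(np.pi)) ** -0.5
    return norm * np.exp(-t**2 / (2 * sigma**2)) * np.exp(2j * np.pi * f * t), half

def wavelet_power(x, freqs, sfreq):
    """x: (C, T) real signal -> power: (C, F, T)."""
    n_ch, n_t = x.shape
    power = np.empty((n_ch, len(freqs), n_t))
    for j, f in enumerate(freqs):
        psi, half = morlet(f, sfreq)
        for c in range(n_ch):
            # Full convolution, cropped so output sample n aligns with input sample n
            w = np.convolve(x[c], psi, mode="full")[half:half + n_t] / sfreq
            power[c, j] = np.abs(w) ** 2
    return power

sfreq = 250.0                                    # assumed sampling rate (Hz)
freqs = np.concatenate([np.geomspace(4, 8, 4), np.geomspace(30, 80, 8)])
t = np.arange(250) / sfreq
x = np.stack([np.sin(2 * np.pi * 6 * t), np.sin(2 * np.pi * 40 * t)])  # (2, 250)
power = wavelet_power(x, freqs, sfreq)
print(power.shape)                               # (2, 12, 250)
print(round(float(freqs[power[0, :, 125].argmax()]), 2))  # grid frequency nearest 6 Hz
```

Samples near the start and end of the window are affected by edge effects, since the wavelet extends beyond the available data there.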
Why Morlet: good time-frequency localization, biologically meaningful, widely used in EEG analysis.
- Convolutional encoder with attention
Architecture summary:
x in R^{B x C x F x T} -> reshape to R^{B x (C*F) x T} -> Conv1D blocks -> temporal attention -> embedding e in R^{B x D}
Convolutional blocks: each block = Conv1D -> BatchNorm -> ReLU -> MaxPool
Example parameters:
- Conv block 1: 32 filters, kernel size 7, pool 2
- Conv block 2: 64 filters, kernel size 5, pool 2
Temporal attention (per time step h_t):
a(t) = W2 * tanh(W1 * h_t)
alpha_t = softmax_t(a(t))
context c = sum_t alpha_t * h_t
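The three attention equations can be checked numerically with a few lines of NumPy. The sizes below (62 time steps, 64 hidden units, attention width 32) and the random weights are illustrative assumptions, not values taken from the repository.

```python
import numpy as np

def temporal_attention(h, W1, W2):
    """h: (T, H) hidden states -> (context of shape (H,), weights of shape (T,))."""
    scores = np.tanh(h @ W1.T) @ W2          # a(t), shape (T,)
    scores = scores - scores.max()           # stabilise the softmax
    alpha = np.exp(scores) / np.exp(scores).sum()
    return alpha @ h, alpha                  # weighted sum over time

rng = np.random.default_rng(0)
n_steps, hidden, attn_width = 62, 64, 32     # assumed sizes
h = rng.standard_normal((n_steps, hidden))
W1 = 0.1 * rng.standard_normal((attn_width, hidden))
W2 = 0.1 * rng.standard_normal(attn_width)
context, alpha = temporal_attention(h, W1, W2)
print(context.shape)                         # (64,)
```

The weights `alpha` are positive and sum to 1, so the context vector is a convex combination of the per-time-step features.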
Projection:
e = ReLU(Linear(c))
e = Linear(e)
e = e / ||e||_2 (L2 normalization)
Final embedding dimension D = 128 (example)
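Putting the pieces together, a PyTorch sketch of the encoder might look like the following. It is an approximation under stated assumptions rather than the repository's model: "same" padding in the convolutions, an attention width of 32 with no bias on the scoring layer, and a projection hidden width of 128 are guesses, chosen because they land on the parameter count quoted elsewhere in this README. The class name `WaveletEncoder` is hypothetical.

```python
import torch
import torch.nn as nn

class WaveletEncoder(nn.Module):
    def __init__(self, in_channels=58 * 12, attn_dim=32, hidden_dim=128, embed_dim=128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels, 32, kernel_size=7, padding=3),   # block 1
            nn.BatchNorm1d(32), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(32, 64, kernel_size=5, padding=2),            # block 2
            nn.BatchNorm1d(64), nn.ReLU(), nn.MaxPool1d(2),
        )
        self.attn_w1 = nn.Linear(64, attn_dim)                      # W1 (assumed width)
        self.attn_w2 = nn.Linear(attn_dim, 1, bias=False)           # W2 (assumed no bias)
        self.proj = nn.Sequential(
            nn.Linear(64, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, embed_dim)
        )

    def forward(self, x):                        # x: (B, C, F, T)
        b, c, n_freq, t = x.shape
        h = self.conv(x.reshape(b, c * n_freq, t))   # (B, 64, T')
        h = h.transpose(1, 2)                        # (B, T', 64)
        alpha = torch.softmax(self.attn_w2(torch.tanh(self.attn_w1(h))), dim=1)
        context = (alpha * h).sum(dim=1)             # (B, 64)
        return nn.functional.normalize(self.proj(context), dim=-1)  # unit L2 norm

model = WaveletEncoder().eval()
with torch.no_grad():
    e = model(torch.randn(4, 58, 12, 250))
print(e.shape)                                   # torch.Size([4, 128])
```

Because the first convolution is sized for 58 * 12 input channels, data with a different channel count would need channel selection or a different `in_channels`.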
- Prototypical classification
For each class k:
p_k = (1/K) * sum_{i: y_i = k} e_i
Distance: squared Euclidean d(e, p_k) = ||e - p_k||_2^2
Logits = -d(e_q, p_k) (negative distance)
Probabilities = softmax(-d)
Loss = cross-entropy over query labels
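The prototype, distance, and loss steps can be written compactly in NumPy. This is a sketch on toy embeddings (well-separated clusters that are not L2-normalized) rather than real encoder outputs, and `prototypical_loss` is a hypothetical name.

```python
import numpy as np

def prototypical_loss(support, support_y, query, query_y, n_way):
    """Embeddings are (n, D); labels are integers in [0, n_way)."""
    # One prototype per class: the mean of that class's support embeddings
    prototypes = np.stack([support[support_y == k].mean(axis=0) for k in range(n_way)])
    # Squared Euclidean distance from every query to every prototype: (m, n_way)
    d = ((query[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    logits = -d
    logits = logits - logits.max(axis=1, keepdims=True)      # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    loss = -log_probs[np.arange(len(query_y)), query_y].mean()
    accuracy = (log_probs.argmax(axis=1) == query_y).mean()
    return loss, accuracy

rng = np.random.default_rng(0)
n_way, k_shot, n_query, dim = 3, 5, 4, 8
centers = 5.0 * np.eye(n_way, dim)                           # toy class centers
support_y = np.repeat(np.arange(n_way), k_shot)
query_y = np.repeat(np.arange(n_way), n_query)
support = centers[support_y] + 0.1 * rng.standard_normal((n_way * k_shot, dim))
query = centers[query_y] + 0.1 * rng.standard_normal((n_way * n_query, dim))
loss, acc = prototypical_loss(support, support_y, query, query_y, n_way)
print(acc)                                                   # 1.0 on this toy data
```

With unit-norm embeddings the squared distances are bounded, so the logits span a limited range and the loss cannot shrink all the way to zero.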
- Training algorithm (episode-based)
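for episode in range(N_episodes):
    # Sample one N-way K-shot episode; support and query examples must not overlap
    classes = sample_n_classes(N_way)
    support_set, support_labels = sample_examples(classes, K_shot)
    query_set, query_labels = sample_examples(classes, N_query)
    # Embed both sets and build one prototype (mean embedding) per class
    support_embeddings = encoder(support_set)
    prototypes = compute_prototypes(support_embeddings, support_labels)
    query_embeddings = encoder(query_set)
    # Classify queries by negative squared distance to each prototype
    logits = -distance(query_embeddings, prototypes)
    loss = CrossEntropy(logits, query_labels)
    # Gradient step (clear stale gradients first)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()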
Episode-based training simulates test conditions and trains the model to learn from few examples.
ARCHITECTURE DETAILS
Pipeline (text):
Input EEG (C x T) -> Theta/Gamma wavelet bank (F frequencies) -> concat -> Conv1D blocks -> Attention -> FC -> Embedding (D)
Prototype matching and classification follow.
Model size (example): 193,376 parameters (~193K).
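The stated model size can be cross-checked by hand. The breakdown below is a back-of-the-envelope sketch: the filter counts and kernel sizes come from the example parameters above, while the attention width (32, no bias on the scoring vector) and the projection hidden width (128) are assumptions that happen to reproduce the quoted figure. The actual model may be configured differently.

```python
def conv1d(c_in, c_out, k):
    return c_in * c_out * k + c_out        # weights + biases

def batchnorm(c):
    return 2 * c                           # learnable scale + shift

def linear(n_in, n_out, bias=True):
    return n_in * n_out + (n_out if bias else 0)

in_ch = 58 * 12                            # C * F channels after the reshape
layers = {
    "conv1": conv1d(in_ch, 32, 7),
    "bn1": batchnorm(32),
    "conv2": conv1d(32, 64, 5),
    "bn2": batchnorm(64),
    "attn_W1": linear(64, 32),             # assumed attention width 32
    "attn_W2": linear(32, 1, bias=False),  # assumed: no bias on the score vector
    "proj1": linear(64, 128),              # assumed hidden width 128
    "proj2": linear(128, 128),             # D = 128
}
total = sum(layers.values())
print(layers["conv1"], total)              # 155936 193376
```

Under these assumptions the first convolution accounts for roughly 81% of the parameters, because it spans all 696 input channels.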
INSTALLATION
Requirements
- Python 3.8+
- PyTorch 2.0+
- numpy, scipy, matplotlib, seaborn, scikit-learn, pyyaml, tqdm
Install example (Windows PowerShell):
python -m venv venv
venv\Scripts\Activate.ps1
pip install -r requirements.txt
Suggested requirements.txt contents (minimum versions):
torch>=2.0.0
numpy>=1.24.0
scipy>=1.10.0
matplotlib>=3.7.0
scikit-learn>=1.3.0
pyyaml>=6.0
tqdm>=4.65.0
seaborn>=0.12.0
QUICK START
Run a quick test (if present):
python quick_test.py
Train:
python train.py
Evaluate:
python evaluate.py --model_path checkpoints/best_model.pt
Test cross-domain:
python create_challenge_data.py
python test_challenge.py
EXPERIMENTS & ANALYSIS
Training dynamics (example):
Episode 0: Loss=1.6027, Acc=0.28
Episode 25: Loss=1.0421, Acc=0.55
Episode 50: Loss=0.7865, Acc=0.80
Episode 100: Loss=0.6177, Acc=0.92
Episode 200: Loss=0.5133, Acc=0.9584
Episode 500: Loss=0.43, Acc=0.9816
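A quick sanity check on the first log entry: if training episodes are 5-way (an assumption, since the training configuration is not stated here), an untrained model that guesses uniformly should start near a cross-entropy of ln(5) and an accuracy of 0.2.

```python
import math

n_way = 5                               # assumed training configuration
uniform_loss = math.log(n_way)          # cross-entropy of a uniform guess
print(round(uniform_loss, 4))           # 1.6094
print(1 / n_way)                        # 0.2
```

The logged starting loss of 1.6027 is within 0.01 of this value, which is consistent with 5-way episodes and a near-uniform initial prediction.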
A summary of the generalization analysis is included in the paper and the experiment scripts.
REPRODUCING RESULTS
- Install deps
- Prepare dataset
- python train.py
- python evaluate.py --model_path checkpoints/best_model.pt
FUTURE WORK
- Learnable wavelet parameters
- Include alpha/beta bands
- Subject-adaptive few-shot fine-tuning
CITATION
If you use this code, cite (placeholder):
@misc{maitrisavaliya2025waveletproto,
  title={Few-Shot EEG Visual Decoding via Wavelet-Based Prototypical Networks},
  author={Maitri Savaliya},
  year={2025},
  howpublished={\url{https://github.com/maitrisavaliya/eeg-few-shot-wavelet}}
}
LICENSE
This project is provided under the MIT License. See LICENSE.
CONTACT
Questions: maitrisavaliya05@gmail.com