soroush-mim/memloss

MemLoss: Enhancing Adversarial Training with Recycling Adversarial Examples

This repository contains the implementation of MemLoss, a novel approach to adversarial training that leverages Memory Adversarial Examples (adversarial examples from previous training epochs) to enhance model robustness and accuracy. The method is described in the paper "MemLoss: Enhancing Adversarial Training with Recycling Adversarial Examples" by Soroush Mahdi, Maryam Amirmazlaghani, Saeed Saravani, and Zahra Dehghanian.

Note: This repository is forked from the HAT (Helper-based Adversarial Training) implementation by Rahul Rade and Seyed-Mohsen Moosavi-Dezfooli. We extend their work by incorporating memory mechanisms into adversarial training.

Key Features

  • Memory Adversarial Examples: Recycles adversarial examples from previous training epochs to maintain robustness against a broader range of attacks
  • Improved Accuracy-Robustness Trade-off: Achieves better clean accuracy while maintaining strong adversarial robustness
  • Flexible Integration: Can be seamlessly integrated with existing adversarial training methods like TRADES and MART
  • Weighted Memory Loss: Dynamically adjusts the influence of memory examples based on their classification difficulty
  • Multiple Variants: Includes implementations for MemLoss+TRADES, MemLoss+MART, and MemLoss+HAT

Setup

Requirements

Our code has been implemented and tested with Python 3.8.5 and PyTorch 1.8.0. To install the required packages:

$ pip install -r requirements.txt

Repository Structure

.
├── core                         # Source code for the experiments
│   ├── attacks                  # Adversarial attacks
│   ├── data                     # Data setup and loading
│   ├── models                   # Model architectures
│   ├── utils                    # Helpers, training and testing functions
│   │   ├── memory_trades.py     # MemLoss+TRADES implementation
│   │   ├── memory_mart.py       # MemLoss+MART implementation
│   │   ├── memory_trades_V2.py  # Alternative MemLoss variant
│   │   └── train.py             # Training functions with memory support
│   └── metrics.py               # Evaluation metrics
├── train.py                     # Training script with MemLoss support
├── train-wa.py                  # Training with model weight averaging
├── eval-aa.py                   # AutoAttack evaluation
├── eval-adv.py                  # PGD+ and CW evaluation
├── eval-rb.py                   # RobustBench evaluation
└── plot_results.py              # Utility for plotting training results

Usage

Training

Run train.py for standard, adversarial, TRADES, MART, HAT, and MemLoss training.

MemLoss Training with TRADES

To train a ResNet-18 model using MemLoss+TRADES on CIFAR-10:

$ python train.py --data-dir <data_dir> \
    --log-dir <log_dir> \
    --desc memloss-trades-cifar10 \
    --data cifar10 \
    --model resnet18 \
    --num-adv-epochs 70 \
    --beta 5.0 \
    --beta-prime 2.0 \
    --lr-max 0.21 \
    --attack linf-pgd \
    --epsilon 8 \
    --attack-loss kl \
    --memory-train

MemLoss Training with MART

To train using MemLoss+MART on CIFAR-10:

$ python train.py --data-dir <data_dir> \
    --log-dir <log_dir> \
    --desc memloss-mart-cifar10 \
    --data cifar10 \
    --model resnet18 \
    --num-adv-epochs 70 \
    --beta 6.0 \
    --memory-train \
    --mart

MemLoss Training with HAT

To combine MemLoss with HAT, first train a standard model, then:

$ python train.py --data-dir <data_dir> \
    --log-dir <log_dir> \
    --desc memloss-hat-cifar10 \
    --data cifar10 \
    --model resnet18 \
    --num-adv-epochs 70 \
    --helper-model <std-model-name> \
    --beta 2.0 \
    --beta-prime 1.0 \
    --gamma 0.5 \
    --memory-train

Key Training Parameters for MemLoss

  • --memory-train: Enable memory adversarial training
  • --beta: Weight for the robust loss term (default: 5.0 for TRADES)
  • --beta-prime: Weight for the memory loss term (default: 2.0)
  • --attack-loss: Loss type for generating adversarial examples (kl, ce, or memory-kl)
  • --weighted: Enable weighted memory loss based on prediction confidence
  • --ema-xprime: Enable exponential moving average for memory examples
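The --ema-xprime option implies that stored memory examples are smoothed across epochs rather than replaced outright. A minimal sketch of such an update, assuming an exponential moving average over the stored tensors (the function name and the decay value alpha are illustrative assumptions, not the repository's exact code):

```python
import torch

def update_memory(x_mem, x_adv_new, alpha=0.9):
    """Illustrative EMA update for memory adversarial examples: blend the
    stored example from earlier epochs with the newly generated one.
    (Sketch only; the repository's actual update may differ.)"""
    if x_mem is None:  # first epoch: initialize memory with the new example
        return x_adv_new.detach().clone()
    return (alpha * x_mem + (1.0 - alpha) * x_adv_new).detach()
```

With alpha close to 1, the memory changes slowly and retains attack patterns from older epochs; alpha = 0 reduces to simply storing the latest adversarial example.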

Robustness Evaluation

Trained models can be evaluated with eval-aa.py, which uses AutoAttack to measure robust accuracy. For example:

$ python eval-aa.py --data-dir <data_dir> \
    --log-dir <log_dir> \
    --desc memloss-trades-cifar10

For evaluation with PGD+ and CW attacks, use:

$ python eval-adv.py --wb --data-dir <data_dir> \
    --log-dir <log_dir> \
    --desc memloss-trades-cifar10

Method Overview

MemLoss addresses a key limitation in existing adversarial training methods: the model's tendency to "forget" adversarial examples from previous epochs when adapting to new ones. By incorporating Memory Adversarial Examples (adversarial inputs generated in earlier training epochs), MemLoss:

  1. Prevents Catastrophic Forgetting: Maintains robustness against adversarial examples from previous epochs while learning from new ones
  2. Provides Data Diversity: Treats memory examples as an additional source of attack diversity, improving overall robustness
  3. Balances the Trade-off: Achieves better clean accuracy compared to TRADES and MART while maintaining competitive robust accuracy
  4. Weighted Strategy: Dynamically emphasizes harder memory examples using a confidence-based weighting scheme
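A rough sketch of how one training step might combine the TRADES objective with an (unweighted) memory term follows. This is illustrative PyTorch, not the repository's API: the attack callable, function name, and buffer handling are all assumptions.

```python
import torch
import torch.nn.functional as F

def memloss_trades_step(model, x, y, x_mem, attack, beta=5.0, beta_prime=2.0):
    """One illustrative MemLoss+TRADES step: generate a fresh adversarial
    example, add a KL term on the stored example from a previous epoch,
    and return the new example so it can be kept as memory."""
    x_adv = attack(model, x, y)              # fresh adversarial example
    logits = model(x)
    loss_nat = F.cross_entropy(logits, y)    # clean cross-entropy
    # TRADES robust term: KL between clean and adversarial predictions
    loss_rob = F.kl_div(F.log_softmax(model(x_adv), dim=1),
                        F.softmax(logits, dim=1),
                        reduction='batchmean')
    loss = loss_nat + beta * loss_rob
    if x_mem is not None:                    # memory term from a past epoch
        loss_mem = F.kl_div(F.log_softmax(model(x_mem), dim=1),
                            F.softmax(logits, dim=1),
                            reduction='batchmean')
        loss = loss + beta_prime * loss_mem
    return loss, x_adv.detach()              # x_adv becomes next memory
```

The key structural point is that each step both consumes last epoch's adversarial example and produces the one to be recycled next epoch.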

The memory loss term is defined as:

L_MemLoss = β' · KL(f_θ(x) || f_θ(x'_prev)) · (1 - p_y(x'_prev))

where x'_prev is the adversarial example generated for x in a previous epoch, and p_y(x'_prev) is the model's predicted probability for the true class y on that example. The (1 - p_y) factor down-weights memory examples the model already classifies confidently.
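The weighted term above can be sketched in PyTorch as follows (a minimal illustration under assumed tensor names and a per-sample weighting; not the repository's exact implementation):

```python
import torch
import torch.nn.functional as F

def memory_loss(model, x, x_prev_adv, y, beta_prime=2.0):
    """Sketch of the weighted memory loss: a per-sample KL divergence between
    predictions on clean inputs and on a past epoch's adversarial examples,
    down-weighted when the memory example is already classified confidently."""
    logits_clean = model(x)
    logits_mem = model(x_prev_adv)
    # KL(f(x) || f(x'_prev)) per sample; F.kl_div(log q, p) computes KL(p || q)
    kl = F.kl_div(F.log_softmax(logits_mem, dim=1),
                  F.softmax(logits_clean, dim=1),
                  reduction='none').sum(dim=1)
    # Model's probability for the true class on the memory example
    p_y = F.softmax(logits_mem, dim=1).gather(1, y.unsqueeze(1)).squeeze(1)
    # Harder memory examples (low p_y) receive larger weight
    return beta_prime * (kl * (1.0 - p_y)).mean()
```

Since each per-sample KL is non-negative and the weight (1 - p_y) lies in [0, 1], the resulting loss is always non-negative.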

Results Summary

MemLoss demonstrates significant improvements on multiple datasets:

  • CIFAR-10: 49.79% robust accuracy (AutoAttack) with 83.08% clean accuracy
  • CIFAR-100: 24.24% robust accuracy with 56.81% clean accuracy
  • SVHN: 51.58% robust accuracy with 90.10% clean accuracy

For detailed experimental results, please refer to the paper.

Citing this work

If you use this code or find our work helpful, please cite:

@misc{mahdi2025memlossenhancingadversarialtraining,
      title={MemLoss: Enhancing Adversarial Training with Recycling Adversarial Examples}, 
      author={Soroush Mahdi and Maryam Amirmazlaghani and Saeed Saravani and Zahra Dehghanian},
      year={2025},
      eprint={2510.09105},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2510.09105}, 
}

This work builds upon HAT. If you use the HAT components, please also cite:

@inproceedings{rade2022reducing,
    title={Reducing Excessive Margin to Achieve a Better Accuracy vs. Robustness Trade-off},
    author={Rahul Rade and Seyed-Mohsen Moosavi-Dezfooli},
    booktitle={International Conference on Learning Representations},
    year={2022},
    url={https://openreview.net/forum?id=Azh9QBQ4tR7}
}

Acknowledgments

This repository is based on the HAT (Helper-based Adversarial Training) implementation. We thank the original authors for making their code publicly available.
