An official implementation of GMAIR.
This project uses Python 3.8 and PyTorch 1.8.1.
git clone https://github.com/EmoFuncs/GMAIR-pytorch.git
pip install -r GMAIR-pytorch/requirements.txt
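To confirm that the installed environment matches the versions stated above, a small check along these lines may help. It is only a sketch: the helper names `check_versions` and `installed_torch_version` are hypothetical and not part of this repository, and the check assumes that a local build suffix such as `+cu111` in the PyTorch version string can be ignored.

```python
import sys
from importlib import metadata

# Versions stated in this README.
EXPECTED_PYTHON = (3, 8)
EXPECTED_TORCH = "1.8.1"


def installed_torch_version():
    """Return the installed torch version string, or None if torch is absent."""
    try:
        return metadata.version("torch")
    except metadata.PackageNotFoundError:
        return None


def check_versions(python_version, torch_version):
    """Return a list of human-readable mismatches (empty if everything matches)."""
    problems = []
    major, minor = python_version[0], python_version[1]
    if (major, minor) != EXPECTED_PYTHON:
        problems.append(
            f"Python {major}.{minor} found, README states "
            f"{EXPECTED_PYTHON[0]}.{EXPECTED_PYTHON[1]}"
        )
    if torch_version is None:
        problems.append("PyTorch is not installed")
    # Assumption: a local build tag such as "+cu111" does not matter here.
    elif torch_version.split("+")[0] != EXPECTED_TORCH:
        problems.append(
            f"PyTorch {torch_version} found, README states {EXPECTED_TORCH}"
        )
    return problems


if __name__ == "__main__":
    found = check_versions(sys.version_info, installed_torch_version())
    for line in found or ["Environment matches the README."]:
        print(line)
```

Other versions may work as well; the check only reports differences from the versions listed in this README.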
Build bbox:
cd GMAIR-pytorch/gmair/utils/bbox
python setup.py build
cp build/lib/bbox.so .
cd ../../..
The dataset is generated from a modified version of MultiDigitMNIST.
Fruit2d: train images, train annotations, test images, test annotations.
Note that annotations are only used for evaluation.
For MultiMNIST, download the MultiMNIST dataset, unzip it, and put it into 'data/multi_mnist/'. Then replace 'config.py' with 'mnist_config.py' in 'gmair/config':
cd gmair/config
cp mnist_config.py config.py
cd ../..
For Fruit2d, download the Fruit2d dataset, unzip the archives, and put them into 'data/fruit2d/'. Then replace 'config.py' with 'fruit_config.py' in 'gmair/config':
cd gmair/config
cp fruit_config.py config.py
cd ../..
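The config switch described above can also be scripted. The following is a minimal sketch, assuming it is run from the repository root; `select_config` and `DATASET_CONFIGS` are hypothetical names introduced for illustration and are not part of the repository. It does the same thing as the `cp` commands: it overwrites 'config.py' with the chosen dataset config.

```python
import shutil
from pathlib import Path

# Config file names taken from the steps above.
DATASET_CONFIGS = {
    "mnist": "mnist_config.py",
    "fruit": "fruit_config.py",
}


def select_config(dataset, config_dir="gmair/config"):
    """Copy the dataset-specific config over config.py and return its path."""
    if dataset not in DATASET_CONFIGS:
        raise ValueError(
            f"unknown dataset {dataset!r}; expected one of {sorted(DATASET_CONFIGS)}"
        )
    config_dir = Path(config_dir)
    source = config_dir / DATASET_CONFIGS[dataset]
    if not source.is_file():
        raise FileNotFoundError(source)
    target = config_dir / "config.py"
    # Overwrites any existing config.py, exactly like the cp command.
    shutil.copyfile(source, target)
    return target
```

For example, `select_config("fruit")` prepares the Fruit2d configuration. An explicit dataset name makes it harder to copy the wrong file by accident.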
The data directory should be laid out as follows:
data
|---fruit2d
| |---test_images
| | |---x.png
| |
| |---test_labels
| | |---x.txt
| |
| |---train_images
| | |---y.png
| |
| |---train_labels
| |---y.txt
|
|---scatter_mnist
|---scattered_mnist_128x128_obj14x14.hdf5
Then, start training:
python train.py
Checkpoints: MultiMNIST, Fruit2D
Set the path of the checkpoint file in the configuration file 'config.py' (the variable 'test_model_path'), then run:
python test.py