GMAIR-pytorch

An official PyTorch implementation of GMAIR.

Prepare

This project uses Python 3.8 and PyTorch 1.8.1. Clone the repository and install the dependencies:

git clone https://github.com/EmoFuncs/GMAIR-pytorch.git
pip install -r GMAIR-pytorch/requirements.txt
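To confirm that the environment matches the versions listed above, a small check like the following can be run after installation. It is a minimal sketch and not part of the repository: the helper names are hypothetical, and it only compares the Python minor version and the installed torch package version (ignoring local build suffixes such as +cu111).

```python
import sys
from importlib import metadata  # available since Python 3.8

# Versions named in this README.
EXPECTED_PYTHON = (3, 8)
EXPECTED_TORCH = "1.8.1"


def version_mismatches(python_version, torch_version):
    """Return human-readable mismatches against the README versions (hypothetical helper)."""
    problems = []
    if tuple(python_version[:2]) != EXPECTED_PYTHON:
        problems.append(
            f"Python {python_version[0]}.{python_version[1]} "
            f"(README uses {EXPECTED_PYTHON[0]}.{EXPECTED_PYTHON[1]})"
        )
    if torch_version is None:
        problems.append("PyTorch not installed")
    elif torch_version.split("+")[0] != EXPECTED_TORCH:
        # Strip a local suffix such as "+cu111" before comparing.
        problems.append(f"PyTorch {torch_version} (README uses {EXPECTED_TORCH})")
    return problems


def installed_torch_version():
    """Return the installed torch distribution version, or None if it is missing."""
    try:
        return metadata.version("torch")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    issues = version_mismatches(sys.version_info, installed_torch_version())
    if issues:
        for issue in issues:
            print("Version mismatch:", issue)
    else:
        print("Python and PyTorch versions match the README.")
```

A mismatch is not necessarily fatal, but it is the first thing to rule out if training or the bbox build below fails.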

Build the bbox extension:

cd GMAIR-pytorch/gmair/utils/bbox
python setup.py build
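cp build/lib/bbox.so .

The copy step assumes the compiled module is written to build/lib/bbox.so. With a default setuptools build it may instead land in a platform-specific directory such as build/lib.linux-x86_64-3.8/ with an ABI-tagged name such as bbox.cpython-38-x86_64-linux-gnu.so; in that case, copy that file instead.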
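Rather than assuming the exact output path of the build, the compiled module can be located by searching build/. The following is a hypothetical helper, not part of the repository. It assumes exactly one file matching bbox*.so (or bbox*.pyd on Windows) exists under build/ and copies it while keeping its original file name, which Python can still import as bbox.

```python
import glob
import os
import shutil


def copy_built_extension(build_dir="build", dest_dir=".", name="bbox"):
    """Locate the compiled `name` extension under build_dir and copy it to dest_dir.

    Hypothetical helper: raises if zero or several candidates are found.
    """
    matches = []
    # Search recursively: setuptools usually nests the module in a
    # platform-specific directory such as build/lib.linux-x86_64-3.8/.
    for pattern in (f"{name}*.so", f"{name}*.pyd"):
        matches += glob.glob(os.path.join(build_dir, "**", pattern), recursive=True)
    if not matches:
        raise FileNotFoundError(f"no compiled {name} module found under {build_dir!r}")
    if len(matches) > 1:
        raise RuntimeError(f"several candidates found, copy one by hand: {sorted(matches)}")
    # Keep the original name: an ABI-tagged file such as
    # bbox.cpython-38-x86_64-linux-gnu.so is still importable as `bbox`.
    target = os.path.join(dest_dir, os.path.basename(matches[0]))
    shutil.copy2(matches[0], target)
    return target
```

Calling copy_built_extension() from inside gmair/utils/bbox, after python setup.py build, copies the module next to setup.py and returns its new path.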

Datasets

MultiMNIST dataset (link)

The dataset is generated from a modified version of MultiDigitMNIST.

Fruit2D dataset (links): train images, train annotations, test images, test annotations

Note that the annotations are used only for evaluation, not for training.

Train

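For MultiMNIST, download the MultiMNIST dataset, unzip it, and put it into 'data/multi_mnist/'. (The directory layout below lists this data under 'data/scatter_mnist/' instead; check the data path set in 'mnist_config.py' to confirm which directory name it expects.) Then replace 'config.py' with 'mnist_config.py' in 'gmair/config':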

cd gmair/config
cp mnist_config.py config.py
cd ../..

For Fruit2D, download the Fruit2D dataset files, unzip them, and put them into 'data/fruit2d/'. Then replace 'config.py' with 'fruit_config.py' in 'gmair/config':

cd gmair/config
cp fruit_config.py config.py
cd ../..

The resulting directory layout should be:

data
|---fruit2d
|   |---test_images
|   |   |---x.png
|   |
|   |---test_labels
|   |   |---x.txt
|   |
|   |---train_images
|   |   |---y.png
|   |
|   |---train_labels
|       |---y.txt
|   
|---scatter_mnist
    |---scattered_mnist_128x128_obj14x14.hdf5
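Before training on Fruit2D, the layout above can be checked with a short script. This is a sketch under an assumption read off the tree: every image x.png in a *_images directory has a label file x.txt with the same stem in the matching *_labels directory. The function name is hypothetical and not part of the repository.

```python
from pathlib import Path


def check_fruit2d_layout(root="data/fruit2d"):
    """Return {split: [problems]} for the Fruit2D layout; an empty dict means it looks OK.

    Assumes each <stem>.png in <split>_images has a <stem>.txt in <split>_labels.
    """
    root = Path(root)
    problems = {}
    for split in ("train", "test"):
        images, labels = root / f"{split}_images", root / f"{split}_labels"
        missing = [p for p in (images, labels) if not p.is_dir()]
        if missing:
            problems[split] = [f"missing directory: {p}" for p in missing]
            continue
        # Compare file stems so that x.png pairs with x.txt.
        image_stems = {p.stem for p in images.glob("*.png")}
        label_stems = {p.stem for p in labels.glob("*.txt")}
        issues = [f"no label for {s}.png" for s in sorted(image_stems - label_stems)]
        issues += [f"no image for {s}.txt" for s in sorted(label_stems - image_stems)]
        if issues:
            problems[split] = issues
    return problems


if __name__ == "__main__":
    found = check_fruit2d_layout()
    if not found:
        print("data/fruit2d layout looks OK")
    for split, issues in found.items():
        for issue in issues:
            print(f"[{split}] {issue}")
```

Since the annotations are only needed for evaluation, a missing label file would not stop training, but it would leave that image out of (or break) the evaluation.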

Then start training:

python train.py

Test

Checkpoints (links): MultiMNIST, Fruit2D

Set the variable 'test_model_path' in the configuration file 'gmair/config/config.py' to the path of the checkpoint file, then run:

python test.py
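When switching between the MultiMNIST and Fruit2D checkpoints, the edit to 'config.py' can be scripted. This hypothetical helper is not part of the repository and assumes that 'config.py' contains exactly one line of the form test_model_path = ... (possibly indented); it replaces the right-hand side with a quoted path and drops any trailing comment on that line.

```python
import re
from pathlib import Path

# Matches a line such as "test_model_path = '...'", optionally indented.
_ASSIGNMENT = re.compile(r"^([ \t]*test_model_path[ \t]*=[ \t]*).*$", re.MULTILINE)


def set_test_model_path(config_file, checkpoint):
    """Point the test_model_path assignment in config_file at checkpoint (hypothetical helper)."""
    path = Path(config_file)
    text = path.read_text()
    # A function replacement keeps backslashes in Windows paths from being
    # interpreted as regex escapes; repr() produces a valid Python string literal.
    new_text, count = _ASSIGNMENT.subn(lambda m: m.group(1) + repr(str(checkpoint)), text)
    if count != 1:
        # Leave the file untouched if the assumption about its contents does not hold.
        raise ValueError(f"expected one test_model_path assignment in {path}, found {count}")
    path.write_text(new_text)
```

For example, set_test_model_path('gmair/config/config.py', 'path/to/checkpoint.pth') before running python test.py, where the checkpoint path is a placeholder for wherever the downloaded file was saved.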
