Skip to content

Commit

Permalink
Implement siamese network for one shot leraning
Browse files Browse the repository at this point in the history
  • Loading branch information
Seungmin Oh committed Jan 18, 2021
1 parent 773882b commit 87fab5c
Show file tree
Hide file tree
Showing 16 changed files with 36 additions and 87 deletions.
File renamed without changes.
28 changes: 19 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,45 @@ You can run one shot learning step by step. Also, I posted the details of the co

### 🚀How to run

All executions begin at the location of `./siamese`. You can execute three action. *just run*, *download-data*, *train*, *test*.
You can execute three action. *just run*, *download-data*, *train*, *test*.

1. #### Run
1. #### Clone

This commend automatically executes the entire process according to `config_maker`. If you just want to try this network, I recommend this.
Clone this repository and go into the directory.

```bash
git clone https://github.com/Rhcsky/siamese-one-shot-pytorch.git

cd siamese-one-shot-pytorch
```

2. #### Run

This commend automatically executes the entire process according to `config_maker`(download data + train + test).

If you just want to try this network, I recommend this.

```bash
python main.py run
```

2. #### Download-data
3. #### Download-data

The Omniglot data is downloaded and divided into 30 types of train data, 10 types of validation data, and 10 types of test data. All data is contained in `./data/processed/`.

```bash
python main.py download-data
```

3. #### Train
4. #### Train

Only model learning is conducted. If you want to run 'train', you have to run 'download-data' first.

```bash
python main.py train
```

4. #### Test
5. #### Test

Only test the model. Stored models and datasets must exist.

Expand All @@ -48,7 +60,7 @@ All parameters are present in `config_maker`. If you want to adjust the paramete



### Result
### Check Result

Train logs, saved model and configuration data were in `./result/[model_number]`. Logs are made by `tensorboard`. So if you want to see more detail about train metrics, write commend on `./siamese_network/result/[model_number]` like this.

Expand All @@ -60,8 +72,6 @@ tensorboard --logdir=logs

### 📌Reference

siamese network

* [Siamese Neural Networks for One-shot Image Recognition](https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf)

* [kevinzakka/one-shot-siamese](https://github.com/kevinzakka/one-shot-siamese)
Expand Down
File renamed without changes.
File renamed without changes.
24 changes: 14 additions & 10 deletions siamese_network/data_prepare.py → data_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tqdm import tqdm


def copy_image_to_processed_dir(alpha_list, img_dir, desc):
def move_image_to_processed_dir(alpha_list, img_dir, desc):
for alpha in tqdm(alpha_list, desc=desc):
write_dir1 = img_dir + '/' + os.path.basename(alpha) + '_'
for char in (os.listdir(alpha)):
Expand All @@ -19,17 +19,21 @@ def copy_image_to_processed_dir(alpha_list, img_dir, desc):


def prepare_data():
background_dir = "../data/unzip/background"
evaluation_dir = "../data/unzip/evaluation"
processed_dir = "../data/processed"
background_dir = "data/unzip/background"
evaluation_dir = "data/unzip/evaluation"
processed_dir = "data/processed"
random.seed(5)

if not os.path.exists(processed_dir):
os.makedirs(processed_dir)
os.makedirs(processed_dir +'/train')
os.makedirs(processed_dir +'/val')
os.makedirs(processed_dir +'/test')


if any([True for _ in os.scandir(processed_dir)]):
return

if os.path.exists(processed_dir) is None:
os.makedirs(processed_dir)

# Move 10 of evaluation image for getting more train set.
if len(glob(evaluation_dir + '/*')) >= 20:
for d in random.sample(glob(evaluation_dir + '/*'), 10):
Expand All @@ -48,6 +52,6 @@ def prepare_data():
val_dir = os.path.join(processed_dir, 'val')
test_dir = os.path.join(processed_dir, 'test')

copy_image_to_processed_dir(train_alpha, train_dir, 'train')
copy_image_to_processed_dir(val_alpha, val_dir, 'val')
copy_image_to_processed_dir(test_alpha, test_dir, 'test')
move_image_to_processed_dir(train_alpha, train_dir, 'train')
move_image_to_processed_dir(val_alpha, val_dir, 'val')
move_image_to_processed_dir(test_alpha, test_dir, 'test')
File renamed without changes.
File renamed without changes.
File renamed without changes.
Binary file added raw
Binary file not shown.
Binary file added raw6du9mufz.tmp
Binary file not shown.
File renamed without changes.
File renamed without changes.
67 changes: 0 additions & 67 deletions siamese_network/README.md

This file was deleted.

File renamed without changes.
File renamed without changes.
4 changes: 3 additions & 1 deletion siamese_network/utils.py → utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,15 @@ def load_config(config):

# download omniglot dataset
def download_omniglot_data():
BASEDIR = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir)) + '/data'
BASEDIR = os.path.dirname(os.path.realpath(__file__)) + "/data"

# make directory
if not os.path.exists(BASEDIR):
os.mkdir(BASEDIR)
if not os.path.exists(os.path.join(BASEDIR, 'unzip')):
os.mkdir(os.path.join(BASEDIR, 'unzip'))
if not os.path.exists(os.path.join(BASEDIR, 'raw')):
os.mkdir(os.path.join(BASEDIR, 'raw'))

# download zip file
if not os.path.exists(BASEDIR + '/raw/images_background.zip'):
Expand Down

0 comments on commit 87fab5c

Please sign in to comment.