
Commit c8e6c33

Merge pull request #441 from zwtu/YOWO

Add YOWO

2 parents a2a838f + 38d6ac2

35 files changed: +3278 -7 lines

.gitignore (+1)

```diff
@@ -5,6 +5,7 @@ __pycache__/
 *.swp
 *.swo
 *.swn
+.DS_Store

 # Byte-compiled / optimized / DLL files
 __pycache__/
```

configs/localization/yowo.yaml (+90)

```yaml
MODEL: #MODEL field
    framework: "YOWOLocalizer" #Mandatory, indicate the type of network, associate to the 'paddlevideo/modeling/framework/'.
    backbone: #Mandatory, indicate the type of backbone, associate to the 'paddlevideo/modeling/backbones/'.
        name: "YOWO" #Mandatory, the name of the backbone.
        num_class: 24
        pretrained_2d: "data/ucf24/darknet.pdparam"
        pretrained_3d: "data/ucf24/resnext101_kinetics.pdparams"
    loss:
        name: "RegionLoss"
        num_classes: 24
        num_anchors: 5
        # Five (w, h) anchor pairs for the YOLOv2-style region loss.
        anchors: [0.70458, 1.18803, 1.26654, 2.55121, 1.59382, 4.08321, 2.30548, 4.94180, 3.52332, 5.91979]
        object_scale: 5
        noobject_scale: 1
        class_scale: 1
        coord_scale: 1

DATASET: #DATASET field
    batch_size: 8 #Mandatory, batch size
    num_workers: 4 #Mandatory, the number of subprocesses on each GPU.
    test_batch_size: 8
    valid_batch_size: 8
    train:
        format: "UCF24Dataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevideo/loader/dataset'
        file_path: "data/ucf24/trainlist.txt" #Mandatory, train data index file path
    valid:
        format: "UCF24Dataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevideo/loader/dataset'
        file_path: "data/ucf24/testlist.txt" #Mandatory, valid data index file path
    test:
        format: "UCF24Dataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevideo/loader/dataset'
        file_path: "data/ucf24/testlist.txt" #Mandatory, test data index file path

PIPELINE: #PIPELINE field
    train: #Mandatory, indicate the pipeline to deal with the training data, associate to the 'paddlevideo/loader/pipelines/'
        sample:
            name: "SamplerUCF24"
            num_frames: 16
            valid_mode: False
        transform: #Mandatory, image transform operators.
            - YowoAug:
                valid_mode: False
    valid: #Mandatory, indicate the pipeline to deal with the validation data, associate to the 'paddlevideo/loader/pipelines/'
        sample:
            name: "SamplerUCF24"
            num_frames: 16
            valid_mode: True
        transform: #Mandatory, image transform operators.
            - YowoAug:
                valid_mode: True
    test:
        sample:
            name: "SamplerUCF24"
            num_frames: 16
            valid_mode: True
        transform: #Mandatory, image transform operators.
            - YowoAug:
                valid_mode: True

OPTIMIZER: #OPTIMIZER field
    name: Adam
    learning_rate:
        learning_rate: 0.0001
        name: 'MultiStepDecay'
        milestones: [1, 2, 3, 4]
        gamma: 0.5
    weight_decay:
        name: "L2"
        value: 0.0005

GRADIENT_ACCUMULATION:
    global_batch_size: 128 # Specify the sum of batches to be calculated by all GPUs

METRIC:
    name: 'YOWOMetric'
    gt_folder: 'data/ucf24/groundtruths_ucf'
    result_path: 'output/detections_test'
    threshold: 0.5
    log_interval: 100

INFERENCE:
    name: 'YOWO_Inference_helper'
    num_seg: 16
    target_size: 224

model_name: "YOWO"
log_interval: 20 #Optional, the interval of the logger, default: 10
save_interval: 1
epochs: 5 #Mandatory, total epochs
log_level: "INFO" #Optional, the logger level, default: "INFO"
val_interval: 1
```
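Two settings above are worth unpacking. The `MultiStepDecay` schedule starts at 1e-4 and halves the learning rate at each of epochs 1-4, giving 1e-4, 5e-5, 2.5e-5, 1.25e-5, and 6.25e-6 across the 5 training epochs. And `GRADIENT_ACCUMULATION.global_batch_size: 128` with `batch_size: 8` implies summing gradients over 128 / (8 × num_gpus) micro-batches before each optimizer step. A minimal sketch of that idea in Paddle (illustrative only, not PaddleVideo's actual trainer loop; the function name and arguments are hypothetical):

```python
import paddle


def train_with_accumulation(model, loss_fn, optimizer, loader,
                            global_batch_size=128, micro_batch_size=8,
                            num_gpus=1):
    # Number of micro-batches whose gradients are summed per optimizer step.
    accum_steps = global_batch_size // (micro_batch_size * num_gpus)
    for step, (imgs, target) in enumerate(loader):
        # Scale each loss so the accumulated sum averages correctly.
        loss = loss_fn(model(imgs), target) / accum_steps
        loss.backward()  # Paddle accumulates gradients across backward() calls
        if (step + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.clear_grad()
```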

data/ucf24/build_split.py (+42)

```python
import argparse
import os


def build_split_list(raw_path, mode):
    """
    Generate target-format split lists from the original UCF24 splits.
    """
    raw_path = os.path.join(raw_path, '{}list01.txt'.format(mode))
    print('{} analysis begin'.format(raw_path))
    with open(raw_path, 'r') as fin:
        lines = fin.readlines()

    with open('{}list.txt'.format(mode), 'w') as fout:
        for i, line in enumerate(lines):
            line = line.strip()  # 'class_name/video_name'
            label_dir = os.path.join('labels', line)  # 'labels/class_name/video_name'
            if not os.path.isdir(label_dir):
                continue
            txt_list = sorted(os.listdir(label_dir))
            for txt_item in txt_list:
                filename = os.path.join('data', 'ucf24', label_dir, txt_item)
                fout.write(filename + '\n')
            if i % 200 == 0:
                print('{} videos parsed'.format(i))
    print('{} analysis done'.format(raw_path))


def parse_args():
    parser = argparse.ArgumentParser(description='Build file list')
    parser.add_argument('--raw_path', type=str, default='./splitfiles')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    build_split_list(args.raw_path, 'train')
    build_split_list(args.raw_path, 'test')
```
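Each line the script writes to `trainlist.txt` / `testlist.txt` is the repo-relative path of one frame-level label file, so the generated lists look like this (paths illustrative):

```
data/ucf24/labels/Basketball/v_Basketball_g01_c01/00009.txt
data/ucf24/labels/Basketball/v_Basketball_g01_c01/00010.txt
```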
data/ucf24/download_frames_annotations.sh (+5)

```bash
#!/usr/bin/env bash

wget --no-check-certificate "https://videotag.bj.bcebos.com/Data/ucf24.zip"
unzip -q ucf24.zip
rm -f ./ucf24.zip
```

data/ucf24/visualization.py (+30)

```python
import argparse
import os

import imageio


def imgs2gif(frames_dir, duration):
    """
    frames_dir: directory containing the inference result frames
    duration: seconds per frame (duration = 1 / fps)
    """
    frames = []
    for idx in sorted(os.listdir(frames_dir)):
        img = os.path.join(frames_dir, idx)
        if img.endswith('jpg'):
            frames.append(imageio.imread(img))
    save_name = '.'.join([frames_dir, 'gif'])
    imageio.mimsave(save_name, frames, 'GIF', duration=duration)
    print(save_name, 'saved!')


def parse_args():
    parser = argparse.ArgumentParser(description='Convert inference frames to a GIF')
    parser.add_argument('--frames_dir', type=str, default='./inference/YOWO_infer/HorseRiding')
    parser.add_argument('--duration', type=float, default=0.04)
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    imgs2gif(args.frames_dir, args.duration)
```

docs/en/dataset/ucf24.md (+73)

English | [简体中文](../../zh-CN/dataset/ucf24.md)

# UCF24 Data Preparation

This document introduces the preparation process for the UCF24 dataset: downloading the RGB frame files and annotation files, and generating the file path lists.

---
## 1. Data Download

Detailed information on the UCF24 data can be found on the [UCF24](http://www.thumos.info/download.html) website. For ease of use, PaddleVideo provides a download script for the RGB frames and annotation files of the UCF24 data.

First, make sure you are in the [data/ucf24/ directory](../../../data/ucf24), then run the following command to download the RGB frames and annotation files of the UCF24 dataset:

```shell
bash download_frames_annotations.sh
```

- Running this command requires the `unzip` decompression tool, since the downloaded archive is a zip file.

- The RGB frame files will be stored in the [data/ucf24/rgb-images/ directory](../../../data/ucf24/rgb-images)

- The annotation files will be stored in the [data/ucf24/labels/ directory](../../../data/ucf24/labels)
---
## 2. File Pathlist Generation

To generate the train/test file path lists from the original split files, run:

```shell
python build_split.py --raw_path ./splitfiles
```

**Description of parameters**

`--raw_path`: the storage path of the original split files
## 3. Folder Structure

After the whole data pipeline for UCF24 preparation, the folder structure will look like:

```
├── data
│   ├── ucf24
│   │   ├── groundtruths_ucf
│   │   ├── labels
│   │   │   ├── Basketball
│   │   │   │   ├── v_Basketball_g01_c01
│   │   │   │   │   ├── 00009.txt
│   │   │   │   │   ├── 00010.txt
│   │   │   │   │   ├── ...
│   │   │   │   │   ├── 00050.txt
│   │   │   │   │   ├── 00051.txt
│   │   │   ├── ...
│   │   │   ├── WalkingWithDog
│   │   │   │   ├── v_WalkingWithDog_g01_c01
│   │   │   │   ├── ...
│   │   │   │   ├── v_WalkingWithDog_g25_c04
│   │   ├── rgb-images
│   │   │   ├── Basketball
│   │   │   │   ├── v_Basketball_g01_c01
│   │   │   │   │   ├── 00001.jpg
│   │   │   │   │   ├── 00002.jpg
│   │   │   │   │   ├── ...
│   │   │   │   │   ├── 00140.jpg
│   │   │   │   │   ├── 00141.jpg
│   │   │   ├── ...
│   │   │   ├── WalkingWithDog
│   │   │   │   ├── v_WalkingWithDog_g01_c01
│   │   │   │   ├── ...
│   │   │   │   ├── v_WalkingWithDog_g25_c04
│   │   ├── splitfiles
│   │   │   ├── trainlist01.txt
│   │   │   ├── testlist01.txt
│   │   ├── trainlist.txt
│   │   ├── testlist.txt
```
docs/en/model_zoo/localization/yowo.md (+138)

[简体中文](../../../zh-CN/model_zoo/localization/yowo.md) | English

# YOWO

## Content

- [Introduction](#Introduction)
- [Data](#Data)
- [Train](#Train)
- [Test](#Test)
- [Inference](#Inference)
- [Reference](#Reference)


## Introduction

YOWO is a single-stage network with two branches. One branch extracts spatial features of the key frame (i.e., the current frame) via a 2D-CNN, while the other acquires spatio-temporal features of the clip consisting of the preceding frames via a 3D-CNN. To aggregate these features accurately, YOWO uses a channel fusion and attention mechanism that exploits inter-channel dependencies. Finally, frame-level detection is performed on the fused features.

<div align="center">
<img src="../../../images/yowo.jpg">
</div>
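The fusion step concatenates the 2D and 3D feature maps along the channel axis and re-weights the result with a Gram-matrix channel attention. Below is a minimal sketch of that idea in Paddle; the class name, channel sizes, and layer layout are illustrative, not the repo's actual fusion module:

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class ChannelFusionAttention(nn.Layer):
    """Sketch: fuse 2D key-frame and 3D clip features, then re-weight
    channels with a Gram-matrix attention over the channel dimension."""

    def __init__(self, ch_2d, ch_3d, ch_out):
        super().__init__()
        self.fuse = nn.Conv2D(ch_2d + ch_3d, ch_out, kernel_size=1)

    def forward(self, feat_2d, feat_3d):
        # feat_2d: [N, C2, H, W]; feat_3d: [N, C3, H, W] (temporal dim already squeezed)
        x = self.fuse(paddle.concat([feat_2d, feat_3d], axis=1))  # [N, C, H, W]
        n, c, h, w = x.shape
        flat = x.reshape([n, c, h * w])                     # [N, C, HW]
        gram = paddle.bmm(flat, flat.transpose([0, 2, 1]))  # [N, C, C] channel affinities
        attn = F.softmax(gram, axis=-1)
        out = paddle.bmm(attn, flat).reshape([n, c, h, w])  # re-weighted features
        return x + out                                      # residual connection


# Quick shape check with random tensors (channel sizes are assumptions):
# cfam = ChannelFusionAttention(ch_2d=425, ch_3d=2048, ch_out=1024)
# y = cfam(paddle.randn([1, 425, 7, 7]), paddle.randn([1, 2048, 7, 7]))  # [1, 1024, 7, 7]
```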
## Data

For UCF101-24 data download and preparation, please refer to [UCF101-24 data preparation](../../dataset/ucf24.md).


## Train

### Training on the UCF101-24 dataset

#### Download and add pre-trained models

1. Download the pre-trained models [resnext-101-kinetics](https://videotag.bj.bcebos.com/PaddleVideo-release2.3/resnext101_kinetics.pdparams) and [darknet](https://videotag.bj.bcebos.com/PaddleVideo-release2.3/darknet.pdparam) as backbone initialization parameters, or fetch them with wget:

```bash
wget -nc https://videotag.bj.bcebos.com/PaddleVideo-release2.3/darknet.pdparam
wget -nc https://videotag.bj.bcebos.com/PaddleVideo-release2.3/resnext101_kinetics.pdparams
```

2. Open `PaddleVideo/configs/localization/yowo.yaml`, and fill in the downloaded weight paths after `pretrained_2d:` and `pretrained_3d:` respectively:

```yaml
MODEL:
    framework: "YOWOLocalizer"
    backbone:
        name: "YOWO"
        num_class: 24
        pretrained_2d: fill in the path of the 2D pre-trained model here
        pretrained_3d: fill in the path of the 3D pre-trained model here
```
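For example, if the weights were downloaded into `data/ucf24/`, the two fields become `pretrained_2d: "data/ucf24/darknet.pdparam"` and `pretrained_3d: "data/ucf24/resnext101_kinetics.pdparams"`, matching the defaults shipped in `configs/localization/yowo.yaml`.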
#### Start training

- UCF101-24 training uses a single GPU; start it with:

```bash
python3 main.py -c configs/localization/yowo.yaml --validate --seed=1
```

- Enable AMP mixed-precision training to speed up the training process:

```bash
python3 main.py --amp -c configs/localization/yowo.yaml --validate --seed=1
```

- In addition, you can customize the parameter configuration to train/test on other datasets. The recommended naming convention for configuration files is `model_dataset-name_file-format_data-format_sampling-method.yaml`; please refer to [config](../../tutorials/config.md) for parameter usage.
## Test

- The YOWO model is validated synchronously during training. You can find the keyword `best` in the training log to obtain the model's validation accuracy. The log looks like:

```
Already save the best model (fsocre)0.8779
```

- The evaluation metric in YOWO's test mode is **Frame-mAP (@ IoU 0.5)**, which differs from the **fscore** used for validation during training. The `fscore` recorded in the training log therefore does not represent the final test score. After training completes, run the best model in test mode to obtain the final metric:

```bash
python3 main.py -c configs/localization/yowo.yaml --test --seed=1 -w 'output/YOWO/YOWO_epoch_00005.pdparams'
```
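Frame-mAP treats every frame as an independent detection problem: per class, detections across all frames are ranked by confidence, matched to ground-truth boxes at IoU ≥ 0.5 (the `threshold: 0.5` in the METRIC section of `yowo.yaml`), and the average precision is then averaged over classes. A minimal sketch of the IoU test used for matching (illustrative, not the repo's metric code):

```python
def iou(box_a, box_b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter + 1e-9)


# A detection counts as a true positive when iou(det, gt) >= 0.5.
assert iou((0, 0, 2, 2), (1, 1, 3, 3)) < 0.5  # IoU = 1/7, so not a match
```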
With the test configuration above, the metrics on the UCF101-24 validation set are as follows:


| Model | 3D-CNN backbone | 2D-CNN backbone | Dataset | Input | Frame-mAP <br>(@ IoU 0.5) | checkpoints |
| :-----------: | :-----------: | :-----------: | :-----------: | :-----------: | :-----------: | :-----------: |
| YOWO | 3D-ResNext-101 | Darknet-19 | UCF101-24 | 16-frames, d=1 | 80.94 | [YOWO.pdparams](https://videotag.bj.bcebos.com/PaddleVideo-release2.3/YOWO_epoch_00005.pdparams) |
## Inference

### Export the inference model

```bash
python3 tools/export_model.py -c configs/localization/yowo.yaml -p 'output/YOWO/YOWO_epoch_00005.pdparams'
```

The above command generates the model structure file `YOWO.pdmodel` and the model weight file `YOWO.pdiparams` required for prediction.

- For the meaning of each parameter, please refer to [Model Inference Methods](../../usage.md#2-infer)
### Inference with the prediction engine

- Download the test video [HorseRiding.avi](https://videotag.bj.bcebos.com/Data/HorseRiding.avi) for a quick experience, or fetch it with wget. Place the downloaded video in the `data/ucf24` directory:

```bash
wget -nc https://videotag.bj.bcebos.com/Data/HorseRiding.avi
```

- Run the following command for inference:

```bash
python3 tools/predict.py -c configs/localization/yowo.yaml -i 'data/ucf24/HorseRiding.avi' --model_file ./inference/YOWO.pdmodel --params_file ./inference/YOWO.pdiparams
```

- When inference finishes, the prediction results are saved as images in the `inference/YOWO_infer` directory. The image sequence can then be converted to a GIF to complete the final visualization:

```bash
python3 data/ucf24/visualization.py --frames_dir ./inference/YOWO_infer/HorseRiding --duration 0.04
```
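Here `--duration` is the display time per frame in seconds, so `--duration 0.04` yields 1 / 0.04 = 25 frames per second.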
The resulting visualization looks as follows:

<div align="center">
<img src="../../../images/horse_riding.gif" alt="Horse Riding">
</div>

Using the YOWO model trained on UCF101-24 to predict `data/ucf24/HorseRiding.avi`, the predicted category for every frame is HorseRiding, with a confidence of about 0.80.

## Reference

- [You Only Watch Once: A Unified CNN Architecture for Real-Time Spatiotemporal Action Localization](https://arxiv.org/pdf/1911.06644.pdf), Köpüklü O, Wei X, Rigoll G.

docs/images/horse_riding.gif (996 KB)

docs/images/yowo.jpg (314 KB)
