Skip to content

Commit 0cfc7c9

Browse files
committedDec 15, 2023
Major updates: NTU dataset added, along with a GCN+Transformer model
1 parent 5fd6ea7 commit 0cfc7c9

11 files changed

+781
-76
lines changed
 

‎config/model.json

+12-7
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
{
2-
"d_model" : 512,
3-
"nhead" : 16,
4-
"num_layers" : 8,
5-
"num_features" : 99,
6-
"dropout" : 0.5,
7-
"dim_feedforward" : 2048,
8-
"num_classes" : 2
2+
"gcn_num_features" : 3,
3+
"gcn_hidden_dim1" : 32,
4+
"gcn_hidden_dim2" : 64,
5+
"gcn_output_dim" : 128,
6+
7+
"transformer_d_model" : 128,
8+
"transformer_nhead" : 4,
9+
"transformer_num_layers" : 2,
10+
"transformer_num_features" : 128,
11+
"transformer_dropout" : 0.3,
12+
"transformer_dim_feedforward" : 256,
13+
"transformer_num_classes" : 2
914
}
+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import torch
2+
from torch_geometric.data import Batch
3+
from torch.utils.data.dataloader import default_collate
4+
from typing import Any, List, Mapping, Sequence, Tuple
5+
6+
7+
class Collater:
8+
"""
9+
Collates the batch of data
10+
11+
Parameters
12+
----------
13+
dataset : torch.utils.data.Dataset
14+
Dataset to collate
15+
"""
16+
17+
def __init__(self, dataset):
18+
self.dataset = dataset
19+
20+
def __call__(self, batch) -> Any:
21+
"""
22+
Collates the batch of data
23+
24+
Parameters
25+
----------
26+
batch : List[Any]
27+
Batch of data
28+
29+
Returns
30+
-------
31+
Any
32+
Collated batch of data
33+
"""
34+
elem = batch[0]
35+
36+
if isinstance(elem, torch.Tensor):
37+
return default_collate(batch)
38+
elif isinstance(elem, float):
39+
return torch.tensor(batch, dtype=torch.float)
40+
elif isinstance(elem, int):
41+
return torch.tensor(batch)
42+
elif isinstance(elem, str):
43+
return batch
44+
elif isinstance(elem, Mapping):
45+
return {key: self([data[key] for data in batch]) for key in elem}
46+
elif isinstance(elem, Sequence) and not isinstance(elem, str):
47+
return [self(s) for s in zip(*batch)]
48+
49+
raise TypeError(f"DataLoader found invalid type: '{type(elem)}'")
50+
51+
def collate_fn(self, batch: List[Any]) -> Any:
52+
"""
53+
Collates the batch of data
54+
55+
Parameters
56+
----------
57+
batch : List[Any]
58+
Batch of data
59+
60+
Returns
61+
-------
62+
Any
63+
Collated batch of data
64+
"""
65+
batched_graphs = [item["poses"] for item in batch]
66+
labels = [item["label"] for item in batch]
67+
68+
for i in range(len(batched_graphs)):
69+
batched_graphs[i] = Batch.from_data_list(batched_graphs[i])
70+
71+
labels = torch.tensor(labels, dtype=torch.long)
72+
73+
return batched_graphs, labels
74+
75+
76+
class DataLoader(torch.utils.data.DataLoader):
77+
"""
78+
Dataloader for the single view case
79+
80+
Parameters
81+
----------
82+
dataset : torch.utils.data.Dataset
83+
Dataset to load
84+
batch_size : int
85+
Batch size, by default 1
86+
shuffle : bool, optional
87+
Whether to shuffle the dataset, by default False
88+
"""
89+
90+
def __init__(self, dataset, batch_size: int = 1, shuffle: bool = False, **kwargs):
91+
self.collator = Collater(dataset)
92+
93+
super().__init__(
94+
dataset,
95+
batch_size,
96+
shuffle,
97+
collate_fn=self.collator.collate_fn,
98+
**kwargs,
99+
)

‎data_mgmt/dataloader.py ‎data_mgmt/dataloaders/transformer.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,14 @@ def collate_fn(self, batch: List[Any]) -> Any:
6161
Any
6262
Collated batch of data
6363
"""
64-
poses = [item[0] for item in batch]
65-
labels = [item[1] for item in batch]
66-
64+
poses = [item["keypoints"] for item in batch]
65+
labels = [item["label"] for item in batch]
66+
6767
max_length = max([item.shape[0] for item in poses])
6868
masks = [torch.ones(item.shape[0]) for item in poses]
69+
6970
for i, item in enumerate(masks):
7071
masks[i] = torch.nn.functional.pad(item, (0, max_length - item.shape[0]))
71-
7272
poses = [torch.nn.functional.pad(item, (0, 0, 0, max_length - item.shape[0])) for item in poses]
7373

7474
poses = torch.stack(poses)
@@ -77,7 +77,6 @@ def collate_fn(self, batch: List[Any]) -> Any:
7777

7878
return poses, masks, labels
7979

80-
8180
class DataLoader(torch.utils.data.DataLoader):
8281
"""
8382
Dataloader for the single view case

‎data_mgmt/datasets/ntu_dataset.py

+279
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
import torch
2+
import numpy as np
3+
import os
4+
import regex as re
5+
6+
from torch_geometric.data import Data
7+
from torch_geometric.data import Dataset
8+
9+
from typing import Dict
10+
11+
label_action = [
12+
{"id": 0, "A043": "falling"},
13+
{"id" : 1, "A008" : "sitting down"},
14+
{"id": 1, "A026": "hopping (one foot jumping)"},
15+
]
16+
17+
file_name_regex = r"S(\d{3})C001P(\d{3})R(\d{3})A(\d{3})"
18+
file_name_regex = re.compile(file_name_regex)
19+
20+
21+
def get_label(file_name: str) -> int:
22+
"""
23+
Returns the label of the file
24+
25+
Parameters
26+
----------
27+
file_name : str
28+
Name of the file
29+
30+
Returns
31+
-------
32+
int
33+
Label of the file
34+
"""
35+
label = file_name[-4:]
36+
for i in label_action:
37+
if label in i:
38+
return i["id"]
39+
return -1
40+
41+
42+
def is_valid_file(file_name: str, skip: int = 11) -> bool:
43+
"""
44+
Checks if the file is a valid file
45+
46+
Parameters
47+
----------
48+
file_name : str
49+
Name of the file
50+
skip : int, optional
51+
Number of frames to skip, by default 11
52+
53+
Returns
54+
-------
55+
bool
56+
True if the file is valid, False otherwise
57+
"""
58+
npy_file = file_name.endswith(".npy")
59+
file_name = file_name.split("/")[-1].split(".")[0]
60+
61+
if file_name_regex.match(file_name) is None or get_label(file_name) == -1:
62+
return False
63+
64+
return npy_file
65+
66+
67+
def get_edge_index():
68+
POSE_CONNECTIONS = [
69+
(3, 2),
70+
(20, 8),
71+
(8, 9),
72+
(9, 10),
73+
(10, 11),
74+
(11, 24),
75+
(11, 23),
76+
(20, 4),
77+
(4, 5),
78+
(5, 6),
79+
(6, 7),
80+
(7, 21),
81+
(7, 22),
82+
(0, 1),
83+
(1, 20),
84+
(0, 16),
85+
(0, 12),
86+
(16, 17),
87+
(17, 18),
88+
(18, 19),
89+
(12, 13),
90+
(13, 14),
91+
(14, 15),
92+
]
93+
edge_index = torch.tensor(POSE_CONNECTIONS, dtype=torch.long).t().contiguous()
94+
95+
return edge_index
96+
97+
98+
def get_multiview_files(dataset_folder: str) -> list:
99+
"""
100+
Returns a list of files that have multiple views
101+
102+
Parameters
103+
----------
104+
dataset_folder : str
105+
Path to the dataset folder
106+
107+
Returns
108+
-------
109+
list
110+
List of files that have multiple views
111+
"""
112+
multiview_files = []
113+
114+
for root, dirs, files in os.walk(dataset_folder):
115+
for file in files:
116+
if is_valid_file(file):
117+
file_name = file.split("/")[-1].split(".")[0]
118+
119+
file_name = file_name.split("C001")
120+
other_views = [
121+
file_name[0] + "C002" + file_name[1],
122+
file_name[0] + "C003" + file_name[1],
123+
]
124+
125+
not_exist = False
126+
for view in other_views:
127+
if not os.path.exists(os.path.join(root, view + ".skeleton.npy")):
128+
not_exist = True
129+
break
130+
if not_exist:
131+
continue
132+
133+
other_views.append(file_name[0] + "C001" + file_name[1])
134+
for i in range(len(other_views)):
135+
other_views[i] = os.path.join(
136+
root, other_views[i] + ".skeleton.npy"
137+
)
138+
multiview_files.append(other_views)
139+
140+
return multiview_files
141+
142+
143+
class NTUDataset(Dataset):
144+
"""
145+
Dataset class for the keypoint dataset
146+
"""
147+
148+
def __init__(
149+
self, dataset_folder: str, skip: int = 11, occlude: bool = False
150+
) -> None:
151+
super().__init__(None, None, None)
152+
self.dataset_folder = dataset_folder
153+
self.edge_index = get_edge_index()
154+
155+
self.poses = []
156+
self.labels = []
157+
self.keypoints = []
158+
159+
self.occluded_kps = np.array([23, 24, 10, 11, 9, 8, 4, 5, 6, 7, 21, 22])
160+
161+
self.multi_view_files = get_multiview_files(dataset_folder)
162+
for files in self.multi_view_files:
163+
rand_view = np.random.randint(3)
164+
165+
for idx, file in enumerate(files):
166+
file_data = np.load(file, allow_pickle=True).item()
167+
frames = file_data["skel_body0"]
168+
169+
if occlude and idx == rand_view:
170+
frames = self._occlude_keypoints(frames)
171+
pose_graphs = self._create_pose_graph(frames)
172+
173+
if "C001" in file:
174+
kps = self._get_flattened_keypoints(torch.tensor(frames))
175+
self.keypoints.append(kps)
176+
self.poses.append(pose_graphs)
177+
178+
file_name = files[0].split("/")[-1].split(".")[0]
179+
self.labels.append(get_label(file_name))
180+
181+
def _create_pose_graph(self, keypoints: torch.Tensor) -> Data:
182+
"""
183+
Creates a Pose Graph from the given keypoints and edge index
184+
185+
Parameters
186+
----------
187+
keypoints : torch.Tensor
188+
Keypoints of the pose
189+
edge_index : torch.Tensor
190+
Edge index of the pose
191+
192+
Returns
193+
-------
194+
Data
195+
Pose Graph
196+
"""
197+
pose_graphs = []
198+
for t in range(keypoints.shape[0]):
199+
pose_graph = Data(
200+
x=torch.tensor(keypoints[t, :, :], dtype=torch.float),
201+
edge_index=self.edge_index,
202+
)
203+
pose_graphs.append(pose_graph)
204+
205+
return pose_graphs
206+
207+
def _get_flattened_keypoints(self, keypoints: torch.Tensor) -> torch.Tensor:
208+
"""
209+
Returns the flattened keypoints
210+
211+
Parameters
212+
----------
213+
keypoints : torch.Tensor
214+
Keypoints
215+
216+
Returns
217+
-------
218+
torch.Tensor
219+
Flattened keypoints
220+
"""
221+
return keypoints.reshape(keypoints.shape[0], -1)
222+
223+
def _occlude_keypoints(
224+
self, frames: torch.Tensor, mask_prob: float = 0.2
225+
) -> torch.Tensor:
226+
"""
227+
Occludes the keypoints of the pose
228+
229+
Parameters
230+
----------
231+
frames : torch.Tensor
232+
Keypoints of the pose
233+
mask_prob : float, optional
234+
Probability of masking the frames, by default 0.2
235+
236+
Returns
237+
-------
238+
torch.Tensor
239+
Occluded frames
240+
"""
241+
index = np.random.randint(3)
242+
if index == 0:
243+
mask_indices = np.arange(0, frames.shape[0] // 2)
244+
elif index == 1:
245+
mask_indices = np.arange(frames.shape[0] // 2, frames.shape[0])
246+
else:
247+
mask_indices = np.arange(frames.shape[0])
248+
249+
masked_kps = frames[mask_indices]
250+
masked_kps[:, self.occluded_kps, :] = -1
251+
frames[mask_indices] = masked_kps
252+
253+
return frames
254+
255+
def len(self) -> int:
256+
"""
257+
Returns the number of samples in the dataset
258+
259+
Returns
260+
-------
261+
int : len
262+
Number of samples in the dataset
263+
"""
264+
return len(self.labels)
265+
266+
def get(self, index: int) -> Dict[str, torch.Tensor]:
267+
"""
268+
Returns the sample at the given index
269+
270+
Returns
271+
-------
272+
Dict[str, torch.Tensor] : sample
273+
Sample at the given index
274+
"""
275+
keypoints = self.keypoints[index]
276+
poses = self.poses[index]
277+
label = self.labels[index]
278+
279+
return {"keypoints": keypoints, "poses": poses, "label": label}

‎data_mgmt/dataset.py ‎data_mgmt/datasets/ur_dataset.py

+87-30
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
import torch
22
from torch.utils.data import Dataset
3+
from torch_geometric.data import Data
34
import numpy as np
45
import os
56

6-
# from dataloader import DataLoader
7-
87
from typing import Tuple
98

109
def get_label(file_name: str) -> int:
1110
if "adl" in file_name:
12-
return 0
13-
return 1
11+
return 1
12+
return 0
1413

1514

1615
def is_valid_file(file_name: str, skip: int = 11) -> bool:
@@ -30,23 +29,68 @@ def is_valid_file(file_name: str, skip: int = 11) -> bool:
3029
True if the file is valid, False otherwise
3130
"""
3231
npy_file = file_name.endswith(".npy")
33-
cam0 = "cam0" in file_name
3432
skip_frame_num = file_name.split("/")[-1].split("-")[-2] == str(skip)
3533

36-
return npy_file and cam0 and skip_frame_num
37-
34+
return npy_file and skip_frame_num
3835

39-
class KeypointsDataset(Dataset):
36+
def get_edge_index():
37+
"""
38+
Returns the edge index of the pose graph
39+
40+
Returns
41+
-------
42+
torch.Tensor
43+
Edge index of the pose graph
44+
"""
45+
POSE_CONNECTIONS = [
46+
(0, 1),
47+
(1, 2),
48+
(2, 3),
49+
(3, 7), # Head to left shoulder
50+
(0, 4),
51+
(4, 5),
52+
(5, 6),
53+
(6, 8), # Head to right shoulder
54+
(9, 10),
55+
(11, 12), # Left and right shoulder
56+
(11, 13),
57+
(13, 15),
58+
(15, 17),
59+
(15, 19),
60+
(15, 21), # Left arm
61+
(12, 14),
62+
(14, 16),
63+
(16, 18),
64+
(16, 20),
65+
(16, 22), # Right arm
66+
(11, 23),
67+
(12, 24),
68+
(23, 24), # Torso
69+
(23, 25),
70+
(25, 27),
71+
(27, 29),
72+
(29, 31), # Left leg
73+
(24, 26),
74+
(26, 28),
75+
(28, 30),
76+
(30, 32), # Right leg
77+
]
78+
edge_index = torch.tensor(POSE_CONNECTIONS, dtype=torch.long).t().contiguous()
79+
80+
return edge_index
81+
82+
class URDataset(Dataset):
4083
"""
4184
Dataset class for the keypoint dataset
4285
"""
4386

4487
def __init__(self, dataset_folder: str, skip: int = 11) -> None:
4588
self.dataset_folder = dataset_folder
89+
self.edge_index = get_edge_index()
4690

91+
self.keypoints = []
4792
self.poses = []
4893
self.labels = []
49-
self.file_names = []
5094

5195
for root, dirs, files in os.walk(dataset_folder):
5296
for file in files:
@@ -55,11 +99,38 @@ def __init__(self, dataset_folder: str, skip: int = 11) -> None:
5599

56100
kps = np.load(file_path)
57101
kps = kps[:, :, :3]
102+
pose_graphs = self._create_pose_graph(torch.tensor(kps))
58103
kps = self._get_flattened_keypoints(torch.tensor(kps))
59104

60-
self.poses.append(kps)
105+
self.poses.append(pose_graphs)
106+
self.keypoints.append(kps)
61107
self.labels.append(get_label(file_path))
62-
self.file_names.append(file_path)
108+
109+
def _create_pose_graph(self, keypoints: torch.Tensor) -> Data:
110+
"""
111+
Creates a Pose Graph from the given keypoints and edge index
112+
113+
Parameters
114+
----------
115+
keypoints : torch.Tensor
116+
Keypoints of the pose
117+
edge_index : torch.Tensor
118+
Edge index of the pose
119+
120+
Returns
121+
-------
122+
Data
123+
Pose Graph
124+
"""
125+
pose_graphs = []
126+
for t in range(keypoints.shape[0]):
127+
pose_graph = Data(
128+
x=torch.tensor(keypoints[t, :, :], dtype=torch.float),
129+
edge_index=self.edge_index,
130+
)
131+
pose_graphs.append(pose_graph)
132+
133+
return pose_graphs
63134

64135
def _get_flattened_keypoints(self, keypoints: torch.Tensor) -> torch.Tensor:
65136
"""
@@ -86,7 +157,7 @@ def __len__(self) -> int:
86157
int : len
87158
Number of samples in the dataset
88159
"""
89-
return len(self.poses)
160+
return len(self.keypoints)
90161

91162
def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
92163
"""
@@ -97,22 +168,8 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
97168
dict : {kps, label, file_name}
98169
A dictionary containing the keypoint array, label and file name
99170
"""
100-
poses = self.poses[index]
171+
keypoints = self.keypoints[index]
101172
label = self.labels[index]
102-
return poses, label
103-
104-
# if __name__ == "__main__":
105-
# dataset = KeypointsDataset("../data", skip=3)
106-
107-
# train_dataset, test_dataset = torch.utils.data.random_split(dataset, [int(0.8 * len(dataset)), len(dataset) - int(0.8 * len(dataset))])
108-
109-
# train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
110-
# test_loader = DataLoader(test_dataset, batch_size=2, shuffle=True)
111-
112-
# for batch in train_loader:
113-
# print(batch[0].shape)
114-
# print(batch[0])
115-
# print(batch[1].shape)
116-
# print(batch[1])
117-
# print(batch[2].shape)
118-
# break
173+
poses = self.poses[index]
174+
175+
return {"keypoints": keypoints, "label": label, "poses": poses}

‎main.py

+57-11
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
from collections import Counter
55

66
from trainer import Trainer
7-
from model import get_model
7+
from model import get_transformer_model, get_gcn_transformer_model
88
from utils.logger import Logger
99
from utils.model_config import ModelConfig
10-
from data_mgmt.dataset import KeypointsDataset
10+
from data_mgmt.datasets.ur_dataset import URDataset
11+
from data_mgmt.datasets.ntu_dataset import NTUDataset
1112

12-
def parse_args():
13+
from typing import Tuple
14+
15+
def parse_args() -> argparse.Namespace:
1316
parser = argparse.ArgumentParser(description="Train the model")
1417

1518
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
@@ -19,6 +22,24 @@ def parse_args():
1922
default="./data",
2023
help="Path to the dataset folder",
2124
)
25+
parser.add_argument(
26+
"--model",
27+
type=str,
28+
default="transformer",
29+
help="Model to use for training, transformer or gcn_transformer",
30+
)
31+
parser.add_argument(
32+
"--dataset_type",
33+
type=str,
34+
default="ur",
35+
help="Type of dataset to use, ntu or ur",
36+
)
37+
parser.add_argument(
38+
"--skip",
39+
type=int,
40+
default=11,
41+
help="Number of frames to skip",
42+
)
2243
parser.add_argument("--epochs", type=int, default=50, help="Number of epochs")
2344
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
2445
parser.add_argument(
@@ -42,14 +63,34 @@ def parse_args():
4263
default="./config/model.json",
4364
help="Path to the model config file",
4465
)
66+
parser.add_argument(
67+
"--occlude",
68+
action="store_true",
69+
help="Whether to occlude the input or not",
70+
)
4571
args = parser.parse_args()
4672

73+
if args.dataset_type not in ["ntu", "ur"]:
74+
raise ValueError("Dataset type should be either ntu or ur")
75+
76+
if args.model not in ["transformer", "gcn_transformer"]:
77+
raise ValueError("Model should be either transformer or gcn_transformer")
78+
79+
if args.dataset_type == "ur":
80+
if args.skip % 2 == 0:
81+
raise ValueError("Skip frames should be odd")
82+
if args.skip > 11:
83+
raise ValueError("Skip frames should be less than 11")
84+
4785
return args
4886

4987

50-
def load_dataset(dataset_folder, logger):
88+
def load_dataset(args : argparse.Namespace, logger : Logger) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset, torch.utils.data.Dataset]:
5189
np.random.seed(42)
52-
dataset = KeypointsDataset(dataset_folder, skip=3)
90+
if args.dataset_type == "ntu":
91+
dataset = NTUDataset(args.dataset, occlude=args.occlude)
92+
elif args.dataset_type == "ur":
93+
dataset = URDataset(args.dataset, skip=args.skip)
5394

5495
if len(dataset) > 0:
5596
logger.info("Dataset loaded successfully.")
@@ -88,19 +129,24 @@ def main():
88129
logger.info("\n")
89130
logger.info("Loading the dataset...")
90131
train_dataset, val_dataset, test_dataset = load_dataset(
91-
args.dataset, logger
132+
args, logger
92133
)
93134

94135
logger.info(f"Training dataset size: {len(train_dataset)}")
95136
logger.info(f"Validation dataset size: {len(val_dataset)}")
96137
logger.info(f"Testing dataset size: {len(test_dataset)}")
97138

98139
model_config = ModelConfig(args.model_config).get_config()
99-
model, (train_dataloader, val_dataloader, test_dataloader) = get_model(
100-
model_config, args, (train_dataset, val_dataset, test_dataset)
101-
)
102-
103-
trainer = Trainer(model, lr=args.lr, logger=logger)
140+
if args.model == "transformer":
141+
model, (train_dataloader, val_dataloader, test_dataloader) = get_transformer_model(
142+
model_config, args, (train_dataset, val_dataset, test_dataset)
143+
)
144+
elif args.model == "gcn_transformer":
145+
model, (train_dataloader, val_dataloader, test_dataloader) = get_gcn_transformer_model(
146+
model_config, args, (train_dataset, val_dataset, test_dataset)
147+
)
148+
149+
trainer = Trainer(model, lr=args.lr, logger=logger, model_type=args.model)
104150
logger.info(f"Batch size: {args.batch_size}")
105151
logger.info(f"Number of epochs: {args.epochs}")
106152
logger.info(f"Learning rate: {args.lr}")

‎model.py

+64-12
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22
from torch.utils.data import Dataset
33

44
from models.transformer import Transformer
5-
from data_mgmt.dataloader import DataLoader
5+
from models.action_recognizer import ActionRecognizer
6+
from data_mgmt.dataloaders.transformer import DataLoader as TransformerDataLoader
7+
from data_mgmt.dataloaders.gcn_transformer import DataLoader as GCNTransformerDataLoader
68

79
from typing import Dict, Tuple
810

9-
def get_model(
11+
def get_transformer_model(
1012
config: Dict,
1113
args: argparse.Namespace,
1214
dataset: Tuple[Dataset, Dataset, Dataset],
1315
) -> Tuple[
14-
Transformer, Tuple[DataLoader, DataLoader, DataLoader]
16+
Transformer, Tuple[TransformerDataLoader, TransformerDataLoader, TransformerDataLoader]
1517
]:
1618
"""
1719
Returns the model and the dataloader
@@ -31,19 +33,69 @@ def get_model(
3133
Model and the dataloaders
3234
"""
3335
train_dataset, val_dataset, test_dataset = dataset
34-
train_loader = DataLoader(
36+
train_loader = TransformerDataLoader(
3537
train_dataset, batch_size=args.batch_size, shuffle=True
3638
)
37-
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True)
38-
test_loader = DataLoader(
39+
val_loader = TransformerDataLoader(val_dataset, batch_size=args.batch_size, shuffle=True)
40+
test_loader = TransformerDataLoader(
3941
test_dataset, batch_size=args.batch_size, shuffle=True
4042
)
4143

4244
return Transformer(
43-
d_model=config["d_model"],
44-
nhead=config["nhead"],
45-
num_layers=config["num_layers"],
46-
num_features=config["num_features"],
47-
dropout=config["dropout"],
48-
dim_ff=config["dim_feedforward"],
45+
d_model=config["transformer_d_model"],
46+
nhead=config["transformer_nhead"],
47+
num_layers=config["transformer_num_layers"],
48+
num_features=config["transformer_num_features"],
49+
dropout=config["transformer_dropout"],
50+
dim_ff=config["transformer_dim_feedforward"],
51+
num_classes=config["transformer_num_classes"],
52+
dataset=args.dataset_type,
4953
), (train_loader, val_loader, test_loader)
54+
55+
def get_gcn_transformer_model(
56+
config: Dict,
57+
args: argparse.Namespace,
58+
dataset: Tuple[Dataset, Dataset, Dataset],
59+
) -> Tuple[
60+
ActionRecognizer, Tuple[GCNTransformerDataLoader, GCNTransformerDataLoader, GCNTransformerDataLoader]
61+
]:
62+
"""
63+
Returns the model and the dataloader
64+
65+
Parameters
66+
----------
67+
config : Dict
68+
Configuration for the model
69+
args : argparse.Namespace
70+
Arguments passed to the program
71+
dataset : Tuple[Dataset, Dataset, Dataset]
72+
Dataset to use for training, validation and testing
73+
74+
Returns
75+
-------
76+
Tuple[ActionRecognizer, Tuple[DataLoader, DataLoader, DataLoader]]
77+
Model and the dataloaders
78+
"""
79+
train_dataset, val_dataset, test_dataset = dataset
80+
train_loader = GCNTransformerDataLoader(
81+
train_dataset, batch_size=args.batch_size, shuffle=True
82+
)
83+
val_loader = GCNTransformerDataLoader(val_dataset, batch_size=args.batch_size, shuffle=True)
84+
test_loader = GCNTransformerDataLoader(
85+
test_dataset, batch_size=args.batch_size, shuffle=True
86+
)
87+
88+
return ActionRecognizer(
89+
gcn_num_features=config["gcn_num_features"],
90+
gcn_hidden_dim1=config["gcn_hidden_dim1"],
91+
gcn_hidden_dim2=config["gcn_hidden_dim2"],
92+
gcn_output_dim=config["gcn_output_dim"],
93+
transformer_d_model=config["transformer_d_model"],
94+
transformer_nhead=config["transformer_nhead"],
95+
transformer_num_layers=config["transformer_num_layers"],
96+
transformer_num_features=config["transformer_num_features"],
97+
transformer_dropout=config["transformer_dropout"],
98+
transformer_dim_feedforward=config["transformer_dim_feedforward"],
99+
transformer_num_classes=config["transformer_num_classes"],
100+
dataset=args.dataset_type,
101+
), (train_loader, val_loader, test_loader)

‎models/action_recognizer.py

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import torch
2+
import torch.nn as nn
3+
from torch_geometric.data import Batch
4+
5+
from models.transformer import Transformer
6+
from models.gcn import PoseGCN
7+
8+
9+
class ActionRecognizer(nn.Module):
10+
def __init__(
11+
self,
12+
gcn_num_features: int,
13+
gcn_hidden_dim1: int,
14+
gcn_hidden_dim2: int,
15+
gcn_output_dim: int,
16+
transformer_d_model: int,
17+
transformer_nhead: int,
18+
transformer_num_layers: int,
19+
transformer_num_features: int,
20+
transformer_dropout: float = 0.1,
21+
transformer_dim_feedforward: int = 2048,
22+
transformer_num_classes: int = 2,
23+
dataset: str = "ntu",
24+
) -> None:
25+
"""
26+
Parameters
27+
----------
28+
gcn_num_features : int
29+
Number of features in the input sequence
30+
gcn_hidden_dim1 : int
31+
Dimension of the first hidden layer of the GCN
32+
gcn_hidden_dim2 : int
33+
Dimension of the second hidden layer of the GCN
34+
gcn_output_dim : int
35+
Dimension of the output layer of the GCN
36+
transformer_d_model : int
37+
Dimension of the input embedding
38+
transformer_nhead : int
39+
Number of attention heads
40+
transformer_num_layers : int
41+
Number of transformer encoder layers
42+
transformer_num_features : int
43+
Number of features in the input sequence
44+
transformer_dropout : float, optional
45+
Dropout rate, by default 0.1
46+
transformer_dim_feedforward : int, optional
47+
Dimension of the feedforward network, by default 2048
48+
"""
49+
super(ActionRecognizer, self).__init__()
50+
51+
self.gcn = PoseGCN(
52+
gcn_num_features, gcn_hidden_dim1, gcn_hidden_dim2, gcn_output_dim
53+
)
54+
self.transformer = Transformer(
55+
transformer_d_model,
56+
transformer_nhead,
57+
transformer_num_layers,
58+
transformer_num_features,
59+
transformer_dropout,
60+
transformer_dim_feedforward,
61+
num_classes=transformer_num_classes,
62+
)
63+
self.num_classes = transformer_num_classes
64+
self.dataset = dataset
65+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66+
67+
def forward(self, batch: torch.Tensor) -> torch.Tensor:
68+
"""
69+
Parameters
70+
----------
71+
batch : list
72+
Input sequence of keypoints
73+
74+
Returns
75+
-------
76+
torch.Tensor
77+
Classification of the input sequence of keypoints
78+
"""
79+
outputs = []
80+
81+
for item in batch:
82+
view_embedding = self.gcn(item)
83+
84+
output = self.transformer(view_embedding.unsqueeze(0).to(self.device))
85+
outputs.append(output)
86+
87+
return torch.stack(outputs).squeeze(1)

‎models/gcn.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import torch
2+
import torch.nn.functional as F
3+
from torch_geometric.nn import GCNConv
4+
from torch_geometric.data import Data, Batch
5+
from torch_geometric.nn import global_mean_pool
6+
7+
8+
class PoseGCN(torch.nn.Module):
9+
def __init__(
10+
self, num_features: int, hidden_dim1: int, hidden_dim2: int, output_dim: int
11+
) -> None:
12+
"""
13+
Parameters
14+
----------
15+
num_features : int
16+
Number of features in the input sequence
17+
hidden_dim1 : int
18+
Dimension of the first hidden layer of the GCN
19+
hidden_dim2 : int
20+
Dimension of the second hidden layer of the GCN
21+
output_dim : int
22+
Dimension of the output layer of the GCN
23+
"""
24+
super(PoseGCN, self).__init__()
25+
self.conv1 = GCNConv(num_features, hidden_dim1)
26+
self.conv2 = GCNConv(hidden_dim1, hidden_dim2)
27+
self.conv3 = GCNConv(hidden_dim2, output_dim)
28+
29+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30+
31+
def forward(self, data: Batch) -> torch.Tensor:
32+
"""
33+
Parameters
34+
----------
35+
data : Batch
36+
Pose Graph
37+
38+
Returns
39+
-------
40+
torch.Tensor
41+
Output of the GCN of shape (batch_size, output_dim)
42+
"""
43+
x, edge_index, batch = (
44+
data.x.to(self.device),
45+
data.edge_index.to(self.device),
46+
data.batch.to(self.device),
47+
)
48+
49+
x = self.conv1(x, edge_index)
50+
x = torch.relu(x)
51+
x = self.conv2(x, edge_index)
52+
x = torch.relu(x)
53+
x = self.conv3(x, edge_index)
54+
55+
x = global_mean_pool(x, batch)
56+
return x

‎models/transformer.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(
5151
dropout: float = 0.1,
5252
dim_ff: int = 2048,
5353
num_classes: int = 2,
54+
dataset: str = "ntu",
5455
) -> None:
5556
"""
5657
Parameters
@@ -74,6 +75,7 @@ def __init__(
7475
self.num_layers = num_layers
7576
self.num_features = num_features
7677
self.num_classes = num_classes
78+
self.dataset = dataset
7779

7880
self.pos_encoding = get_positional_encoding(
7981
1000, d_model
@@ -89,7 +91,7 @@ def __init__(
8991
)
9092
self.decoder = nn.Linear(self.d_model, self.num_classes)
9193

92-
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
94+
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
9395
"""
9496
Parameters
9597
----------

‎trainer.py

+33-10
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(
2222
model: nn.Module,
2323
lr: float = 5e-5,
2424
logger: Logger = None,
25+
model_type: str = "transformer",
2526
) -> None:
2627
"""
2728
Parameters
@@ -37,6 +38,8 @@ def __init__(
3738
self.logger = logger
3839
self.model = model
3940
self.lr = lr
41+
self.model_type = model_type
42+
4043
self.writer = SummaryWriter()
4144

4245
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -148,8 +151,12 @@ def train_one_epoch(
148151
epoch_correct = 0
149152
epoch_count = 0
150153
for idx, batch in enumerate(iter(train_loader)):
151-
predictions = self.model(batch[0].float().to(self.device), batch[1].to(self.device))
152-
labels = batch[2].to(self.device)
154+
if self.model_type == "transformer":
155+
predictions = self.model(batch[0].float().to(self.device), batch[1].to(self.device))
156+
labels = batch[2].to(self.device)
157+
elif self.model_type == "gcn_transformer":
158+
predictions = self.model(batch[0])
159+
labels = batch[1].to(self.device)
153160

154161
loss = self.criterion(predictions, labels)
155162
self.writer.add_scalar("Training loss per batch", loss, idx)
@@ -191,8 +198,12 @@ def evaluate(
191198
val_epoch_count = 0
192199

193200
for idx, batch in enumerate(iter(val_loader)):
194-
predictions = self.model(batch[0].float().to(self.device), batch[1].to(self.device))
195-
labels = batch[2].to(self.device)
201+
if self.model_type == "transformer":
202+
predictions = self.model(batch[0].float().to(self.device), batch[1].to(self.device))
203+
labels = batch[2].to(self.device)
204+
elif self.model_type == "gcn_transformer":
205+
predictions = self.model(batch[0])
206+
labels = batch[1].to(self.device)
196207

197208
val_loss = self.criterion(predictions, labels)
198209
self.writer.add_scalar("Validation loss per batch", val_loss, idx)
@@ -224,6 +235,7 @@ def test(
224235
Tuple containing the test epoch loss, test epoch correct
225236
and test epoch count
226237
"""
238+
output_path = os.path.join(output_path, self.model.dataset)
227239
if not os.path.exists(output_path):
228240
os.makedirs(output_path)
229241

@@ -237,14 +249,19 @@ def test(
237249
+ ".pt"
238250
)
239251
torch.save(self.best_model.state_dict(), os.path.join(output_path, file_name))
252+
240253
self.best_model.to(self.device)
241254
self.best_model.eval()
242255
with torch.no_grad():
243256
predictions = []
244257
labels = []
245258
for idx, batch in enumerate(iter(test_loader)):
246-
predictions.extend(self.best_model(batch[0].float().to(self.device), batch[1].to(self.device)).argmax(axis=1).tolist())
247-
labels.extend(batch[2].tolist())
259+
if self.model_type == "transformer":
260+
predictions.extend(self.best_model(batch[0].float().to(self.device), batch[1].to(self.device)).argmax(axis=1).tolist())
261+
labels.extend(batch[2].tolist())
262+
elif self.model_type == "gcn_transformer":
263+
predictions.extend(self.best_model(batch[0]).argmax(axis=1).tolist())
264+
labels.extend(batch[1].tolist())
248265

249266
self.logger.info(f"Predictions: {predictions}")
250267
self.logger.info(f"Labels: {labels}")
@@ -253,11 +270,17 @@ def test(
253270
precision, recall, f1_score, _ = precision_recall_fscore_support(
254271
labels, predictions, average="weighted"
255272
)
273+
cm = confusion_matrix(labels, predictions)
274+
tn, fp, fn, tp = cm.ravel()
275+
sensitivity = tp / (tp + fn)
276+
specificity = tn / (tn + fp)
277+
geometric_mean = (sensitivity * specificity) ** 0.5
256278

257279
self.logger.info(f"Accuracy: {accuracy:.4f}")
258-
self.logger.info(f"Precision: {precision}")
259-
self.logger.info(f"Recall: {recall}")
260-
self.logger.info(f"F1 Score: {f1_score}")
280+
self.logger.info(f"Precision: {precision:.4f}")
281+
self.logger.info(f"Recall: {recall:.4f}")
282+
self.logger.info(f"F1 Score: {f1_score:.4f}")
283+
self.logger.info(f"G-Mean: {geometric_mean:.4f}")
261284

262285
plt.figure(figsize=(8, 6))
263286
colors = cycle(["aqua", "darkorange"])
@@ -288,7 +311,6 @@ def test(
288311
plt.savefig(os.path.join(output_path, file_name))
289312
plt.show()
290313

291-
cm = confusion_matrix(labels, predictions)
292314
ax = sns.heatmap(
293315
cm,
294316
annot=True,
@@ -319,6 +341,7 @@ def _plot_losses(self, output_path: str) -> None:
319341
-------
320342
None
321343
"""
344+
output_path = os.path.join(output_path, self.model.dataset)
322345
if not os.path.exists(output_path):
323346
os.makedirs(output_path)
324347

0 commit comments

Comments
 (0)
Please sign in to comment.