1
+ import torch
2
+ import numpy as np
3
+ import os
4
+ import regex as re
5
+
6
+ from torch_geometric .data import Data
7
+ from torch_geometric .data import Dataset
8
+
9
+ from typing import Dict
10
+
11
# Lookup table used by get_label: each entry maps one action code (the
# "A###" key) to a human-readable description and carries the class id
# under the "id" key.
# NOTE(review): "A008" and "A026" share id 1 - confirm that this grouping
# is intended and is not a copy/paste slip.
label_action = [
    {"id": 0, "A043": "falling"},
    {"id": 1, "A008": "sitting down"},
    {"id": 1, "A026": "hopping (one foot jumping)"},
]

# Compiled pattern for sample names of the form S###C001P###R###A###.
# Only names whose camera field is exactly "C001" match; the four groups
# capture the digits following S, P, R and A.
file_name_regex = re.compile(r"S(\d{3})C001P(\d{3})R(\d{3})A(\d{3})")
19
+
20
+
21
def get_label(file_name: str) -> int:
    """
    Look up the class id for a sample name

    The last four characters of ``file_name`` are taken as the action code
    and searched for among the keys of every ``label_action`` entry.

    Parameters
    ----------
    file_name : str
        Sample name without directory or extension; its last four
        characters are used as the action code

    Returns
    -------
    int
        The ``"id"`` of the first matching entry, or -1 when no entry
        contains the action code
    """
    action_code = file_name[-4:]
    matching_ids = (entry["id"] for entry in label_action if action_code in entry)
    # First hit wins; -1 is the "unknown action" sentinel callers test for.
    return next(matching_ids, -1)
40
+
41
+
42
def is_valid_file(file_name: str, skip: int = 11) -> bool:
    """
    Decide whether a file should be loaded into the dataset

    A file is accepted when its stem (text after the last "/" and before
    the first ".") matches ``file_name_regex``, ``get_label`` knows its
    action code, and the full name ends in ".npy".

    Parameters
    ----------
    file_name : str
        Name (or "/"-separated path) of the file
    skip : int, optional
        Not read by this function; kept so existing callers keep working,
        by default 11

    Returns
    -------
    bool
        True if the file is valid, False otherwise
    """
    stem = file_name.split("/")[-1].split(".")[0]

    if file_name_regex.match(stem) is None:
        return False
    if get_label(stem) == -1:
        return False

    return file_name.endswith(".npy")
65
+
66
+
67
def get_edge_index():
    """
    Build the skeleton connectivity tensor

    Returns
    -------
    torch.Tensor
        Contiguous ``torch.long`` tensor of shape (2, 23). Column ``k``
        holds the two joint indices of the k-th pair below: row 0 is the
        first element of the pair, row 1 the second. Every pair is listed
        once, in a single direction.

    Notes
    -----
    NOTE(review): joints 2 and 3 occur only in the pair (3, 2), so that
    pair is not linked to any other joint in this list - confirm that this
    is intended.
    """
    joint_pairs = (
        (3, 2),
        (20, 8),
        (8, 9),
        (9, 10),
        (10, 11),
        (11, 24),
        (11, 23),
        (20, 4),
        (4, 5),
        (5, 6),
        (6, 7),
        (7, 21),
        (7, 22),
        (0, 1),
        (1, 20),
        (0, 16),
        (0, 12),
        (16, 17),
        (17, 18),
        (18, 19),
        (12, 13),
        (13, 14),
        (14, 15),
    )
    # Split the pairs into the two rows directly instead of transposing an
    # (E, 2) tensor afterwards; the result has the same values and layout.
    first_joints, second_joints = zip(*joint_pairs)
    return torch.tensor([list(first_joints), list(second_joints)], dtype=torch.long)
96
+
97
+
98
def get_multiview_files(dataset_folder: str) -> list:
    """
    Collect the samples for which all three camera views are on disk

    The folder is walked recursively. For every file accepted by
    ``is_valid_file`` (a C001 recording), the C002 and C003 counterparts
    are looked up in the same directory under the name
    ``<stem>.skeleton.npy``; the sample is kept only if both exist.

    Parameters
    ----------
    dataset_folder : str
        Path to the dataset folder

    Returns
    -------
    list
        One entry per kept sample. Each entry is a list of three paths in
        the order C002, C003, C001, all ending in ".skeleton.npy"
    """
    extension = ".skeleton.npy"
    view_groups = []

    for root, _dirs, names in os.walk(dataset_folder):
        for name in names:
            if not is_valid_file(name):
                continue

            stem = name.split("/")[-1].split(".")[0]
            pieces = stem.split("C001")
            before, after = pieces[0], pieces[1]

            companions = [before + camera + after for camera in ("C002", "C003")]
            both_present = all(
                os.path.exists(os.path.join(root, companion + extension))
                for companion in companions
            )
            if not both_present:
                continue

            # The C001 view goes last, after its two companions.
            ordered_views = companions + [before + "C001" + after]
            view_groups.append(
                [os.path.join(root, view + extension) for view in ordered_views]
            )

    return view_groups
141
+
142
+
143
class NTUDataset(Dataset):
    """
    Dataset class for the keypoint dataset

    Each group of camera views returned by ``get_multiview_files`` yields
    exactly one sample, so ``keypoints``, ``poses`` and ``labels`` always
    have the same length and share the same index.
    """

    def __init__(
        self, dataset_folder: str, skip: int = 11, occlude: bool = False
    ) -> None:
        """
        Loads every multi-view group found under the dataset folder

        Parameters
        ----------
        dataset_folder : str
            Path to the dataset folder
        skip : int, optional
            Not read by this class; kept for interface compatibility,
            by default 11
        occlude : bool, optional
            When True, one randomly chosen view per group is passed
            through ``_occlude_keypoints``, by default False
        """
        super().__init__(None, None, None)
        self.dataset_folder = dataset_folder
        self.edge_index = get_edge_index()

        self.poses = []
        self.labels = []
        self.keypoints = []

        # Joint indices that _occlude_keypoints overwrites with -1.
        self.occluded_kps = np.array([23, 24, 10, 11, 9, 8, 4, 5, 6, 7, 21, 22])

        self.multi_view_files = get_multiview_files(dataset_folder)
        for files in self.multi_view_files:
            # Drawn for every group, even when occlude is False, so the
            # random stream does not depend on the flag.
            rand_view = np.random.randint(3)

            for idx, file in enumerate(files):
                # NOTE(review): allow_pickle=True unpickles the file's
                # content; only load files from a trusted source.
                file_data = np.load(file, allow_pickle=True).item()
                frames = file_data["skel_body0"]

                if occlude and idx == rand_view:
                    frames = self._occlude_keypoints(frames)

                # Test the file name only: a directory whose name contains
                # "C001" must not make every view look like the C001 view.
                if "C001" not in os.path.basename(file):
                    # NOTE(review): the other views are loaded (and possibly
                    # occluded) but not stored. Confirm whether their pose
                    # graphs should be grouped into the same sample.
                    continue

                # Keypoints and pose graphs come from the same view and are
                # stored once per group, which keeps them index-aligned
                # with self.labels for get().
                kps = self._get_flattened_keypoints(torch.tensor(frames))
                self.keypoints.append(kps)
                self.poses.append(self._create_pose_graph(frames))

            file_name = os.path.basename(files[0]).split(".")[0]
            self.labels.append(get_label(file_name))

    def _create_pose_graph(self, keypoints: torch.Tensor) -> list:
        """
        Creates one Pose Graph per frame from the given keypoints

        Parameters
        ----------
        keypoints : torch.Tensor
            Keypoints indexed as ``keypoints[frame, :, :]``
            (the constructor passes the array loaded from disk)

        Returns
        -------
        list
            One ``Data`` object per frame; ``x`` holds that frame's
            keypoints as float and ``edge_index`` is ``self.edge_index``
        """
        return [
            Data(
                x=torch.tensor(keypoints[t, :, :], dtype=torch.float),
                edge_index=self.edge_index,
            )
            for t in range(keypoints.shape[0])
        ]

    def _get_flattened_keypoints(self, keypoints: torch.Tensor) -> torch.Tensor:
        """
        Returns the flattened keypoints

        Parameters
        ----------
        keypoints : torch.Tensor
            Keypoints whose first dimension is kept

        Returns
        -------
        torch.Tensor
            Keypoints reshaped to ``(keypoints.shape[0], -1)``
        """
        return keypoints.reshape(keypoints.shape[0], -1)

    def _occlude_keypoints(
        self, frames: torch.Tensor, mask_prob: float = 0.2
    ) -> torch.Tensor:
        """
        Occludes the keypoints of the pose

        One of three frame ranges is picked uniformly at random (first
        half, second half, or all frames). Within that range the joints
        listed in ``self.occluded_kps`` are set to -1.

        Parameters
        ----------
        frames : torch.Tensor
            Keypoints of the pose, indexed as ``[frame, joint, coordinate]``;
            modified in place
        mask_prob : float, optional
            Not read by this method, by default 0.2

        Returns
        -------
        torch.Tensor
            The same ``frames`` object after masking
        """
        frame_count = frames.shape[0]
        midpoint = frame_count // 2
        ranges = ((0, midpoint), (midpoint, frame_count), (0, frame_count))
        start, stop = ranges[np.random.randint(3)]

        # The chosen frames are always one contiguous run, so a slice
        # replaces the copy / modify / write-back sequence.
        frames[start:stop, self.occluded_kps, :] = -1

        return frames

    def len(self) -> int:
        """
        Returns the number of samples in the dataset

        Returns
        -------
        int : len
            Number of samples in the dataset
        """
        return len(self.labels)

    def get(self, index: int) -> Dict[str, object]:
        """
        Returns the sample at the given index

        Parameters
        ----------
        index : int
            Position of the sample

        Returns
        -------
        Dict[str, object] : sample
            Dictionary with the keys "keypoints" (flattened keypoints),
            "poses" (list of per-frame pose graphs) and "label" (class id)
        """
        return {
            "keypoints": self.keypoints[index],
            "poses": self.poses[index],
            "label": self.labels[index],
        }
0 commit comments