-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtraining.py
More file actions
375 lines (320 loc) · 12.6 KB
/
Copy pathtraining.py
File metadata and controls
375 lines (320 loc) · 12.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score, precision_score, recall_score, f1_score
import logging
from tqdm import tqdm
from torch.utils.data import DataLoader
from typing import Dict, Tuple, Optional, Union
# Module-level logger named after this module. It is not referenced by any of
# the functions shown in this file.
logger = logging.getLogger(__name__)
def triplet_loss(
    anchor: torch.Tensor,
    positive: torch.Tensor,
    negative: torch.Tensor,
    margin: float = 1.0,
) -> torch.Tensor:
    """Compute the mean triplet margin loss.

    For each row ``i`` the loss is
    ``max(d(anchor_i, positive_i) - d(anchor_i, negative_i) + margin, 0)``,
    where ``d`` is the Euclidean distance computed by ``F.pairwise_distance``.
    The result is the mean over rows.

    Args:
        anchor: Anchor embeddings.
        positive: Positive sample embeddings, same shape as ``anchor``.
        negative: Negative sample embeddings, same shape as ``anchor``.
        margin: Required gap between the negative and positive distances.

    Returns:
        Scalar tensor with the mean loss. It is exactly 0 when every triplet
        already satisfies the margin.
    """
    pos_dist = F.pairwise_distance(anchor, positive)
    neg_dist = F.pairwise_distance(anchor, negative)
    # Hinge: only triplets whose negative is not at least `margin` farther
    # away than the positive contribute. No epsilon is added afterwards: a
    # constant offset changes neither gradients nor numerical stability. It
    # would only make fully satisfied batches report a spurious non-zero loss.
    losses = torch.clamp(pos_dist - neg_dist + margin, min=0.0)
    return losses.mean()
def _labels_to_list(labels) -> list:
    """Return a batch of labels as a flat list of plain, hashable Python values.

    Tensors and NumPy arrays are unpacked with ``tolist()``. Lists and tuples
    have any tensor elements unwrapped with ``item()``. Any other object is
    treated as a single label. Unwrapping matters because tensors hash by
    identity, so ``set()`` over tensor labels would make every sample its own
    class.
    """
    if isinstance(labels, (torch.Tensor, np.ndarray)):
        values = labels.tolist()
        return values if isinstance(values, list) else [values]
    if isinstance(labels, (list, tuple)):
        return [value.item() if isinstance(value, torch.Tensor) else value for value in labels]
    return [labels]


def _split_batch(batch, device: torch.device):
    """Unpack one batch into ``(images, audio, labels)`` with tensors on *device*.

    Two batch forms are accepted:

    - A dict with 'image', 'audio' and 'crow_id' keys, where 'audio' is None
      or a dict with 'mel_spec' and 'chroma'.
    - An ``(images, audio, labels)`` sequence, where audio is None or a
      ``(mel_spec, chroma)`` pair.

    Audio is returned as None or as a dict with 'mel_spec' and 'chroma'
    tensors. Labels are returned as a plain list.
    """
    if isinstance(batch, dict):
        images, raw_audio, labels = batch['image'], batch['audio'], batch['crow_id']
        if raw_audio is not None:
            raw_audio = (raw_audio['mel_spec'], raw_audio['chroma'])
    else:
        images, raw_audio, labels = batch
    audio = None
    if raw_audio is not None:
        audio = {
            'mel_spec': raw_audio[0].to(device),
            'chroma': raw_audio[1].to(device),
        }
    return images.to(device), audio, _labels_to_list(labels)


def compute_metrics(
    model: torch.nn.Module,
    data: Union[DataLoader, Dict[str, torch.Tensor], tuple],
    device: torch.device,
    threshold: float = 0.5,
) -> Tuple[Dict[str, float], np.ndarray]:
    """Compute pairwise same/different-identity metrics for model evaluation.

    Every sample is embedded and the embeddings are L2-normalised. Each
    ordered pair (i, j), including i == j, is predicted "same crow" when its
    cosine similarity exceeds *threshold*. Predictions are scored against
    whether the two samples share a 'crow_id'.

    Args:
        model: Model called as ``model(images, audio)``. It should return one
            embedding row per image.
        data: Either a single batch or an iterable of batches (e.g. a
            DataLoader). A single batch is a dict with 'image', 'audio' and
            'crow_id' keys, or an ``(images, audio, labels)`` tuple.
        device: Device to run evaluation on.
        threshold: Cosine-similarity cut-off for a "same crow" prediction.

    Returns:
        Tuple of (metrics dict, similarities array).

        - The metrics dict has 'accuracy', 'precision', 'recall' and 'f1'.
        - The similarities array is the (N, N) cosine-similarity matrix.
        - If no samples are seen, all metrics are 0.0 and the array is empty.

    The model's train/eval mode is restored before returning.
    """
    # A single batch (dict or (images, audio, labels) tuple) is wrapped so it
    # is not mistaken for an iterable of batches.
    if isinstance(data, dict) or (
        isinstance(data, tuple) and len(data) == 3 and isinstance(data[0], torch.Tensor)
    ):
        batches = [data]
    else:
        batches = data

    chunks = []
    labels = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in batches:
                images, audio, batch_labels = _split_batch(batch, device)
                if images.shape[0] == 0:
                    continue
                chunks.append(model(images, audio).cpu().numpy())
                labels.extend(batch_labels)
    finally:
        # Do not leave a model that was training stuck in eval mode.
        model.train(was_training)

    if not chunks:
        return {
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1': 0.0
        }, np.array([])

    embeddings = np.vstack(chunks)
    # Normalise to unit length so the dot product is cosine similarity. The
    # floor keeps an all-zero embedding from producing NaNs (0 / 0).
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    embeddings = embeddings / np.maximum(norms, 1e-12)
    similarities = np.clip(embeddings @ embeddings.T, -1.0, 1.0)

    # Map labels to integers so pairwise equality is a cheap array compare.
    label_to_int = {label: i for i, label in enumerate(sorted(set(labels)))}
    labels_int = np.array([label_to_int[label] for label in labels])

    predictions = similarities > threshold
    true_labels = labels_int[:, None] == labels_int[None, :]

    true_positives = np.sum(predictions & true_labels)
    accuracy = np.mean(predictions == true_labels)
    precision = true_positives / (np.sum(predictions) + 1e-8)
    recall = true_positives / (np.sum(true_labels) + 1e-8)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-8)

    metrics = {
        'accuracy': float(accuracy),
        'precision': float(precision),
        'recall': float(recall),
        'f1': float(f1)
    }
    return metrics, similarities
def training_step(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    batch: Union[Dict[str, torch.Tensor], tuple],
    device: torch.device
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Perform a single triplet-loss training step and report batch metrics.

    Triplets are mined at random within the batch:

    - The positive is another sample with the same label. The anchor itself
      is used when no such sample exists, or when the batch holds only one
      class.
    - The negative is a sample with a different label. A noise-perturbed copy
      of the anchor is used when no such sample exists.

    Args:
        model: Model to train, called as ``model(images, audio)``.
        optimizer: Optimizer to use.
        batch: Either a dictionary containing 'image', 'audio' and 'crow_id',
            or a tuple of (images, audio, labels). In the tuple form audio is
            None or a (mel_spec, chroma) pair.
        device: Device to run training on.

    Returns:
        Tuple of (loss tensor, metrics dict from ``compute_metrics``).

    Raises:
        ValueError: If the batch yields no triplets (e.g. it is empty).
    """
    model.train()
    optimizer.zero_grad()

    # Get data
    if isinstance(batch, dict):
        images = batch['image'].to(device)
        audio = None
        if batch['audio'] is not None:
            audio = {
                'mel_spec': batch['audio']['mel_spec'].to(device),
                'chroma': batch['audio']['chroma'].to(device)
            }
        labels = batch['crow_id']
    else:
        # Handle tuple format (imgs, audio, labels)
        images, audio, labels = batch
        images = images.to(device)
        if audio is not None:
            audio = {
                'mel_spec': audio[0].to(device),
                'chroma': audio[1].to(device)
            }

    # Compare labels by value. Tensor elements hash by identity, so set() over
    # them would see one "class" per sample and pick the wrong mining mode.
    if isinstance(labels, (torch.Tensor, np.ndarray)):
        labels = labels.tolist()
        if not isinstance(labels, list):
            labels = [labels]
    else:
        labels = [lab.item() if isinstance(lab, torch.Tensor) else lab for lab in labels]

    # Forward pass
    embeddings = model(images, audio)

    # With fewer than two classes there are no real positives/negatives to
    # mine, so the anchor serves as its own positive.
    single_class = len(set(labels)) < 2

    anchors, positives, negatives = [], [], []
    for i, (emb, label) in enumerate(zip(embeddings, labels)):
        if single_class:
            pos_indices = []
        else:
            pos_indices = [j for j, other in enumerate(labels) if other == label and j != i]
        neg_indices = [j for j, other in enumerate(labels) if other != label]

        pos_emb = embeddings[np.random.choice(pos_indices)] if pos_indices else emb
        if neg_indices:
            neg_emb = embeddings[np.random.choice(neg_indices)]
        else:
            # No other class available: use a slightly perturbed anchor.
            neg_emb = emb + torch.randn_like(emb) * 0.1

        anchors.append(emb)
        positives.append(pos_emb)
        negatives.append(neg_emb)

    if not anchors:
        raise ValueError('training_step received a batch that yields no triplets')

    # Compute loss
    loss = compute_triplet_loss(
        torch.stack(anchors),
        torch.stack(positives),
        torch.stack(negatives),
        margin=1.0
    )

    # Backward pass
    loss.backward()
    optimizer.step()

    # Metrics are computed on the updated model, without tracking gradients.
    with torch.no_grad():
        metrics, _ = compute_metrics(model, batch, device)
    return loss, metrics
def validation_step(
    model: torch.nn.Module,
    batch: Dict[str, torch.Tensor],
    device: torch.device
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Evaluate triplet loss and similarity metrics on one validation batch.

    Args:
        model: Model to evaluate, called as ``model(images, audio)``.
        batch: Dictionary containing 'image', 'audio' and 'crow_id'. 'audio'
            is None or a dict with 'mel_spec' and 'chroma'.
        device: Device to run validation on.

    Returns:
        Tuple of (loss tensor, metrics dict). The metrics dict always contains
        these keys:

        - 'accuracy', 'precision', 'recall' and 'f1'.
        - 'same_crow_mean' and 'same_crow_std': mean/std of cosine similarity
          over off-diagonal pairs that share a label.
        - 'diff_crow_mean' and 'diff_crow_std': mean/std over pairs with
          different labels.

        If the batch has fewer than two distinct labels, or no sample has both
        a positive and a negative, the loss is 0 and every metric is 0.0.
    """
    zero_metrics = {
        'accuracy': 0.0,
        'precision': 0.0,
        'recall': 0.0,
        'f1': 0.0,
        'same_crow_mean': 0.0,
        'same_crow_std': 0.0,
        'diff_crow_mean': 0.0,
        'diff_crow_std': 0.0
    }

    model.eval()
    with torch.no_grad():
        # Get data
        images = batch['image'].to(device)
        audio = None
        if batch['audio'] is not None:
            audio = {
                'mel_spec': batch['audio']['mel_spec'].to(device),
                'chroma': batch['audio']['chroma'].to(device)
            }

        # Convert labels to strings. A label tensor/array is unpacked first.
        # Otherwise str() would turn the whole tensor into one label and
        # validation would always take the "fewer than two classes" exit.
        raw_labels = batch['crow_id']
        if isinstance(raw_labels, (torch.Tensor, np.ndarray)):
            raw_labels = raw_labels.tolist()
        if isinstance(raw_labels, (list, tuple)):
            labels = [
                str(lab.item() if isinstance(lab, torch.Tensor) else lab)
                for lab in raw_labels
            ]
        else:
            labels = [str(raw_labels)]

        # Forward pass
        embeddings = model(images, audio)

        if len(set(labels)) < 2:
            return torch.tensor(0.0, device=device), dict(zero_metrics)

        # Random triplet mining; samples lacking a positive or a negative are skipped.
        triplets = []
        for i, (emb, label) in enumerate(zip(embeddings, labels)):
            pos_indices = [j for j, other in enumerate(labels) if other == label and j != i]
            neg_indices = [j for j, other in enumerate(labels) if other != label]
            if not pos_indices or not neg_indices:
                continue
            pos_idx = np.random.choice(pos_indices)
            neg_idx = np.random.choice(neg_indices)
            triplets.append((emb, embeddings[pos_idx], embeddings[neg_idx]))

        if not triplets:
            return torch.tensor(0.0, device=device), dict(zero_metrics)

        # Compute loss
        anchors, positives, negatives = zip(*triplets)
        loss = compute_triplet_loss(
            torch.stack(anchors),
            torch.stack(positives),
            torch.stack(negatives),
            margin=1.0
        )

        # Compute metrics
        base_metrics, similarities = compute_metrics(model, batch, device)

    # Start from the zero template so the returned schema matches the early
    # returns, then fill in the similarity statistics.
    metrics = {**zero_metrics, **base_metrics}
    n = len(labels)
    if isinstance(similarities, np.ndarray) and similarities.shape == (n, n):
        label_arr = np.asarray(labels)
        same_label = label_arr[:, None] == label_arr[None, :]
        off_diagonal = ~np.eye(n, dtype=bool)  # self-similarity is trivially 1
        same_values = similarities[same_label & off_diagonal]
        diff_values = similarities[~same_label]
        if same_values.size:
            metrics['same_crow_mean'] = float(same_values.mean())
            metrics['same_crow_std'] = float(same_values.std())
        if diff_values.size:
            metrics['diff_crow_mean'] = float(diff_values.mean())
            metrics['diff_crow_std'] = float(diff_values.std())
    return loss, metrics
def compute_triplet_loss(
    anchor: torch.Tensor,
    positive: torch.Tensor,
    negative: torch.Tensor,
    margin: float = 1.0
) -> torch.Tensor:
    """Compute the mean triplet margin loss for a batch of embeddings.

    Per row the loss is ``max(d(a, p) - d(a, n) + margin, 0)``. Here ``d`` is
    the Euclidean distance from ``F.pairwise_distance``, whose default
    ``eps=1e-6`` already guards against exact-zero differences.

    Args:
        anchor: Anchor embeddings of shape (batch_size, embed_dim).
        positive: Positive embeddings of shape (batch_size, embed_dim).
        negative: Negative embeddings of shape (batch_size, embed_dim).
        margin: Required gap between the negative and positive distances.

    Returns:
        Scalar tensor containing the mean triplet loss. It is exactly 0 when
        every triplet satisfies the margin.
    """
    pos_dist = F.pairwise_distance(anchor, positive)
    neg_dist = F.pairwise_distance(anchor, negative)
    # Hinge on the distance gap. Note that equal distances give a loss of
    # `margin`, not 0. No epsilon is added to the clamped values: a constant
    # offset changes neither the gradients nor numerical stability, only the
    # reported number.
    loss = torch.clamp(pos_dist - neg_dist + margin, min=0.0)
    return loss.mean()