-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlosses.py
More file actions
181 lines (152 loc) · 9.08 KB
/
Copy pathlosses.py
File metadata and controls
181 lines (152 loc) · 9.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# losses.py
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
# MODIFIED: Import the new lambda name for MSE logit distillation
from config import (
    LAMBDA_CE_FINAL, LAMBDA_CE_AUX,
    LAMBDA_LOGIT_MSE_DISTILL,  # Changed from LAMBDA_KL_DISTILL
    # DISTILLATION_TEMP,  # No longer directly needed for MSE on raw logits
    LAMBDA_L2_HINT,
    # For the test block, we might need these if not set by a calling script
    NUM_CLASSES_ACTUAL, BATCH_SIZE  # And other model params for test model
)
# Modules like losses.py should not import model definitions just to test
# themselves. The __main__ self-test at the bottom of this file therefore uses
# random mock tensors in place of real model outputs. It reads hyper-parameters
# from config.py (falling back to placeholders) only to size those tensors.
# --- Loss Criteria ---
# Supervised term: cross-entropy against the ground-truth labels. It is
# applied to the final (teacher) head and to every auxiliary (student) head.
criterion_ce = nn.CrossEntropyLoss()

# Logit distillation: mean squared error between each student's raw logits
# and the detached teacher logits. nn.MSELoss averages by default.
criterion_logit_distill_mse = nn.MSELoss()

# Hint (feature) distillation: mean squared error between the L2-normalised
# student and teacher hint features.
criterion_l2_hint = nn.MSELoss()
def compute_total_loss(model_outputs, labels, *, detach_teacher_features=True):
    """Compute the weighted self-distillation loss for a multi-head model.

    The last head in ``model_outputs`` acts as the teacher and every earlier
    head is a student. The returned loss is::

          LAMBDA_CE_FINAL          * CE(teacher_logits, labels)
        + LAMBDA_CE_AUX            * sum_i CE(student_logits_i, labels)
        + LAMBDA_LOGIT_MSE_DISTILL * sum_i MSE(student_logits_i, teacher_logits.detach())
        + LAMBDA_L2_HINT           * sum_i MSE(l2norm(student_hint_i), l2norm(teacher_hint))

    Here ``l2norm`` is ``F.normalize(..., p=2, dim=1)``.

    Args:
        model_outputs: Pair ``(logits_list, hint_features_list)``.
            ``logits_list`` holds one logits tensor per head, with the teacher
            last. ``hint_features_list`` holds the projected hint features in
            the same head order. The hint term is computed only when that list
            has exactly one entry per head. Otherwise a ``RuntimeWarning`` is
            issued and the term stays 0.
        labels: Targets passed to ``nn.CrossEntropyLoss`` for every head. Its
            ``device`` is used for the zero-valued accumulators.
        detach_teacher_features: If True (default), the teacher's hint features
            are detached, so the hint loss only updates the students. This
            matches how the teacher logits are detached for logit distillation.
            Pass False to let hint-loss gradients also reach the teacher
            features (the previous behaviour).

    Returns:
        ``(total_loss, components)``. ``total_loss`` is a scalar tensor ready
        for ``backward()``. ``components`` maps ``"total_loss"``,
        ``"ce_final"``, ``"ce_aux_sum"``, ``"logit_mse_sum"`` and
        ``"l2_hint_sum"`` to Python floats. With no heads at all, the loss is a
        zero tensor and every component is ``0.0``.
    """
    logits_list, hint_features_list = model_outputs
    device = labels.device
    num_total_heads = len(logits_list)

    if num_total_heads == 0:
        # Keep the component keys identical to the normal path, so callers
        # that log or accumulate components by key never hit a KeyError.
        empty_loss = torch.tensor(0.0, device=device, requires_grad=True)
        component_names = ("total_loss", "ce_final", "ce_aux_sum", "logit_mse_sum", "l2_hint_sum")
        return empty_loss, dict.fromkeys(component_names, 0.0)

    teacher_logits = logits_list[-1]
    student_logits_list = logits_list[:-1]
    # Start value for the per-student sums. With no students they stay 0.
    # Non-in-place addition (sum) lets every sum share this start tensor.
    zero = torch.tensor(0.0, device=device)

    loss_ce_final = criterion_ce(teacher_logits, labels)

    loss_ce_aux_sum = sum(
        (criterion_ce(student_logits, labels) for student_logits in student_logits_list),
        zero,
    )

    # The teacher's logits are a fixed target; no gradient flows back into them.
    teacher_logits_target = teacher_logits.detach()
    loss_logit_distill_mse_sum = sum(
        (criterion_logit_distill_mse(student_logits, teacher_logits_target)
         for student_logits in student_logits_list),
        zero,
    )

    loss_l2_hint_sum = zero
    if student_logits_list:
        if len(hint_features_list) == num_total_heads:
            teacher_hint = hint_features_list[-1]
            if detach_teacher_features:
                # Treat the teacher features as a fixed target, as for the logits.
                teacher_hint = teacher_hint.detach()
            # Normalise the teacher once instead of once per student.
            teacher_hint_norm = F.normalize(teacher_hint, p=2, dim=1)
            loss_l2_hint_sum = sum(
                (criterion_l2_hint(F.normalize(student_hint, p=2, dim=1), teacher_hint_norm)
                 for student_hint in hint_features_list[:-1]),
                zero,
            )
        else:
            # warnings.warn de-duplicates repeats, unlike a print on every batch.
            warnings.warn(
                f"Mismatch for L2 hint loss. Logits: {num_total_heads}, "
                f"Hints: {len(hint_features_list)}. Skipping L2 hint loss.",
                RuntimeWarning,
                stacklevel=2,
            )

    total_loss = (LAMBDA_CE_FINAL * loss_ce_final
                  + LAMBDA_CE_AUX * loss_ce_aux_sum
                  + LAMBDA_LOGIT_MSE_DISTILL * loss_logit_distill_mse_sum
                  + LAMBDA_L2_HINT * loss_l2_hint_sum)

    loss_components = {
        "total_loss": total_loss.item(),
        "ce_final": loss_ce_final.item(),
        "ce_aux_sum": loss_ce_aux_sum.item(),
        "logit_mse_sum": loss_logit_distill_mse_sum.item(),
        "l2_hint_sum": loss_l2_hint_sum.item(),
    }
    return total_loss, loss_components
def _estimate_teacher_hint_dim(cfg):
    """Return a plausible teacher hint-feature width from config module *cfg*.

    If ``cfg.HYBRIDSN_FC_HIDDEN_UNITS`` is non-empty, its last entry is used.
    Otherwise the flattened 2D-conv output size is derived from
    ``cfg.PATCH_SIZE`` with simplified shape arithmetic:

    * every entry of ``cfg.HYBRIDSN_CONV3D_LAYERS`` is treated as an unpadded,
      unit-stride conv;
    * a single 2D conv follows, using ``cfg.HYBRIDSN_CONV2D_PADDING`` and
      ``cfg.HYBRIDSN_CONV2D_KERNEL_SIZE``.

    NOTE(review): this is only an approximation of the real model. Confirm the
    padding and stride assumptions if the model changes.

    Raises:
        AttributeError: if *cfg* lacks one of the attributes above.
    """
    if cfg.HYBRIDSN_FC_HIDDEN_UNITS:
        return cfg.HYBRIDSN_FC_HIDDEN_UNITS[-1]

    height = width = cfg.PATCH_SIZE
    for layer_conf_3d in cfg.HYBRIDSN_CONV3D_LAYERS:
        # The first kernel dimension is not spatial here (presumably spectral).
        _, kernel_h, kernel_w = layer_conf_3d["kernel_size"]
        height, width = height - kernel_h + 1, width - kernel_w + 1

    padding = cfg.HYBRIDSN_CONV2D_PADDING
    kernel = cfg.HYBRIDSN_CONV2D_KERNEL_SIZE
    height = height + 2 * padding - kernel + 1
    width = width + 2 * padding - kernel + 1
    return cfg.HYBRIDSN_CONV2D_OUT_CHANNELS * height * width


def _print_loss_components(components):
    """Print each ``name: value`` pair of a loss-components dict to 4 decimals."""
    for name, value in components.items():
        value_float = value.item() if isinstance(value, torch.Tensor) else float(value)
        print(f" {name}: {value_float:.4f}")


def _run_self_test():
    """Smoke-test ``compute_total_loss`` on random mock outputs.

    Runs a 3-head case (two students plus the teacher) and a 1-head case. It
    then checks that the single-head case reports exactly 0 for every
    student-only term. Sizes come from ``config.py`` when available.
    Otherwise placeholders are used, so the loss wiring can still be exercised.
    """
    print("--- Testing losses.py (Standalone with MSE Logit Distillation) ---")

    # Read sizes and the seed from config.py (and set_seeds from utils.py).
    # Any missing module or attribute drops back to placeholders.
    try:
        import config
        from utils import set_seeds  # Assuming utils.py is importable from here
        random_seed = config.RANDOM_SEED
        cfg_num_classes = config.NUM_CLASSES_ACTUAL
        cfg_batch_size = config.BATCH_SIZE
        teacher_hint_dim_test = _estimate_teacher_hint_dim(config)
    except (ImportError, AttributeError):
        print("Warning: Could not import all necessary items from config.py or utils.py for full test. Using placeholders.")
        torch.manual_seed(42)  # Keep placeholder runs reproducible as well.
        num_classes_test = 16
        batch_size_test = 4
        # Placeholder width: need not match a real model, but it lets the
        # loss-computation structure be tested.
        teacher_hint_dim_test = 128
    else:
        set_seeds(random_seed)
        num_classes_test = cfg_num_classes if cfg_num_classes is not None else 16
        # A falsy check first, so a None batch size cannot reach the comparison.
        batch_size_test = cfg_batch_size if cfg_batch_size and cfg_batch_size > 0 else 2

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"Test params: NumClasses={num_classes_test}, BatchSize={batch_size_test}, TeacherHintDim={teacher_hint_dim_test}")
    print(f"Lambdas: CE_Final={LAMBDA_CE_FINAL}, CE_Aux={LAMBDA_CE_AUX}, "
          f"Logit_MSE={LAMBDA_LOGIT_MSE_DISTILL}, L2_Hint={LAMBDA_L2_HINT}")

    # 3 heads: [student_1, student_2, teacher]. Hint features are assumed to
    # be already projected to the teacher's hint width.
    num_heads = 3
    mock_logits_list = [
        torch.randn(batch_size_test, num_classes_test, device=device) for _ in range(num_heads)
    ]
    mock_features_list_projected = [
        torch.randn(batch_size_test, teacher_hint_dim_test, device=device) for _ in range(num_heads)
    ]
    mock_labels = torch.randint(0, num_classes_test, (batch_size_test,), device=device)

    print("\nMock data shapes (for 3 heads):")
    for i, (logits, features) in enumerate(zip(mock_logits_list, mock_features_list_projected), start=1):
        print(f" Logits {i}: {logits.shape}, Hint Features {i}: {features.shape}")
    print(f"Labels shape: {mock_labels.shape}")

    total_loss_val, components = compute_total_loss(
        (mock_logits_list, mock_features_list_projected), mock_labels
    )
    print(f"\nComputed total loss: {total_loss_val.item():.4f}")
    print("Loss components:")
    _print_loss_components(components)

    print("\n--- Testing with 1 head (no shallow/student heads) ---")
    mock_logits_list_single = [torch.randn(batch_size_test, num_classes_test, device=device)]
    mock_features_list_single = [torch.randn(batch_size_test, teacher_hint_dim_test, device=device)]
    total_loss_single, components_single = compute_total_loss(
        (mock_logits_list_single, mock_features_list_single), mock_labels
    )
    print(f"Computed total loss (1 head): {total_loss_single.item():.4f}")
    print("Loss components (1 head):")
    _print_loss_components(components_single)

    # With only the teacher head there are no students, so every
    # student-only term must be exactly zero.
    for key in ("ce_aux_sum", "logit_mse_sum", "l2_hint_sum"):
        if components_single[key] != 0.0:
            raise AssertionError(
                f"{key} should be 0.0 for a single head, got {components_single[key]}"
            )
    print("Single-head check passed.")


if __name__ == "__main__":
    _run_self_test()