-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_preprocessing.py
More file actions
210 lines (179 loc) · 10.5 KB
/
Copy pathdata_preprocessing.py
File metadata and controls
210 lines (179 loc) · 10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# data_preprocessing.py
import numpy as np
from scipy.io import loadmat
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import os
import config
from utils import set_seeds
# Supported datasets: name -> (data .mat file, GT .mat file, data variable key, GT variable key).
_HSI_DATASET_FILES = {
    "IndianPines": ("Indian_pines_corrected.mat", "Indian_pines_gt.mat",
                    'indian_pines_corrected', 'indian_pines_gt'),
    "Salinas": ("Salinas_corrected.mat", "Salinas_gt.mat",
                'salinas_corrected', 'salinas_gt'),
    "PaviaU": ("PaviaU.mat", "PaviaU_gt.mat",
               'paviaU', 'paviaU_gt'),
}


def _resolve_mat_key(mat_dict, preferred_key, file_name, kind):
    """Choose which variable to read from a ``loadmat`` result.

    Returns ``preferred_key`` when it is present. Otherwise (after printing a
    warning) prefers a variable whose name equals the file name without its
    extension (case-insensitive), then falls back to the first non-metadata
    variable.

    Raises:
        KeyError: if the file contains no non-metadata variables.
    """
    if preferred_key in mat_dict:
        return preferred_key
    print(f"Warning: Standard key '{preferred_key}' not found in {file_name}. Trying other common keys...")
    # Names starting with '__' are loadmat metadata (e.g. '__header__'), not arrays.
    candidates = [k for k in mat_dict if not k.startswith('__')]
    if not candidates:
        raise KeyError(f"No suitable {kind} key found in {file_name}")
    stem = os.path.splitext(file_name)[0].lower()
    return next((k for k in candidates if k.lower() == stem), candidates[0])


def load_hsi_data(dataset_name_from_config):
    """Load a hyperspectral cube and its ground-truth map from ``config.DATA_PATH``.

    Args:
        dataset_name_from_config: one of "IndianPines", "Salinas", "PaviaU".

    Returns:
        (data, gt): the arrays read from the data and GT .mat files. The rest
        of this module unpacks them as (height, width, bands) and
        (height, width) respectively.

    Side effects:
        Sets ``config.NUM_CLASSES_ACTUAL`` to the number of distinct non-zero
        GT labels (label 0 is treated as unlabeled background).

    Raises:
        ValueError: unsupported dataset name.
        FileNotFoundError: either .mat file is missing from ``config.DATA_PATH``.
        KeyError: a .mat file contains no usable variable.
    """
    spec = _HSI_DATASET_FILES.get(dataset_name_from_config)
    if spec is None:
        raise ValueError(f"Dataset '{dataset_name_from_config}' is not supported for loading.")
    data_file_name, gt_file_name, data_key, gt_key = spec

    data_file_path = os.path.join(config.DATA_PATH, data_file_name)
    gt_file_path = os.path.join(config.DATA_PATH, gt_file_name)
    if not os.path.exists(data_file_path) or not os.path.exists(gt_file_path):
        raise FileNotFoundError(
            f"Dataset files ('{data_file_name}', '{gt_file_name}') not found in '{config.DATA_PATH}'. "
            "Run download_data.py or place them manually."
        )

    data_mat = loadmat(data_file_path)
    gt_mat = loadmat(gt_file_path)
    # Fall back to other variable names if the conventional key is absent.
    data = data_mat[_resolve_mat_key(data_mat, data_key, data_file_name, "data")]
    gt = gt_mat[_resolve_mat_key(gt_mat, gt_key, gt_file_name, "ground-truth")]

    print(f"Successfully loaded {dataset_name_from_config}:")
    print(f"  Original data shape: {data.shape}, GT shape: {gt.shape}")

    unique_labels_in_gt = np.unique(gt)
    class_labels = unique_labels_in_gt[unique_labels_in_gt != 0]
    config.NUM_CLASSES_ACTUAL = len(class_labels)
    print(f"  Number of classes (excluding background 0): {config.NUM_CLASSES_ACTUAL}")
    # Patches are labeled as (gt label - 1), so class indices only stay within
    # [0, NUM_CLASSES_ACTUAL) when the non-zero labels are exactly 1..K.
    if class_labels.size and (class_labels.min() != 1 or class_labels.max() != class_labels.size):
        print(
            f"Warning: GT labels are not the contiguous range 1..{class_labels.size}; "
            "label-1 class indices may exceed NUM_CLASSES_ACTUAL-1."
        )
    return data, gt
def apply_pca(data, num_components):
    """Standardize every band over all pixels, then reduce the bands with PCA.

    Args:
        data: cube unpacked as (height, width, bands).
        num_components: number of principal components to keep. An integer
            larger than the number of bands or pixels is clamped (with a
            warning). A float in (0, 1) is passed straight to sklearn's PCA,
            which then keeps as many components as needed to explain that
            fraction of the variance.

    Returns:
        (data_pca, pca, scaler): the projected cube of shape
        (height, width, k), where k is the number of components actually
        kept, plus the fitted ``PCA`` and ``StandardScaler`` objects.

    Raises:
        ValueError: if the (clamped) component count is not positive.

    NOTE(review): the scaler and PCA are fitted on every pixel, including
    unlabeled and later test-split pixels. This is unsupervised, but confirm
    the transductive setup is intended.
    """
    h, w, c = data.shape
    pixels = data.reshape(-1, c).astype(np.float32)
    scaler = StandardScaler()
    data_scaled = scaler.fit_transform(pixels)
    n_samples, n_features = data_scaled.shape

    actual_num_components = num_components
    if actual_num_components > n_features:
        print(f"Warning: num_pca_components ({num_components}) > num features ({n_features}). Clamping to {n_features}.")
        actual_num_components = n_features
    if actual_num_components > n_samples:
        print(f"Warning: num_pca_components ({actual_num_components}) > num samples ({n_samples}). Clamping to {n_samples}.")
        actual_num_components = n_samples
    if actual_num_components <= 0:
        raise ValueError(f"Number of PCA components must be positive. Calculated: {actual_num_components}")

    pca = PCA(n_components=actual_num_components, random_state=config.RANDOM_SEED)
    data_pca_transformed = pca.fit_transform(data_scaled)
    # n_components_ is the count sklearn actually kept: identical to the request
    # for an int, resolved from the variance target for a float.
    n_kept = pca.n_components_
    print(f"PCA: Explained variance by {n_kept} components: {np.sum(pca.explained_variance_ratio_):.4f}")
    return data_pca_transformed.reshape(h, w, n_kept), pca, scaler
def create_patches(data_pca, gt, patch_size):
    """Extract a square spatial patch around every labeled (non-zero) GT pixel.

    Args:
        data_pca: cube unpacked as (height, width, channels).
        gt: 2-D label map of shape (height, width); 0 marks unlabeled pixels.
        patch_size: side length of each patch (>= 1). The cube is zero-padded
            by ``patch_size // 2`` on each spatial side, so edge pixels still
            get full-size patches. For an even ``patch_size`` the labeled pixel
            sits at index ``patch_size // 2`` of the patch, i.e. one step
            below/right of the geometric centre.

    Returns:
        patches: float32 array (n_labeled, patch_size, patch_size, channels).
        labels: int64 array (n_labeled,), each GT label minus 1.
        coordinates: int array (n_labeled, 2) of (row, col) positions in ``gt``.
        Samples are ordered row-major. With no labeled pixels, all three arrays
        have 0 rows but keep their trailing dimensions.

    Raises:
        ValueError: if ``gt`` does not match the cube's spatial shape, or
            ``patch_size`` < 1.
    """
    h, w, c_pca = data_pca.shape
    if gt.shape != (h, w):
        raise ValueError(f"GT shape {gt.shape} does not match data spatial shape {(h, w)}.")
    if patch_size < 1:
        raise ValueError(f"patch_size must be >= 1, got {patch_size}.")

    pad_width = patch_size // 2
    padded_data = np.pad(data_pca,
                         ((pad_width, pad_width), (pad_width, pad_width), (0, 0)),
                         mode='constant', constant_values=0)

    # np.nonzero returns indices in row-major order, matching a row/column scan.
    rows, cols = np.nonzero(gt)
    labels = gt[rows, cols].astype(np.int64) - 1
    coordinates = np.column_stack((rows, cols)).astype(int)

    patches = np.empty((len(rows), patch_size, patch_size, c_pca), dtype=np.float32)
    for i, (row, col) in enumerate(zip(rows, cols)):
        # Padding shifts indices by pad_width, so this window is centred on (row, col).
        patches[i] = padded_data[row:row + patch_size, col:col + patch_size, :]
    return patches, labels, coordinates
def _stratified_index_split(indices, strat_labels, test_size):
    """Stratified split of ``indices`` into (kept, held_out) using config.RANDOM_SEED."""
    kept, held_out, _, _ = train_test_split(
        indices, strat_labels,
        test_size=test_size,
        random_state=config.RANDOM_SEED,
        stratify=strat_labels,
    )
    return kept, held_out


def split_data(patches, labels, coordinates):
    """Stratified train/validation/test split of the labeled patches.

    The test split takes ``config.TEST_RATIO`` of all samples. The validation
    fraction ``config.VAL_RATIO_FROM_TRAIN`` is divided by
    ``1 - config.TEST_RATIO`` before being applied to the remaining pool, so it
    works out to roughly that fraction of ALL samples.

    When the rescaled ratio is unusable (<= 1e-6 or >= 1):
      * if VAL_RATIO_FROM_TRAIN is 0 or the pool has fewer than 2 samples,
        the validation set is empty;
      * otherwise a fallback validation slice is drawn: 10% of the pool (or
        the rescaled ratio, if it is positive and smaller), but at least 2
        samples. If that is impossible or cannot be stratified, the
        validation set is left empty with a warning.

    Args:
        patches: array indexed along axis 0 by sample.
        labels: 1-D class labels aligned with ``patches``; used to stratify.
        coordinates: accepted for interface compatibility; currently unused.

    Returns:
        ``((X_train, y_train, None), (X_val, y_val, None), (X_test, y_test, None))``.
        An empty validation split keeps the trailing shape and dtype of
        ``patches`` and ``labels``.

    Raises:
        ValueError: from ``train_test_split`` when a class is too small to
            stratify the test split or the regular validation split.
    """
    indices = np.arange(len(labels))
    train_val_indices, test_indices, y_train_val_strat, _ = train_test_split(
        indices, labels,
        test_size=config.TEST_RATIO,
        random_state=config.RANDOM_SEED,
        stratify=labels,
    )
    n_train_val = len(y_train_val_strat)
    empty_indices = np.array([], dtype=int)

    effective_val_ratio = 0
    if (1.0 - config.TEST_RATIO) > 1e-6:
        effective_val_ratio = config.VAL_RATIO_FROM_TRAIN / (1.0 - config.TEST_RATIO)

    if 1e-6 < effective_val_ratio < 1.0 and n_train_val > 0:
        train_indices, val_indices = _stratified_index_split(
            train_val_indices, y_train_val_strat, effective_val_ratio)
    elif config.VAL_RATIO_FROM_TRAIN == 0 or n_train_val < 2:
        train_indices, val_indices = train_val_indices, empty_indices
        print("Warning: Validation set is empty based on ratios or insufficient samples.")
    else:
        print(f"Warning: effective_val_ratio ({effective_val_ratio:.2f}) is problematic. Adjusting val split to small fixed proportion (e.g. 0.1 of train_val).")
        min_val_samples = 2
        base_ratio = min(0.1, effective_val_ratio if effective_val_ratio > 0 else 0.1)
        # Integer validation count computed the way sklearn rounds a float
        # test_size (ceil), raised to the minimum so stratification has a chance.
        n_val = max(min_val_samples, int(np.ceil(base_ratio * n_train_val)))
        train_indices, val_indices = train_val_indices, empty_indices
        if n_val < n_train_val:
            try:
                train_indices, val_indices = _stratified_index_split(
                    train_val_indices, y_train_val_strat, n_val)
            except ValueError as err:
                # e.g. fewer validation slots than classes; keep the fallback best-effort.
                print(f"Warning: Stratified validation split failed ({err}). Val set is empty.")
        else:
            print("Warning: Not enough samples in train_val for validation split. Val set is empty.")

    X_train, y_train = patches[train_indices], labels[train_indices]
    X_val, y_val = patches[val_indices], labels[val_indices]
    X_test, y_test = patches[test_indices], labels[test_indices]
    print(f"Data split: Train={len(X_train)}, Val={len(X_val)}, Test={len(X_test)}")
    return (X_train, y_train, None), (X_val, y_val, None), (X_test, y_test, None)
# Primary preprocessing entry point (called by main.py).
def get_full_data_and_splits():
    """Run the complete preprocessing pipeline for ``config.DATASET_NAME``.

    Steps: load the cube and GT map, standardize + PCA over all pixels, cut
    patches around labeled pixels, then split those patches into
    train/validation/test.

    Returns:
        An 11-tuple ``(X_train, y_train, X_val, y_val, X_test, y_test,
        fitted_pca, fitted_scaler, original_gt, original_hsi_data,
        data_pca_full)``. The fitted PCA/scaler, the raw inputs and the full
        PCA cube are returned alongside the splits for reuse by callers.
    """
    hsi_cube, gt_map = load_hsi_data(dataset_name_from_config=config.DATASET_NAME)
    # Fitting on the whole cube yields processors that can re-project it later.
    pca_cube, pca_model, band_scaler = apply_pca(hsi_cube, config.NUM_PCA_COMPONENTS)
    patches, patch_labels, patch_coords = create_patches(pca_cube, gt_map, config.PATCH_SIZE)

    train_part, val_part, test_part = split_data(patches, patch_labels, patch_coords)
    X_train, y_train, _ = train_part
    X_val, y_val, _ = val_part
    X_test, y_test, _ = test_part

    return (X_train, y_train, X_val, y_val, X_test, y_test,
            pca_model, band_scaler, gt_map, hsi_cube, pca_cube)
# Helper: re-project a full HSI cube with an already-fitted scaler and PCA.
def apply_pca_transform_to_full_hsi(full_hsi_data, pca_obj, scaler_obj):
    """Project a (height, width, bands) cube with an already-fitted scaler and PCA.

    Pixels are flattened, cast to float32, passed through
    ``scaler_obj.transform`` and then ``pca_obj.transform``; neither object is
    refitted. Returns an array of shape (height, width, n_components).
    """
    height, width, n_bands = full_hsi_data.shape
    pixels = full_hsi_data.reshape(-1, n_bands).astype(np.float32)
    # transform (not fit_transform): reuse the statistics learned earlier.
    projected = pca_obj.transform(scaler_obj.transform(pixels))
    return projected.reshape(height, width, -1)
if __name__ == "__main__":
    # Smoke test: run the whole pipeline once and report the shapes it returns.
    set_seeds(config.RANDOM_SEED)
    print(f"\n--- Testing Data Preprocessing for: {config.DATASET_NAME} ---")
    results = get_full_data_and_splits()
    X_tr, y_tr, X_v, y_v, X_te, y_te, pca_o, scaler_o, gt_o, hsi_o, hsi_pca_o = results

    print(f"\nShape of X_train: {X_tr.shape}, y_train: {y_tr.shape}")
    print(f"Shape of X_val: {X_v.shape}, y_val: {y_v.shape}" if X_v.shape[0] > 0 else "X_val is empty.")
    print(f"Shape of X_test: {X_te.shape}, y_test: {y_te.shape}")
    print(f"Number of classes (from config): {config.NUM_CLASSES_ACTUAL}")
    if len(y_tr):
        print(f"Min/Max label in y_train: {np.min(y_tr)}, {np.max(y_tr)}")

    print(f"\nPCA object type: {type(pca_o)}")
    print(f"Scaler object type: {type(scaler_o)}")
    for caption, array in (
        ("Original GT shape returned", gt_o),
        ("Original HSI data shape returned", hsi_o),
        ("Full PCA HSI data shape returned", hsi_pca_o),
    ):
        print(f"{caption}: {array.shape}")

    # Optional check that the re-projection helper reproduces the fitted PCA output:
    # test_pca_transformed_full = apply_pca_transform_to_full_hsi(hsi_o, pca_o, scaler_o)
    # print(f"Test transform on full HSI data shape: {test_pca_transformed_full.shape}")
    # assert np.allclose(hsi_pca_o, test_pca_transformed_full), "PCA transform function output mismatch with initial PCA."