Skip to content

Commit 7824ba9

Browse files
committed
minor fix
1 parent 6a45e83 commit 7824ba9

2 files changed

Lines changed: 51 additions & 58 deletions

File tree

ml_grid/pipeline/data_correlation_matrix.py

Lines changed: 43 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ def correlation_coefficient(col1: pd.Series, col2: pd.Series) -> float:
2222

2323
def handle_correlation_matrix(
2424
local_param_dict: Dict, drop_list: List, df: pd.DataFrame, chunk_size: int = 50
25-
) -> List[Tuple[str, str]]:
25+
) -> List[str]:
2626
"""Identifies highly correlated column pairs and adds them to a drop list.
2727
2828
This function calculates the correlation matrix of a DataFrame in chunks to
@@ -33,15 +33,15 @@ def handle_correlation_matrix(
3333
Args:
3434
local_param_dict: A dictionary containing local parameters, including
3535
the 'corr' threshold.
36-
drop_list: A list to which correlated column pairs `(col1, col2)`
36+
drop_list: A list to which one column from each highly correlated pair
3737
will be appended.
3838
df: The input DataFrame to analyze.
3939
chunk_size: The number of columns to process in each chunk.
4040
Defaults to 50.
4141
4242
Returns:
43-
A list of unique tuples, where each tuple contains a pair of
44-
column names that are correlated above the threshold.
43+
An updated list of columns to drop, including one column from each
44+
highly correlated pair.
4545
"""
4646

4747
# Define the correlation threshold
@@ -55,51 +55,46 @@ def handle_correlation_matrix(
5555
numeric_columns = df.select_dtypes(include=["number"]).columns
5656
df_numeric = df[numeric_columns]
5757

58-
# Split columns into chunks
58+
if df_numeric.empty:
59+
return []
60+
61+
n_cols = len(df_numeric.columns)
62+
to_drop = set()
63+
processed_cols = set()
64+
65+
# Split columns into chunks for memory efficiency
5966
column_chunks = [
6067
df_numeric.columns[i : i + chunk_size]
61-
for i in range(0, len(df_numeric.columns), chunk_size)
68+
for i in range(0, n_cols, chunk_size)
6269
]
6370

64-
# Iterate through each column chunk
65-
for chunk in tqdm(column_chunks, desc="Calculating Correlations"):
66-
# Calculate the correlation coefficients for the current chunk
67-
try:
68-
# Using abs() to consider both positive and negative correlations
69-
correlations = df_numeric[chunk].corr().abs()
70-
except Exception as e:
71-
logger.error(
72-
"Encountered exception while calculating correlations for chunk", chunk
73-
)
74-
logger.error(e)
75-
continue
76-
77-
# Iterate through each column in the chunk
78-
for col in chunk:
79-
# Filter columns with correlation coefficient greater than the threshold
80-
try:
81-
# Exclude self-correlation (which is always 1)
82-
correlated_cols = correlations[col][
83-
(correlations[col] > threshold) & (correlations[col] <= 1.0)
84-
].index.tolist()
85-
# Explicitly remove self-correlation if present
86-
if col in correlated_cols:
87-
correlated_cols.remove(col)
88-
except KeyError:
89-
logger.error(
90-
"Encountered KeyError while calculating correlations for column",
91-
col,
92-
)
93-
logger.error("Continuing with an empty list of correlated columns")
94-
correlated_cols = []
95-
96-
# Add the correlated columns to the list
97-
drop_list.extend([(col, corr_col) for corr_col in correlated_cols])
98-
99-
# Remove duplicates from the list
100-
# A frozenset is used to make pairs order-independent, e.g., (a, b) is same as (b, a)
101-
unique_pairs = {frozenset(pair) for pair in drop_list}
102-
# Convert back to list of tuples
103-
drop_list = [tuple(pair) for pair in unique_pairs]
104-
105-
return drop_list
71+
with tqdm(total=n_cols, desc="Calculating Correlations") as pbar:
72+
for i, chunk_cols in enumerate(column_chunks):
73+
# Define the columns to correlate against: the current chunk and all subsequent chunks
74+
remaining_cols = df_numeric.columns[i * chunk_size :]
75+
76+
# Calculate correlation for the current slice of the matrix
77+
corr_matrix_chunk = df_numeric[remaining_cols].corr().abs()
78+
79+
# We only need to check correlations of the current chunk against all remaining columns
80+
# This is the top block of rows of the correlation matrix (current chunk rows, all remaining columns)
81+
sub_matrix = corr_matrix_chunk.loc[chunk_cols, :]
82+
83+
# Find highly correlated pairs
84+
for col1 in chunk_cols:
85+
# Skip if this column is already marked to be dropped
86+
if col1 in to_drop:
87+
continue
88+
# Find correlations above the threshold, excluding self-correlation
89+
correlated_series = sub_matrix.loc[col1][sub_matrix.loc[col1] > threshold]
90+
for col2, _ in correlated_series.items():
91+
# Ensure we don't drop both columns in a pair.
92+
# If col2 is not already in to_drop, add it.
93+
if col1 != col2 and col2 not in to_drop:
94+
to_drop.add(col2)
95+
pbar.update(len(chunk_cols))
96+
97+
logger.info(f"Identified {len(to_drop)} columns to drop due to high correlation.")
98+
99+
# Return a list of unique columns to drop
100+
return sorted(list(to_drop))

tests/test_data_correlation_matrix_handle.py

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -52,22 +52,20 @@ def test_negative_correlation(self):
5252
result = handle_correlation_matrix(
5353
self.local_param_dict, self.drop_list, self.df, chunk_size=3
5454
)
55-
# Should find both ('A', 'B') and ('A', 'C')
56-
self.assertEqual(len(result), 2)
57-
self.assertIn(frozenset(("A", "B")), [frozenset(p) for p in result])
58-
self.assertIn(frozenset(("A", "C")), [frozenset(p) for p in result])
55+
# Pairs found: (A,B), (A,C), (B,C)
56+
# Logic: drop B (from A,B), drop C (from A,C). B is already marked for dropping, so it is skipped when (B,C) would be checked.
57+
# The function now correctly identifies that 'B' and 'C' should be dropped.
58+
self.assertEqual(len(result), 2, f"Expected 2 items, but got {len(result)}: {result}")
59+
self.assertCountEqual(result, ["B", "C"])
5960

6061
def test_all_columns_in_single_chunk(self):
6162
result = handle_correlation_matrix(
6263
self.local_param_dict, self.drop_list, self.df, chunk_size=3
6364
)
6465

65-
# fmt: off
66-
# Since the function uses frozenset for order-independent pairs, the output will contain
67-
# either ('A', 'B') or ('B', 'A'), but not both. We check for the presence of one.
68-
self.assertTrue(len(result) == 1)
69-
self.assertEqual(frozenset(result[0]), frozenset(('A', 'B')))
70-
# fmt: on
66+
# 'B' is highly correlated with 'A', so 'B' should be added to the drop list.
67+
self.assertEqual(len(result), 1, f"Expected 1 item, but got {len(result)}: {result}")
68+
self.assertCountEqual(result, ["B"])
7169

7270

7371
if __name__ == "__main__":

0 commit comments

Comments
 (0)