@@ -22,7 +22,7 @@ def correlation_coefficient(col1: pd.Series, col2: pd.Series) -> float:
2222
def handle_correlation_matrix(
    local_param_dict: Dict, drop_list: List, df: pd.DataFrame, chunk_size: int = 50
) -> List[str]:
    """Identifies highly correlated columns and adds them to a drop list.

    The correlation matrix is computed in chunks for memory efficiency: each
    chunk of columns is correlated against itself and all subsequent columns,
    so every unordered pair is examined exactly once. For each pair above the
    threshold, one column is kept and the other is marked for dropping.

    Args:
        local_param_dict: A dictionary containing local parameters, including
            the 'corr' threshold.
        drop_list: Columns already marked to be dropped; they are kept in the
            result and are never chosen as the survivor of a correlated pair.
        df: The input DataFrame to analyze.
        chunk_size: The number of columns to process in each chunk.
            Defaults to 50.

    Returns:
        A sorted list of unique column names to drop, including the entries
        of ``drop_list`` plus one column from each highly correlated pair.
    """
    # NOTE(review): this assignment sits in a diff-hunk gap of the reviewed
    # patch; the 'corr' key is assumed from the docstring — confirm against
    # the full file.
    threshold = local_param_dict["corr"]

    # Only numeric columns can participate in a correlation matrix.
    numeric_columns = df.select_dtypes(include=["number"]).columns
    df_numeric = df[numeric_columns]

    # BUG FIX: the previous version ignored `drop_list` entirely (and the
    # empty-frame branch returned []), breaking the documented contract of
    # returning an *updated* drop list. Seed the result set with it.
    to_drop = set(drop_list)

    if df_numeric.empty:
        return sorted(to_drop)

    n_cols = len(df_numeric.columns)

    # Split columns into chunks for memory efficiency.
    column_chunks = [
        df_numeric.columns[i : i + chunk_size] for i in range(0, n_cols, chunk_size)
    ]

    with tqdm(total=n_cols, desc="Calculating Correlations") as pbar:
        for i, chunk_cols in enumerate(column_chunks):
            # Correlate the current chunk against itself and every later
            # column; earlier pairs were handled by earlier chunks.
            remaining_cols = df_numeric.columns[i * chunk_size :]

            # abs() so strong negative correlations count as well.
            corr_matrix_chunk = df_numeric[remaining_cols].corr().abs()

            # Restrict rows to the current chunk: each row holds that
            # column's correlations with every remaining column.
            sub_matrix = corr_matrix_chunk.loc[chunk_cols, :]

            for col1 in chunk_cols:
                # A column already slated for dropping needs no survivors.
                if col1 in to_drop:
                    continue
                row = sub_matrix.loc[col1]
                # Columns correlated with col1 above the threshold; the
                # self-correlation (always 1.0) is excluded below.
                for col2 in row.index[row > threshold]:
                    # Keep col1, drop col2 — never drop both of a pair.
                    if col1 != col2 and col2 not in to_drop:
                        to_drop.add(col2)
            pbar.update(len(chunk_cols))

    logger.info(f"Identified {len(to_drop)} columns to drop due to high correlation.")

    # Sorted for deterministic output.
    return sorted(to_drop)
0 commit comments