File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -170,6 +170,8 @@ def distance_to_closest_record(
170170 categorical_features = np .array (X .dtypes )== pl .Utf8
171171 if not isinstance (categorical_features , np .ndarray ):
172172 categorical_features = np .array (categorical_features )
173+ X_categorical = X [:, categorical_features ]
174+ Y_categorical = Y [:, categorical_features ]
173175
174176 # Both datafrmaes are turned into numpy arrays
175177 if not isinstance (X , np .ndarray ):
@@ -192,9 +194,9 @@ def distance_to_closest_record(
192194
193195 # Apply the encoder on all categorical columns at once
194196 # Categorical feature matrix of X (num_rows_X x num_cat_feat)
195- X_categorical = encoder .fit_transform (X [:, categorical_features ] )
197+ X_categorical = encoder .fit_transform (X_categorical )
196198 # Categorical feature matrix of Y (num_rows_Y x num_cat_feat)
197- Y_categorical = encoder .transform (Y [:, categorical_features ] )
199+ Y_categorical = encoder .transform (Y_categorical )
198200
199201 # Numerical feature matrix of X (num_rows_X x num_num_feat)
200202 X_numerical = X [:, np .logical_not (categorical_features )].astype ("float32" )
You can’t perform that action at this time.
0 commit comments