Skip to content

Commit 6b08775

Browse files
committed
minor fix
1 parent 0e4648c commit 6b08775

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

sure/distance_metrics/distance.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ def distance_to_closest_record(
170170
categorical_features = np.array(X.dtypes)==pl.Utf8
171171
if not isinstance(categorical_features, np.ndarray):
172172
categorical_features = np.array(categorical_features)
173+
X_categorical = X[:, categorical_features]
174+
Y_categorical = Y[:, categorical_features]
173175

174176
# Both datafrmaes are turned into numpy arrays
175177
if not isinstance(X, np.ndarray):
@@ -192,9 +194,9 @@ def distance_to_closest_record(
192194

193195
# Apply the encoder on all categorical columns at once
194196
# Categorical feature matrix of X (num_rows_X x num_cat_feat)
195-
X_categorical = encoder.fit_transform(X[:, categorical_features])
197+
X_categorical = encoder.fit_transform(X_categorical)
196198
# Categorical feature matrix of Y (num_rows_Y x num_cat_feat)
197-
Y_categorical = encoder.transform(Y[:, categorical_features])
199+
Y_categorical = encoder.transform(Y_categorical)
198200

199201
# Numerical feature matrix of X (num_rows_X x num_num_feat)
200202
X_numerical = X[:, np.logical_not(categorical_features)].astype("float32")

0 commit comments

Comments
 (0)