diff --git a/medcat-v2/medcat/components/addons/meta_cat/ml_utils.py b/medcat-v2/medcat/components/addons/meta_cat/ml_utils.py index 08cd701b1..fa8c5c615 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/ml_utils.py +++ b/medcat-v2/medcat/components/addons/meta_cat/ml_utils.py @@ -144,14 +144,11 @@ def split_list_train_test(data: list, test_size: float, shuffle: bool = True Returns: tuple: The train data, and the test data. """ - if shuffle: - random.shuffle(data) - X_features = [x[:-1] for x in data] y_labels = [x[-1] for x in data] X_train, X_test, y_train, y_test = train_test_split( - X_features, y_labels, test_size=test_size, random_state=42) + X_features, y_labels, test_size=test_size, shuffle=shuffle) train_data = [x + [y] for x, y in zip(X_train, y_train)] test_data = [x + [y] for x, y in zip(X_test, y_test)] diff --git a/v1/medcat/medcat/utils/meta_cat/ml_utils.py b/v1/medcat/medcat/utils/meta_cat/ml_utils.py index 0d75eabe8..6050e4a24 100644 --- a/v1/medcat/medcat/utils/meta_cat/ml_utils.py +++ b/v1/medcat/medcat/utils/meta_cat/ml_utils.py @@ -132,14 +132,10 @@ def split_list_train_test(data: List, test_size: float, shuffle: bool = True) -> Returns: Tuple: The train data, and the test data. """ - if shuffle: - random.shuffle(data) - X_features = [x[:-1] for x in data] y_labels = [x[-1] for x in data] - X_train, X_test, y_train, y_test = train_test_split(X_features, y_labels, test_size=test_size, - random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X_features, y_labels, test_size=test_size, shuffle=shuffle) train_data = [x + [y] for x, y in zip(X_train, y_train)] test_data = [x + [y] for x, y in zip(X_test, y_test)]