diff --git a/implicit/evaluation.pyx b/implicit/evaluation.pyx index f25d0da..e029520 100644 --- a/implicit/evaluation.pyx +++ b/implicit/evaluation.pyx @@ -187,16 +187,17 @@ cpdef leave_k_out_split( # get only users with n + 1 interactions candidate_mask = counts > K + 1 + if sum(candidate_mask) == 0: + return ratings.tocsr(), csr_matrix(ratings.shape) + unique_candidate_users = unique_users[candidate_mask] # keep a given subset of users _only_ in the training set. if train_only_size > 0.0: - train_only_mask = ~np.isin( - unique_users, _choose(random_state, len(unique_users), train_only_size) - ) - candidate_mask = train_only_mask & candidate_mask + adjusted_ratio = min(1, (1 - train_only_size) / (unique_candidate_users.shape[0] / (unique_users.shape[0] + 1))) + train_only_mask = _choose(random_state, len(unique_candidate_users), adjusted_ratio) + unique_candidate_users = unique_candidate_users[train_only_mask] # get unique users who appear in the test set - unique_candidate_users = unique_users[candidate_mask] full_candidate_mask = np.isin(users, unique_candidate_users) # get all users, items and ratings that match specified requirements to be