Skip to content

Commit 5fc85ac

Browse files
Removed shuffling of RCV1 and of RCV1-Vectors (subset "all") to preserve the chronological split
1 parent 2fd2ecf commit 5fc85ac

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

data.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,17 +233,13 @@ class TextRCV1(TextDataset):
233233
Paper: http://www.jmlr.org/papers/volume5/lewis04a/lewis04a.pdf
234234
"""
235235

236-
def __init__(self, shuffle=True, random_state=42):
236+
def __init__(self):
237237
self.documents, self.labels, self.class_names = self._load()
238238

239239
assert len(self.class_names) == 103 # 103 categories according to LYRL2004
240240
N, C = self.labels.shape
241241
assert C == len(self.class_names)
242242

243-
if shuffle:
244-
# TODO: Implement shuffling of dataset? Violates chronological split recommended in LYRL2004.
245-
pass
246-
247243
def preprocess(self, out, vocab_size=2000, **params):
248244
# Selection of classes
249245
keep = ['C11', 'C12', 'C13', 'C14', 'C15', 'C16', 'C17', 'C18', 'C21', 'C22', 'C23', 'C24',
@@ -324,6 +320,11 @@ class TextRCV1_Vectors(TextDataset):
324320
"""
325321

326322
def __init__(self, subset, shuffle=True, random_state=42):
323+
if subset == "all":
324+
shuffle = False # chronological split violated if shuffled
325+
else:
326+
shuffle = shuffle
327+
327328
dataset = sklearn.datasets.fetch_rcv1(subset=subset, shuffle=shuffle, random_state=random_state)
328329
self.data = dataset.data
329330
self.labels = dataset.target
@@ -334,8 +335,7 @@ def __init__(self, subset, shuffle=True, random_state=42):
334335
assert C == len(self.class_names)
335336

336337
N, V = self.data.shape
337-
# TODO: Hacky workaround to create placeholder value
338-
self.vocab = np.zeros(V)
338+
self.vocab = np.zeros(V) # hacky workaround to create placeholder value
339339
self.orig_vocab_size = V
340340

341341
def preprocess(self, out, **params):
@@ -459,7 +459,7 @@ def prepare_dataset(dataset, out, vocab_size, **params):
459459
assert vocab_size == None
460460

461461
print("Preparing data...")
462-
all_data = TextRCV1_Vectors(subset="all") # TODO: shuffle=False? Chronological split violated?
462+
all_data = TextRCV1_Vectors(subset="all")
463463
all_data.preprocess(out="tfidf", **params)
464464

465465
# Split train/test set

0 commit comments

Comments
 (0)