@@ -233,17 +233,13 @@ class TextRCV1(TextDataset):
233
233
Paper: http://www.jmlr.org/papers/volume5/lewis04a/lewis04a.pdf
234
234
"""
235
235
236
- def __init__ (self , shuffle = True , random_state = 42 ):
236
+ def __init__ (self ):
237
237
self .documents , self .labels , self .class_names = self ._load ()
238
238
239
239
assert len (self .class_names ) == 103 # 103 categories according to LYRL2004
240
240
N , C = self .labels .shape
241
241
assert C == len (self .class_names )
242
242
243
- if shuffle :
244
- # TODO: Implement shuffling of dataset? Violates chronological split recommended in LYRL2004.
245
- pass
246
-
247
243
def preprocess (self , out , vocab_size = 2000 , ** params ):
248
244
# Selection of classes
249
245
keep = ['C11' , 'C12' , 'C13' , 'C14' , 'C15' , 'C16' , 'C17' , 'C18' , 'C21' , 'C22' , 'C23' , 'C24' ,
@@ -324,6 +320,11 @@ class TextRCV1_Vectors(TextDataset):
324
320
"""
325
321
326
322
def __init__ (self , subset , shuffle = True , random_state = 42 ):
323
+ if subset == "all" :
324
+ shuffle = False # chronological split violated if shuffled
325
+ else :
326
+ shuffle = shuffle
327
+
327
328
dataset = sklearn .datasets .fetch_rcv1 (subset = subset , shuffle = shuffle , random_state = random_state )
328
329
self .data = dataset .data
329
330
self .labels = dataset .target
@@ -334,8 +335,7 @@ def __init__(self, subset, shuffle=True, random_state=42):
334
335
assert C == len (self .class_names )
335
336
336
337
N , V = self .data .shape
337
- # TODO: Hacky workaround to create placeholder value
338
- self .vocab = np .zeros (V )
338
+ self .vocab = np .zeros (V ) # hacky workaround to create placeholder value
339
339
self .orig_vocab_size = V
340
340
341
341
def preprocess (self , out , ** params ):
@@ -459,7 +459,7 @@ def prepare_dataset(dataset, out, vocab_size, **params):
459
459
assert vocab_size == None
460
460
461
461
print ("Preparing data..." )
462
- all_data = TextRCV1_Vectors (subset = "all" ) # TODO: shuffle=False? Chronological split violated?
462
+ all_data = TextRCV1_Vectors (subset = "all" )
463
463
all_data .preprocess (out = "tfidf" , ** params )
464
464
465
465
# Split train/test set
0 commit comments