@@ -442,6 +442,18 @@ def prepare_dataset(dataset, out, vocab_size, **params):
442
442
train .documents , test .documents = all_data .documents [:split_index ], all_data .documents [split_index :]
443
443
train .data , test .data = all_data .data [:split_index ], all_data .data [split_index :]
444
444
train .labels , test .labels = all_data .labels [:split_index ], all_data .labels [split_index :]
445
+ elif dataset == "RCV1" :
446
+ print ("Preparing data..." )
447
+ all_data = TextRCV1 ()
448
+ all_data .preprocess (out = out , vocab_size = vocab_size , ** params )
449
+
450
+ # Split train/test set
451
+ train = copy .deepcopy (all_data )
452
+ test = copy .deepcopy (all_data )
453
+ split_index = all_data .data .shape [0 ] // 2 # according to Bruna's paper & Hinton's dropout paper
454
+ train .documents , test .documents = all_data .documents [:split_index ], all_data .documents [split_index :]
455
+ train .data , test .data = all_data .data [:split_index ], all_data .data [split_index :]
456
+ train .labels , test .labels = all_data .labels [:split_index ], all_data .labels [split_index :]
445
457
elif dataset == "RCV1-Vectors-Original" :
446
458
assert out == "tfidf"
447
459
assert vocab_size == None
@@ -467,18 +479,6 @@ def prepare_dataset(dataset, out, vocab_size, **params):
467
479
split_index = all_data .data .shape [0 ] // 2 # according to Bruna's paper & Hinton's dropout paper
468
480
train .data , test .data = all_data .data [:split_index ], all_data .data [split_index :]
469
481
train .labels , test .labels = all_data .labels [:split_index ], all_data .labels [split_index :]
470
- elif dataset == "RCV1" :
471
- print ("Preparing data..." )
472
- all_data = TextRCV1 ()
473
- all_data .preprocess (out = out , vocab_size = vocab_size , ** params )
474
-
475
- # Split train/test set
476
- train = copy .deepcopy (all_data )
477
- test = copy .deepcopy (all_data )
478
- split_index = all_data .data .shape [0 ] // 2 # according to Bruna's paper & Hinton's dropout paper
479
- train .documents , test .documents = all_data .documents [:split_index ], all_data .documents [split_index :]
480
- train .data , test .data = all_data .data [:split_index ], all_data .data [split_index :]
481
- train .labels , test .labels = all_data .labels [:split_index ], all_data .labels [split_index :]
482
482
483
483
return train , test
484
484
0 commit comments