Skip to content

Commit e2cdd42

Browse files
Moved RCV1 elif Block
1 parent 4550389 commit e2cdd42

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

data.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,18 @@ def prepare_dataset(dataset, out, vocab_size, **params):
442442
train.documents, test.documents = all_data.documents[:split_index], all_data.documents[split_index:]
443443
train.data, test.data = all_data.data[:split_index], all_data.data[split_index:]
444444
train.labels, test.labels = all_data.labels[:split_index], all_data.labels[split_index:]
445+
elif dataset == "RCV1":
446+
print("Preparing data...")
447+
all_data = TextRCV1()
448+
all_data.preprocess(out=out, vocab_size=vocab_size, **params)
449+
450+
# Split train/test set
451+
train = copy.deepcopy(all_data)
452+
test = copy.deepcopy(all_data)
453+
split_index = all_data.data.shape[0] // 2 # according to Bruna's paper & Hinton's dropout paper
454+
train.documents, test.documents = all_data.documents[:split_index], all_data.documents[split_index:]
455+
train.data, test.data = all_data.data[:split_index], all_data.data[split_index:]
456+
train.labels, test.labels = all_data.labels[:split_index], all_data.labels[split_index:]
445457
elif dataset == "RCV1-Vectors-Original":
446458
assert out == "tfidf"
447459
assert vocab_size == None
@@ -467,18 +479,6 @@ def prepare_dataset(dataset, out, vocab_size, **params):
467479
split_index = all_data.data.shape[0] // 2 # according to Bruna's paper & Hinton's dropout paper
468480
train.data, test.data = all_data.data[:split_index], all_data.data[split_index:]
469481
train.labels, test.labels = all_data.labels[:split_index], all_data.labels[split_index:]
470-
elif dataset == "RCV1":
471-
print("Preparing data...")
472-
all_data = TextRCV1()
473-
all_data.preprocess(out=out, vocab_size=vocab_size, **params)
474-
475-
# Split train/test set
476-
train = copy.deepcopy(all_data)
477-
test = copy.deepcopy(all_data)
478-
split_index = all_data.data.shape[0] // 2 # according to Bruna's paper & Hinton's dropout paper
479-
train.documents, test.documents = all_data.documents[:split_index], all_data.documents[split_index:]
480-
train.data, test.data = all_data.data[:split_index], all_data.data[split_index:]
481-
train.labels, test.labels = all_data.labels[:split_index], all_data.labels[split_index:]
482482

483483
return train, test
484484

0 commit comments

Comments
 (0)