From c94831bbc199feb07b9006251f1c39e463a0aa3b Mon Sep 17 00:00:00 2001 From: ZCheng <34483849+ZhiweiCheng2020@users.noreply.github.com> Date: Mon, 10 Apr 2023 20:47:57 +0200 Subject: [PATCH] Update TextCNN.py --- 2-1.TextCNN/TextCNN.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/2-1.TextCNN/TextCNN.py b/2-1.TextCNN/TextCNN.py index f19c139..b7fb373 100644 --- a/2-1.TextCNN/TextCNN.py +++ b/2-1.TextCNN/TextCNN.py @@ -16,7 +16,7 @@ def __init__(self): self.filter_list = nn.ModuleList([nn.Conv2d(1, num_filters, (size, embedding_size)) for size in filter_sizes]) def forward(self, X): - embedded_chars = self.W(X) # [batch_size, sequence_length, sequence_length] + embedded_chars = self.W(X) # [batch_size, sequence_length, embedding_size] embedded_chars = embedded_chars.unsqueeze(1) # add channel(=1) [batch, channel(=1), sequence_length, embedding_size] pooled_outputs = [] @@ -81,4 +81,4 @@ def forward(self, X): if predict[0][0] == 0: print(test_text,"is Bad Mean...") else: - print(test_text,"is Good Mean!!") \ No newline at end of file + print(test_text,"is Good Mean!!")