-
Notifications
You must be signed in to change notification settings - Fork 166
/
Copy pathbert_spc.py
62 lines (54 loc) · 2.74 KB
/
bert_spc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# -*- coding: utf-8 -*-
# file: BERT_SPC.py
# author: songyouwei <[email protected]>
# Copyright (C) 2019. All Rights Reserved.
import torch
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertPooler
from pyabsa.network.sa_encoder import Encoder
class BERT_SPC(nn.Module):
    """BERT classifier with optional local sentiment aggregation (LSA).

    Without LSA (``opt.lsa`` falsy), only ``text_bert_indices`` is encoded by
    ``bert`` and passed through the classification head.  With LSA, the main
    text features are concatenated along the last dimension with the features
    of the left and/or right neighbouring texts (chosen by ``opt.window``) and
    projected back to ``opt.embed_dim`` before the same head is applied.
    """

    inputs = ['text_bert_indices',
              'left_text_bert_indices',
              'right_text_bert_indices']

    def __init__(self, bert, opt):
        super(BERT_SPC, self).__init__()
        self.bert = bert
        self.opt = opt
        self.linear = nn.Linear(opt.embed_dim, opt.embed_dim)
        # The window projections are registered whether or not LSA is on
        # (as embed_dim -> embed_dim layers without it), but forward() only
        # calls them when opt.lsa is truthy.
        self.linear_window_2h = nn.Linear(2 * opt.embed_dim, opt.embed_dim) if self.opt.lsa else nn.Linear(opt.embed_dim, opt.embed_dim)
        self.linear_window_3h = nn.Linear(3 * opt.embed_dim, opt.embed_dim) if self.opt.lsa else nn.Linear(opt.embed_dim, opt.embed_dim)
        self.encoder = Encoder(bert.config, opt)
        self.dropout = nn.Dropout(opt.dropout)
        self.pooler = BertPooler(bert.config)
        self.dense = nn.Linear(opt.embed_dim, opt.polarities_dim)

    def forward(self, inputs):
        """Compute classification logits.

        Args:
            inputs: mapping from the names in ``BERT_SPC.inputs`` to token-id
                tensors accepted by ``self.bert``.  The left/right entries are
                read only when ``opt.lsa`` is enabled, and only for the side(s)
                selected by ``opt.window``.

        Returns:
            dict with the single key ``'logits'`` holding the output of
            ``self.dense``.

        Raises:
            KeyError: if ``opt.lsa`` is enabled and ``opt.window`` is not one
                of ``'lr'``, ``'rl'``, ``'l'`` or ``'r'``.
        """
        if self.opt.lsa:
            sent_out = self._lsa_features(inputs)
        else:
            sent_out = self._encode(inputs['text_bert_indices'])
        return {'logits': self._classify(sent_out)}

    def _encode(self, token_ids):
        """Run ``self.bert`` on ``token_ids`` and return its ``last_hidden_state``."""
        return self.bert(token_ids)['last_hidden_state']

    def _lsa_features(self, inputs):
        """Fuse main-text features with neighbour features as selected by ``opt.window``."""
        window = self.opt.window
        # Validate before any BERT pass so a bad config fails without wasted compute.
        if window not in ('lr', 'rl', 'l', 'r'):
            raise KeyError('Invalid parameter: window={!r}'.format(window))

        feat = self._encode(inputs['text_bert_indices'])
        if window in ('lr', 'rl'):
            left_feat = self._encode(inputs['left_text_bert_indices'])
            right_feat = self._encode(inputs['right_text_bert_indices'])
            # A non-negative eta weights left by eta and right by (1 - eta);
            # a negative eta concatenates the neighbours unweighted.
            if self.opt.eta >= 0:
                left_feat = self.opt.eta * left_feat
                right_feat = (1 - self.opt.eta) * right_feat
            return self.linear_window_3h(torch.cat((feat, left_feat, right_feat), -1))

        # Single-sided window: only the neighbour actually used is encoded.
        side_key = 'left_text_bert_indices' if window == 'l' else 'right_text_bert_indices'
        side_feat = self._encode(inputs[side_key])
        return self.linear_window_2h(torch.cat((feat, side_feat), -1))

    def _classify(self, hidden):
        """Shared head: linear -> dropout -> encoder -> pooler -> dense."""
        hidden = self.linear(hidden)
        hidden = self.dropout(hidden)
        hidden = self.encoder(hidden)
        hidden = self.pooler(hidden)
        return self.dense(hidden)