Add a CRF layer on top of RoBERTa for sequence labeling

The main goal of this project is to add a CRF layer to RoBERTa through fairseq while keeping RoBERTa's original structure and code logic intact. Before using this code you need to clone fairseq.
- Clone fairseq.
- Replace `fairseq/fairseq/models/roberta/model.py`, `fairseq/fairseq/models/roberta/hub_interface.py` and `fairseq/fairseq/trainer.py`. Don't worry about these changes affecting other tasks; the modified code is listed below.
- Move the files under the `tasks` and `criterions` folders into `fairseq/fairseq/tasks` and `fairseq/fairseq/criterions`.
- Run `kaggel_ner_encoder.py` to encode the training data.
- Run `pre_process.sh` to generate the training and validation files.
- Run `train.sh` to fine-tune RoBERTa (see the inference sketch below for how the fine-tuned model can be used).
Click here to download the data.
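Once `train.sh` has finished, the fine-tuned checkpoint can be loaded through fairseq's hub interface and used for tagging with the `predict_label` method added to `hub_interface.py` below. The snippet is only a sketch: the checkpoint directory, the data directory and the head name `ner_head` are placeholders that depend on how training was configured, and the returned `path` contains label ids that still need to be mapped back to tag strings.

```python
import torch
from fairseq.models.roberta import RobertaModel

# Paths, checkpoint file name and the labeling-head name are assumptions.
roberta = RobertaModel.from_pretrained(
    'checkpoints/',                     # directory written by train.sh (assumed)
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='data-bin/',      # binarized data from pre_process.sh (assumed)
)
roberta.eval()

tokens = roberta.encode('EU rejects German call to boycott British lamb .')
with torch.no_grad():
    # predict_label is added in hub_interface.py below; it returns the CRF's
    # best path score and the predicted label-id sequence.
    path_score, path = roberta.predict_label('ner_head', tokens)
print(path_score, path)
```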
Difference in `trainer.py`
```diff
         try:
             self.get_model().load_state_dict(
                 state["model"], strict=True, args=self.args
             )
             if utils.has_parameters(self.get_criterion()):
+                state["criterion"] = self.get_criterion().state_dict()  # add this code to load crf layer while loading state dict
                 self.get_criterion().load_state_dict(
                     state["criterion"], strict=True
                 )
```
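This branch only runs when `utils.has_parameters(self.get_criterion())` is true, i.e. when the criterion itself owns trainable parameters, which the comment above suggests is the case for the CRF-based criterion used here. One plausible reason for the change: a checkpoint that was not saved with such a criterion (for example the pretrained RoBERTa weights used to start fine-tuning) has no matching entries, so a strict load can fail; overwriting `state["criterion"]` with the criterion's own state dict keeps its current parameters instead. The toy snippet below only illustrates that plain-PyTorch mechanic; `TinyCrfCriterion` is a hypothetical stand-in, not part of this project.

```python
import torch
import torch.nn as nn

class TinyCrfCriterion(nn.Module):
    """Hypothetical stand-in for a criterion that owns CRF parameters."""
    def __init__(self, num_tags=5):
        super().__init__()
        self.transitions = nn.Parameter(torch.zeros(num_tags, num_tags))

criterion = TinyCrfCriterion()

# A checkpoint saved without these parameters cannot be loaded strictly:
try:
    criterion.load_state_dict({}, strict=True)
except RuntimeError as err:
    print('strict load fails:', err)   # missing key: 'transitions'

# What the patched trainer.py effectively does: load the criterion's own
# current state dict, so its (freshly initialised) parameters are kept.
criterion.load_state_dict(criterion.state_dict(), strict=True)
```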
Difference in `hub_interface.py`
```diff
+    def predict_label(self, head: str, tokens: torch.LongTensor, return_logits: bool = False):
+        features = self.extract_features(tokens.to(device=self.device))
+        path_score, path = self.model.labeling_heads[head].forward_decode(features)
+        return path_score, path
```
Difference in `model.py`
```diff
 from fairseq.modules.transformer_sentence_encoder import init_bert_params
 from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
+from fairseq.modules.dynamic_crf_layer import DynamicCRF as CRF

 from .hub_interface import RobertaHubInterface

 logger = logging.getLogger(__name__)


 @register_model('roberta')
 class RobertaModel(FairseqEncoderModel):

     ...

         self.classification_heads = nn.ModuleDict()
+        self.labeling_heads = nn.ModuleDict()

     ...

-    def forward(self, src_tokens, features_only=False, return_all_hiddens=False, classification_head_name=None, **kwargs):
-        if classification_head_name is not None:
+    # add arg `labeling_head_name`
+    def forward(self, src_tokens, features_only=False, return_all_hiddens=False, classification_head_name=None, labeling_head_name=None, **kwargs):
+        if (classification_head_name or labeling_head_name) is not None:
             features_only = True

         x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)

         if classification_head_name is not None:
             x = self.classification_heads[classification_head_name](x)
+        if labeling_head_name is not None:
+            x = self.labeling_heads[labeling_head_name](x, **kwargs)
         return x, extra

     def get_normalized_probs(self, net_output, log_probs, sample=None):
         """Get normalized probabilities (or log probs) from a net's output."""
         logits = net_output[0].float()
         if log_probs:
             return F.log_softmax(logits, dim=-1)
         else:
             return F.softmax(logits, dim=-1)

     ...

+    # modeled on `register_classification_head`
+    def register_labeling_head(self, name, num_tags=None, inner_dim=None, **kwargs):
+        """Register a labeling head."""
+        if name in self.labeling_heads:
+            prev_num_tags = self.labeling_heads[name].dense.out_features
+            prev_inner_dim = self.labeling_heads[name].dense.in_features
+            if num_tags != prev_num_tags or inner_dim != prev_inner_dim:
+                logger.warning(
+                    're-registering head "{}" with num_tags {} (prev: {}) '
+                    'and inner_dim {} (prev: {})'.format(
+                        name, num_tags, prev_num_tags, inner_dim, prev_inner_dim
+                    )
+                )
+        self.labeling_heads[name] = RobertaLabelingHead(
+            self.args.encoder_embed_dim,
+            inner_dim or self.args.encoder_embed_dim,
+            num_tags,
+            self.args.pooler_dropout,
+            self.args.quant_noise_pq,
+            self.args.quant_noise_pq_block_size,
+        )

     ...
```
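`register_labeling_head` constructs a `RobertaLabelingHead`, which is defined in the replaced `model.py` but not shown in this diff. From how it is built above and called from `hub_interface.py` (a `dense` projection whose `in_features`/`out_features` are inspected, plus a `forward_decode` method), a head of this kind roughly looks like the sketch below. This is an illustration only, assuming the imported `DynamicCRF` supplies the transition model; the project's actual head may differ.

```python
import torch.nn as nn
from fairseq.modules.dynamic_crf_layer import DynamicCRF as CRF
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_


class RobertaLabelingHead(nn.Module):
    """Illustrative sketch of a per-token tagging head: emission scores from a
    linear layer, with a CRF on top for decoding. Not the project's exact code."""

    def __init__(self, input_dim, inner_dim, num_tags, pooler_dropout,
                 q_noise=0, qn_block_size=8):
        super().__init__()
        self.dropout = nn.Dropout(p=pooler_dropout)
        # register_labeling_head reads dense.in_features / dense.out_features, so
        # the emission projection maps inner_dim -> num_tags; inner_dim defaults to
        # the encoder embedding dim (input_dim), which is unused in this sketch.
        self.dense = apply_quant_noise_(
            nn.Linear(inner_dim, num_tags), q_noise, qn_block_size
        )
        # low-rank CRF over tag transitions, shipped with fairseq
        self.crf = CRF(num_tags)

    def forward(self, features, **kwargs):
        # features: (batch, seq_len, hidden) -> emissions: (batch, seq_len, num_tags)
        return self.dense(self.dropout(features))

    def forward_decode(self, features, masks=None):
        # Viterbi-style decoding of the best tag sequence, as used by
        # predict_label() in hub_interface.py above.
        emissions = self.forward(features)
        # DynamicCRF.forward_decoder returns (best path scores, best tag ids)
        return self.crf.forward_decoder(emissions, masks)
```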
```diff
     def upgrade_state_dict_named(self, state_dict, name):
         prefix = name + '.' if name != '' else ''

         # rename decoder -> encoder before upgrading children modules
         for k in list(state_dict.keys()):
             if k.startswith(prefix + 'decoder'):
                 new_k = prefix + 'encoder' + k[len(prefix + 'decoder'):]
                 state_dict[new_k] = state_dict[k]
                 del state_dict[k]

         # upgrade children modules
         super().upgrade_state_dict_named(state_dict, name)

         # Handle new classification heads present in the state dict.
-        current_head_names = (
-            [] if not hasattr(self, 'classification_heads')
-            else self.classification_heads.keys()
-        )
+        current_head_names = []
+        if hasattr(self, 'classification_heads'):
+            current_head_names += list(self.classification_heads.keys())
+        if hasattr(self, 'labeling_heads'):
+            current_head_names += list(self.labeling_heads.keys())
         keys_to_delete = []
         for k in state_dict.keys():
-            if not k.startswith(prefix + 'classification_heads.'):
-                continue
-
-            head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
-            num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
-            inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
+            if not (k.startswith(prefix + 'classification_heads.') or k.startswith(prefix + 'labeling_heads.')):
+                continue
+            elif k.startswith(prefix + 'classification_heads.'):
+                head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
+                num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
+                inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
+            elif k.startswith(prefix + 'labeling_heads.'):
+                head_name = k[len(prefix + 'labeling_heads.'):].split('.')[0]
+                num_classes = state_dict[prefix + 'labeling_heads.' + head_name + '.dense.weight'].size(0)
+                inner_dim = state_dict[prefix + 'labeling_heads.' + head_name + '.dense.weight'].size(1)

             if getattr(self.args, 'load_checkpoint_heads', False):
-                if head_name not in current_head_names:
-                    self.register_classification_head(head_name, num_classes, inner_dim)
+                if (head_name not in current_head_names
+                        and k.startswith(prefix + 'classification_heads.')):
+                    self.register_classification_head(head_name, num_classes, inner_dim)
+                elif (head_name not in current_head_names
+                        and k.startswith(prefix + 'labeling_heads.')):
+                    self.register_labeling_head(head_name, num_classes, inner_dim)
             else:
                 if head_name not in current_head_names:
                     logger.warning(
                         'deleting classification head ({}) from checkpoint '
                         'not present in current model: {}'.format(head_name, k)
                     )
                     keys_to_delete.append(k)
-                elif (
-                    num_classes != self.classification_heads[head_name].out_proj.out_features
-                    or inner_dim != self.classification_heads[head_name].dense.out_features
-                ):
+                elif (
+                    k.startswith(prefix + 'classification_heads.')
+                    and (num_classes != self.classification_heads[head_name].out_proj.out_features
+                         or inner_dim != self.classification_heads[head_name].dense.out_features)
+                ) or (
+                    k.startswith(prefix + 'labeling_heads.')
+                    and (num_classes != self.labeling_heads[head_name].dense.out_features
+                         or inner_dim != self.labeling_heads[head_name].dense.in_features)
+                ):
                     logger.warning(
                         'deleting classification head ({}) from checkpoint '
                         'with different dimensions than current model: {}'.format(head_name, k)
                     )
                     keys_to_delete.append(k)

         for k in keys_to_delete:
             del state_dict[k]

         # Copy any newly-added classification heads into the state dict
         # with their current weights.
         if hasattr(self, 'classification_heads'):
             cur_state = self.classification_heads.state_dict()
             for k, v in cur_state.items():
                 if prefix + 'classification_heads.' + k not in state_dict:
                     logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
                     state_dict[prefix + 'classification_heads.' + k] = v
+        if hasattr(self, 'labeling_heads'):
+            cur_state = self.labeling_heads.state_dict()
+            for k, v in cur_state.items():
+                if prefix + 'labeling_heads.' + k not in state_dict:
+                    state_dict[prefix + 'labeling_heads.' + k] = v

     ...
```