
roberta-crf

English | 简体中文

Add a CRF layer on top of RoBERTa for sequence labeling (text tagging) tasks.

The main goal of this project is to add a CRF layer to RoBERTa with the help of fairseq while keeping RoBERTa's original structure and code logic unchanged. Before using this code you need to clone fairseq.

  1. Clone fairseq.
  2. Replace fairseq/fairseq/models/roberta/model.py, fairseq/fairseq/models/roberta/hub_interface.py and fairseq/fairseq/trainer.py. Don't worry about these changes affecting other tasks; the modified code is attached below.
  3. Move the files under the tasks and criterions folders into fairseq/fairseq/tasks and fairseq/fairseq/criterions.
  4. Run kaggel_ner_encoder.py to encode the training data (a rough sketch of this step is given after the download link below).
  5. Run pre_process.sh to generate the training and validation files.
  6. Run train.sh to finetune RoBERTa.

Click here to download the data.
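The encoder script itself is not reproduced in this README. The snippet below is only a rough, hypothetical sketch of what step 4 typically involves; the CSV column names (`Sentence #`, `Word`, `Tag`), the file names (`ner_dataset.csv`, `train.input0`, `train.label`) and the encoding are assumptions about the Kaggle NER data and the usual fairseq preprocessing convention, not code taken from this repository.

```python
# Hypothetical sketch: group the Kaggle NER CSV into sentences and write
# parallel token / tag files that pre_process.sh could then binarize.
import csv
from collections import defaultdict

sentences = defaultdict(list)
with open('ner_dataset.csv', encoding='latin-1') as f:    # assumed file name and encoding
    sent_id = None
    for row in csv.DictReader(f):
        sent_id = row['Sentence #'] or sent_id             # column names are assumptions
        sentences[sent_id].append((row['Word'], row['Tag']))

with open('train.input0', 'w') as f_tok, open('train.label', 'w') as f_tag:
    for pairs in sentences.values():
        f_tok.write(' '.join(w for w, _ in pairs) + '\n')  # one sentence per line
        f_tag.write(' '.join(t for _, t in pairs) + '\n')  # aligned tag sequence
```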

difference in trainer.py

            try:
                self.get_model().load_state_dict(
                    state["model"], strict=True, args=self.args
                )
                if utils.has_parameters(self.get_criterion()):
+                    state["criterion"] = self.get_criterion().state_dict()  # added to load the CRF layer when loading the state dict
                    self.get_criterion().load_state_dict(
                        state["criterion"], strict=True
                    )

difference in hub_interface.py

+    def predict_label(self, head: str, tokens: torch.LongTensor, return_logits: bool = False):
+        features = self.extract_features(tokens.to(device=self.device))
+        path_score, path = self.model.labeling_heads[head].forward_decode(features)
+        return path_score, path
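
Once the model is finetuned, the new hook can be called through the normal hub interface. The following is only an illustrative sketch: the checkpoint directory, the data directory and the head name `ner` are placeholders (the actual head name is whatever the labeling task registers), not values taken from this repository.

```python
from fairseq.models.roberta import RobertaModel

# Load the finetuned checkpoint through the standard hub interface;
# from_pretrained passes load_checkpoint_heads=True, so the labeling head
# stored in the checkpoint is re-registered by upgrade_state_dict_named.
roberta = RobertaModel.from_pretrained(
    'checkpoints',                        # placeholder checkpoint directory
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='data-bin',         # placeholder binarized data directory
)
roberta.eval()

tokens = roberta.encode('John lives in New York')
path_score, path = roberta.predict_label('ner', tokens)  # 'ner' is a placeholder head name
print(path_score, path)  # best-path score and the decoded tag-id sequence
```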

difference in model.py

from fairseq.modules.transformer_sentence_encoder import init_bert_params
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
+from fairseq.modules.dynamic_crf_layer import DynamicCRF as CRF

from .hub_interface import RobertaHubInterface


logger = logging.getLogger(__name__)


@register_model('roberta')
class RobertaModel(FairseqEncoderModel):
  ...

        self.classification_heads = nn.ModuleDict()
+        self.labeling_heads = nn.ModuleDict()

    @staticmethod
    def add_args(parser):
        ...
-    def forward(self, src_tokens, features_only=False, return_all_hiddens=False, classification_head_name=None, **kwargs):
-        if classification_head_name is not None:
+    # add arg `labeling_head_name`
+    def forward(self, src_tokens, features_only=False, return_all_hiddens=False, classification_head_name=None, labeling_head_name=None, **kwargs):
+        if classification_head_name is not None or labeling_head_name is not None:
            features_only = True

        x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)

        if classification_head_name is not None:
            x = self.classification_heads[classification_head_name](x)
+        if labeling_head_name is not None:
+            x = self.labeling_heads[labeling_head_name](x, **kwargs)
        return x, extra

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output[0].float()
        if log_probs:
            return F.log_softmax(logits, dim=-1)
        else:
            return F.softmax(logits, dim=-1)
    ...
+    # modeled on `register_classification_head`
+    def register_labeling_head(self, name, num_tags=None, inner_dim=None, **kwargs):
+        """Register a labeling head."""
+        if name in self.labeling_heads:
+            prev_num_tags = self.labeling_heads[name].dense.out_features
+            prev_inner_dim = self.labeling_heads[name].dense.in_features
+            if num_tags != prev_num_tags or inner_dim != prev_inner_dim:
+                logger.warning(
+                    're-registering head "{}" with num_tags {} (prev: {}) '
+                    'and inner_dim {} (prev: {})'.format(
+                        name, num_tags, prev_num_tags, inner_dim, prev_inner_dim
+                    )
+                )
+        self.labeling_heads[name] = RobertaLabelingHead(
+            self.args.encoder_embed_dim,
+            inner_dim or self.args.encoder_embed_dim,
+            num_tags,
+            self.args.pooler_dropout,
+            self.args.quant_noise_pq,
+            self.args.quant_noise_pq_block_size,
+        )
    ...
    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + '.' if name != '' else ''
        # rename decoder -> encoder before upgrading children modules
        for k in list(state_dict.keys()):
            if k.startswith(prefix + 'decoder'):
                new_k = prefix + 'encoder' + k[len(prefix + 'decoder'):]
                state_dict[new_k] = state_dict[k]
                del state_dict[k]

        # upgrade children modules
        super().upgrade_state_dict_named(state_dict, name)

        # Handle new classification heads present in the state dict.
-        current_head_names = (
-            [] if not hasattr(self, 'classification_heads')
-            else self.classification_heads.keys()
-        )
+        # collect both classification and labeling head names (the model defines both ModuleDicts)
+        current_head_names = []
+        if hasattr(self, 'classification_heads'):
+            current_head_names += list(self.classification_heads.keys())
+        if hasattr(self, 'labeling_heads'):
+            current_head_names += list(self.labeling_heads.keys())
        keys_to_delete = []
        for k in state_dict.keys():
-            if not k.startswith(prefix + 'classification_heads.'):
-                continue
-
-            head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
-            num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
-            inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
+            if not (k.startswith(prefix + 'classification_heads.') or k.startswith(prefix + 'labeling_heads.')):
+                continue
+            elif k.startswith(prefix + 'classification_heads.'):
+                head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
+                num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
+                inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
+            elif k.startswith(prefix + 'labeling_heads.'):
+                head_name = k[len(prefix + 'labeling_heads.'):].split('.')[0]
+                num_classes = state_dict[prefix + 'labeling_heads.' + head_name + '.dense.weight'].size(0)
+                inner_dim = state_dict[prefix + 'labeling_heads.' + head_name + '.dense.weight'].size(1)

            if getattr(self.args, 'load_checkpoint_heads', False):
-                if head_name not in current_head_names:
-                    self.register_classification_head(head_name, num_classes, inner_dim)
+                if (head_name not in current_head_names 
+                    and k.startswith(prefix + 'classification_heads.')):
+                    self.register_classification_head(head_name, num_classes, inner_dim)
+                elif (head_name not in current_head_names 
+                      and k.startswith(prefix + 'labeling_heads.')):
+                    self.register_labeling_head(head_name, num_classes, inner_dim)
            else:
                if head_name not in current_head_names:
                    logger.warning(
                        'deleting classification head ({}) from checkpoint '
                        'not present in current model: {}'.format(head_name, k)
                    )
                    keys_to_delete.append(k)
-                elif (
-                    num_classes != self.classification_heads[head_name].out_proj.out_features
-                    or inner_dim != self.classification_heads[head_name].dense.out_features
-                ):
+                elif (
+                    k.startswith(prefix + 'classification_heads.')
+                    and (num_classes != self.classification_heads[head_name].out_proj.out_features
+                         or inner_dim != self.classification_heads[head_name].dense.out_features)
+                ) or (
+                    k.startswith(prefix + 'labeling_heads.')
+                    and (num_classes != self.labeling_heads[head_name].dense.out_features
+                         or inner_dim != self.labeling_heads[head_name].dense.in_features)
+                ):
                    logger.warning(
                        'deleting classification head ({}) from checkpoint '
                        'with different dimensions than current model: {}'.format(head_name, k)
                    )
                    keys_to_delete.append(k)
        for k in keys_to_delete:
            del state_dict[k]

        # Copy any newly-added classification heads into the state dict
        # with their current weights.
        if hasattr(self, 'classification_heads'):
            cur_state = self.classification_heads.state_dict()
            for k, v in cur_state.items():
                if prefix + 'classification_heads.' + k not in state_dict:
                    logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
                    state_dict[prefix + 'classification_heads.' + k] = v
+        if hasattr(self, 'labeling_heads'):
+            cur_state = self.labeling_heads.state_dict()
+            for k, v in cur_state.items():
+                if prefix + 'labeling_heads.' + k not in state_dict:
+                    state_dict[prefix + 'labeling_heads.' + k] = v
    ...
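
register_labeling_head above constructs a `RobertaLabelingHead`, but that class is not part of this diff. Below is a minimal sketch of what such a head plausibly looks like, assuming it mirrors fairseq's `RobertaClassificationHead` and wraps `DynamicCRF`; the real definition in this repository's model.py may differ (for example, the CRF loss may instead live in the custom criterion mentioned in step 3).

```python
import torch.nn as nn
from fairseq.modules.dynamic_crf_layer import DynamicCRF as CRF
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_


class RobertaLabelingHead(nn.Module):
    """Token-level labeling head: linear projection to tag space plus a CRF (sketch)."""

    def __init__(self, input_dim, inner_dim, num_tags, pooler_dropout,
                 q_noise=0, qn_block_size=8):
        super().__init__()
        # register_labeling_head reads num_tags from dense.out_features and
        # inner_dim from dense.in_features, so dense projects straight to tag space
        # (with the default call, inner_dim == input_dim == encoder_embed_dim).
        self.dense = apply_quant_noise_(
            nn.Linear(inner_dim, num_tags), q_noise, qn_block_size
        )
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.crf = CRF(num_tags)  # fairseq's DynamicCRF over the tag vocabulary

    def forward(self, features, targets=None, masks=None, **kwargs):
        # features: B x T x C token representations from the encoder
        emissions = self.dense(self.dropout(features))
        if targets is None:
            return emissions  # per-token tag scores
        if masks is None:
            masks = features.new_ones(features.shape[:2])  # assume no padding
        # log-likelihood of the gold tag sequence under the CRF, shape (B,)
        return self.crf(emissions, targets, masks)

    def forward_decode(self, features, masks=None):
        # best-path decoding, used by predict_label in hub_interface.py;
        # the CRF decoding entry point is assumed to be DynamicCRF.forward_decoder
        emissions = self.dense(self.dropout(features))
        return self.crf.forward_decoder(emissions, masks)
```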