lavis-nlp / spert

PyTorch code for SpERT: Span-based Entity and Relation Transformer
MIT License

Softmax implementation #48

Open ghost opened 3 years ago

ghost commented 3 years ago

Hi @markus-eberts , thanks for sharing your great work.

I was experimenting with a variation of SpERT in which relations are classified with a softmax instead of a sigmoid. To check that the overall system was correct, I trained it on the version of conll04 that you provide with the model, and everything seemed fine. The issues arose when I trained it on a different dataset, converted to a SpERT-compatible format. Training went smoothly, but the model made no predictions at all, neither entities nor relations. I am surely missing something, so I was wondering if you could point me in a direction to start from.

Here is a single sample from the training dataset:

{"tokens": ["The", "role", "of", "p27(Kip1", ")", "in", "dasatinib-enhanced", "paclitaxel", "cytotoxicity", "in", "human", "ovarian", "cancer", "cells", ".", "\r\n"], "entities": [{"type": "drug", "start": 6, "end": 7}, {"type": "drug", "start": 7, "end": 8}], "relations": [{"type": "effect", "head": 0, "tail": 1}], "orig_id": "DDI-MedLine.d194.s0"}

On this dataset a softmax is the natural choice, since all relations are symmetric and at most one relation exists between any two entities.
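To make the difference between the two heads concrete, here is a minimal, framework-free sketch. The logits are made up for illustration, and the label ids (0 = No Relation, 1 = Effect, 2 = Int, ...) are assumed to follow the mapping printed in the training log. It is not SpERT's code: it only shows that a sigmoid head scores each relation type independently and can fire several at once, while a softmax head picks exactly one class, including an explicit "No Relation" class.

```python
import math


def sigmoid_scores(logits):
    """Independent probability per relation type (multi-label)."""
    return [1.0 / (1.0 + math.exp(-x)) for x in logits]


def softmax_scores(logits):
    """Probabilities over mutually exclusive classes (sum to 1)."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for a single entity pair.
# Sigmoid head: 4 relation types, no explicit "No Relation" output.
sigmoid_logits = [2.0, 1.5, -3.0, -2.0]
# Softmax head: index 0 is "No Relation", followed by the 4 relation types.
softmax_logits = [0.5, 2.0, 1.5, -3.0, -2.0]

threshold = 0.4  # same value as rel_filter_threshold in the config
# Sigmoid: every type above the threshold is predicted (+1 skips "No Relation").
multi_label = [i + 1 for i, p in enumerate(sigmoid_scores(sigmoid_logits)) if p >= threshold]

# Softmax: a single winning class; class 0 means "no relation".
probs = softmax_scores(softmax_logits)
single_label = max(range(len(probs)), key=probs.__getitem__)

print(multi_label, single_label)
```

With these logits the sigmoid head predicts two relation types for the same pair, whereas the softmax head commits to one, which matches the single-relation assumption stated above.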

Here is the log of the training run:


Config: {'label': 'softmax_ddi', 'model_type': 'spert', 'model_path': 'bert-base-cased', 'tokenizer_path': 'bert-base-cased', 'train_path': 'data/datasets/unibs/train/all.json', 'valid_path': 'data/datasets/unibs/dev/all.json', 'types_path': 'data/datasets/unibs/types.json', 'train_batch_size': '2', 'eval_batch_size': '1', 'neg_entity_count': '100', 'neg_relation_count': '100', 'epochs': '5', 'lr': '5e-5', 'lr_warmup': '0.1', 'weight_decay': '0.01', 'max_grad_norm': '1.0', 'rel_filter_threshold': '0.4', 'size_embedding': '25', 'prop_drop': '0.1', 'max_span_size': '10', 'store_predictions': 'true', 'store_examples': 'true', 'sampling_processes': '4', 'max_pairs': '1000', 'final_eval': 'true', 'log_path': 'data/log/', 'save_path': 'data/save/'}
Repeat 1 times

Iteration 0

2021-05-29 09:45:30,631 [MainThread ] [INFO ] Datasets: data/datasets/unibs/train/all.json, data/datasets/unibs/dev/all.json
2021-05-29 09:45:30,631 [MainThread ] [INFO ] Model type: spert
Parse dataset 'train': 100%|██████████| 6045/6045 [00:09<00:00, 622.16it/s]
Parse dataset 'valid': 100%|██████████| 931/931 [00:01<00:00, 609.67it/s]
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Relation type count: 5
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Entity type count: 5
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Entities:
2021-05-29 09:45:41,928 [MainThread ] [INFO ] No Entity=0
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Drug name=1
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Drug=2
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Group=3
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Brand=4
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Relations:
2021-05-29 09:45:41,928 [MainThread ] [INFO ] No Relation=0
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Effect=1
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Int=2
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Mechanism=3
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Advise=4
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Dataset: train
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Document count: 6045
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Relation count: 3378
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Entity count: 12549
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Dataset: valid
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Document count: 931
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Relation count: 642
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Entity count: 2216
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Updates per epoch: 3022
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Updates total: 15110

[...]

Evaluation

--- Entities (named entity recognition (NER)) ---
An entity is considered correct if the entity type and span is predicted correctly

           type    precision       recall     f1-score      support
         drug_n         0.00         0.00         0.00        101.0
           drug         0.00         0.00         0.00       1396.0
          group         0.00         0.00         0.00        538.0
          brand         0.00         0.00         0.00        169.0

          micro         0.00         0.00         0.00       2204.0
          macro         0.00         0.00         0.00       2204.0

--- Relations ---

Without named entity classification (NEC)
A relation is considered correct if the relation type and the spans of the two related entities are predicted correctly (entity type is not considered)

           type    precision       recall     f1-score      support
         advise         0.00         0.00         0.00        130.0
            int         0.00         0.00         0.00          8.0
         effect         0.00         0.00         0.00        250.0
      mechanism         0.00         0.00         0.00        253.0

          micro         0.00         0.00         0.00        641.0
          macro         0.00         0.00         0.00        641.0

With named entity classification (NEC)
A relation is considered correct if the relation type and the two related entities are predicted correctly (in span and entity type)

           type    precision       recall     f1-score      support
         advise         0.00         0.00         0.00        130.0
            int         0.00         0.00         0.00          8.0
         effect         0.00         0.00         0.00        250.0
      mechanism         0.00         0.00         0.00        253.0

          micro         0.00         0.00         0.00        641.0
          macro         0.00         0.00         0.00        641.0

2021-05-29 12:21:19,887 [MainThread ] [INFO ] Logged in: data/log/softmax_ddi/2021-05-29_09-45-29.875587
2021-05-29 12:21:19,887 [MainThread ] [INFO ] Saved in: data/save/softmax_ddi/2021-05-29_09-45-29.875587

The following are the major changes that I applied to the original model:

spert/spert_trainer.py

 class SpERTTrainer(BaseTrainer):
                                             config=config,
                                             # SpERT model parameters
                                             cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'),
-                                            relation_types=input_reader.relation_type_count - 1,
+                                            relation_types=input_reader.relation_type_count,
                                             entity_types=input_reader.entity_type_count,
                                             max_pairs=self._args.max_pairs,
                                             prop_drop=self._args.prop_drop,
 class SpERTTrainer(BaseTrainer):
                                                                  num_warmup_steps=args.lr_warmup * updates_total,
                                                                  num_training_steps=updates_total)
         # create loss function
-        rel_criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
+        rel_criterion = torch.nn.CrossEntropyLoss(reduction='none')
         entity_criterion = torch.nn.CrossEntropyLoss(reduction='none')

spert/loss.py

 class SpERTLoss(Loss):

         if rel_count.item() != 0:
             rel_logits = rel_logits.view(-1, rel_logits.shape[-1])
-            rel_types = rel_types.view(-1, rel_types.shape[-1])
+            rel_types = rel_types.view(-1)

             rel_loss = self._rel_criterion(rel_logits, rel_types)
-            rel_loss = rel_loss.sum(-1) / rel_loss.shape[-1]
             rel_loss = (rel_loss * rel_sample_masks).sum() / rel_count

             # joint loss

spert/sampling.py

 def create_train_sample(doc, neg_entity_count: int, neg_rel_count: int, max_span
         rel_sample_masks = torch.zeros([1], dtype=torch.bool)

     # relation types to one-hot encoding
-    rel_types_onehot = torch.zeros([rel_types.shape[0], rel_type_count], dtype=torch.float32)
-    rel_types_onehot.scatter_(1, rel_types.unsqueeze(1), 1)
-    rel_types_onehot = rel_types_onehot[:, 1:]  # all zeros for 'none' relation

     return dict(encodings=encodings, context_masks=context_masks, entity_masks=entity_masks,
                 entity_sizes=entity_sizes, entity_types=entity_types,
-                rels=rels, rel_masks=rel_masks, rel_types=rel_types_onehot,
+                rels=rels, rel_masks=rel_masks, rel_types=rel_types,
                 entity_sample_masks=entity_sample_masks, rel_sample_masks=rel_sample_masks)

spert/models.py

 class SpERT(BertPreTrainedModel):
             chunk_rel_logits = self._classify_relations(entity_spans_pool, size_embeddings,
                                                         relations, rel_masks, h_large, i)
             # apply sigmoid
-            chunk_rel_clf = torch.sigmoid(chunk_rel_logits)
-            rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf
+            rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_logits

-        rel_clf = rel_clf * rel_sample_masks  # mask

         # apply softmax
         entity_clf = torch.softmax(entity_clf, dim=2)
+        rel_clf = torch.softmax(rel_clf, dim=2)
+        rel_clf *= rel_sample_masks

         return entity_clf, rel_clf, relations

spert/predictions.py

 def convert_predictions(batch_entity_clf: torch.tensor, batch_rel_clf: torch.ten
     batch_entity_types *= batch['entity_sample_masks'].long()

     # apply threshold to relations
-    batch_rel_clf[batch_rel_clf < rel_filter_threshold] = 0

     batch_pred_entities = []
     batch_pred_relations = []

spert/predictions.py

 def _convert_pred_relations(rel_clf: torch.tensor, rels: torch.tensor,
                             entity_types: torch.tensor, entity_spans: torch.tensor, input_reader: BaseInputReader):
-    rel_class_count = rel_clf.shape[1]
-    rel_clf = rel_clf.view(-1)

     # get predicted relation labels and corresponding entity pairs
-    rel_nonzero = rel_clf.nonzero().view(-1)
-    pred_rel_scores = rel_clf[rel_nonzero]
-
-    pred_rel_types = (rel_nonzero % rel_class_count) + 1  # model does not predict None class (+1)
-    valid_rel_indices = rel_nonzero // rel_class_count
+    valid_rel_indices = torch.nonzero(torch.sum(rel_clf, dim=-1)).view(-1)
+    valid_rel_indices = valid_rel_indices.view(-1)
+    
+    pred_rel_types = rel_clf[valid_rel_indices]
+    if pred_rel_types.shape[0] != 0:
+        pred_rel_types = pred_rel_types.argmax(dim=-1)
+        valid_rel_indices = torch.nonzero(pred_rel_types).view(-1)
+        
+        pred_rel_types = pred_rel_types[valid_rel_indices]
+
+    pred_rel_scores = rel_clf[valid_rel_indices]
+    if pred_rel_scores.shape[0] != 0:
+        pred_rel_scores = pred_rel_scores.max(dim=-1)[0]

     valid_rels = rels[valid_rel_indices]
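
One detail in the hunk above may be worth double-checking: the second `torch.nonzero` call returns positions within the already filtered `pred_rel_types`, and those positions are then used to index the unfiltered `rel_clf` and `rels`. If any masked pair precedes a kept one, the two index spaces no longer line up. The following sketch uses plain Python lists instead of tensors (an assumption made only to keep it self-contained) and a hypothetical helper name; it is not SpERT's implementation. It shows one way to decode softmax output while keeping every index relative to the original candidate list.

```python
def decode_softmax_relations(rel_probs, sample_mask, threshold=0.0):
    """Decode per-pair softmax rows into (pair_index, type_id, score) triples.

    rel_probs:   one probability row per candidate pair, column 0 = "No Relation".
    sample_mask: one bool per candidate pair, False for padding.
    threshold:   optional minimum score for the winning class.
    """
    predictions = []
    for pair_idx, (row, keep) in enumerate(zip(rel_probs, sample_mask)):
        if not keep:
            continue  # padded pair, ignore
        best_type = max(range(len(row)), key=row.__getitem__)
        best_score = row[best_type]
        if best_type == 0 or best_score < threshold:
            continue  # "No Relation" won, or the score is too low
        # pair_idx still refers to the ORIGINAL candidate list.
        predictions.append((pair_idx, best_type, best_score))
    return predictions


# Hypothetical rows: 1 masked pair followed by 3 real candidates.
rel_probs = [
    [0.0, 0.0, 0.0, 0.0, 0.0],      # masked out
    [0.7, 0.1, 0.1, 0.05, 0.05],    # "No Relation" wins
    [0.1, 0.6, 0.1, 0.1, 0.1],      # type 1 wins
    [0.05, 0.05, 0.1, 0.7, 0.1],    # type 3 wins
]
sample_mask = [False, True, True, True]

decoded = decode_softmax_relations(rel_probs, sample_mask, threshold=0.4)
print(decoded)
```

Here the kept relations come back with pair indices 2 and 3, so they can be used directly to look up the corresponding entity pairs.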

This is unrelated to the previous topic, but I am adding it here since the same dataset is involved. While experimenting with the original SpERT, I replaced BERT with SciBERT. With 1 epoch of training I had no issues whatsoever; when I increased the number of epochs to 5, the procedure that stores the predictions started picking up relations that should have been filtered out by earlier processing (if I interpreted everything correctly). Here is the log:


Config: {'label': 'scibert_ddi', 'model_type': 'spert', 'model_path': '/home/deeplearning/Salvalai/scibert_scivocab_uncased', 'tokenizer_path': '/home/deeplearning/Salvalai/scibert_scivocab_uncased', 'train_path': 'data/datasets/unibs/train/all.json', 'valid_path': 'data/datasets/unibs/dev/all.json', 'types_path': 'data/datasets/unibs/types.json', 'train_batch_size': '2', 'eval_batch_size': '1', 'neg_entity_count': '100', 'neg_relation_count': '100', 'epochs': '5', 'lr': '5e-5', 'lr_warmup': '0.1', 'weight_decay': '0.01', 'max_grad_norm': '1.0', 'rel_filter_threshold': '0.4', 'size_embedding': '25', 'prop_drop': '0.1', 'max_span_size': '10', 'store_predictions': 'true', 'store_examples': 'true', 'sampling_processes': '4', 'max_pairs': '1000', 'final_eval': 'true', 'log_path': 'data/log/', 'save_path': 'data/save/'}
Repeat 1 times

Iteration 0

2021-05-28 10:54:28,162 [MainThread ] [INFO ] Datasets: data/datasets/unibs/train/all.json, data/datasets/unibs/dev/all.json
2021-05-28 10:54:28,162 [MainThread ] [INFO ] Model type: spert
Parse dataset 'train': 100%|██████████| 6045/6045 [00:12<00:00, 466.41it/s]
Parse dataset 'valid': 100%|██████████| 931/931 [00:02<00:00, 359.59it/s]
2021-05-28 10:54:43,771 [MainThread ] [INFO ] Relation type count: 5
2021-05-28 10:54:43,771 [MainThread ] [INFO ] Entity type count: 5
2021-05-28 10:54:43,771 [MainThread ] [INFO ] Entities:
2021-05-28 10:54:43,772 [MainThread ] [INFO ] No Entity=0
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Drug name=1
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Drug=2
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Group=3
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Brand=4
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Relations:
2021-05-28 10:54:43,772 [MainThread ] [INFO ] No Relation=0
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Effect=1
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Int=2
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Mechanism=3
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Advise=4
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Dataset: train
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Document count: 6045
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Relation count: 3378
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Entity count: 12549
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Dataset: valid
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Document count: 931
2021-05-28 10:54:43,773 [MainThread ] [INFO ] Relation count: 642
2021-05-28 10:54:43,773 [MainThread ] [INFO ] Entity count: 2216
2021-05-28 10:54:43,773 [MainThread ] [INFO ] Updates per epoch: 3022
2021-05-28 10:54:43,773 [MainThread ] [INFO ] Updates total: 15110

[...]

Evaluation

--- Entities (named entity recognition (NER)) ---
An entity is considered correct if the entity type and span is predicted correctly

           type    precision       recall     f1-score      support
          brand         0.00         0.00         0.00        169.0
           drug         0.00         0.00         0.00       1396.0
         drug_n         0.00         0.00         0.00        101.0
          group         0.00         0.00         0.00        538.0

          micro         0.00         0.00         0.00       2204.0
          macro         0.00         0.00         0.00       2204.0

--- Relations ---

Without named entity classification (NEC)
A relation is considered correct if the relation type and the spans of the two related entities are predicted correctly (entity type is not considered)

           type    precision       recall     f1-score      support
         advise         0.00         0.00         0.00        130.0
      mechanism         0.00         0.00         0.00        253.0
         effect         0.00         0.00         0.00        250.0
            int         0.00         0.00         0.00          8.0

          micro         0.00         0.00         0.00        641.0
          macro         0.00         0.00         0.00        641.0

With named entity classification (NEC)
A relation is considered correct if the relation type and the two related entities are predicted correctly (in span and entity type)

           type    precision       recall     f1-score      support
         advise         0.00         0.00         0.00        130.0
      mechanism         0.00         0.00         0.00        253.0
         effect         0.00         0.00         0.00        250.0
            int         0.00         0.00         0.00          8.0

          micro         0.00         0.00         0.00        641.0
          macro         0.00         0.00         0.00        641.0

Process SpawnProcess-1:
Traceback (most recent call last):
  File "/home/deeplearning/.conda/envs/salvalai/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/deeplearning/.conda/envs/salvalai/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/deeplearning/Salvalai/spert/spert.py", line 16, in __train
    trainer.train(train_path=run_args.train_path, valid_path=run_args.valid_path,
  File "/home/deeplearning/Salvalai/spert/spert/spert_trainer.py", line 97, in train
    self._eval(model, validation_dataset, input_reader, epoch + 1, updates_epoch)
  File "/home/deeplearning/Salvalai/spert/spert/spert_trainer.py", line 253, in _eval
    evaluator.store_predictions()
  File "/home/deeplearning/Salvalai/spert/spert/evaluator.py", line 87, in store_predictions
    prediction.store_predictions(self._dataset.documents, self._pred_entities,
  File "/home/deeplearning/Salvalai/spert/spert/prediction.py", line 196, in store_predictions
    head_idx = converted_entities.index(converted_head)
ValueError: {'type': 'None', 'start': 0, 'end': 1} is not in list
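
The `ValueError` is raised by `list.index`, which means a predicted relation refers to a head entity (typed 'None') that is not in the list of stored entities. As a stopgap while the root cause is investigated, relations whose head or tail is missing could be skipped rather than crashing. The sketch below is only an illustration: the function name and the relation layout (`head` and `tail` holding entity dicts) are assumptions, not the actual structure used in `prediction.py`; only the entity dict shape is taken from the error message.

```python
def index_relations(converted_entities, converted_relations):
    """Replace each relation's head/tail entity dict by its index in
    converted_entities, skipping relations that reference a missing entity.

    Returns (indexed_relations, skipped_relations).
    """
    indexed, skipped = [], []
    for rel in converted_relations:
        head, tail = rel["head"], rel["tail"]
        if head not in converted_entities or tail not in converted_entities:
            skipped.append(rel)  # keep it around for later inspection
            continue
        indexed.append({
            "type": rel["type"],
            "head": converted_entities.index(head),
            "tail": converted_entities.index(tail),
        })
    return indexed, skipped


# Hypothetical predictions: the second relation has a 'None'-typed head
# like the one in the traceback.
entities = [
    {"type": "drug", "start": 6, "end": 7},
    {"type": "drug", "start": 7, "end": 8},
]
relations = [
    {"type": "effect", "head": entities[0], "tail": entities[1]},
    {"type": "effect", "head": {"type": "None", "start": 0, "end": 1}, "tail": entities[1]},
]

indexed, skipped = index_relations(entities, relations)
print(indexed, len(skipped))
```

Skipping only hides the symptom. Logging the skipped relations would help to find out why a relation with a 'None'-typed argument survives the earlier filtering.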

Best regards

markus-eberts commented 3 years ago

Hi, I just pushed a corner case handling (commit e0d9aee90cd774fff3cb244701cbcf323359f4ee) which may be related to your problem. In some cases (especially strings containing only control characters) the tokenizer we are using maps tokens to empty sequences, which could lead to zero divisions (and NaN values) down the road. Can you please check if the commit fixes your problem? If so, it would still be better to remove any control characters from your dataset beforehand.
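
As an illustration of the suggested preprocessing, the following sketch drops tokens that consist only of whitespace or control characters (such as the trailing "\r\n" token in the sample above) and shifts the entity spans accordingly. It assumes the JSON layout shown in the sample, with token-level `start` (inclusive) and `end` (exclusive) offsets. The function names are hypothetical, and whether the tokenizer maps exactly these tokens to empty sequences is an assumption.

```python
import unicodedata


def is_blank_token(token):
    """True if the token is empty or holds only whitespace/control/format chars."""
    return all(
        ch.isspace() or unicodedata.category(ch) in ("Cc", "Cf")
        for ch in token
    )


def clean_document(doc):
    """Return a copy of doc without blank tokens, with entity spans re-indexed."""
    tokens = doc["tokens"]
    keep = [not is_blank_token(t) for t in tokens]

    # new_index[i] = number of kept tokens before original position i
    new_index, count = [], 0
    for k in keep:
        new_index.append(count)
        count += int(k)
    new_index.append(count)  # lets an exclusive end equal len(tokens)

    entities = []
    for ent in doc["entities"]:
        start, end = new_index[ent["start"]], new_index[ent["end"]]
        if start == end:
            # Dropping the entity would invalidate relation head/tail indices.
            raise ValueError(f"entity {ent} covers only blank tokens")
        entities.append({**ent, "start": start, "end": end})

    cleaned_tokens = [t for t, k in zip(tokens, keep) if k]
    return {**doc, "tokens": cleaned_tokens, "entities": entities}


# Shortened, hypothetical document with one blank token before the entity.
example = {
    "tokens": ["aspirin", "\t", "warfarin", "\r\n"],
    "entities": [{"type": "drug", "start": 2, "end": 3}],
    "relations": [],
}
cleaned = clean_document(example)
print(cleaned["tokens"], cleaned["entities"])
```

Relations are left untouched because they refer to entities by their position in the `entities` list, which does not change.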

If this does not fix your issue, could you please send me the dataset (or a representative part of it) by email (markus.eberts@hs-rm.de)? I can have a look at it then.

ghost commented 3 years ago

Thank you, I'll check asap and let you know