lavis-nlp / spert

PyTorch code for SpERT: Span-based Entity and Relation Transformer
MIT License

Using Transformers other than BERT and SciBERT #49

Closed dogatekin closed 3 years ago

dogatekin commented 3 years ago

First of all, thanks for the great implementation!

I've been trying to use different transformers (e.g. RoBERTa) with SpERT, but I ran into some problems. Simply changing the model_path and tokenizer_path in the config to the name of a different transformer from https://huggingface.co/models does not work, since the code currently uses BERT-specific classes such as BertTokenizer and BertModel instead of AutoTokenizer and AutoModel.

But even if I change those, I still have problems, possibly because the SpERT class itself is derived from BertPreTrainedModel. Using SciBERT by following your instructions definitely works, but I think that's because SciBERT has the exact same architecture/layer names/etc.

Do you know whether using other transformers is possible with the current implementation (I might have misunderstood some parts)? If not, do you know what modifications would be needed to make it work with other transformers, or if there are any workarounds I could use? Thanks in advance!

markus-eberts commented 3 years ago

Hi @dogatekin, thanks for your interest in SpERT. In this case, you also need to derive from *PreTrainedModel (where * depends on the transformer) and change the transformer instance. For example, if you choose to use RoBERTa instead of BERT you should adjust the 'models.py' file like this:

class SpERT(RobertaPreTrainedModel):
    def __init__(self, config: RobertaConfig, cls_token: int, relation_types: int, entity_types: int,
                 size_embedding: int, prop_drop: float, freeze_transformer: bool, max_pairs: int = 100):
        super(SpERT, self).__init__(config)

        # RoBERTa model (replaces the former self.bert attribute)
        self.roberta = RobertaModel(config)

        [...]

    def _forward_train(self, encodings: torch.Tensor, context_masks: torch.Tensor, entity_masks: torch.Tensor,
                       entity_sizes: torch.Tensor, relations: torch.Tensor, rel_masks: torch.Tensor):
        [...]
        h = self.roberta(input_ids=encodings, attention_mask=context_masks)['last_hidden_state']
        [...]

    def _forward_inference(self, encodings: torch.Tensor, context_masks: torch.Tensor, entity_masks: torch.Tensor,
                           entity_sizes: torch.Tensor, entity_spans: torch.Tensor, entity_sample_masks: torch.Tensor):
        [...]
        h = self.roberta(input_ids=encodings, attention_mask=context_masks)['last_hidden_state']

dogatekin commented 3 years ago

Thanks a lot @markus-eberts, this does indeed work for many transformer models such as ALBERT. Interestingly, it doesn't seem possible to use from transformers import RobertaPreTrainedModel, but it is possible to import *PreTrainedModel versions of many other transformers. I've tried this using both transformers version 3 and version 4. This is unrelated to SpERT though, probably this is a feature/bug of the transformers library itself.

markus-eberts commented 3 years ago

I see. This seems to work for me (version 4.1.1): from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
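
Since the top-level import and the module-path import behave differently across transformers versions in this thread, a version-tolerant lookup could be sketched as follows. This is only an illustration: `load_roberta_base_class` is a hypothetical helper name (not part of SpERT or transformers), and it assumes nothing beyond the two import paths mentioned above. It returns None when neither path works, for example when transformers or PyTorch is not installed.

```python
import importlib


def load_roberta_base_class():
    """Return RobertaPreTrainedModel, or None if it cannot be imported.

    Hypothetical helper: tries the two import paths mentioned in this thread,
    top-level first, then the module path.
    """
    candidates = (
        "transformers",
        "transformers.models.roberta.modeling_roberta",
    )
    for module_name in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, "RobertaPreTrainedModel")
        except (ImportError, AttributeError, RuntimeError):
            # ImportError: package or module path missing in this version.
            # AttributeError: module exists but does not export the class.
            # RuntimeError: assumed to cover lazy-import failures in some versions.
            continue
    return None


roberta_base_class = load_roberta_base_class()
```

If the result is not None, it can be used as the base class in 'models.py' in place of BertPreTrainedModel.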

dogatekin commented 3 years ago

Thanks for the help! When you say it seems to work, do you mean that SpERT itself works after that change? The import also works for me, but when I change everything to Roberta I get the following error just as training starts (I'm running the unchanged example, python ./spert.py train --config configs/example_train.conf):

Process SpawnProcess-1:
Traceback (most recent call last):
  File "/storage/dtekin/miniconda3/envs/spert/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/storage/dtekin/miniconda3/envs/spert/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/storage/dtekin/spert-orig/spert.py", line 17, in __train
    types_path=run_args.types_path, input_reader_cls=input_reader.JsonInputReader)
  File "/storage/dtekin/spert-orig/spert/spert_trainer.py", line 93, in train
    self._train_epoch(model, compute_loss, optimizer, train_dataset, updates_epoch, epoch)
  File "/storage/dtekin/spert-orig/spert/spert_trainer.py", line 191, in _train_epoch
    relations=batch['rels'], rel_masks=batch['rel_masks'])
  File "/storage/dtekin/miniconda3/envs/spert/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/storage/dtekin/spert-orig/spert/models.py", line 223, in forward
    return self._forward_train(*args, **kwargs)
  File "/storage/dtekin/spert-orig/spert/models.py", line 67, in _forward_train
    entity_clf, entity_spans_pool = self._classify_entities(encodings, h, entity_masks, size_embeddings)
  File "/storage/dtekin/spert-orig/spert/models.py", line 134, in _classify_entities
    entity_spans_pool, size_embeddings], dim=2)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 2. Got 2 and 4 in dimension 0 at /pytorch/aten/src/THC/generic/THCTensorMath.cu:71

This does not happen when I import AlbertPreTrainedModel and change everything to Albert: I can train and evaluate models normally. It does not matter whether I import it using from transformers.models.albert.modeling_albert import AlbertPreTrainedModel or from transformers import AlbertPreTrainedModel. Do you have any idea what might be happening?

markus-eberts commented 3 years ago

Hi, sorry for the late reply. In case you have not already solved the problem: the transition to RoBERTa is a bit more complicated with the current code. Since RoBERTa uses different special tokens and IDs for the cls, sep, unk and padding tokens, you need to adjust some parts of the code. In 'input_reader.py', in the '_parse_tokens' function, replace [CLS], [SEP] and [UNK] with <s>, </s> and <unk> respectively. In 'spert_trainer.py' ('_load_model' function), [CLS] must also be replaced with <s>. Also, in 'sampling.py' ('collate_fn_padding' function), the padding ID must be set to 1 for encodings, so something like this:

def collate_fn_padding(batch):
    padded_batch = dict()
    keys = batch[0].keys()

    for key in keys:
        samples = [s[key] for s in batch]

        if not batch[0][key].shape:
            padded_batch[key] = torch.stack(samples)
        else:
            # RoBERTa's padding ID is 1; all other tensors (masks etc.) are still padded with 0
            padded_batch[key] = util.padded_stack(samples,
                                                  padding=1 if key == 'encodings' else 0)

    return padded_batch
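
To see the effect of the per-key padding value without PyTorch, here is a minimal list-based sketch of the same idea. The names `pad_sequences` and `collate_lists` are hypothetical, the 'context_masks' key is assumed for illustration, and the token IDs between 0 and 2 are arbitrary placeholders rather than real vocabulary lookups.

```python
def pad_sequences(sequences, padding=0):
    """Right-pad lists of ints to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [padding] * (max_len - len(seq)) for seq in sequences]


def collate_lists(batch, encoding_pad_id=1):
    """List-based analogue of collate_fn_padding.

    Only the 'encodings' key is padded with the model's padding ID
    (1 for RoBERTa, as stated above); every other key is padded with 0.
    """
    padded = {}
    for key in batch[0]:
        samples = [sample[key] for sample in batch]
        padding = encoding_pad_id if key == "encodings" else 0
        padded[key] = pad_sequences(samples, padding=padding)
    return padded


# Two samples of different length; the middle IDs are placeholders.
batch = [
    {"encodings": [0, 713, 2], "context_masks": [1, 1, 1]},
    {"encodings": [0, 713, 16, 10, 2], "context_masks": [1, 1, 1, 1, 1]},
]
padded = collate_lists(batch)
```

The shorter sample's encodings are filled with 1 while its mask is filled with 0, so the padded positions carry the padding ID and stay masked out.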

Maybe there are some other code parts that need to be adjusted but you get the idea.
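
One way to keep these replacements in a single place is a small lookup table instead of literals scattered across files. The sketch below is an assumption about how this could be organized, not SpERT's actual code: `SPECIAL_TOKENS` and `wrap_with_special_tokens` are hypothetical names, and only the BERT and RoBERTa token strings are taken as given.

```python
# Special-token strings per model family (hypothetical table for illustration).
SPECIAL_TOKENS = {
    "bert": {"cls": "[CLS]", "sep": "[SEP]", "unk": "[UNK]", "pad": "[PAD]"},
    "roberta": {"cls": "<s>", "sep": "</s>", "unk": "<unk>", "pad": "<pad>"},
}


def wrap_with_special_tokens(tokens, family):
    """Surround a token sequence with the family's cls and sep tokens."""
    special = SPECIAL_TOKENS[family]
    return [special["cls"], *tokens, special["sep"]]


bert_tokens = wrap_with_special_tokens(["Hello", "world"], "bert")
roberta_tokens = wrap_with_special_tokens(["Hello", "world"], "roberta")
```

With a Hugging Face tokenizer, the same values are also exposed as `tokenizer.cls_token`, `tokenizer.sep_token`, `tokenizer.unk_token` and `tokenizer.pad_token_id`, which would avoid hard-coding them at all.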

dogatekin commented 3 years ago

Thanks so much for your efforts @markus-eberts, the changes you mentioned along with the renaming discussed earlier seem to be enough to get SpERT to work with RoBERTa. You've been incredibly helpful!

markus-eberts commented 3 years ago

@dogatekin no problem! One further suggestion: I'm not aware of a 'cased' RoBERTa version in huggingface. There seem to be only 'roberta-base' and 'roberta-large', which are both uncased as far as I know. In this case, you should set 'lowercase' to True (in the config file) when using SpERT (otherwise many cased words are mapped to the 'unknown' token).

dogatekin commented 3 years ago

Thanks for the thought and suggestion! This issue at https://github.com/pytorch/fairseq/issues/1429 seems to suggest that the RoBERTa models are cased and my preliminary experiments using SpERT as-is with roberta-base did not have any problems, so I believe it's all good.

markus-eberts commented 3 years ago

I see - good to know!