asyml / texar-pytorch

Integrating the Best of TF into PyTorch, for Machine Learning, Natural Language Processing, and Text Generation. This is part of the CASL project: http://casl-project.ai/
https://asyml.io
Apache License 2.0

Query related to data iterators for Seq2Seq translation using bert-gpt2 #256

Open seekingpeace opened 4 years ago

seekingpeace commented 4 years ago

Hi, while trying to use the following snippet:

import torch
import torch.nn as nn

import texar.torch as tx
from texar.torch.run import *

# (1) Modeling
class BERTGPT2Model(nn.Module):
  """An encoder-decoder model with GPT-2 as the decoder."""
  def __init__(self, vocab_size):
    super().__init__()
    # Use hyperparameter dict for model configuration
    self.tokeniserBERT = tx.data.BERTTokenizer('bert-base-uncased')
    self.tokeniserGPT2 = tx.data.GPT2Tokenizer('gpt2-medium')
    self.encoder = tx.modules.BERTEncoder('bert-base-uncased')
    self.decoder = tx.modules.GPT2Decoder("gpt2-medium")  # With pre-trained weights

  def _get_decoder_output(self, batch, train=True):
    """Perform model inference, i.e., decoding."""
    # BERTEncoder takes token ids directly and returns (sequence_output, pooled_output).
    enc_states, _ = self.encoder(inputs=batch['source_text_ids'],
                                 sequence_length=batch['source_length'])
    if train:  # Teacher-forcing decoding at training time
      return self.decoder(
          inputs=batch['target_text_ids'], sequence_length=batch['target_length'] - 1,
          memory=enc_states, memory_sequence_length=batch['source_length'])
    else:      # Beam search decoding at prediction time
      start_tokens = torch.full_like(batch['source_text_ids'][:, 0], BOS)  # which BOS to use?
      return self.decoder(
          beam_width=5, start_tokens=start_tokens,
          memory=enc_states, memory_sequence_length=batch['source_length'])

  def forward(self, batch):
    """Compute training loss."""
    outputs = self._get_decoder_output(batch)
    loss = tx.losses.sequence_sparse_softmax_cross_entropy(  # Sequence loss
        labels=batch['target_text_ids'][:, 1:], logits=outputs.logits,
        sequence_length=batch['target_length'] - 1)  # Automatic masking
    return {"loss": loss}

  def predict(self, batch):
    """Compute model predictions."""
    sequence, _ = self._get_decoder_output(batch, train=False)
    return {"gen_text_ids": sequence}

# (2) Data
# Create dataset splits using built-in data loaders
datasets = {split: tx.data.PairedTextData(hparams=data_hparams[split])
            for split in ["train", "valid", "test"]}

model = BERTGPT2Model(datasets["train"].target_vocab.size)

# (3) Training
# Manage the train-eval loop with the Executor API
executor = Executor(
  model=model, datasets=datasets,
  optimizer={"type": torch.optim.Adam, "kwargs": {"lr": 5e-4}},
  stop_training_on=cond.epoch(20),
  log_every=cond.iteration(100),
  validate_every=cond.epoch(1),
  train_metric=("loss", metric.RunningAverage(10, pred_name="loss")),
  valid_metric=metric.BLEU(pred_name="gen_text_ids", label_name="target_text_ids"),
  save_every=cond.validation(better=True),
  checkpoint_dir="outputs/saved_models/")
executor.train()
executor.test(datasets["test"]) 

In this example:

  1. How should I use data iterators to read from files?
  2. What data config should I use so that the source text is encoded with tokeniserBERT.encode_text(src) and the target text with tokeniserGPT2.encode_text(tgt) before being passed through in the batch?
  3. Does PairedTextData have an option to pass different processors for the above use case?

TIA

huzecong commented 4 years ago
  1. PairedTextData reads directly from files. Please take a look at its hparams, where you can set file paths for both the source-side and target-side datasets (see the sketch at the end of this comment).

  2. (and 3.) I'm not sure what you want to achieve here. My guess is that you want to use the BERT and GPT2 tokenizers to tokenize the source and target datasets respectively, is this correct?

    PairedTextData has a number of processing options built in, such as the delimiter for tokenization ("delimiter") and the maximum sentence length ("max_seq_length"). If you want to do additional processing, you have two options:

    • Write transformation functions (functions that take as input a list of strings representing a sentence and return the processed sentence, also as a list of strings) and add them to "other_transformations".
    • Inherit from PairedTextData and override the process function. Only do this if you understand how PairedTextData works.

    However, in your case, I think the easier way would be to write the data loader yourself, since the tokenizers directly convert untokenized strings to token IDs. @gpengzhi, can you help write an example for this use case?
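
For the file-path and other_transformations options mentioned above, a rough, untested sketch; the file paths, hparams values, and the lowercasing transformation are placeholders, not verbatim texar examples:

import texar.torch as tx

def lowercase(tokens):
    # An "other_transformations" function: takes and returns a list of token strings.
    return [t.lower() for t in tokens]

train_hparams = {
    'source_dataset': {
        'files': 'data/train_src.txt',        # placeholder path
        'vocab_file': 'data/src_vocab.txt',   # placeholder path
        'max_seq_length': 40,
        'other_transformations': [lowercase],
    },
    'target_dataset': {
        'files': 'data/train_tgt.txt',        # placeholder path
        'vocab_file': 'data/tgt_vocab.txt',   # placeholder path
        'max_seq_length': 40,
    },
    'batch_size': 32,
}

train_data = tx.data.PairedTextData(hparams=train_hparams)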

seekingpeace commented 4 years ago

Thanks @huzecong for the reply. To make it work, I made separate vocab files for both GPT2 and BERT, stored them as text, and passed part of the tokenizer as transformations, as shown below:

tokenizer_gpt2 = tx.data.GPT2Tokenizer(
        pretrained_model_name='gpt2-small')
tokenizer_bert = tx.data.BERTTokenizer(
        pretrained_model_name='bert-base-uncased')

def token_transform_bert(arr):
    arr_str = ' '.join(arr)
    ret_arr = tokenizer_bert.map_text_to_token(arr_str)
    return ret_arr

def token_transform_gpt2(arr):
    arr_str = ' '.join(arr)
    ret_arr = tokenizer_gpt2.map_text_to_token(arr_str)
    return ret_arr

data_hparams={
        'train':{
            'source_dataset': {'files': 'exp/train_src.txt','vocab_file':'exp/bert_vocab.txt','max_seq_length': 40,
                               'bos_token':'[CLS]','eos_token':'[SEP]','other_transformations':[token_transform_bert]},
            'target_dataset': {'files': 'exp/train_tgt.txt','vocab_file':'exp/gpt2_vocab.txt','max_seq_length': 40,
                               'bos_token':'<|endoftext|>','eos_token':'<|endoftext|>','other_transformations':[token_transform_gpt2]},
            'batch_size': 40,
            "allow_smaller_final_batch": True,
            "shuffle": True,
            "num_parallel_calls":3
            },
        'test':{
            'source_dataset': {'files': 'exp/test_src.txt','vocab_file':'exp/bert_vocab.txt','max_seq_length': 40,
                               'bos_token':'[CLS]','eos_token':'[SEP]','other_transformations':[token_transform_bert]},
            'target_dataset': {'files': 'exp/test_tgt.txt','vocab_file':'exp/gpt2_vocab.txt','max_seq_length': 40,
                               'bos_token':'<|endoftext|>','eos_token':'<|endoftext|>','other_transformations':[token_transform_gpt2]},
            'batch_size': 12
            },
        'valid':{
            'source_dataset': {'files': 'exp/valid_src.txt','vocab_file':'exp/bert_vocab.txt','max_seq_length': 40,
                               'bos_token':'[CLS]','eos_token':'[SEP]','other_transformations':[token_transform_bert]},
            'target_dataset': {'files': 'exp/valid_tgt.txt','vocab_file':'exp/gpt2_vocab.txt','max_seq_length': 40,
                               'bos_token':'<|endoftext|>','eos_token':'<|endoftext|>','other_transformations':[token_transform_gpt2]},
            'batch_size': 12
            }

        }
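
These hparams are then consumed the same way as in the first snippet; roughly as follows (the shape check is just an illustrative sanity check):

datasets = {split: tx.data.PairedTextData(hparams=data_hparams[split])
            for split in ["train", "valid", "test"]}

# Peek at one processed batch to verify that the transformations were applied.
iterator = tx.data.DataIterator(datasets["train"])
batch = next(iter(iterator))
print(batch['source_text_ids'].shape, batch['target_text_ids'].shape)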

After this, an exception was raised saying that these special tokens already exist in the vocab, so I had to remove them from the vocabulary.py class. I also monkey-patched paired_text_data.py, since there was no way to pass pad and unk tokens to PairedTextData:

self._src_vocab = Vocab(src_hparams.vocab_file,
                        bos_token=src_hparams.bos_token,
                        eos_token=src_hparams.eos_token,
                        pad_token='[PAD]',
                        unk_token='[UNK]')
self._tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                        bos_token=tgt_bos_token,
                        eos_token=tgt_eos_token,
                        pad_token='<|endoftext|>',
                        unk_token='<|endoftext|>')

I think:

  1. We could add an option to pass these additional tokens (pad and unk) to PairedTextData.
  2. In the vocab-building class, when adding special tokens to the vocabulary, we could check whether a special token is already present and, if so, skip adding it and reuse its existing id. Since the vocab is currently built as vocab = [self._pad_token, self._bos_token, self._eos_token, self._unk_token] + vocab, my whole vocabulary shifts to the right, giving incorrect results (see the sketch below).
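
As a rough, hypothetical sketch of that check (not the actual vocabulary.py code), where missing special tokens are appended rather than prepended so that the ids already assigned by a pretrained vocab file are preserved:

def add_special_tokens(vocab, special_tokens):
    # `vocab` is the token list read from the vocab file; `special_tokens`
    # would be [pad, bos, eos, unk]. Tokens already present keep their ids;
    # only genuinely missing ones are added, at the end, to avoid shifting.
    existing = set(vocab)
    missing = [tok for tok in special_tokens if tok not in existing]
    return vocab + missing
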
huzecong commented 4 years ago

Thank you for your feedback! These are all valuable suggestions and I think we could add them. We're actually discussing the possibility of deprecating the Vocab class and switching everything to be tokenizer-based, or at least providing interfaces in the data modules to use tokenizers instead of vocabs. @gpengzhi Could you try working on this after the holidays?

gpengzhi commented 4 years ago

Yes, I think we should support this feature. Since the pre-trained tokenizers already take care of the corresponding vocabulary files and special tokens, it is unnecessary to require a vocabulary file and special tokens when people use PairedTextData. I will think about this enhancement for our data module.

seekingpeace commented 4 years ago

So, @gpengzhi @huzecong, would there ideally be a tokeniser builder class that can accept a pretrained tokeniser or a new tokeniser, with the source and target tokenisers then passed to the data modules for processing the data? This could be a cool new feature and would make many things seamless.
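
In the meantime, here is a rough, untested sketch of what such a tokeniser-driven paired dataset could look like today, written as a plain torch Dataset. The class name, file handling, naive truncation, and the use of map_text_to_id (which maps an untokenized string to token ids) are illustrative assumptions, not an existing texar API:

import torch
from torch.utils.data import Dataset


class TokenizerPairedTextDataset(Dataset):
    """Hypothetical paired-text dataset that delegates all text processing
    to pretrained tokenizers instead of vocab files."""

    def __init__(self, src_file, tgt_file, src_tokenizer, tgt_tokenizer,
                 max_seq_length=40):
        with open(src_file) as f:
            src_lines = [line.strip() for line in f]
        with open(tgt_file) as f:
            tgt_lines = [line.strip() for line in f]
        assert len(src_lines) == len(tgt_lines)
        self.examples = list(zip(src_lines, tgt_lines))
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_seq_length = max_seq_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        src, tgt = self.examples[index]
        # The tokenizers convert untokenized strings straight to token ids;
        # special-token handling is left to the tokenizer configuration.
        src_ids = self.src_tokenizer.map_text_to_id(src)[:self.max_seq_length]
        tgt_ids = self.tgt_tokenizer.map_text_to_id(tgt)[:self.max_seq_length]
        return {
            'source_text_ids': torch.tensor(src_ids, dtype=torch.long),
            'source_length': len(src_ids),
            'target_text_ids': torch.tensor(tgt_ids, dtype=torch.long),
            'target_length': len(tgt_ids),
        }


# Usage sketch with the pretrained tokenizers from this thread:
# import texar.torch as tx
# src_tok = tx.data.BERTTokenizer('bert-base-uncased')
# tgt_tok = tx.data.GPT2Tokenizer('gpt2-small')
# train_set = TokenizerPairedTextDataset('exp/train_src.txt', 'exp/train_tgt.txt',
#                                        src_tok, tgt_tok)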