allenai / allennlp

An open-source NLP research library, built on PyTorch.
http://www.allennlp.org
Apache License 2.0

tutorial for basic language model: nested tokens break with a PytorchSeq2SeqWrapper model #3452

Closed: DSLituiev closed this issue 4 years ago

DSLituiev commented 4 years ago

Is there a tutorial or any sample code showing how to use LanguageModel?

I am trying to write a basic GRU-based LanguageModel and am running into a huge inconvenience.

from torch.nn import GRU
from allennlp.models import LanguageModel
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper

gru = GRU(config['embedding_dim'],
          hidden_size=64,
          num_layers=2,
          dropout=0.0,
          batch_first=True)

gru_contextualizer = PytorchSeq2SeqWrapper(gru)

model = LanguageModel(vocab=vocab,
                      text_field_embedder=word_embeddings,
                      contextualizer=gru_contextualizer)
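
The vocab and word_embeddings objects are not shown above; a minimal sketch of how they could be built with the same AllenNLP API (the embedder key and sizes here are illustrative assumptions, not the notebook's actual values):

from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding

# Build the vocabulary from the same instances that feed the iterator below
vocab = Vocabulary.from_instances(instances)

# Single-id token embedder keyed on "tokens", matching the TextField's indexer name
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=config['embedding_dim'])
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})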

My batches come from a BucketIterator that takes objects of the Instance class:

iterator = BucketIterator(batch_size=config['batch_size'],
                          sorting_keys=[("tokens", "num_tokens")])
iterator.index_with(vocab)

for batch in iterator(instances):
    break

print(batch)
# {'tokens': {'tokens': tensor}}

The following call works:

model(batch['tokens'])

However, when I try to feed it to a trainer, the trainer internally calls model(**batch), and everything breaks.
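
A minimal illustration (not from the issue) of what happens here: the top-level keys of the batch dict become keyword arguments when model(**batch) is called, so they have to match the parameter names of the model's forward method.

import torch

def forward(source, targets=None):    # simplified stand-in for a model's forward signature
    return source

batch = {'tokens': {'tokens': torch.zeros(2, 5, dtype=torch.long)}}

forward(batch['tokens'])               # works: the inner dict binds positionally to `source`
try:
    forward(**batch)                   # expands to forward(tokens=...), which is not accepted
except TypeError as error:
    print(error)                       # forward() got an unexpected keyword argument 'tokens'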

import torch.optim as optim
from allennlp.training.trainer import Trainer

optimizer = optim.Adam(model.parameters(), lr=config['lr'])

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=instances,
    cuda_device=0 if config['gpu'] else -1,
    num_epochs=config['epochs'],
)
/Applications/anaconda3/envs/nlp/lib/python3.7/site-packages/allennlp/training/trainer.py in batch_loss(self, batch_group, for_training)
    269             batch = batch_group[0]
    270             batch = nn_util.move_to_device(batch, self._cuda_devices[0])
--> 271             output_dict = self.model(**batch)

Is there any functionality or a wrapper that will unpack the tokens?

matt-gardner commented 4 years ago

Can you give a complete stack trace? You didn't actually paste in the error that you got.

I suspect that your error is because of a mismatch between your dataset reader and your model, but it's hard to confirm that without more information.

DSLituiev commented 4 years ago

It is definitely because of the mismatch between my reader (which returns Instance({"tokens": TextField(...)})) and the PytorchSeq2SeqWrapper model:

https://gist.github.com/f6a6215aa82725bf47d9bc7a6cb5e594

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-169-3613d6448b3a> in <module>
----> 1 metrics = trainer.train()

/Applications/anaconda3/envs/nlp/lib/python3.7/site-packages/allennlp/training/trainer.py in train(self)
    488         for epoch in range(epoch_counter, self._num_epochs):
    489             epoch_start_time = time.time()
--> 490             train_metrics = self._train_epoch(epoch)
    491 
    492             # get peak of memory usage

/Applications/anaconda3/envs/nlp/lib/python3.7/site-packages/allennlp/training/trainer.py in _train_epoch(self, epoch)
    326             self.optimizer.zero_grad()
    327 
--> 328             loss = self.batch_loss(batch_group, for_training=True)
    329 
    330             if torch.isnan(loss):

/Applications/anaconda3/envs/nlp/lib/python3.7/site-packages/allennlp/training/trainer.py in batch_loss(self, batch_group, for_training)
    269             batch = batch_group[0]
    270             batch = nn_util.move_to_device(batch, self._cuda_devices[0])
--> 271             output_dict = self.model(**batch)
    272 
    273         try:

/Applications/anaconda3/envs/nlp/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    539             result = self._slow_forward(*input, **kwargs)
    540         else:
--> 541             result = self.forward(*input, **kwargs)
    542         for hook in self._forward_hooks.values():
    543             hook_result = hook(self, input, result)

TypeError: forward() got an unexpected keyword argument 'tokens'

matt-gardner commented 4 years ago

It's not a mismatch with PytorchSeq2SeqWrapper; it's with the LanguageModel. You wrote your own DatasetReader, and you didn't match the keys in your Instance with the arguments to LanguageModel.forward. The LanguageModel expects an argument named source, but you gave it one named tokens. I can't link directly to your line of code because it's a notebook, but it happens in JsonLanguageModelingDatasetReader.text_to_instance(): you need to use {"source": ...} instead of {"tokens": ...} so that the keys match between your Instance and your model's forward method.
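
A minimal sketch of the fix described above (the reader class name comes from the issue; the tokenizer, indexer, and method body are assumptions, not the notebook's actual code):

from allennlp.data import Instance
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.fields import TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import WordTokenizer

class JsonLanguageModelingDatasetReader(DatasetReader):
    def __init__(self, lazy: bool = False) -> None:
        super().__init__(lazy)
        self._tokenizer = WordTokenizer()
        self._token_indexers = {"tokens": SingleIdTokenIndexer()}

    # _read() is omitted here; it would yield self.text_to_instance(...) per input record.

    def text_to_instance(self, text: str) -> Instance:
        tokens = self._tokenizer.tokenize(text)
        # The key must be "source" so that the Trainer's model(**batch) call
        # lines up with LanguageModel.forward(source=...).
        return Instance({"source": TextField(tokens, self._token_indexers)})

If the field is renamed this way, the BucketIterator's sorting_keys presumably need to change to [("source", "num_tokens")] as well.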