Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

_has_len does not handle NotImplementedError (raised by torchtext) #2277

Closed thschaaf closed 4 years ago

thschaaf commented 4 years ago

πŸ› Bug

When using torchtext.data.Iterator with a batch_size_fn, calling len() on the iterator raises a NotImplementedError, which is not caught by the _has_len function.

The bug fix is very simple: return False if a NotImplementedError is raised. This is unlikely to have any negative side effects, since it corresponds to what _has_len is expected to do. The fix allowed me to train my model using torchtext.
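For illustration, a minimal sketch of the proposed change; the try/except structure around the len(dataloader) == 0 check (visible in the stack trace below) is assumed here:

def _has_len(dataloader) -> bool:
    """Check whether a dataloader has a usable __len__, i.e. whether it is finite."""
    try:
        # Some iterables, e.g. torchtext's Iterator with a batch_size_fn,
        # raise NotImplementedError from __len__.
        if len(dataloader) == 0:
            raise ValueError('Dataloader returned 0 length. Please make sure'
                             ' your Dataloader returns at least 1 batch')
        return True
    except TypeError:
        # The object has no __len__ at all.
        return False
    except NotImplementedError:
        # Proposed fix: treat "length not implemented" the same as "no length".
        return False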

I plan to submit a pull request with the fix above.

There are no additional dependencies required; however, this problem occurred when using torchtext.

Example stack trace:

Traceback (most recent call last):
  File "/Users/thomas/scm/OakDataPrep/oakSkipThoughtTrainer.py", line 18, in <module>
    trainer.fit(model)
  File "/Users/thomas/virtualenv/Python3/PyTorch/env/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 952, in fit
    self.run_pretrain_routine(model)
  File "/Users/thomas/virtualenv/Python3/PyTorch/env/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1091, in run_pretrain_routine
    self.train()
  File "/Users/thomas/virtualenv/Python3/PyTorch/env/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 334, in train
    self.reset_train_dataloader(model)
  File "/Users/thomas/virtualenv/Python3/PyTorch/env/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py", line 201, in reset_train_dataloader
    if not _has_len(self.train_dataloader):
  File "/Users/thomas/virtualenv/Python3/PyTorch/env/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py", line 49, in _has_len
    if len(dataloader) == 0:
  File "/Users/thomas/virtualenv/Python3/PyTorch/env/lib/python3.7/site-packages/torchtext/data/iterator.py", line 136, in __len__
    raise NotImplementedError
NotImplementedError

To Reproduce

Sorry, I currently don't have a minimal example. The issue always occurs when torchtext.data.Iterator is given a batch_size_fn. If the fix is not convincing I can take the time to construct a proper code example; I hope this is not necessary. A rough, untested sketch of what such an example could look like is below.
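This sketch only assumes a tiny, hypothetical single-field dataset; its content is irrelevant, what matters is that a batch_size_fn is passed to the Iterator:

from torchtext import data

# Hypothetical one-example dataset; only the presence of a batch_size_fn
# on the Iterator matters for triggering the error.
TEXT = data.Field()
examples = [data.Example.fromlist([["hello", "world"]], [("text", TEXT)])]
dataset = data.Dataset(examples, [("text", TEXT)])

it = data.Iterator(dataset, batch_size=32,
                   batch_size_fn=lambda new, count, sofar: count)

len(it)  # raises NotImplementedError, which _has_len does not catch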

Code sample

I created my own Iterator for a Skip-Thought model that dynamically batches sentences together. This might be unnecessarily complex, or even not really useful, but it revealed the issue described above when using torchtext. For context, here is a code excerpt that triggers the issue:

import torchtext
from torchtext import data  # data.batch and data.interleave_keys are used below
import pytorch_lightning as pl
...
global max_src_in_batch, max_tgt_in_batch
def batch_size_fn(new, count, sofar):
    "Keep augmenting batch and calculate total number of tokens + padding."
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch,  len(new.current))
    max_tgt_in_batch = max(max_tgt_in_batch,  len(new.next) + 2)
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)

class MyIterator(torchtext.data.Iterator):

    def create_batches(self):
        if self.train:
            def pool(d, random_shuffler):
                for p in data.batch(d, self.batch_size * 100):
                    p_batch = data.batch(
                            sorted(p, key=self.sort_key),
                            self.batch_size, self.batch_size_fn)
                    for b in random_shuffler(list(p_batch)):
                        yield b
            self.batches = pool(self.data(), self.random_shuffler)

        else:
            self.batches = []
            for b in data.batch(self.data(), self.batch_size,
                                self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))

...
class SkipThoughts(pl.LightningModule):
...
    @pl.data_loader
    def train_dataloader(self):
        train_iter = MyIterator(self.my_train_dataloader, batch_size=self.batch_size, repeat=False,
                                   sort_key=lambda x:
                                   data.interleave_keys(len(x.current),
                                                        data.interleave_keys(len(x.prev),
                                                                             len(x.next))),
                                   batch_size_fn=batch_size_fn, train=True,
                                   shuffle=True)
        return train_iter

But this happens whenever a batch_size_fn is used in torchtext. Because it is unknown how many batches the dataset will have, torchtext's __len__ method raises a NotImplementedError. See the code snippet below:

def __len__(self):
    if self.batch_size_fn is not None:
        raise NotImplementedError
    return math.ceil(len(self.dataset) / self.batch_size)

Expected behavior

The function _has_len tests whether len is available and returns True if it is, otherwise False. It should also return False if a NotImplementedError is raised.
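A small stand-in object is enough to sketch that expectation, assuming _has_len can be imported from pytorch_lightning.trainer.data_loading (the module shown in the stack trace):

from pytorch_lightning.trainer.data_loading import _has_len


class NoLenIterable:
    """Stand-in for a torchtext Iterator with a batch_size_fn: iterable, but __len__ raises."""

    def __iter__(self):
        return iter(range(3))

    def __len__(self):
        raise NotImplementedError


# With the proposed fix, the NotImplementedError is swallowed and the object
# is simply treated as having no usable length.
assert _has_len(NoLenIterable()) is False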

Environment

/Users/thomas/virtualenv/Python3/PyTorch/env/bin/python /Users/thomas/scm/OakDataPrep/collect_env_details.py

Process finished with exit code 0

Additional context

The issue occurs with PyTorch-Lightning 0.8 and torchtext 0.6.

github-actions[bot] commented 4 years ago

Hi! Thanks for your contribution, great first issue!

williamFalcon commented 4 years ago

good catch! mind submitting a PR?

thschaaf commented 4 years ago

Absolutely. The fix is easy. How important is a test for that case?