Lightning-Universe / lightning-flash

Your PyTorch AI Factory - Flash enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains
https://lightning-flash.readthedocs.io
Apache License 2.0

`RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <enum 'DefaultDataKeys'>.` #770

Closed dlangerm closed 2 years ago

dlangerm commented 2 years ago

🐛 Bug

It seems that there is some sort of state issue with custom preprocessors and Flash modules. The basic code for my custom preprocessor is below; it's based on the examples in the documentation and is very simple. I can't get the error to appear consistently, but it has happened at least once in my past few training iterations. to_onnx seems to trigger it as well, though not always, which makes me think it has something to do with the preprocessor state (current_transform, perhaps).

I am happy to help debug this issue but it's really annoying and I do need a custom preprocessor for my data.

To Reproduce

Steps to reproduce the behavior:

  1. Add a custom preprocessor to a data module initialized from a data source
  2. Add an example_input_array to the module
  3. Finetune the model (a rough sketch of this setup follows below)
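
A rough sketch of the setup (hypothetical stand-ins only: I can't share the custom data source, and ImageClassifier here is just an example task):

import torch
import flash
from flash.image import ImageClassifier

# Stand-in for my real task; the actual module uses a custom backbone/head
model = ImageClassifier(backbone="resnet18", num_classes=2)
model.example_input_array = torch.ones(1, 3, 224, 224)

# DataModule built from the custom data source + preprocessor shown in the code sample below
datamodule = ...

trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")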

Stack trace:

Traceback (most recent call last):
  File "xxx\dummy_run.py", line 35, in <module>
    run()
  File "xxx\dummy_run.py", line 31, in run
    train(_cfg_)
  File "xxx\train.py", line 78, in train
    test_model(trainer, had_error, checkpointer, data_module)
  File "xxx\test_model.py", line 18, in test_model
    raise had_error
  File "xxx\train.py", line 69, in train
    trainer.finetune(net, datamodule=data_module,
  File "xxx\env\lib\site-packages\flash\core\trainer.py", line 165, in finetune
    return super().fit(model, train_dataloader, val_dataloaders, datamodule)
  File "xxxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 553, in fit
    self._run(model)
  File "xxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 918, in _run
    self._dispatch()
  File "xxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 986, in _dispatch
    self.accelerator.start_training(self)
  File "xxx\overview.ai\tyson\env\lib\site-packages\pytorch_lightning\accelerators\accelerator.py", line 92, in start_training
    self.training_type_plugin.start_training(trainer)
  File "xxx\env\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py", line 161, in start_training
    self._results = trainer.run_stage()
  File "xxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 996, in run_stage
    return self._run_train()
  File "xxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1026, in _run_train
    self._pre_training_routine()
  File "xxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1019, in _pre_training_routine
    ref_model.summarize(max_depth=max_depth)
  File "xxx\env\lib\site-packages\pytorch_lightning\core\lightning.py", line 1711, in summarize
    model_summary = ModelSummary(self, max_depth=max_depth)
  File "xxx\env\lib\site-packages\pytorch_lightning\core\memory.py", line 215, in __init__
    self._layer_summary = self.summarize()
  File "xxx\env\lib\site-packages\pytorch_lightning\core\memory.py", line 271, in summarize
    self._forward_example_input()
  File "xxx\env\lib\site-packages\pytorch_lightning\core\memory.py", line 288, in _forward_example_input
    input_ = model._apply_batch_transfer_handler(input_)
  File "xxx\env\lib\site-packages\pytorch_lightning\core\lightning.py", line 281, in _apply_batch_transfer_handler
    batch = self.transfer_batch_to_device(batch, device, dataloader_idx)
  File "xxx\env\lib\site-packages\flash\core\data\data_pipeline.py", line 609, in __call__
    outputs = additional_func(outputs)
  File "xxx\env\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "xxx\env\lib\site-packages\flash\core\data\batch.py", line 239, in forward
    samples = self.per_batch_transform(samples)
  File "xxx\env\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "xxx\env\lib\site-packages\flash\core\data\utils.py", line 178, in forward
    return self.func(*args, **kwargs)
  File "xxx\lib\site-packages\flash\core\data\process.py", line 409, in per_batch_transform_on_device
    return self.current_transform(batch)
  File "xxx\env\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "xxx\lib\site-packages\flash\core\data\transforms.py", line 40, in forward
    keys = list(filter(lambda key: key in x, self.keys))
  File "xxx\lib\site-packages\flash\core\data\transforms.py", line 40, in <lambda>
    keys = list(filter(lambda key: key in x, self.keys))
  File "xxx\env\lib\site-packages\torch\_tensor.py", line 670, in __contains__
    raise RuntimeError(
RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <enum 'DefaultDataKeys'>.

Code sample

from lightning.CustomDataSource import CustomDataSource

import numpy as np
from torchvision import transforms as T
from torchvision.transforms.functional import pil_to_tensor

from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.process import DefaultPreprocess
from flash.core.data.transforms import ApplyToKeys, merge_transforms
from flash.image.classification.transforms import default_transforms

class OverviewPreprocessor(DefaultPreprocess):
    def __init__(self, config):
        self.config = config
        img_size = (config['image_size'], config['image_size'])

        thing = ApplyToKeys(
            DefaultDataKeys.INPUT,
            T.Compose([
                T.Resize(config['image_size'])
            ])
        )

        train_transform = merge_transforms(
            default_transforms(img_size),
            {
                "pre_tensor_transform": thing,
                "post_tensor_transform": ApplyToKeys(
                    DefaultDataKeys.INPUT,
                    T.Compose([
                        T.RandomHorizontalFlip(),
                    ])
                )
            }
        )

        tform = merge_transforms(
            default_transforms(img_size),
            {"post_tensor_transform": thing}
        )

        super().__init__(
            train_transform=train_transform,
            val_transform=tform,
            test_transform=tform,
            data_sources={
                "regression": CustomDataSource()
            },
            default_data_source="xxx",
        )

    @staticmethod
    def input_to_tensor(in_pil: np.ndarray):
        """Transform which creates a tensor from the given pil image and converts it to ``float``"""
        return pil_to_tensor(in_pil).float()

Expected behavior

I expect to_onnx and model.summarize to use the correct data types for inference.

Environment

Additional context

FWIW the Task that I'm using is very basic but initializes the example_input_array in the constructor like this:

self.example_input_array = torch.ones(
    (
        1,
        channels,
        config['image_size'],
        config['image_size'],
    )
)
dlangerm commented 2 years ago

Additional investigation: I can load my checkpoint created just before this error and run to_onnx or summarize and it works fine. This only happens during training.

dlangerm commented 2 years ago

It turns out that to fix it I had to modify the forward method of my module to accept either a dictionary or a Tensor.

    def forward(self, x) -> torch.Tensor:
        if isinstance(x, torch.Tensor):
            x = self.backbone(x)  ## <---- initially only had this line here
        elif isinstance(x, dict):
            x = self.backbone(x[DefaultDataKeys.INPUT])
        return x

Not exactly sure how to classify this as a bug or just inconsistent behavior but I think a fix is in order. In my mind, there's no reason for to_onnx to fail on data when I pass it in using the lightning method of example_input_array.

ethanwharris commented 2 years ago

Hi @dlangerm thanks for your work on this! Glad you got something working :smiley: Internally in Flash we represent everything as a dict. Usually in our tasks we unpack the dict in the *_step method and then implement the forward to just expect a tensor. So your common_step or similar method would have:

input, target = x[DefaultDataKeys.INPUT], x[DefaultDataKeys.TARGET]
pred = self(input)

Or something like that. That way, things like ONNX or TorchScript, which assume your forward works with just an input tensor, will still work without the need for your modification. Really interested to hear your thoughts or suggestions about how we can make this cleaner / smoother to use. Thanks :smiley:
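
For example, a training_step in that style might look roughly like this (backbone and loss_fn are just placeholders, not the actual task internals):

def training_step(self, batch, batch_idx):
    # Unpack the Flash dict here so that forward only ever sees a tensor
    x, y = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]
    y_hat = self(x)  # forward(tensor) -> predictions
    loss = self.loss_fn(y_hat, y)
    self.log("train_loss", loss)
    return loss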

dlangerm commented 2 years ago

@ethanwharris So the thing is, I do have all of my *_step methods set up that way, but it seems like to_onnx and summarize don't call them, and when the example_input_array is passed through the preprocessor, it's wrapped in the dict, which causes this error.

ethanwharris commented 2 years ago

Ahhhh, that's interesting, so the ONNX export includes the preprocess but skips the step method? That's definitely something we should address. Let me know if you have any ideas for solving it; I'll try to look into it and see what options there are.

dlangerm commented 2 years ago

It seems like that is what's happening. Since my to_onnx call is in a callback itself, that may be where it's getting tripped up? It still doesn't make sense why summarize is failing though...

dlangerm commented 2 years ago

@ethanwharris I was able to confirm that every other stage of training is passing a dict to _apply_batch_transfer_handler except for the to_onnx and summarize calls which are passing raw Tensors.
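
A quick way to see this is a temporary override along these lines (just instrumentation, not a fix):

    def _apply_batch_transfer_handler(self, batch, device=None, dataloader_idx=None):
        # Log what each caller actually hands us before delegating to Lightning
        if isinstance(batch, dict):
            print("dict batch with keys:", list(batch.keys()))
        else:
            print("raw batch of type:", type(batch))
        return super()._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)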

dlangerm commented 2 years ago

@ethanwharris To fix the to_onnx call I had to change this line

https://github.com/PyTorchLightning/pytorch-lightning/blob/3aba9d16a8f9a6b23853dcece967b05b976a329b/pytorch_lightning/core/lightning.py#L1799

To this:

input_sample = self._apply_batch_transfer_handler({'input': input_sample})['input']

I think it could be fixed by overriding to_onnx in Task, since the behavior exists specifically because everything in a Task is expected to be a dictionary. Lightning itself doesn't have this behavior, so inserting this code directly into Lightning would probably have unintended consequences. A rough sketch of what I mean is below.
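
This is untested and simplified (the real to_onnx also validates that an input sample exists and handles a few more edge cases), but roughly:

    def to_onnx(self, file_path, input_sample=None, **kwargs):
        if input_sample is None:
            input_sample = self.example_input_array
        # Wrap the raw tensor in the dict the Flash transforms expect,
        # then unwrap it again before handing a plain tensor to the exporter
        batch = {DefaultDataKeys.INPUT: input_sample}
        batch = self._apply_batch_transfer_handler(batch)
        input_sample = batch[DefaultDataKeys.INPUT]
        mode = self.training
        self.eval()
        torch.onnx.export(self, input_sample, file_path, **kwargs)
        self.train(mode)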

dlangerm commented 2 years ago

The hunt for this bug has led to other, unrelated issues; I'm going to go ahead and open a new issue about the documentation for the Preprocess module.

dlangerm commented 2 years ago

BTW, this is still an issue. For anyone else who runs into it, a hotfix is to override _apply_batch_transfer_handler in your Task.

    def _apply_batch_transfer_handler(
        self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None
    ) -> Any:
        if isinstance(batch, torch.Tensor):
            return super()._apply_batch_transfer_handler(
                batch={DefaultDataKeys.INPUT: batch}, device=device, dataloader_idx=dataloader_idx
            )[DefaultDataKeys.INPUT]
        else:
            return super()._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)

But I'm not comfortable making this a pull request, since underscore-prefixed functions are, by convention, not meant to be overridden, and I'm not exactly sure what other side effects this would have.
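
For completeness, a self-contained version of the hotfix with the imports it needs might look like this (ImageClassifier is just an example base class; substitute whichever Task you are using):

    from typing import Any, Optional

    import torch
    from flash.core.data.data_source import DefaultDataKeys
    from flash.image import ImageClassifier  # example Task, use your own

    class PatchedClassifier(ImageClassifier):
        def _apply_batch_transfer_handler(
            self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None
        ) -> Any:
            # to_onnx / summarize pass a bare Tensor, so wrap it in the dict layout
            # the Flash transforms expect and unwrap the result afterwards
            if isinstance(batch, torch.Tensor):
                wrapped = {DefaultDataKeys.INPUT: batch}
                out = super()._apply_batch_transfer_handler(wrapped, device=device, dataloader_idx=dataloader_idx)
                return out[DefaultDataKeys.INPUT]
            return super()._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)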

stale[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.