Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

[BUG] (`DataLoader`) sanity check fails due to `Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor)` #20456

Open · MathiasBaumgartinger opened this issue 3 days ago

MathiasBaumgartinger commented 3 days ago

Bug description

Hi there! I have previously created my first LightningDataModule; more specifically, a NonGeoDataModule that inherits from it (see my torchgeo fork). Interestingly, when I try to run this module I get RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor. Even more interesting is the fact that if I override transfer_batch_to_device like:

def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
    # delegate to the default implementation, then report the device of each tensor
    batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    print("----------------------------------------")
    for k in batch.keys():
        print(k, batch[k][0].get_device())
    print("----------------------------------------")
    return batch

I get the output

image 0
mask 0

i.e. get_device() reports CUDA device 0 for both tensors right after the transfer.

The error happens during the validation step (lightning/pytorch/strategies/strategy.py, line 411).
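To narrow down where the tensors end up back on the CPU, a small callback can print the batch devices right before validation_step runs. This is only a debugging sketch (the callback name is mine, not part of the pipeline above):

```python
import torch
from lightning.pytorch.callbacks import Callback


class BatchDeviceProbe(Callback):
    """Debugging aid: report the device of every tensor in a dict batch before validation_step."""

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        if isinstance(batch, dict):
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    print(f"{key}: {value.device} (module on {pl_module.device})")
```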

What version are you seeing the problem on?

v2.4

How to reproduce the bug

# FL (the custom NonGeoDataModule described above), `accelerator`,
# `default_data_dir` and `default_root_dir` are defined elsewhere in the script.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torchgeo.trainers import SemanticSegmentationTask
# import path may differ depending on the Ray version
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback


def train(
    config: dict,
    data_dir: str = default_data_dir,
    root_dir: str = default_root_dir,
    min_epochs: int = 1,
    max_epochs: int = 25,
) -> None:

    tune_metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}

    module = FL(
        num_workers=config["num_workers"],
        batch_size=config["batch_size"],
        patch_size=config["patch_size"],
        val_split_pct=0.25,
        use_toy=True,
        # augs=transforms,
        root=data_dir,
    )
    task = SemanticSegmentationTask(
        model="unet",
        backbone="resnet50",
        ignore_index=255,
        in_channels=5,  # (5+3) appended indices
        num_classes=13,
        lr=config["lr"],
        patience=config["lr_patience"],
    )

    # Callbacks
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min")
    lr_monitor = LearningRateMonitor(logging_interval="step")
    tune_callback = TuneReportCheckpointCallback(
        {"loss": "val_loss", "accuracy": "val_accuracy"}, on="validation_end"
    )
    logger = TensorBoardLogger(save_dir=root_dir, name="FLAIR2logs")

    trainer = Trainer(
        accelerator=accelerator,
        num_nodes=1,
        callbacks=[checkpoint_callback, lr_monitor, tune_callback],
        log_every_n_steps=1,
        logger=logger,
        min_epochs=min_epochs,
        max_epochs=max_epochs,
        precision=32,
    )

    trainer.fit(model=task, datamodule=module)
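For comparison, here is a minimal, self-contained sketch that exercises the same dict-batch path without torchgeo or Ray in the loop. Everything in it (dataset, shapes, module) is made up; it is only meant to check whether plain Lightning moves dict batches to the GPU during the validation sanity check:

```python
import torch
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch import LightningModule, Trainer


class DictDataset(Dataset):
    """Returns dict samples shaped like the torchgeo batch ({'image', 'mask'})."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {
            "image": torch.rand(5, 32, 32),
            "mask": torch.zeros(32, 32, dtype=torch.long),
        }


class DictModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(5, 13, kernel_size=1)

    def training_step(self, batch, batch_idx):
        return self.conv(batch["image"]).mean()

    def validation_step(self, batch, batch_idx):
        # If Lightning moved the dict correctly, both tensors live on the GPU here.
        assert batch["image"].device == self.device, batch["image"].device
        return self.conv(batch["image"]).mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


if __name__ == "__main__":
    loader = DataLoader(DictDataset(), batch_size=2)
    trainer = Trainer(accelerator="gpu", devices=1, max_epochs=1)
    trainer.fit(DictModel(), train_dataloaders=loader, val_dataloaders=loader)
```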

Error messages and logs

Traceback (most recent call last):
  File "//Dev/forks/torchgeo/train_simple.py", line 158, in <module>
    main()
  File "//Dev/forks/torchgeo/train_simple.py", line 154, in main
    train(config)
  File "//Dev/forks/torchgeo/train_simple.py", line 151, in train
    trainer.fit(model=task, datamodule=module)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1023, in _run_stage
    self._run_sanity_check()
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1052, in _run_sanity_check
    val_loop.run()
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 178, in _decorator
    return loop_run(self, *args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 411, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "//Dev/forks/torchgeo/torchgeo/trainers/segmentation.py", line 251, in validation_step
    y_hat = self(x)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "//Dev/forks/torchgeo/torchgeo/trainers/base.py", line 81, in forward
    return self.model(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/segmentation_models_pytorch/base/model.py", line 38, in forward
    features = self.encoder(x)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/segmentation_models_pytorch/encoders/resnet.py", line 63, in forward
    x = stages[i](x)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/container.py", line 219, in forward
    input = module(input)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 458, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 454, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

Environment

Current environment

```
-----------------------------------------------------------
Python Version: 3.10.4
PyTorch Version: 2.4.1
Cuda is available version: 12.4
Torch built with CUDA: True
cuDNN Version: 90100
cuDNN Enabled: True
cuDNN available: True
Device: cuda
Accelerator: gpu
lightning            2.4.0
lightning-utilities  0.11.9
pytorch-lightning    2.4.0

## conda env
name: torchgeo
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - python=3.10
  - pytorch-cuda=12.4
  - pytorch=2.4
  - torchgeo=0.6.0
  - tensorboard=2.17
-----------------------------------------------------------
```

More info

No response

MathiasBaumgartinger commented 2 days ago

After some debugging, I found that the batches were indeed not on the GPU during the individual steps. When I add .to(self.device) to the batch['image'] and batch['mask'] accesses (see https://github.com/microsoft/torchgeo/blob/main/torchgeo/trainers/segmentation.py), the pipeline runs without errors.
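A less invasive way to express the same workaround, without editing segmentation.py, would be to force the transfer in on_after_batch_transfer on the task. This is only a sketch of the idea (the subclass name is hypothetical) and should not be needed if transfer_batch_to_device worked as expected:

```python
from typing import Any

import torch
from torchgeo.trainers import SemanticSegmentationTask


class PatchedSemanticSegmentationTask(SemanticSegmentationTask):
    """Workaround sketch: explicitly move every tensor in the dict batch to the module's device."""

    def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        if isinstance(batch, dict):
            batch = {
                k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()
            }
        return batch
```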

AFAIK, this should not be necessary: a pl.Trainer combined with a pl.LightningDataModule and pl.LightningModule should guarantee that batch and model end up on the same device via the transfer_batch_to_device hook. And the hook is clearly being called, since I do get the printed output above.

EDIT: Tagging @adamjstewart as this might be related to torchgeo

robmarkcole commented 2 days ago

Hit this too. Not sure how significant it is that the batch is a dict.
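For reference, Lightning's default batch transfer does recurse into dict collections, so a dict batch alone shouldn't be the problem. A quick check (import path assumed for lightning 2.4):

```python
import torch
from lightning.fabric.utilities.apply_func import move_data_to_device

# move_data_to_device is what the default transfer_batch_to_device uses under the hood;
# it recurses into dicts, so both tensors should come back on cuda:0.
batch = {
    "image": torch.rand(2, 5, 32, 32),
    "mask": torch.zeros(2, 32, 32, dtype=torch.long),
}
moved = move_data_to_device(batch, torch.device("cuda:0"))
print({k: v.device for k, v in moved.items()})
```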