TorchJD / torchjd

Library for Jacobian descent with PyTorch. It enables optimization of neural networks with multiple losses (e.g. multi-task learning).
https://torchjd.org
MIT License

Cannot access storage of BatchedTensorImpl #154

Open rwesterman opened 2 days ago

rwesterman commented 2 days ago

I am trying to use TorchJD with PyTorch Lightning, which requires turning off Lightning's automatic optimization loop. My code looks like the minimal example in the documentation here, except that I have to override manual_backward() in order to run torchjd.backward().
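
For context, a rough sketch of that kind of override (hypothetical; the minimal example below instead ends up calling torchjd.backward() directly inside training_step):

    # Hypothetical override of LightningModule.manual_backward that routes the
    # backward pass through torchjd instead of a plain loss.backward().
    # Assumes `losses` is a list of scalar loss tensors and `self.aggregator`
    # is a torchjd aggregator such as UPGrad.
    def manual_backward(self, losses, **kwargs):
        torchjd.backward(losses, self.model.parameters(), self.aggregator)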

I have tried a few approaches to integrate the two libraries, but no matter what I do, I get an error:

NotImplementedError: Cannot access storage of BatchedTensorImpl

Here is a minimal example of my code that causes the error:

from typing import List, Optional

import torch
import torch.nn

import torchjd
from torchjd.aggregation import UPGrad

from lightning import LightningModule

class MinimalModule(LightningModule):
    def __init__(self,
                 embedding_model: torch.nn.Module,
                 classifier_model: torch.nn.Module,
                 lr: float, 
                 labels: Optional[List[str]] = None,
                 multiclass_loss_fn: torch.nn.Module = torch.nn.CrossEntropyLoss(),
                 binary_loss_fn: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
                 ):
        super().__init__()
        self.model = torch.nn.Sequential(embedding_model, classifier_model)
        self.lr = lr
        self.labels = labels
        self.multiclass_loss_fn = multiclass_loss_fn
        self.binary_loss_fn = binary_loss_fn

        self.val_step_outputs = []
        self.test_step_outputs = []
        # Need to manually perform optimization for Jacobian Descent Multi-task training
        self.automatic_optimization = False
        self.aggregator = UPGrad()

    def forward(self, input_features):
        # Returns a tuple of the multiclass and binary outputs
        return self.model.forward(input_features)

    def training_step(self, batch, batch_idx):
        input_features, labels, _, _ = batch
        multiclass_out, binary_out = self.forward(input_features)
        multiclass_loss = self.multiclass_loss_fn(multiclass_out, labels) 
        binary_loss = self.binary_loss_fn(binary_out, labels)
        opt = self.optimizers()
        opt.zero_grad()

        torchjd.backward([multiclass_loss, binary_loss], self.model.parameters(), self.aggregator)

        opt.step()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), self.lr)

And finally, here is the full traceback:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/ryan.westerman/sap//lid/bin/run_lid_train.py", line 15, in <module>
[rank0]:     cli_main()
[rank0]:   File "/home/ryan.westerman/sap//lid/bin/run_lid_train.py", line 7, in cli_main
[rank0]:     cli = LightningCLI(
[rank0]:           ^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/cli.py", line 394, in __init__
[rank0]:     self._run_subcommand(self.subcommand)
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/cli.py", line 701, in _run_subcommand
[rank0]:     fn(**fn_kwargs)
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
[rank0]:     call._call_and_handle_interrupt(
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank0]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank0]:     return function(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
[rank0]:     self._run(model, ckpt_path=ckpt_path)
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 986, in _run
[rank0]:     results = self._run_stage()
[rank0]:               ^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1030, in _run_stage
[rank0]:     self.fit_loop.run()
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
[rank0]:     self.advance()
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
[rank0]:     self.epoch_loop.run(self._data_fetcher)
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
[rank0]:     self.advance(data_fetcher)
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 252, in advance
[rank0]:     batch_output = self.manual_optimization.run(kwargs)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/manual.py", line 94, in run
[rank0]:     self.advance(kwargs)
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/manual.py", line 114, in advance
[rank0]:     training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 311, in _call_strategy_hook
[rank0]:     output = fn(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 389, in training_step
[rank0]:     return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 640, in __call__
[rank0]:     wrapper_output = wrapper_module(*args, **kwargs)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1636, in forward
[rank0]:     else self._run_ddp_forward(*inputs, **kwargs)
[rank0]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1454, in _run_ddp_forward
[rank0]:     return self.module(*inputs, **kwargs)  # type: ignore[index]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 633, in wrapped_forward
[rank0]:     out = method(*_args, **_kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/sap/lid/training/minimal_example.py", line 46, in training_step
[rank0]:     torchjd.backward([multiclass_loss, binary_loss], self.model.parameters(), self.aggregator)
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torchjd/autojac/backward.py", line 91, in backward
[rank0]:     backward_transform(EmptyTensorDict())
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torchjd/autojac/_transform/base.py", line 48, in __call__
[rank0]:     return self._compute(input)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torchjd/autojac/_transform/base.py", line 79, in _compute
[rank0]:     return self.outer(intermediate)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torchjd/autojac/_transform/base.py", line 48, in __call__
[rank0]:     return self._compute(input)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torchjd/autojac/_transform/base.py", line 79, in _compute
[rank0]:     return self.outer(intermediate)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torchjd/autojac/_transform/base.py", line 48, in __call__
[rank0]:     return self._compute(input)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torchjd/autojac/_transform/base.py", line 78, in _compute
[rank0]:     intermediate = self.inner(input)
[rank0]:                    ^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torchjd/autojac/_transform/base.py", line 48, in __call__
[rank0]:     return self._compute(input)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torchjd/autojac/_transform/_differentiate.py", line 26, in _compute
[rank0]:     differentiated_tuple = self._differentiate(tensor_outputs)
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torchjd/autojac/_transform/jac.py", line 25, in _differentiate
[rank0]:     return _jac(
[rank0]:            ^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torchjd/autojac/_transform/jac.py", line 83, in _jac
[rank0]:     grouped_jacobian_matrix = torch.vmap(get_vjp, chunk_size=chunk_size)(jac_outputs)
[rank0]:                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torch/_functorch/apis.py", line 201, in wrapped
[rank0]:     return vmap_impl(
[rank0]:            ^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
[rank0]:     return _flat_vmap(
[rank0]:            ^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 48, in fn
[rank0]:     return f(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 480, in _flat_vmap
[rank0]:     batched_outputs = func(*batched_inputs, **kwargs)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torchjd/autojac/_transform/jac.py", line 72, in get_vjp
[rank0]:     optional_grads = torch.autograd.grad(
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torch/autograd/__init__.py", line 436, in grad
[rank0]:     result = _engine_run_backward(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torch/autograd/graph.py", line 769, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: NotImplementedError: Cannot access storage of BatchedTensorImpl

Version information:

python: 3.11.9
torch: 2.4.1+cu121
lightning: 2.4.0
torchjd: 0.2.1
PierreQuinton commented 2 days ago

Thanks for the issue, we are looking into this.

In the meantime, could you provide a minimal example of the calling code?

Can you try the following things and report the behavior?

- Stack the two losses into one vector called losses and call torch.autograd.backward(losses, torch.ones_like(losses)) instead of torchjd.backward.
- Give the parameter parallel_chunk_size=1 to your call to torchjd.backward.
- Print the types of the following tensors in training_step: multiclass_out, binary_out, multiclass_loss, and binary_loss.

It feels like the bug is coming from torch, possibly because of a bad interaction between torch.compile (which PyTorch Lightning can apply) and our use of vmap. Did you compile any of the models that you provide to MinimalModule?
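
For context, torchjd's backward batches vector-Jacobian products with torch.vmap around torch.autograd.grad (this is the jac.py frame in your traceback). A simplified standalone sketch, with a placeholder model and losses:

import torch

# Placeholder model and losses, just to make the sketch self-contained.
model = torch.nn.Linear(4, 2)
out = model(torch.randn(8, 4))
losses = torch.stack([out.sum(), out.abs().sum()])
params = list(model.parameters())

def get_vjp(grad_output):
    # One vector-Jacobian product. vmap batches this call over the rows of the
    # identity matrix, so grad_output is a BatchedTensor inside the autograd engine.
    grads = torch.autograd.grad(losses, params, grad_output, retain_graph=True)
    return torch.cat([g.reshape(-1) for g in grads])

# Rows of the Jacobian of `losses` w.r.t. all parameters, shape [num_losses, num_params].
jacobian = torch.vmap(get_vjp)(torch.eye(len(losses)))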

Possibly related:

rwesterman commented 2 days ago

Thank you for the quick response!

Everything is called through LightningCLI:

from lightning.pytorch.cli import LightningCLI

def cli_main():
    cli = LightningCLI(
        parser_kwargs={"parser_mode": "omegaconf"},
        seed_everything_default=42,
    )

if __name__ == "__main__":
    cli_main()

It reads from a config file like this:

# lightning.pytorch==2.2.3
trainer:
    accelerator: auto
    strategy: ddp
    devices: auto
    num_nodes: 1
    precision: null
    logger: null
    plugins:
        class_path: lightning.pytorch.plugins.precision.MixedPrecision
        init_args:
            precision: 16-mixed
            device: cuda
            scaler:
                class_path: torch.cuda.amp.GradScaler
                init_args:
                    init_scale: 1.0
    callbacks:
        - class_path: lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor
          init_args:
              logging_interval: step
    fast_dev_run: false
    max_epochs: 10
    max_steps: -1
    check_val_every_n_epoch: 1
    log_every_n_steps: 50
    enable_checkpointing: false
    enable_progress_bar: false
    accumulate_grad_batches: 1
    gradient_clip_val: null
    gradient_clip_algorithm: null
    inference_mode: false
    reload_dataloaders_every_n_epochs: 0
ckpt_path: null
model:
    class_path: lid.training.minimal_example.MinimalModule
    init_args:
        embedding_model:
            class_path: lid.model.ECAPA_TDNN
            init_args:
                input_size: 80
                conv_config_groups: # (num_channel, kernel_size, dilation)
                    - [448, 5, 1]
                    - [448, 3, 2]
                    - [448, 3, 3]
                    - [448, 3, 4]
                    - [1344, 1, 1]
                attention_channels: null
                res2net_scale: null
                se_channels: null
                output_size: null
                global_context: true
        classifier_model:
            class_path: lid.model.MultiTaskClassifier
            init_args:
                input_size: 192
                first_out_dim: 36
                second_out_dim: 36
        lr: 0.001
        labels:
        - "en"
        - "de"
        - "fr"
        - "es"
data:
    class_path: lid.data_processing.xscp_data_module.FeatureDataModule
    init_args:
        data_dir: null
        labels: ${model.init_args.labels}
        label_name: "language"
        batch_size: 512
        fp16: true
        num_workers: 32
        batch_weight: null # Need to provide this in train script

stack the two losses into one vector called losses and call torch.autograd.backward(losses, torch.ones_like(losses)) instead of torchjd.backward
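
Concretely, my first attempt looked roughly like this (sketch; the scaler handling shown further down was not yet added):

        # First attempt (sketch): replace torchjd.backward with a plain batched backward.
        losses = torch.stack([multiclass_loss, binary_loss])
        torch.autograd.backward(losses, torch.ones_like(losses))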

This first throws an AssertionError about not using the scaler:

[rank1]: AssertionError: Attempted unscale_ but _scale is None.  This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration.

If I manually use the scaler like this:

        scaler = self.trainer.strategy.precision_plugin.scaler
        if scaler is not None:
            losses = scaler.scale(losses)
        torch.autograd.backward(losses, torch.ones_like(losses))

then the code runs without error.

give the parameter parallel_chunk_size=1 to your call to torchjd.backward?
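
The modified call looks roughly like this (sketch; only the extra parameter differs from my original call):

        torchjd.backward(
            [multiclass_loss, binary_loss],
            self.model.parameters(),
            self.aggregator,
            parallel_chunk_size=1,
        )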

This gives the error:

[rank1]: ValueError: When using `retain_graph=False`, parameter `parallel_chunk_size` must be `None` or large enough to compute all gradients in parallel.

And after setting retain_graph=True, the error again becomes:

[rank1]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torch/autograd/graph.py", line 769, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: NotImplementedError: Cannot access storage of BatchedTensorImpl

Print the types of the following tensors in training_step: multiclass_out, binary_out, multiclass_loss, and binary_loss

Here are the types, dtypes, and grad_fns of the outputs and losses:

Type of multiclass_out: <class 'torch.Tensor'>
Type of binary_out: <class 'torch.Tensor'>
dtype of multiclass_out: torch.float16
dtype of binary_out: torch.float16
grad_fn of multiclass_out: <AddmmBackward0 object at 0x7fc6ae199e70>
grad_fn of binary_out: <AddmmBackward0 object at 0x7fc6ae199e70>
Type of multiclass_loss: <class 'torch.Tensor'>
Type of binary_loss: <class 'torch.Tensor'>
dtype of multiclass_loss: torch.float32
dtype of binary_loss: torch.float32
grad_fn of multiclass_loss: <DivBackward1 object at 0x7fc6ae199e70>
grad_fn of binary_loss: <BinaryCrossEntropyWithLogitsBackward0 object at 0x7fc6ae199e70>

Did you compile any of the models that you provide to MinimalModule?

Not manually, but it's possible that Lightning is attempting to compile somewhere. I think I'd see grad_fn=<CompiledFunctionBackward> as the grad function for the losses or model outputs, but I don't see it for any of them.
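
A quick check I could add (sketch; a module wrapped by torch.compile typically shows up as an OptimizedModule wrapper rather than its original class):

        # Sketch, e.g. inside training_step: print the module types to see whether
        # anything was silently compiled.
        print(type(self.model), [type(m).__name__ for m in self.model.children()])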

PierreQuinton commented 1 day ago

Thank you for the detailed update!

One other thing that would be helpful is to replace the call to backward with the following:

losses = torch.stack([multiclass_loss, binary_loss])
ones = torch.ones_like(losses)
param_list = list(self.model.parameters())
_ = torch.autograd.grad([losses], param_list, [ones])
rwesterman commented 1 day ago

I realized that I didn't specify package versions in my original post, so I've updated that; I'll also post them here:

python: 3.11.9
torch: 2.4.1+cu121
lightning: 2.4.0
torchjd: 0.2.1

For the code you posted above:

        opt.zero_grad()

        losses = torch.stack([multiclass_loss, binary_loss])
        ones = torch.ones_like(losses)
        param_list = list(self.model.parameters())
        scaler = self.trainer.strategy.precision_plugin.scaler
        if scaler is not None:
            losses = scaler.scale(losses)
        _ = torch.autograd.grad([losses], param_list, [ones])

        opt.step()

yields the following error from opt.step():

[rank0]:   File "/home/ryan.westerman/sap/lid/training/minimal_example.py", line 73, in training_step
[rank0]:     opt.step()
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/core/optimizer.py", line 153, in step
[rank0]:     step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/strategies/ddp.py", line 270, in optimizer_step
[rank0]:     optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 238, in optimizer_step
[rank0]:     return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py", line 93, in optimizer_step
[rank0]:     step_output = self.scaler.step(optimizer, **kwargs)  # type: ignore[arg-type]
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torch/amp/grad_scaler.py", line 451, in step
[rank0]:     len(optimizer_state["found_inf_per_device"]) > 0
[rank0]: AssertionError: No inf checks were recorded for this optimizer.

However, if I remove the PrecisionPlugin from my config file (in my previous comment) and instead set precision: bf16-mixed, this code runs correctly.

Out of curiosity, I tried that same configuration (precision: bf16-mixed, no PrecisionPlugin) with the torchjd code and found that it still gives the original error, NotImplementedError: Cannot access storage of BatchedTensorImpl:

        opt = self.optimizers()
        opt.zero_grad()

        scaler = self.trainer.strategy.precision_plugin.scaler
        if scaler is not None:
            multiclass_loss = scaler.scale(multiclass_loss)
            binary_loss = scaler.scale(binary_loss)
        torchjd.backward([multiclass_loss, binary_loss], self.model.parameters(), self.aggregator)

        opt.step()