**rwesterman** opened this issue 2 days ago
Thanks for the issue, we are looking into this.
In the meantime, could you provide a minimal example of the calling code?
Can you try the following things and report the behavior?

1. Stack the two losses into one vector called `losses` and call `torch.autograd.backward(losses, torch.ones_like(losses))` instead of `torchjd.backward`, as in the sketch below.
2. Give the parameter `parallel_chunk_size=1` to your call to `torchjd.backward`.
3. Print the types of the following tensors in `training_step`: `multiclass_out`, `binary_out`, `multiclass_loss`, and `binary_loss`.
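For points 1 and 2, the changes would look roughly like this inside `training_step` (a sketch; I'm assuming your losses are named as in point 3):

```python
# Point 1: bypass torchjd and let plain autograd do a vector backward
# over the stacked losses.
losses = torch.stack([multiclass_loss, binary_loss])
torch.autograd.backward(losses, torch.ones_like(losses))

# Point 2 (alternative): keep torchjd but compute the gradients one at a
# time instead of in parallel.
torchjd.backward(
    [multiclass_loss, binary_loss],
    self.model.parameters(),
    self.aggregator,
    parallel_chunk_size=1,
)
```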
It feels like the bug is coming from torch, possibly because of a bad combination of the `@compile` decorator used in PyTorch Lightning and our use of `vmap`.
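For context, that error message usually comes from code that tries to touch a tensor's raw storage while inside `vmap`. A standalone sketch of the mechanism (my assumption, independent of Lightning):

```python
import torch
from torch.func import vmap

def per_output(x):
    # Inside vmap, x is a BatchedTensorImpl wrapper; anything that needs
    # its raw storage (e.g. data_ptr or untyped_storage) is unsupported.
    x.untyped_storage()
    return x.sum()

try:
    vmap(per_output)(torch.randn(4, 3))
except NotImplementedError as e:
    print(e)  # Cannot access storage of BatchedTensorImpl
```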
Did you compile any of the models that you provide to `MinimalModule`?
Possibly related:
Thank you for the quick response!
Everything is called through `LightningCLI`:

```python
from lightning.pytorch.cli import LightningCLI

def cli_main():
    cli = LightningCLI(
        parser_kwargs={"parser_mode": "omegaconf"},
        seed_everything_default=42,
    )

if __name__ == "__main__":
    cli_main()
```
It reads from a config file like this:

```yaml
# lightning.pytorch==2.2.3
trainer:
  accelerator: auto
  strategy: ddp
  devices: auto
  num_nodes: 1
  precision: null
  logger: null
  plugins:
    - class_path: lightning.pytorch.plugins.precision.MixedPrecision
      init_args:
        precision: 16-mixed
        device: cuda
        scaler:
          class_path: torch.cuda.amp.GradScaler
          init_args:
            init_scale: 1.0
  callbacks:
    - class_path: lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor
      init_args:
        logging_interval: step
  fast_dev_run: false
  max_epochs: 10
  max_steps: -1
  check_val_every_n_epoch: 1
  log_every_n_steps: 50
  enable_checkpointing: false
  enable_progress_bar: false
  accumulate_grad_batches: 1
  gradient_clip_val: null
  gradient_clip_algorithm: null
  inference_mode: false
  reload_dataloaders_every_n_epochs: 0
ckpt_path: null
model:
  class_path: lid.training.minimal_example.MinimalModule
  init_args:
    embedding_model:
      class_path: lid.model.ECAPA_TDNN
      init_args:
        input_size: 80
        conv_config_groups:  # (num_channel, kernel_size, dilation)
          - [448, 5, 1]
          - [448, 3, 2]
          - [448, 3, 3]
          - [448, 3, 4]
          - [1344, 1, 1]
        attention_channels: null
        res2net_scale: null
        se_channels: null
        output_size: null
        global_context: true
    classifier_model:
      class_path: lid.model.MultiTaskClassifier
      init_args:
        input_size: 192
        first_out_dim: 36
        second_out_dim: 36
    lr: 0.001
    labels:
      - "en"
      - "de"
      - "fr"
      - "es"
data:
  class_path: lid.data_processing.xscp_data_module.FeatureDataModule
  init_args:
    data_dir: null
    labels: ${model.init_args.labels}
    label_name: "language"
    batch_size: 512
    fp16: true
    num_workers: 32
    batch_weight: null  # Need to provide this in train script
```
> stack the two losses into one vector called `losses` and call `torch.autograd.backward(losses, torch.ones_like(losses))` instead of `torchjd.backward`
This first throws an `AssertionError` about not using the scaler:

```
[rank1]: AssertionError: Attempted unscale_ but _scale is None. This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration.
```
If I manually use the scaler like this:

```python
scaler = self.trainer.strategy.precision_plugin.scaler
if scaler is not None:
    losses = scaler.scale(losses)
torch.autograd.backward(losses, torch.ones_like(losses))
```

then the code runs without error.
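This matches the standard manual AMP recipe, where `scale()` has to see the loss each iteration before `backward()` so that the scale factor is initialized. A self-contained sketch of that recipe (standard PyTorch usage, not my training code):

```python
import torch

model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    x = torch.randn(16, 8, device="cuda")
    optimizer.zero_grad()
    with torch.autocast("cuda", dtype=torch.float16):
        loss = model(x).float().mean()
    scaler.scale(loss).backward()  # scale() initializes _scale, avoiding the assertion above
    scaler.step(optimizer)         # unscales grads and checks for infs before stepping
    scaler.update()
```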
> give the parameter `parallel_chunk_size=1` to your call to `torchjd.backward`?
This gives the error:

```
[rank1]: ValueError: When using `retain_graph=False`, parameter `parallel_chunk_size` must be `None` or large enough to compute all gradients in parallel.
```
And after setting `retain_graph=True`, the error again becomes:

```
[rank1]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torch/autograd/graph.py", line 769, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: NotImplementedError: Cannot access storage of BatchedTensorImpl
```
> Print the type of the following tensors in `training_step`: `multiclass_out` and `binary_out`, `multiclass_loss`, `binary_loss`
Here are the types, dtypes, and grad_fns of the outputs and losses:

```
Type of multiclass_out: <class 'torch.Tensor'>
Type of binary_out: <class 'torch.Tensor'>
dtype of multiclass_out: torch.float16
dtype of binary_out: torch.float16
grad_fn of multiclass_out: <AddmmBackward0 object at 0x7fc6ae199e70>
grad_fn of binary_out: <AddmmBackward0 object at 0x7fc6ae199e70>
Type of multiclass_loss: <class 'torch.Tensor'>
Type of binary_loss: <class 'torch.Tensor'>
dtype of multiclass_loss: torch.float32
dtype of binary_loss: torch.float32
grad_fn of multiclass_loss: <DivBackward1 object at 0x7fc6ae199e70>
grad_fn of binary_loss: <BinaryCrossEntropyWithLogitsBackward0 object at 0x7fc6ae199e70>
```
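The fp16 outputs alongside fp32 losses are what I'd expect under autocast: matmul-type ops run in half precision while loss functions like `cross_entropy` and `binary_cross_entropy_with_logits` are autocast to float32. A quick check (toy code, not my actual model):

```python
import torch
import torch.nn.functional as F

lin = torch.nn.Linear(10, 4).cuda()
x = torch.randn(8, 10, device="cuda")
y = torch.randint(0, 4, (8,), device="cuda")

with torch.autocast("cuda", dtype=torch.float16):
    out = lin(x)                    # linear/matmul runs in fp16 under autocast
    loss = F.cross_entropy(out, y)  # cross_entropy is autocast to fp32

print(out.dtype, loss.dtype)  # torch.float16 torch.float32
```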
> Did you compile any of the models that you provide to `MinimalModule`?
Not manually, but it's possible that Lightning is attempting to compile somewhere. I think I'd see `grad_fn=<CompiledFunctionBackward>` on the losses or model outputs if so, but I don't see it for any of them.
Thank you for the detailed update!
One other thing that would be helpful is to replace the call to `backward` with the following:
```python
# Same stacked backward, but through torch.autograd.grad, which returns
# the gradients directly instead of accumulating them into .grad:
losses = torch.stack([multiclass_loss, binary_loss])
ones = torch.ones_like(losses)
param_list = list(self.model.parameters())
_ = torch.autograd.grad([losses], param_list, [ones])
```
I realized that I didn't specify package versions in my original post, so I've updated it; here they are as well:

```
python: 3.11.9
torch: 2.4.1+cu121
lightning: 2.4.0
torchjd: 0.2.1
```
Running the code you posted above with scaling added:

```python
opt.zero_grad()
losses = torch.stack([multiclass_loss, binary_loss])
ones = torch.ones_like(losses)
param_list = list(self.model.parameters())
scaler = self.trainer.strategy.precision_plugin.scaler
if scaler is not None:
    losses = scaler.scale(losses)
_ = torch.autograd.grad([losses], param_list, [ones])
opt.step()
```
yields the following error from `opt.step()`:
```
[rank0]:   File "/home/ryan.westerman/sap/lid/training/minimal_example.py", line 73, in training_step
[rank0]:     opt.step()
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/core/optimizer.py", line 153, in step
[rank0]:     step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/strategies/ddp.py", line 270, in optimizer_step
[rank0]:     optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 238, in optimizer_step
[rank0]:     return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/amp.py", line 93, in optimizer_step
[rank0]:     step_output = self.scaler.step(optimizer, **kwargs)  # type: ignore[arg-type]
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ryan.westerman/.local/lib/python3.11/site-packages/torch/amp/grad_scaler.py", line 451, in step
[rank0]:     len(optimizer_state["found_inf_per_device"]) > 0
[rank0]: AssertionError: No inf checks were recorded for this optimizer.
```
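I suspect this is because `torch.autograd.grad` returns the gradients instead of writing them into `.grad`, so the scaler's `unscale_` finds no gradients to inf-check. A possible (untested) workaround sketch would be to assign them manually before stepping:

```python
grads = torch.autograd.grad([losses], param_list, [ones])
for p, g in zip(param_list, grads):
    p.grad = g  # make gradients visible to GradScaler.step / opt.step
opt.step()
```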
However, if I remove the `PrecisionPlugin` from my config file (in my previous comment) and instead set `precision: bf16-mixed`, this code runs correctly.
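Concretely, that change amounts to this in the trainer section (a sketch of my edit; bf16 autocast needs no `GradScaler`, since bfloat16 has the same exponent range as fp32):

```yaml
trainer:
  precision: bf16-mixed
  # the plugins: block with MixedPrecision / GradScaler is removed entirely
```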
Out of curiosity, I tried that same configuration (`precision: bf16-mixed`, no `PrecisionPlugin`) with the torchjd code and found that it still gives the original error, `NotImplementedError: Cannot access storage of BatchedTensorImpl`:
```python
opt = self.optimizers()
opt.zero_grad()
scaler = self.trainer.strategy.precision_plugin.scaler
if scaler is not None:
    multiclass_loss = scaler.scale(multiclass_loss)
    binary_loss = scaler.scale(binary_loss)
torchjd.backward([multiclass_loss, binary_loss], self.model.parameters(), self.aggregator)
opt.step()
```
I am trying to use TorchJD with PyTorch Lightning, which requires turning off Lightning's automatic optimization loop. My code looks like the minimal example in the documentation here, except that I am having to override `manual_backward()` in order to run `torchjd.backward()`. I have tried a few approaches to integrate the two libraries, but no matter what I do, I get an error.
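Roughly, the integration pattern looks like this (a sketch of the shape of the setup, not my exact failing code; `compute_losses` is a hypothetical helper, and `UPGrad` is one of torchjd's aggregators):

```python
import torchjd
from torchjd.aggregation import UPGrad
from lightning.pytorch import LightningModule

class MultiTaskModule(LightningModule):
    def __init__(self, model):
        super().__init__()
        self.automatic_optimization = False  # torchjd requires manual optimization
        self.model = model
        self.aggregator = UPGrad()

    def training_step(self, batch, batch_idx):
        multiclass_loss, binary_loss = self.compute_losses(batch)  # hypothetical helper
        opt = self.optimizers()
        opt.zero_grad()
        # torchjd aggregates the two task gradients instead of summing the losses
        torchjd.backward([multiclass_loss, binary_loss], self.model.parameters(), self.aggregator)
        opt.step()
```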
Here is a minimal example of my code that causes the error:
And finally, the full traceback:
Version information: