problem

When `image_logits.dtype == torch.half` and `len(image_logits) == 4096`, I encounter this error:
```
Traceback (most recent call last):
  File "/workspace/train-CLIP/train_finetune.py", line 46, in <module>
    main(args)
  File "/workspace/train-CLIP/train_finetune.py", line 33, in main
    trainer.fit(model, dm)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 741, in fit
    self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 1199, in _run
    self._dispatch()
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 1279, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 1289, in run_stage
    return self._run_train()
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 1319, in _run_train
    self.fit_loop.run()
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
    self.epoch_loop.run(data_fetcher)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 193, in advance
    batch_output = self.batch_loop.run(batch, batch_idx)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 90, in advance
    outputs = self.manual_loop.run(split_batch, batch_idx)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/optimization/manual_loop.py", line 111, in advance
    training_step_output = self.trainer.accelerator.training_step(step_kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py", line 219, in training_step
    return self.training_type_plugin.training_step(*step_kwargs.values())
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/ddp.py", line 439, in training_step
    return self.model(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/distributed.py", line 886, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/overrides/base.py", line 81, in forward
    output = self.module.training_step(*inputs, **kwargs)
  File "/workspace/train-CLIP/models/wrapper.py", line 452, in training_step
    loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth)).div(2)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 2846, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: CUDA error: device-side assert triggered
```
reproducing

It seems that the int64 -> half -> long round trip has a precision problem, changing 4095 to 4096 for example:
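A minimal sketch of that round trip, runnable on CPU without the trainer (the variable names `labels` and `round_trip` are just illustrative): float16 carries only 11 significand bits, so odd integers above 2048 cannot be represented exactly and get rounded.

```python
import torch

# int64 -> half -> long round trip: float16 stores 11 significand bits,
# so integers above 2048 are rounded to the nearest representable value.
labels = torch.arange(4096)        # 0, 1, ..., 4095 (int64)
round_trip = labels.half().long()  # cast to float16 and back to int64

print(labels[-1].item())       # 4095
print(round_trip[-1].item())   # 4096 -- not a valid class index
print((labels != round_trip).sum().item())  # count of corrupted labels
```

With 4096 classes, the valid targets for `F.cross_entropy` are 0..4095, so the 4096 produced by the round trip is what triggers the device-side assert in the traceback above.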
solution

Use `to()` instead of `type_as()` to change the device of `ground_truth`.
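A sketch of the fix, assuming `ground_truth` is built with `torch.arange` inside `training_step` (the exact construction in `wrapper.py` may differ, and `image_logits` is stubbed here for a self-contained run):

```python
import torch

image_logits = torch.zeros(4096, 4096, dtype=torch.half)  # stand-in for the fp16 logits

# buggy: type_as matches device AND dtype, so the int64 labels pass
# through float16 and large values get rounded (4095 -> 4096).
bad = torch.arange(len(image_logits)).type_as(image_logits).long()

# fixed: to(device) only moves the tensor; the labels stay int64.
good = torch.arange(len(image_logits)).to(image_logits.device)

print(bad.max().item())   # 4096 -- out of range for 4096 classes
print(good.max().item())  # 4095
```

Keeping the targets in int64 end to end avoids the lossy float16 cast entirely, regardless of batch size.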