Zasder3 / train-CLIP

A PyTorch Lightning solution to training OpenAI's CLIP from scratch.
MIT License

precision convert problem #32

Closed VincentLeeMax closed 2 years ago

VincentLeeMax commented 2 years ago

problem

When image_logits.dtype == torch.half and len(image_logits) == 4096, I encounter this error:

Traceback (most recent call last):
  File "/workspace/train-CLIP/train_finetune.py", line 46, in <module>
    main(args)
  File "/workspace/train-CLIP/train_finetune.py", line 33, in main
    trainer.fit(model, dm)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 741, in fit
    self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 1199, in _run
    self._dispatch()
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 1279, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 1289, in run_stage
    return self._run_train()
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 1319, in _run_train
    self.fit_loop.run()
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
    self.epoch_loop.run(data_fetcher)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 193, in advance
    batch_output = self.batch_loop.run(batch, batch_idx)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 90, in advance
    outputs = self.manual_loop.run(split_batch, batch_idx)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/optimization/manual_loop.py", line 111, in advance
    training_step_output = self.trainer.accelerator.training_step(step_kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py", line 219, in training_step
    return self.training_type_plugin.training_step(*step_kwargs.values())
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/ddp.py", line 439, in training_step
    return self.model(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/distributed.py", line 886, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/overrides/base.py", line 81, in forward
    output = self.module.training_step(*inputs, **kwargs)
  File "/workspace/train-CLIP/models/wrapper.py", line 452, in training_step
    loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth)).div(2)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 2846, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: CUDA error: device-side assert triggered

reproducing

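The int64 -> half -> long round trip loses precision. float16 has an 11-bit significand, so it represents every integer exactly only up to 2048; larger integers are rounded to the nearest representable value. With 4096 entries, the last index 4095 becomes 4096, which is out of range for 4096 classes and trips the device-side assert in cross_entropy. For example: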

> python3 -c "import torch; a = torch.rand(4096, 4096, dtype=torch.half).cuda(); print(torch.arange(len(a)).type_as(a).long())"

tensor([   0,    1,    2,  ..., 4092, 4094, 4096], device='cuda:0')
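The same rounding can be checked on the CPU, without a GPU. The snippet below is a minimal sketch, assuming a PyTorch build that supports CPU casts to and from torch.half; the variable names are only illustrative.

```python
import torch

n = 4096
idx = torch.arange(n)            # int64 indices 0 .. 4095
round_trip = idx.half().long()   # int64 -> float16 -> int64

# Positions where the round trip changed the value.
changed = (round_trip != idx).nonzero().flatten()

print("last three indices after the round trip:", round_trip[-3:].tolist())
print("number of changed indices:", changed.numel())
print("largest index after the round trip:", int(round_trip.max()))
```

Any index that comes back as n or larger is an invalid class id for an n-way cross_entropy, which matches the failure in the traceback.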

solution

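Use to() instead of type_as() to move ground_truth to the right device. type_as() also casts the tensor to the dtype of image_logits (half here), whereas to(image_logits.device) changes only the device and leaves the indices as int64.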
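A minimal sketch of that change, assuming the targets are built from torch.arange as in the reproduction command. make_ground_truth is a hypothetical helper written for this illustration, not a function in train-CLIP, and the example runs on the CPU so it does not need CUDA.

```python
import torch

def make_ground_truth(logits: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: build int64 targets directly on the logits' device.
    # Only the device is matched, so the indices never pass through float16.
    return torch.arange(len(logits), device=logits.device)

# Stand-in for the half-precision logits of a 4096-sample batch.
logits = torch.zeros(4096, 4096, dtype=torch.half)

fixed = make_ground_truth(logits)
# Pattern from the reproduction: cast to half via type_as, then back to long.
lossy = torch.arange(len(logits)).type_as(logits).long()

print("fixed max index:", int(fixed.max()), fixed.dtype)
print("lossy max index:", int(lossy.max()), lossy.dtype)
```

An equivalent spelling is torch.arange(len(logits)).to(logits.device). Passing a device rather than a tensor matters here, because to(logits) with a tensor argument would also adopt its dtype.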