geoffreyangus closed this issue 9 months ago.
Got the following error when training Mixtral-7b in a multi-GPU setting:
```
tensor at position 96:
saved metadata: {'shape': torch.Size([142]), 'dtype': torch.int64, 'device': device(type='cuda', index=1)}
recomputed metadata: {'shape': torch.Size([143]), 'dtype': torch.int64, 'device': device(type='cuda', index=1)}
tensor at position 97:
saved metadata: {'shape': torch.Size([142]), 'dtype': torch.int64, 'device': device(type='cuda', index=1)}
recomputed metadata: {'shape': torch.Size([143]), 'dtype': torch.int64, 'device': device(type='cuda', index=1)}
...
```
This can be resolved by setting use_reentrant to True: the reentrant checkpoint always recomputes the function in its entirety during the backward pass (source: https://pytorch.org/docs/stable/checkpoint.html), so shape mismatches between the saved and recomputed tensors should be prevented.
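For reference, a minimal sketch of how the flag can be passed, assuming the model is loaded through Hugging Face Transformers with gradient checkpointing enabled via `gradient_checkpointing_enable` (the checkpoint name and loading details here are illustrative, not taken from the original report):

```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative checkpoint name; substitute the actual model being trained.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    torch_dtype=torch.bfloat16,
)

# Select the reentrant checkpoint implementation so the backward pass always
# recomputes each checkpointed block in its entirety, avoiding the
# saved-vs-recomputed metadata (shape) comparison that raised the error above.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": True}
)

# When calling torch.utils.checkpoint directly, the equivalent is:
#   torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)
```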