shuuchen / DetCo.pytorch

A PyTorch implementation of DetCo https://arxiv.org/pdf/2102.04803.pdf
MIT License

Tensor dimension issue occurring partway through training #8

Open olliestanley opened 2 years ago

olliestanley commented 2 years ago

Hi, I have encountered a weird issue when attempting to train a DetCo (ResNet18 backbone) model. In short, the model trained perfectly well for 9 epochs, with the loss trending downwards. Then, partway through the 10th epoch, the following error occurred:

RuntimeError('The expanded size of the tensor (8) must match the existing size (16) at non-singleton dimension 2. Target sizes: [8, 128, 8]. Tensor sizes: [8, 128, 16]')

Running it again with the code wrapped in a try/except confirmed that, once the error first occurs, it recurs on every subsequent forward call, on all of the GPUs. The stack trace points to the line

self.queue[:, :, ptr:ptr + batch_size] = keys.permute(1, 2, 0)

in this function, which is called at the end of the forward function:

https://github.com/shuuchen/DetCo.pytorch/blob/b4591b9cc3bb7ad1777a369079d7fdb0a7f78a15/detco/builder.py#L48-L62
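One mechanism that would produce this error, and would also explain why it persists once it appears, is modelled below in plain Python. It assumes, without having checked the linked lines, that the function follows MoCo's ring-buffer pattern: write the keys into `queue[..., ptr:ptr + batch_size]`, then advance `ptr` modulo the queue length. If a single smaller batch (for example, the last partial batch of an epoch) ever shifts `ptr` off its usual grid, a later slice can run past the end of the queue and come back shorter than the batch. Because the exception fires before `ptr` is updated, every following call then fails the same way. `RingQueue` and the queue length of 64 are hypothetical, chosen only so the numbers match the 8 vs. 16 in the error message.

```python
class RingQueue:
    """Hypothetical model of a MoCo-style key queue (tracks slots only, no data)."""

    def __init__(self, length):
        self.length = length  # K: number of key slots in the queue
        self.ptr = 0          # index of the next slot to write

    def enqueue(self, batch_size):
        # Slicing clips at the end of the buffer, as tensor slicing does,
        # so queue[..., ptr:ptr + batch_size] can be shorter than the batch.
        writable = len(range(self.length)[self.ptr:self.ptr + batch_size])
        if writable != batch_size:
            # Stands in for the RuntimeError raised by the tensor assignment.
            # ptr is never advanced, so every later call hits the same slice.
            raise RuntimeError(
                f"slot region has size {writable}, batch has size {batch_size}"
            )
        self.ptr = (self.ptr + batch_size) % self.length


q = RingQueue(length=64)
for _ in range(3):
    q.enqueue(16)  # full batches: ptr goes 0 -> 16 -> 32 -> 48
q.enqueue(8)       # one partial batch: ptr moves off the 16-slot grid to 56
for _ in range(2):
    try:
        q.enqueue(16)  # slots 56:72 clip to 8 slots, so this fails every time
    except RuntimeError as err:
        print(err)     # slot region has size 8, batch has size 16
```

This sketch does not by itself explain why the failure appeared only after 8 or 9 epochs; it only shows that one misaligned pointer is enough to make the error permanent.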

This does not seem to be a problem with the data: the exact same data had already passed through the model several times, and once the error first appears, it occurs for every batch. I believe the point at which it occurs differs between runs; if I remember correctly, it occurred after 8 epochs the first time I encountered it and after 9 the second time.

I wondered whether it could be due to using 4 GPUs while the code was tested with 8, but that still would not explain why it only begins partway through training. Running repeated experiments to test this or other hypotheses would take a lot of GPU time, so I decided to post here first in case you have any insight. Thanks!

shuuchen commented 2 years ago

Hi,

Have you tested it on 4 GPUs? If you tested it on 8 GPUs, you might change this:

https://github.com/shuuchen/DetCo.pytorch/blob/main/main_detco.py#L136

and then check with nvidia-smi whether each GPU is working correctly.

olliestanley commented 2 years ago

Thanks for your response! I am using an instance with 4 GPUs and am attempting to train using all 4 of them currently.

I am using slightly different launch code, as I only need the single-node DDP functionality. When my worker function is called (via PyTorch multiprocessing's spawn function), I use the gpu argument it receives as the value passed to set_device(), cuda(), device_ids=[], etc.

Based on that code, it looks like you take the process index that spawn passes to your worker function as gpu and add 4 to it. I assumed this was because you have 8 GPUs but want to train on 4 of them, so you spawn torch.cuda.device_count() // 2 processes and then use GPUs 4 through 7. Is that understanding correct? On my end, I have been spawning 4 processes and using all of GPUs 0 through 3.

I am also not sure a GPU problem like this would explain the issue: if the distributed training were misconfigured, wouldn't we expect a failure much earlier in the process?

shuuchen commented 2 years ago

Yes. I used GPUs 4 through 7, with GPU 4 as the master GPU. If your machine has 4 GPUs, just use torch.cuda.device_count() and set the master GPU to 0.
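The two launch setups discussed above differ only in how a spawned process index is mapped to a CUDA device. The helper below, `plan_devices`, is hypothetical and not part of the DetCo repository; it assumes the worker receives the process index that `torch.multiprocessing.spawn` passes as its first argument and adds a fixed offset to it. The actual code at line 136 of main_detco.py is not reproduced here.

```python
def plan_devices(device_count, gpu_offset=0, nprocs=None):
    """Return the CUDA device index each spawned process would use.

    Hypothetical helper: process i (the `gpu` argument that
    torch.multiprocessing.spawn passes to the worker) is mapped to
    device i + gpu_offset. Under this assumed mapping, the first entry
    is the device used by process 0, i.e. the master GPU.
    """
    if nprocs is None:
        nprocs = device_count - gpu_offset
    devices = [i + gpu_offset for i in range(nprocs)]
    if devices and devices[-1] >= device_count:
        raise ValueError("gpu_offset + nprocs exceeds the number of visible GPUs")
    return devices


# Maintainer's setup as described: 8 GPUs, using GPUs 4-7 with GPU 4 as master.
print(plan_devices(device_count=8, gpu_offset=4))  # [4, 5, 6, 7]
# A 4-GPU machine: use every device, with GPU 0 as master.
print(plan_devices(device_count=4))                # [0, 1, 2, 3]
```

Both configurations give each process a distinct, valid device, which is consistent with the reply above that the 4-GPU setup should not by itself cause a mid-training failure.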

olliestanley commented 2 years ago

In that case, I believe that lines up with what I am currently doing, so it shouldn't be related to the problem. I have now tried a much smaller dataset on the same GPU configuration and made it past 10 epochs without problems both times. So perhaps it is a data-related issue after all.
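If the problem is data-related, one quantity that depends on dataset size (and on the number of GPUs) is the size of the final batch in each epoch. The hypothetical helper below computes it, assuming `DistributedSampler`'s default `drop_last=False` behaviour (each rank receives `ceil(N / world_size)` samples, with padding) and a `DataLoader` created without `drop_last=True`. A nonzero result means every epoch ends with one smaller batch, the kind of event that can shift a MoCo-style queue pointer; whether that is what happened here is not established in this thread.

```python
import math


def last_batch_size(num_samples, world_size, per_gpu_batch):
    """Return the size of the final per-GPU batch of an epoch (0 = no partial batch).

    Hypothetical helper. Assumes DistributedSampler with its default
    drop_last=False, which pads the index list so that every rank gets
    ceil(num_samples / world_size) samples, and a DataLoader built
    without drop_last=True.
    """
    per_rank = math.ceil(num_samples / world_size)
    return per_rank % per_gpu_batch


# Illustrative numbers only, not the datasets from this thread.
for n in (10_000, 10_240):
    for world in (4, 8):
        print(n, world, last_batch_size(n, world, per_gpu_batch=16))
```

If the helper reports a partial batch for a given setup, passing `drop_last=True` to `torch.utils.data.DataLoader` discards that final incomplete batch, so every step enqueues the same number of keys.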