raymin0223 / patch-mix_contrastive_learning

Patch-Mix Contrastive Learning with Audio Spectrogram Transformer on Respiratory Sound Classification (INTERSPEECH 2023)

Distributed Run Code Error #4

Closed iffuture799 closed 1 year ago

iffuture799 commented 1 year ago

Hi @raymin0223, I have found that when I run the code on multiple GPUs, the following error occurs:

Traceback (most recent call last):
  File "main.py", line 563, in <module>
    main()
  File "main.py", line 531, in main
    loss, acc = train(train_loader, model, classifier, projector, criterion, optimizer, epoch, args, scaler)
  File "main.py", line 384, in train
    mix_images, labels_a, labels_b, lam, index = model(images, y=labels, patch_mix=True, time_domain=args.time_domain)
  File "/home/ygh/anaconda3/envs/pytorch20/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ygh/anaconda3/envs/pytorch20/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 172, in forward
    return self.gather(outputs, self.output_device)
  File "/home/ygh/anaconda3/envs/pytorch20/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 184, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/home/ygh/anaconda3/envs/pytorch20/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 86, in gather
    res = gather_map(outputs)
  File "/home/ygh/anaconda3/envs/pytorch20/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 81, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/home/ygh/anaconda3/envs/pytorch20/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 81, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
TypeError: 'float' object is not iterable

I think it's caused by the parameter 'lam', so I have rewritten the code as follows:

(Two screenshots of the modified code were attached here.)

When mixing patches, I wrap the lam parameter in a tensor. When calculating the contrastive loss, I call lam.mean(). Will this affect the calculation of the contrastive loss? If so, what should I do to solve the problem?

raymin0223 commented 1 year ago

Hi @iffuture799,

All my lab servers were shut down today, so I will look into this as soon as possible.

raymin0223 commented 1 year ago

I also found that there is a bug with lambda: model = torch.nn.DataParallel(model) causes the outputs of each model replica to be gathered. Using lam.mean() is therefore risky, because we cannot guarantee that the lambda values drawn on different GPUs will be similar (in fact, they were different).
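To make the concern concrete, here is a small arithmetic sketch. It assumes the mixed loss has the usual mixup-style form `lam * loss_a + (1 - lam) * loss_b`; the two replicas, their lambda values, and their loss values are made-up numbers for illustration, not measurements from this repository.

```python
# Hypothetical values: each of two GPU replicas drew its own lam
# and produced its own pair of loss terms.
lams = [0.2, 0.8]
loss_a = [1.0, 3.0]  # term weighted by lam, one value per replica
loss_b = [2.0, 0.5]  # term weighted by (1 - lam), one value per replica


def mixed(lam, a, b):
    """Mixup-style weighting of two loss terms (assumed form)."""
    return lam * a + (1.0 - lam) * b


# Weight each replica's terms with the lam that replica actually used,
# then average over replicas.
exact = sum(mixed(l, a, b) for l, a, b in zip(lams, loss_a, loss_b)) / len(lams)

# Shortcut: apply one averaged lam to the batch-averaged loss terms.
lam_mean = sum(lams) / len(lams)
shortcut = mixed(lam_mean, sum(loss_a) / len(loss_a), sum(loss_b) / len(loss_b))

print(f"per-replica lam: {exact:.3f}, averaged lam: {shortcut:.3f}")
```

The two numbers agree only when every replica draws the same lambda, so the averaged weight no longer matches the mixing ratio that was actually applied to each replica's inputs.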

Possible solutions are to use DDP (DistributedDataParallel) or to move the loss calculation inside the model. I haven't tested DP (DataParallel) with PatchMix-CL, so sorry for this bug.
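As a rough sketch of the second option, the loss can be computed inside `forward`, so that each replica uses its own lambda locally and only tensors are gathered. Everything below is an assumption for illustration: `ModelWithLoss`, `encoder`, `alpha`, the simple input interpolation standing in for Patch-Mix, and the cross-entropy standing in for the contrastive loss are not taken from the repository.

```python
import torch
import torch.nn as nn


class ModelWithLoss(nn.Module):
    """Hypothetical wrapper that mixes inputs and computes the loss per replica."""

    def __init__(self, encoder: nn.Module, feat_dim: int, num_classes: int, alpha: float = 1.0):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Linear(feat_dim, num_classes)
        # reduction="none" keeps one loss value per sample.
        self.criterion = nn.CrossEntropyLoss(reduction="none")
        self.beta = torch.distributions.Beta(alpha, alpha)

    def forward(self, images, labels):
        # lam is drawn and consumed on the same replica; it is never returned.
        lam = float(self.beta.sample())
        index = torch.randperm(images.size(0), device=images.device)
        # Stand-in for Patch-Mix: plain interpolation of whole inputs.
        mixed = lam * images + (1.0 - lam) * images[index]
        logits = self.head(self.encoder(mixed))
        loss_a = self.criterion(logits, labels)
        loss_b = self.criterion(logits, labels[index])
        # Per-sample loss of shape (local_batch,), a tensor DataParallel can gather.
        return lam * loss_a + (1.0 - lam) * loss_b
```

With this layout, something like `nn.DataParallel(ModelWithLoss(...).cuda())(images, labels).mean()` would return a scalar loss, since the gathered output is a single tensor concatenated along the batch dimension. A contrastive loss computed this way only sees the samples on its own replica, which is a trade-off to weigh against switching to DDP.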