GaParmar / clean-fid

PyTorch - FID calculation with proper image resizing and quantization steps [CVPR 2022]
https://www.cs.cmu.edu/~clean-fid/
MIT License

cuda device mismatch in `DataParallel` when not using `cuda:0` #60

Open janfb opened 6 months ago

janfb commented 6 months ago

Hi there, thanks for this package, it's really helpful!

On a cluster with multiple GPUs, I have my model on device `cuda:1`.

When calculating FID with a `gen` function passed in, new samples are generated on the fly. To that end, a `model_fn(x)` function is defined here: https://github.com/GaParmar/clean-fid/blob/bd44693af04626963af76e94bdb1d4529a76bd11/cleanfid/features.py#L23-L25

and if `use_dataparallel=True`, the model is wrapped with `model = torch.nn.DataParallel(model)`.

Problem: `DataParallel` has a kwarg `device_ids=None`, which defaults to all available devices and selects the first of them as the "source" device, i.e., `cuda:0`. At forward time it then asserts that all parameters and buffers of the module are on that source device. So when `device_ids` is not passed, this raises an error because my model lives on a device other than `cuda:0`. I am not sure why `DataParallel` hard-codes the source device to the first available one, but there is a simple fix on the clean-fid side.
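
For illustration, here is a minimal sketch of the mismatch outside of clean-fid (assuming a machine with at least two GPUs; the toy `nn.Linear` just stands in for the feature extractor):

    import torch
    import torch.nn as nn

    # toy module standing in for the feature extractor, placed on cuda:1
    model = nn.Linear(8, 8).to("cuda:1")
    # device_ids defaults to all visible GPUs, with cuda:0 as the source device
    wrapped = nn.DataParallel(model)
    x = torch.randn(4, 8, device="cuda:1")
    # raises: RuntimeError: module must have its parameters and buffers on device
    # cuda:0 (device_ids[0]) but found one of them on device: cuda:1
    out = wrapped(x)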

Solution: pass `device_ids` set to the device the model is actually on:

        if use_dataparallel:
            # restrict DataParallel to the device the model is already on,
            # instead of defaulting to all devices with cuda:0 as the source
            device_ids = [next(model.parameters()).device]
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        def model_fn(x): return model(x)
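
For completeness, a quick sanity check of the pinned `device_ids` in isolation (again a toy sketch, not the actual clean-fid feature extractor):

    import torch
    import torch.nn as nn

    # same toy stand-in as above: the model lives on cuda:1
    model = nn.Linear(8, 8).to("cuda:1")
    # pin DataParallel to the model's own device instead of defaulting to cuda:0
    device_ids = [next(model.parameters()).device]
    wrapped = nn.DataParallel(model, device_ids=device_ids)
    out = wrapped(torch.randn(4, 8, device="cuda:1"))  # no device-mismatch error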

I would be happy to make a PR fixing this. Unless I am missing something?

Cheers, Jan

GaParmar commented 5 months ago

Hi Jan,

Thank you for pointing this out! Feel free to make a PR. Your proposed solution makes a lot of sense; I will add it to the main repo!

-Gaurav