open-mmlab / mmgeneration

MMGeneration is a powerful toolkit for generative models, based on PyTorch and MMCV.
https://mmgeneration.readthedocs.io/en/latest/
Apache License 2.0

Fix error when a tensor and its index aren't on the same device #476

Closed: zeakey closed this pull request 2 years ago

zeakey commented 2 years ago

The line var_indexed = torch.from_numpy(var)[index].float() raises an error with torch==1.13.0+cu117. The tensor created by torch.from_numpy(var) always lives on the CPU, while index may be on the GPU, so the two tensors may not be on the same device.

Changing the line to var_indexed = torch.from_numpy(var)[index.cpu()].float() resolves the issue.
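For illustration, here is a minimal sketch of the pattern involved. It assumes a helper that gathers per-timestep values from a NumPy schedule using a timestep tensor that may sit on the GPU. The helper name index_numpy_var and its parameters are hypothetical and are not mmgeneration's actual API. The sketch applies the index.cpu() fix from this PR, then moves only the gathered values to the target device.

```python
import numpy as np
import torch


def index_numpy_var(var, index, target_shape=None, device=None):
    """Hypothetical helper: gather values from a NumPy array by a tensor index.

    torch.from_numpy(var) always produces a CPU tensor, while `index` may be a
    CUDA tensor. Moving `index` to the CPU first (the fix proposed in this PR)
    keeps both on the same device.
    """
    # Index on the CPU, then cast to float32.
    var_indexed = torch.from_numpy(var)[index.cpu()].float()
    if target_shape is not None:
        # Append trailing singleton dims so the result broadcasts to target_shape.
        while var_indexed.dim() < len(target_shape):
            var_indexed = var_indexed[..., None]
        var_indexed = var_indexed.expand(target_shape)
    # Move only the (small) gathered result to the requested device, if any.
    return var_indexed.to(device) if device is not None else var_indexed


# Usage: use the GPU when available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
betas = np.linspace(1e-4, 0.02, 1000)  # illustrative per-timestep schedule
timesteps = torch.tensor([0, 999, 500], device=device)
out = index_numpy_var(betas, timesteps, target_shape=(3, 1, 2, 2), device=device)
print(out.shape)  # torch.Size([3, 1, 2, 2])
```

An alternative would be torch.from_numpy(var).to(index.device)[index], which copies the whole array to the index's device before gathering. Indexing on the CPU and moving only the selected values keeps the transfer small.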

Screenshots

Screenshots after executing python tools/train.py configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_cifar10_32x32_b8x16_500k.py.

Before modification:

[Screenshot: training run before the modification]

After modification:

[Screenshot: training run after the modification]

LeoXing1996 commented 2 years ago

Hey @zeakey, thanks for your contribution. Can you add a screenshot to the description to prove that the modification is correct? Then this PR can be merged.

zeakey commented 2 years ago

@LeoXing1996 added.