`var_indexed = torch.from_numpy(var)[index].float()` raises an error with `torch==1.13.0+cu117`. `torch.from_numpy(var)` always returns a CPU tensor, while `index` may live on a different device (e.g. the GPU).
Changing the line to `var_indexed = torch.from_numpy(var)[index.cpu()].float()` resolves the issue.
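Below is a minimal sketch of the fix. It assumes `var` is a 1-D NumPy buffer, such as a per-timestep noise schedule, and `index` is an integer tensor of sampled timesteps. The helper name `gather_from_numpy` is hypothetical. Moving the result back to `index.device` is also an added assumption and not part of the one-line change above. The essential step is indexing the CPU tensor with `index.cpu()`.

```python
import numpy as np
import torch


def gather_from_numpy(var: np.ndarray, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: gather entries of a NumPy buffer at ``index``.

    ``torch.from_numpy`` always yields a CPU tensor, whereas ``index`` may be
    a CUDA tensor. Indexing a CPU tensor with a CUDA index can fail, so the
    index is moved to the CPU first (this is the actual fix).
    """
    var_indexed = torch.from_numpy(var)[index.cpu()].float()
    # Assumption (not in the original change): send the result back to the
    # index's device so it can be combined with other tensors on that device.
    return var_indexed.to(index.device)


if __name__ == "__main__":
    # Example buffer, e.g. a schedule with one value per timestep (float64).
    var = np.linspace(0.0, 1.0, num=5)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    index = torch.tensor([0, 2, 4], device=device)
    print(gather_from_numpy(var, index))
```

Without CUDA, the demo runs entirely on the CPU. With CUDA available, `index` starts on the GPU, which is the case that fails when the `.cpu()` call is omitted.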
Screenshots
Screenshots after executing `python tools/train.py configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_cifar10_32x32_b8x16_500k.py`:
Hey @zeakey, thanks for your contribution. Can you add a screenshot to the description to prove the correctness of the modification? Then this PR can be merged.
Before modification:

After modification: