facebookresearch / DiT

Official PyTorch Implementation of "Scalable Diffusion Models with Transformers"

sample_ddp failed (CUDA error: device-side assert triggered) #75

Closed forever208 closed 7 months ago

forever208 commented 8 months ago

I trained DiT-B/2 on my custom dataset, but I get the following error when sampling from the checkpoint:

Traceback (most recent call last):
  File "/gpfs/home5/mning/DiT/sample_ddp.py", line 167, in <module>
    main(args)
  File "/gpfs/home5/mning/DiT/sample_ddp.py", line 125, in main
    samples = diffusion.p_sample_loop(
  File "/gpfs/home5/mning/DiT/diffusion/gaussian_diffusion.py", line 465, in p_sample_loop
    for sample in self.p_sample_loop_progressive(
  File "/gpfs/home5/mning/DiT/diffusion/gaussian_diffusion.py", line 516, in p_sample_loop_progressive
    out = self.p_sample(
  File "/gpfs/home5/mning/DiT/diffusion/gaussian_diffusion.py", line 417, in p_sample
    out = self.p_mean_variance(
  File "/gpfs/home5/mning/DiT/diffusion/respace.py", line 92, in p_mean_variance
    return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
  File "/gpfs/home5/mning/DiT/diffusion/gaussian_diffusion.py", line 301, in p_mean_variance
    min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
  File "/gpfs/home5/mning/DiT/diffusion/gaussian_diffusion.py", line 885, in _extract_into_tensor
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

Any help would be appreciated!
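
The error message itself warns that the stack trace may point at the wrong call, because CUDA kernels are launched asynchronously, and it suggests CUDA_LAUNCH_BLOCKING=1. A minimal sketch of enabling that from Python follows. It assumes the lines are placed at the very top of the sampling script, before any CUDA work starts; exporting the variable in the shell before launching the script should be equivalent.

```python
import os

# Assumption: this runs before torch initialises CUDA in this process.
# Setting the variable after CUDA is already initialised may have no effect.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# With blocking launches, a device-side assert is reported at the call that
# actually triggered it, so the traceback becomes trustworthy (but slower).
print(os.environ["CUDA_LAUNCH_BLOCKING"])
```

This only makes the traceback more accurate; it does not fix the underlying out-of-range index.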

Gokaii commented 7 months ago

@forever208 You may want to edit line 48, class_labels = [...], and line 57, y_null = torch.tensor([...] * n, device=device), so that they match the number of classes in your dataset.
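
A device-side assert of this kind is typically raised when an index falls outside the table it looks up, which is what happens if a class label or the null label is larger than the label embedding built for a custom number of classes. The sketch below illustrates the check in plain Python. It assumes the embedding has num_classes + 1 rows, with the last row (index num_classes) reserved for the null class used by classifier-free guidance; make_cfg_labels is a hypothetical helper, not part of the DiT repository. In the sampling script, the two returned lists would then be wrapped in torch.tensor(..., device=device).

```python
def make_cfg_labels(class_labels, num_classes):
    """Validate class labels and build the matching null labels.

    Assumption: valid class indices are 0 .. num_classes - 1 and the
    null (unconditional) class uses index num_classes.
    """
    bad = [c for c in class_labels if not 0 <= c < num_classes]
    if bad:
        raise ValueError(f"labels {bad} are outside [0, {num_classes - 1}]")
    # One null label per sample, always equal to num_classes.
    y_null = [num_classes] * len(class_labels)
    return list(class_labels), y_null


# Hypothetical custom dataset with 10 classes.
labels, y_null = make_cfg_labels([0, 3, 9], num_classes=10)
print(labels, y_null)  # [0, 3, 9] [10, 10, 10]
```

Failing early with a ValueError on the CPU side gives a readable message instead of the asynchronous CUDA assert.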

forever208 commented 7 months ago

@Gokaii Thanks!

artemi8 commented 3 months ago

Hello, can you please tell me which file I should make this change in?

Gokaii commented 3 months ago

Just sample.py.

artemi8 commented 3 months ago

Thank you, it worked!