Closed: forever208 closed this issue 7 months ago
@forever208 You may want to edit line 48, class_labels = [...], and line 57, y_null = torch.tensor([...] * n, device=device), according to your num_class.
@Gokaii Thanks!
Hello, can you please tell me which file I should make this change in?
Just sample.py.
Thank you, it worked!
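For anyone landing here with the same problem, here is a minimal sketch of what "according to your num_class" could look like. It assumes that the null (unconditional) label used for classifier-free guidance is the index num_classes, i.e. one past the last real class; please check this against your own model's label embedding. The helper name build_cfg_labels and the 10-class dataset are hypothetical, and plain lists are used so the snippet runs without a GPU.

```python
def build_cfg_labels(class_labels, num_classes):
    """Return the conditional labels followed by the same number of null labels.

    Assumption: the null (unconditional) label index is num_classes,
    so valid class labels are 0 .. num_classes - 1.
    """
    bad = [c for c in class_labels if not 0 <= c < num_classes]
    if bad:
        raise ValueError(
            f"labels out of range for num_classes={num_classes}: {bad}"
        )
    n = len(class_labels)
    y_null = [num_classes] * n  # replaces a hard-coded null index
    return list(class_labels) + y_null


# Hypothetical custom dataset with 10 classes.
num_classes = 10
labels = build_cfg_labels([0, 3, 7, 9], num_classes)
print(labels)  # [0, 3, 7, 9, 10, 10, 10, 10]

# In the sampling script, class_labels and y_null would each be wrapped
# with torch.tensor(..., device=device) as in the lines quoted above.
```

Validating the range up front turns an opaque device-side assert into an ordinary Python exception that names the offending labels.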
I trained DiT-B/2 on my custom dataset, but I get the following error when sampling from the checkpoint:
Traceback (most recent call last):
  File "/gpfs/home5/mning/DiT/sample_ddp.py", line 167, in <module>
    main(args)
  File "/gpfs/home5/mning/DiT/sample_ddp.py", line 125, in main
    samples = diffusion.p_sample_loop(
  File "/gpfs/home5/mning/DiT/diffusion/gaussian_diffusion.py", line 465, in p_sample_loop
    for sample in self.p_sample_loop_progressive(
  File "/gpfs/home5/mning/DiT/diffusion/gaussian_diffusion.py", line 516, in p_sample_loop_progressive
    out = self.p_sample(
  File "/gpfs/home5/mning/DiT/diffusion/gaussian_diffusion.py", line 417, in p_sample
    out = self.p_mean_variance(
  File "/gpfs/home5/mning/DiT/diffusion/respace.py", line 92, in p_mean_variance
    return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
  File "/gpfs/home5/mning/DiT/diffusion/gaussian_diffusion.py", line 301, in p_mean_variance
    min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
  File "/gpfs/home5/mning/DiT/diffusion/gaussian_diffusion.py", line 885, in _extract_into_tensor
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Any help would be appreciated!
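As the error message itself warns, the line shown in the traceback may not be where the failure actually happened, because CUDA reports kernel errors asynchronously. A rough sketch of following its suggestion from inside Python is below. The helper name enable_sync_cuda_errors is hypothetical, and the variable has to be set before CUDA is initialized, so the simplest approach is to call it before importing torch. Prefixing the launch command with CUDA_LAUNCH_BLOCKING=1 has the same effect.

```python
import os
import sys


def enable_sync_cuda_errors():
    """Request synchronous CUDA kernel launches for clearer tracebacks.

    Returns True if torch had not been imported yet when the variable
    was set, which is the safest time to set it.
    """
    torch_not_loaded = "torch" not in sys.modules
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    return torch_not_loaded


early_enough = enable_sync_cuda_errors()
if not early_enough:
    # The setting may be ignored if CUDA was already initialized.
    print("torch was already imported; set CUDA_LAUNCH_BLOCKING earlier")

# import torch  # import torch only after the variable is set
```

With synchronous launches the traceback should point at the call that actually failed. Given the fix reported in this thread, a likely candidate is a label index that is out of range for the number of classes the model was trained with.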