jaceqin closed this issue 1 year ago.
Could you print x_t.shape and eps.shape at the indicated line?
eps.shape = torch.Size([1, 2, 224, 224]); x_t.shape = torch.Size([1, 1, 224, 224])
Do you know why eps.shape and x_t.shape are different? I am using your data and code.
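One plausible reading of those shapes, assuming the network was built with learn_sigma=True: in the upstream guided_diffusion code this repository builds on, a learned-variance model emits twice as many channels as x_t (the eps prediction concatenated with the variance values), and the diffusion object splits that output in half along the channel dimension only when it is itself configured for a learned variance. The sketch below mimics that bookkeeping on plain shape tuples; split_output_shape is a hypothetical helper, not part of the repository.

```python
def split_output_shape(output_shape, x_shape, learn_sigma):
    """Return (eps_shape, var_shape) the way a learned-variance diffusion would.

    Hypothetical helper that works on shape tuples only (no tensors).
    """
    batch, channels = x_shape[0], x_shape[1]
    if not learn_sigma:
        # Fixed variance: the whole model output is treated as eps, unsplit.
        return tuple(output_shape), None
    # Learned variance: the output must hold 2*C channels, split into two halves.
    expected = (batch, 2 * channels, *x_shape[2:])
    if tuple(output_shape) != expected:
        raise ValueError(f"expected model output {expected}, got {tuple(output_shape)}")
    half = (batch, channels, *x_shape[2:])
    return half, half


x_t_shape = (1, 1, 224, 224)
model_out = (1, 2, 224, 224)  # the eps.shape reported in this thread

eps, _ = split_output_shape(model_out, x_t_shape, learn_sigma=False)
print(eps == x_t_shape)  # False: the unsplit output is what trips the assert

eps, var = split_output_shape(model_out, x_t_shape, learn_sigma=True)
print(eps == x_t_shape)  # True: the first half lines up with x_t
```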
What is your "learn_sigma" flag? Did you set it to True or False?
learn_sigma is False
OK, try setting all flags as given in the README file. There, we have learn_sigma=True.
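Since the flags used to build the checkpoint and the flags passed to the sampling script have to agree, a quick way to catch this kind of mismatch is to diff the two flag sets before sampling. The sketch below assumes both sets are available as plain dicts; find_flag_mismatches and the example dicts are hypothetical. Only learn_sigma=True comes from the README advice, and image_size=224 simply matches the tensor shapes in this thread.

```python
def find_flag_mismatches(train_flags, sample_flags):
    """Return {flag: (train_value, sample_value)} for every flag that differs.

    A flag present in only one dict is reported with None for the missing side.
    """
    mismatches = {}
    for name in sorted(set(train_flags) | set(sample_flags)):
        a = train_flags.get(name)
        b = sample_flags.get(name)
        if a != b:
            mismatches[name] = (a, b)
    return mismatches


# Hypothetical flag sets for illustration only.
train_flags = {"learn_sigma": True, "image_size": 224}
sample_flags = {"learn_sigma": False, "image_size": 224}

print(find_flag_mismatches(train_flags, sample_flags))
# {'learn_sigma': (True, False)}
```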
OK. Thank you!
Traceback (most recent call last):
  File "scripts/segmentation_sample.py", line 125, in <module>
    main()
  File "scripts/segmentation_sample.py", line 97, in main
    model_kwargs=model_kwargs,
  File "./guided_diffusion/gaussian_diffusion.py", line 524, in p_sample_loop_known
    progress=progress,
  File "./guided_diffusion/gaussian_diffusion.py", line 586, in p_sample_loop_progressive
    model_kwargs=model_kwargs,
  File "./guided_diffusion/gaussian_diffusion.py", line 433, in p_sample
    model_kwargs=model_kwargs,
  File "./guided_diffusion/respace.py", line 90, in p_mean_variance
    return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
  File "./guided_diffusion/gaussian_diffusion.py", line 310, in p_mean_variance
    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
  File "./guided_diffusion/gaussian_diffusion.py", line 331, in _predict_xstart_from_eps
    assert x_t.shape == eps.shape
AssertionError
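The failing assertion guards an element-wise formula. In the upstream guided_diffusion code, _predict_xstart_from_eps recovers x_0 as sqrt(1/alpha_bar_t) * x_t - sqrt(1/alpha_bar_t - 1) * eps, which is only defined when x_t and eps have the same shape. The sketch below reproduces that arithmetic on flat Python lists with a single hypothetical alpha_bar value and keeps the same shape check; it is an illustration, not the repository's tensor code.

```python
import math


def predict_xstart_from_eps(x_t, eps, alpha_bar):
    """x_0 = sqrt(1/alpha_bar) * x_t - sqrt(1/alpha_bar - 1) * eps, element-wise.

    x_t and eps are flat lists standing in for tensors; alpha_bar is the
    cumulative product of alphas at the current timestep (0 < alpha_bar <= 1).
    """
    # Same guard as the assert in the traceback: eps must line up with x_t.
    if len(x_t) != len(eps):
        raise AssertionError(f"x_t has {len(x_t)} elements, eps has {len(eps)}")
    recip = math.sqrt(1.0 / alpha_bar)
    recipm1 = math.sqrt(1.0 / alpha_bar - 1.0)
    return [recip * x - recipm1 * e for x, e in zip(x_t, eps)]


# Round trip with a hypothetical alpha_bar: build x_t from a known x_0 and eps,
# x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps, then recover x_0.
alpha_bar = 0.25
x0 = [1.0, -0.5]
eps = [0.2, 0.4]
x_t = [math.sqrt(alpha_bar) * a + math.sqrt(1 - alpha_bar) * e for a, e in zip(x0, eps)]
print([round(v, 6) for v in predict_xstart_from_eps(x_t, eps, alpha_bar)])
# [1.0, -0.5]
```

Passing an eps with twice as many elements as x_t (the 2-channel vs 1-channel case above) raises the AssertionError instead of silently mixing the eps and variance halves.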