cientgu / VQ-Diffusion

MIT License

'DiffusionTransformer' object has no attribute 'cf_predict_start' #16

Open songxueXS opened 2 years ago

songxueXS commented 2 years ago

Hi author, thanks for sharing your inspiring work! I'm trying to re-implement Improved VQ-Diffusion and ran into the following error:

Traceback (most recent call last):
  File "/env/anaconda3-pytorch1.7-python3.7-cuda11/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/image_synthesis/distributed/launch.py", line 93, in distributed_worker
    fn(local_rank, args)
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/train.py", line 168, in main_worker
    solver.train()
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/image_synthesis/engine/solver.py", line 546, in train
    self.train_epoch()
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/image_synthesis/engine/solver.py", line 463, in train_epoch
    self.sample(batch, phase='train', step_type='iteration')
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/image_synthesis/engine/solver.py", line 198, in sample
    samples = model.sample(batch=batch, step=self.last_iter)
  File "/env/anaconda3-pytorch1.7-python3.7-cuda11/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/image_synthesis/modeling/models/dalle.py", line 303, in sample
    **kwargs)
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/image_synthesis/modeling/transformers/diffusion_transformer.py", line 592, in sample
    log_z, sampled = self.p_sample(log_z, cond_emb, t, sampled, self.n_sample[diffusion_index]) # log_z is log_onehot
  File "/env/anaconda3-pytorch1.7-python3.7-cuda11/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/image_synthesis/modeling/transformers/diffusion_transformer.py", line 278, in p_sample
    model_log_prob, log_x_recon = self.p_pred(log_x, cond_emb, t)
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/image_synthesis/modeling/transformers/diffusion_transformer.py", line 267, in p_pred
    log_x_recon = self.cf_predict_start(log_x, cond_emb, t)
  File "/env/anaconda3-pytorch1.7-python3.7-cuda11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 779, in __getattr__
    type(self).__name__, name))
torch.nn.modules.module.ModuleAttributeError: 'DiffusionTransformer' object has no attribute 'cf_predict_start'

koutilya-pnvr commented 2 years ago

https://github.com/cientgu/VQ-Diffusion/blob/fe79083818b47d4d376ab9579ec19cba2a43c3cb/image_synthesis/modeling/transformers/diffusion_transformer.py#L267

More precisely, this is the code line from the Improved_VQ-Diffusion branch. The cf_predict_start function is not defined in the DiffusionTransformer class. Is it the same as the one from the dalle class? https://github.com/cientgu/VQ-Diffusion/blob/fe79083818b47d4d376ab9579ec19cba2a43c3cb/image_synthesis/modeling/models/dalle.py#L170

songxueXS commented 2 years ago

Thank you very much for your reply, but I found that they are not exactly the same, and 'def cf_predict_start(log_x_t, cond_emb, t)' in dalle.py depends on many other parameters. Could you share a complete version?

tzco commented 2 years ago

Sorry for the late reply. The cf_predict_start function is defined inside generate_content in dalle.py and is only used for classifier-free sampling. During training we don't need classifier-free sampling, so generate_content is never executed and cf_predict_start is never defined. Adding a placeholder method to the DiffusionTransformer class in diffusion_transformer.py should solve it:

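# Add as a method of class DiffusionTransformer in diffusion_transformer.py.
# Training-time sampling (p_pred) calls self.cf_predict_start, so fall back
# to the ordinary conditional prediction of x_0 without classifier-free guidance.
def cf_predict_start(self, log_x_t, cond_emb, t):
    return self.predict_start(log_x_t, cond_emb, t)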
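To make the explanation above concrete, here is a minimal, torch-free sketch of the same pattern. Everything in it is hypothetical: ToyDiffusionTransformer, its toy predict_start, and the null_emb and guidance_scale parameters of generate_content are illustrative stand-ins, not the repository's code. It assumes that dalle.py's generate_content attaches its local cf_predict_start to the transformer instance, and it uses a generic classifier-free guidance rule on log-probabilities, uncond + scale * (cond - uncond) followed by renormalization. Whether this matches the repository's exact formula is also an assumption.

```python
import math


def log_normalize(logits):
    # Renormalize unnormalized log-probabilities with a stable log-sum-exp.
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


class ToyDiffusionTransformer:
    """Hypothetical stand-in for DiffusionTransformer (plain lists, no torch)."""

    def predict_start(self, log_x_t, cond_emb, t):
        # Toy "model": shift log_x_t by the condition and renormalize.
        return log_normalize([v + c for v, c in zip(log_x_t, cond_emb)])

    def cf_predict_start(self, log_x_t, cond_emb, t):
        # Placeholder from the thread: plain conditional prediction.
        return self.predict_start(log_x_t, cond_emb, t)

    def p_pred(self, log_x, cond_emb, t):
        # Like line 267 of diffusion_transformer.py, sampling always
        # goes through self.cf_predict_start.
        return self.cf_predict_start(log_x, cond_emb, t)


def generate_content(transformer, null_emb, guidance_scale):
    # Hypothetical analogue of dalle.py's generate_content: define a closure
    # and store it on the instance, where it shadows the class placeholder.
    def cf_predict_start(log_x_t, cond_emb, t):
        cond = transformer.predict_start(log_x_t, cond_emb, t)
        uncond = transformer.predict_start(log_x_t, null_emb, t)
        # Classifier-free guidance in log space, then renormalize.
        guided = [u + guidance_scale * (c - u) for c, u in zip(cond, uncond)]
        return log_normalize(guided)

    transformer.cf_predict_start = cf_predict_start
```

With the placeholder, p_pred falls back to the conditional prediction during training-time sampling instead of raising. Once generate_content runs, the closure stored in the instance's `__dict__` takes precedence over the method defined on the class, because Python checks the instance dictionary before plain class-level functions. `torch.nn.Module.__setattr__` stores ordinary (non-Parameter, non-Module) values the same way, so the placeholder should not interfere with classifier-free sampling.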