phizaz / diffae

Official implementation of Diffusion Autoencoders
https://diff-ae.github.io/
MIT License

Some error in the evaluation stage #3

Open kingofprank opened 2 years ago

kingofprank commented 2 years ago

Thanks for your amazing work! I am training on a custom dataset, and after 2 days of training I got an error in the LPIPS calculation. The traceback is as follows:

  File "/share_graphics_ai/linminxuan/Workspace/diffusion-models/diffae/experiment.py", line 938, in train
    trainer.fit(model)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 552, in fit
    self._run(model)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 917, in _run
    self._dispatch()
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 985, in _dispatch
    self.accelerator.start_training(self)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
    self._results = trainer.run_stage()
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 995, in run_stage
    return self._run_train()
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 1044, in _run_train
    self.fit_loop.run()
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/loops/fit_loop.py", line 200, in advance
    epoch_output = self.epoch_loop.run(train_dataloader)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 150, in advance
    "on_train_batch_end", processed_batch_end_outputs, batch, self.iteration_count, self._dataloader_idx
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 1226, in call_hook
    output = hook_fx(*args, **kwargs)
  File "/share_graphics_ai/linminxuan/Workspace/diffusion-models/diffae/experiment.py", line 431, in on_train_batch_end
    self.evaluate_scores()
  File "/share_graphics_ai/linminxuan/Workspace/diffusion-models/diffae/experiment.py", line 622, in evaluate_scores
    lpips(self.model, '')
  File "/share_graphics_ai/linminxuan/Workspace/diffusion-models/diffae/experiment.py", line 611, in lpips
    latent_sampler=self.eval_latent_sampler)
  File "/share_graphics_ai/linminxuan/Workspace/diffusion-models/diffae/metrics.py", line 111, in evaluate_lpips
    latent_sampler=latent_sampler)
TypeError: render_condition() got an unexpected keyword argument 'latent_sampler'

The "render_condition" function has no "latent_sampler" keyword argument. I guess latent_sampler is meant to be used in the "render_uncondition" case. Should I delete this keyword?
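A minimal sketch of why the call fails, using a hypothetical, trimmed stand-in for render_condition that (like the one in renderer.py) declares no latent_sampler parameter. Python rejects any keyword the callee does not declare, so the fix is to drop the argument at the call site rather than change the callee.

```python
# Hypothetical, simplified stand-in for renderer.render_condition.
# Like the real function, it declares no `latent_sampler` parameter.
def render_condition(conf, model, x_T, sampler, x_start=None, cond=None):
    return ("rendered", x_start, cond)


def call_with_latent_sampler():
    # Mirrors the failing call in metrics.evaluate_lpips (before the fix).
    return render_condition(conf=None, model=None, x_T=None, sampler=None,
                            x_start="imgs", cond=None,
                            latent_sampler="eval_latent_sampler")


def call_without_latent_sampler():
    # Same call with the unsupported keyword removed.
    return render_condition(conf=None, model=None, x_T=None, sampler=None,
                            x_start="imgs", cond=None)


try:
    call_with_latent_sampler()
except TypeError as e:
    print(e)  # ... got an unexpected keyword argument 'latent_sampler'

print(call_without_latent_sampler())  # ('rendered', 'imgs', None)
```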

pitchayagan commented 2 years ago

I hit the same error, and after reading the author's paper I concluded that it is safe to delete the argument. I made a few changes to the code:

metrics.py: remove the "latent_sampler" argument from "render_condition()" at lines 111 and 296
renderer.py: change "model_kwargs={'cond': cond}" to "cond=cond['cond'])" at line 56

My training has not finished yet, so I cannot confirm that these changes introduce no other problems. Please try.

phizaz commented 2 years ago

My bad. You can delete this keyword; it is fixed in https://github.com/phizaz/diffae/commit/a8f1c246f08e3bffdf14173aebd604d1ca0fb28e. Regarding renderer.py line 56, mentioned by @pitchayagan, I believe both versions (yours and mine) are correct, so there is no need to change it.

matanat commented 2 years ago

When running the latest code on my own dataset, I get the following error. After I apply the changes from @pitchayagan's comment above, everything seems to run smoothly.

File "/home/matan/diffae/run_my_dataset.py", line 9, in <module>
    train(conf, gpus=gpus)     
File "/home/matan/diffae/experiment.py", line 963, in train
    trainer.fit(model)
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 552, in fit
    self._run(model)
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 917, in _run
    self._dispatch()
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 985, in _dispatch
    self.accelerator.start_training(self)
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
    self.training_type_plugin.start_training(trainer)
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
    self._results = trainer.run_stage()
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 995, in run_stage
    return self._run_train()
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1044, in _run_train
    self.fit_loop.run()
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 200, in advance
    epoch_output = self.epoch_loop.run(train_dataloader)
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 149, in advance
    self.trainer.call_hook(
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1226, in call_hook
    output = hook_fx(*args, **kwargs)
File "/home/matan/diffae/experiment.py", line 438, in on_train_batch_end
    self.evaluate_scores()
File "/home/matan/diffae/experiment.py", line 641, in evaluate_scores
    lpips(self.model, '')
File "/home/matan/diffae/experiment.py", line 622, in lpips
    score = evaluate_lpips(self.eval_sampler,
File "/home/matan/diffae/metrics.py", line 107, in evaluate_lpips
    pred_imgs = render_condition(conf=conf,
File "/home/matan/diffae/renderer.py", line 56, in render_condition
    return sampler.sample(model=model,
File "/home/matan/diffae/diffusion/base.py", line 208, in sample
    return self.ddim_sample_loop(model,
File "/home/matan/diffae/diffusion/base.py", line 735, in ddim_sample_loop
    for sample in self.ddim_sample_loop_progressive(
File "/home/matan/diffae/diffusion/base.py", line 795, in ddim_sample_loop_progressive
    out = self.ddim_sample(
File "/home/matan/diffae/diffusion/base.py", line 600, in ddim_sample
    out = self.p_mean_variance(
File "/home/matan/diffae/diffusion/diffusion.py", line 96, in p_mean_variance
    return super().p_mean_variance(self._wrap_model(model), *args,
File "/home/matan/diffae/diffusion/base.py", line 307, in p_mean_variance
    model_forward = model.forward(x=x,
File "/home/matan/diffae/diffusion/diffusion.py", line 153, in forward
    return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs)
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
File "/home/matan/diffae/model/unet_autoenc.py", line 211, in forward
    h = self.input_blocks[k](h,
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
File "/home/matan/diffae/model/blocks.py", line 39, in forward
    x = layer(x, emb=emb, cond=cond, lateral=lateral)
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
File "/home/matan/diffae/model/blocks.py", line 193, in forward
    return torch_checkpoint(self._forward, (x, emb, cond, lateral),
File "/home/matan/diffae/model/nn.py", line 137, in torch_checkpoint
    return func(*args)
File "/home/matan/diffae/model/blocks.py", line 238, in _forward
    cond_out = self.cond_emb_layers(cond).type(h.dtype)
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/torch/nn/modules/container.py", line 119, in forward
    input = module(input)
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 394, in forward
    return F.silu(input, inplace=self.inplace)
File "/home/matan/miniconda3/envs/diffae/lib/python3.9/site-packages/torch/nn/functional.py", line 1796, in silu
    return torch._C._nn.silu(input)
TypeError: silu(): argument 'input' (position 1) must be Tensor, not dict
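The failure path can be reproduced without torch. In the sketch below, all names are hypothetical stand-ins, not diffae code: inner_encode plays BeatGANsAutoencModel.encode (which returns a dict), and silu mimics the type check of torch's SiLU, the first layer of cond_emb_layers that the traceback reaches. Forwarding the whole dict through model_kwargs={'cond': cond} hands it to the activation unchanged; unwrapping cond['cond'] first, as @pitchayagan did, avoids the error.

```python
import math


def inner_encode(imgs):
    # Stand-in for BeatGANsAutoencModel.encode: returns a dict, not a tensor.
    return {"cond": [sum(img) / len(img) for img in imgs]}


def silu(values):
    # Mimics torch's SiLU type check: only a sequence of numbers is accepted.
    if not isinstance(values, (list, tuple)):
        raise TypeError(
            f"silu(): argument 'input' must be Tensor, not {type(values).__name__}")
    return [v / (1.0 + math.exp(-v)) for v in values]


def cond_emb_layers(cond):
    # Stand-in for the block's cond embedding, whose first layer is SiLU.
    return silu(cond)


imgs = [[0.0, 2.0], [1.0, 1.0]]

# Before the fix: the whole dict is forwarded as `cond`.
try:
    cond_emb_layers(inner_encode(imgs))
except TypeError as e:
    print(e)  # silu(): argument 'input' must be Tensor, not dict

# After unwrapping, the activation receives the actual values.
print(cond_emb_layers(inner_encode(imgs)["cond"]))
```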

betterze commented 2 years ago

I also hit this error: 'TypeError: silu(): argument 'input' (position 1) must be Tensor, not dict'

betterze commented 2 years ago

I believe the problem comes from an inconsistency in what the different encode functions return:

type(model.encode(imgs))
Out[7]: torch.Tensor

type(model.model.encode(imgs))
Out[8]: dict

type(model.ema_model.encode(imgs) )
Out[9]: dict

Here 'model' is the LitModel instance, whose encode output is the tensor 'cond'. 'model.model' is a BeatGANsAutoencModel, whose encode output is the dict {'cond': cond}.
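One defensive option, sketched here as a hypothetical helper that is not part of diffae, is to accept either return type at the call site. It assumes, as described above, that BeatGANsAutoencModel.encode returns {'cond': cond} while LitModel.encode returns the tensor itself. This differs from the fix proposed below, which changes the return type of encode instead.

```python
from collections.abc import Mapping
from typing import Any


def as_cond(encoded: Any) -> Any:
    """Return the conditioning tensor whether `encoded` is the tensor itself
    (LitModel.encode) or a {'cond': tensor} dict (BeatGANsAutoencModel.encode).

    Hypothetical helper for illustration; not part of diffae.
    """
    if isinstance(encoded, Mapping):
        return encoded["cond"]
    return encoded


# Possible use at a call site such as render_condition (sketch only):
#     cond = as_cond(model.encode(x_start))
#     sampler.sample(model=model, noise=x_T, model_kwargs={'cond': cond})
```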

The only place that uses this dict structure is unet_autoenc, which extracts cond directly:

            tmp = self.encode(x_start)
            cond = tmp['cond']

If we keep the dict structure, we need to change not only renderer.py, as mentioned by @pitchayagan, but also the inference code in the Colab notebook. I prefer to use the tensor directly instead, which only requires the following changes:

Change lines 151-152 in unet_autoenc (https://github.com/phizaz/diffae/blob/master/model/unet_autoenc.py#L152) to

            #tmp = self.encode(x_start)
            #cond = tmp['cond']
            tmp = self.encode(x_start)

and change line 85 in unet_autoenc to

        # return {'cond': cond}
        return cond

With these changes, both the training and the inference code work for me.

mdv3101 commented 1 year ago

Just a small correction: the new line should assign the result to cond, not tmp. Change lines 151-152 in unet_autoenc (https://github.com/phizaz/diffae/blob/master/model/unet_autoenc.py#L152) to

            #tmp = self.encode(x_start)
            #cond = tmp['cond']
            cond = self.encode(x_start)
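A sketch of why this correction matters, under an assumption: the unet_autoenc forward appears to take cond as a parameter defaulting to None and to encode x_start only when cond is None. If so, writing `tmp = self.encode(x_start)` raises no error; cond silently stays None. The class below is a hypothetical stand-in, not diffae code.

```python
class AutoencStub:
    """Hypothetical stand-in for the unet_autoenc model after line 85 was
    changed to `return cond` (encode now returns the tensor directly)."""

    def encode(self, x_start):
        return [v * 2 for v in x_start]  # placeholder "semantic code"

    def forward_buggy(self, x, x_start=None, cond=None):
        if cond is None:
            tmp = self.encode(x_start)  # result stored in tmp; cond stays None
        return cond

    def forward_fixed(self, x, x_start=None, cond=None):
        if cond is None:
            cond = self.encode(x_start)  # corrected: assign to cond
        return cond


m = AutoencStub()
print(m.forward_buggy(x=None, x_start=[1, 2]))  # None
print(m.forward_fixed(x=None, x_start=[1, 2]))  # [2, 4]
```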