AntixK / PyTorch-VAE

A Collection of Variational Autoencoders (VAE) in PyTorch.
Apache License 2.0
6.44k stars, 1.05k forks

TypeError: loss = recons_loss + kld_weight * kld_loss #95

Open folasefo opened 3 months ago

folasefo commented 3 months ago

Hi, when I use my own dataset (images of shape (3, 192, 192)) and change some parameters, training fails with the error below.

While debugging, I found these values just before `loss = recons_loss + kld_weight * kld_loss`: `kld_loss`: `tensor(0.4086, device='cuda:1', grad_fn=)`, `kld_weight`: `0.00025d`, `recons_loss`: `tensor(0.1564, device='cuda:1', grad_fn=)`.

/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/plugins/training_type/ LightningDeprecationWarning: The `pl.plugins.training_type.ddp.DDPPlugin` is deprecated in v1.6 and will be removed in v1.8. Use `pl.strategies.ddp.DDPStrategy` instead.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
======= Training VanillaVAE =======
Global seed set to 1265
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
All distributed processes registered. Starting with 1 processes


  | Name  | Type       | Params
0 | model | VanillaVAE | 3.2 M 
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.678    Total estimated model params size (MB)
Epoch 0:   0%|          | 0/4 [00:00<?, ?it/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/lulu/PyTorch-VAE/", line 64, in <module>
[rank0]:, datamodule=data)
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/trainer/", line 771, in fit
[rank0]:     self._call_and_handle_interrupt(
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/trainer/", line 722, in _call_and_handle_interrupt
[rank0]:     return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/strategies/launchers/", line 93, in launch
[rank0]:     return function(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/trainer/", line 812, in _fit_impl
[rank0]:     results = self._run(model, ckpt_path=self.ckpt_path)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/trainer/", line 1237, in _run
[rank0]:     results = self._run_stage()
[rank0]:               ^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/trainer/", line 1324, in _run_stage
[rank0]:     return self._run_train()
[rank0]:            ^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/trainer/", line 1354, in _run_train
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/loops/", line 204, in run
[rank0]:     self.advance(*args, **kwargs)
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/loops/", line 269, in advance
[rank0]:     self._outputs =
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/loops/", line 204, in run
[rank0]:     self.advance(*args, **kwargs)
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/loops/epoch/", line 208, in advance
[rank0]:     batch_output =, batch_idx)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/loops/", line 204, in run
[rank0]:     self.advance(*args, **kwargs)
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/loops/batch/", line 88, in advance
[rank0]:     outputs =, optimizers, batch_idx)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/loops/", line 204, in run
[rank0]:     self.advance(*args, **kwargs)
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/", line 203, in advance
[rank0]:     result = self._run_optimization(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/", line 256, in _run_optimization
[rank0]:     self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/", line 369, in _optimizer_step
[rank0]:     self.trainer._call_lightning_module_hook(
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/trainer/", line 1596, in _call_lightning_module_hook
[rank0]:     output = fn(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/core/", line 1625, in optimizer_step
[rank0]:     optimizer.step(closure=optimizer_closure)
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/core/", line 168, in step
[rank0]:     step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/strategies/", line 278, in optimizer_step
[rank0]:     optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/strategies/", line 193, in optimizer_step
[rank0]:     return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/plugins/precision/", line 155, in optimizer_step
[rank0]:     return optimizer.step(closure=closure, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/optim/", line 75, in wrapper
[rank0]:     return wrapped(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/optim/", line 391, in wrapper
[rank0]:     out = func(*args, **kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/optim/", line 76, in _use_grad
[rank0]:     ret = func(self, *args, **kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/optim/", line 148, in step
[rank0]:     loss = closure()
[rank0]:            ^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/plugins/precision/", line 140, in _wrap_closure
[rank0]:     closure_result = closure()
[rank0]:                      ^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/", line 148, in __call__
[rank0]:     self._result = self.closure(*args, **kwargs)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/", line 134, in closure
[rank0]:     step_output = self._step_fn()
[rank0]:                   ^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/", line 427, in _training_step
[rank0]:     training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/trainer/", line 1766, in _call_strategy_hook
[rank0]:     output = fn(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/strategies/", line 344, in training_step
[rank0]:     return self.model(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/nn/parallel/", line 1593, in forward
[rank0]:     else self._run_ddp_forward(*inputs, **kwargs)
[rank0]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/nn/parallel/", line 1411, in _run_ddp_forward
[rank0]:     return self.module(*inputs, **kwargs)  # type: ignore[index]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/anaconda3/envs/pytorch/lib/python3.11/site-packages/pytorch_lightning/overrides/", line 82, in forward
[rank0]:     output = self.module.training_step(*inputs, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/PyTorch-VAE/", line 39, in training_step
[rank0]:     train_loss = self.model.loss_function(*results,
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/lulu/PyTorch-VAE/models/", line 146, in loss_function
[rank0]:     loss = recons_loss + kld_weight * kld_loss
[rank0]:                          ~~~~~~~~~~~^~~~~~~~~~
[rank0]: TypeError: only integer tensors of a single element can be converted to an index

If I set `kld_weight` to 1, training runs, but the results are not very good:

Sample: [image]

Reconstruction: [image]

One image from my dataset: [image]

Is `kld_weight=1` okay? How can I make the reconstructions and samples better? Thanks!
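For reference, the failing line combines two terms. As I understand the VanillaVAE `loss_function`, it is a reconstruction loss plus a weighted KL divergence between the diagonal Gaussian posterior $\mathcal{N}(\mu, \sigma^2)$ and a standard normal prior, with $d$ the latent dimension; check `models/vanilla_vae.py` for the exact reductions. Plugging in the debug values from the original post (first training step only):

```latex
\begin{aligned}
\mathcal{L} &= \mathcal{L}_{\text{recons}} + w_{\text{KLD}}\,\mathcal{L}_{\text{KLD}}, \\
\mathcal{L}_{\text{KLD}} &= -\tfrac{1}{2}\sum_{j=1}^{d}\bigl(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\bigr), \\
w_{\text{KLD}} = 0.00025:\quad \mathcal{L} &= 0.1564 + 0.00025 \times 0.4086 \approx 0.1565, \\
w_{\text{KLD}} = 1:\quad \mathcal{L} &= 0.1564 + 0.4086 = 0.5650.
\end{aligned}
```

With $w_{\text{KLD}} = 1$ the KL term makes up most of the total at this step, so training is pushed harder toward matching the prior. In general VAEs this tends to trade reconstruction sharpness for a smoother latent space; that is a general tendency, not something measured on this dataset.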

MisterBourbaki commented 3 months ago

Hi there, I have a couple of questions to better understand the issue: which version of Python do you use? And which parameters did you change?

folasefo commented 3 months ago

First question: Python 3.11.8.

Second question: in `vae.yaml`:

```yaml
model_params:
  name: 'VanillaVAE'
  in_channels: 3
  latent_dim: 3

data_params:
  data_path: "/home/lulu/PyTorch-VAE/Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025d
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 20

logging_params:
  save_dir: "logs/"
  name: "VanillaVAE"
```

Thank you!

MisterBourbaki commented 3 months ago

I think the issue is simply that there is an extra 'd' on the `kld_weight` line in the YAML file, which makes the value a string (`"0.00025d"`) instead of a float, hence difficult to handle :)
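As far as I can tell, this is why the typo produced that particular message: with `kld_weight` being a string, `kld_weight * kld_loss` is not numeric multiplication. Python falls back to string repetition, which needs an integer count, and PyTorch refuses to turn the float tensor `kld_loss` into an index, hence `only integer tensors of a single element can be converted to an index`. A cheap guard is to check numeric hyperparameters right after loading the YAML. The sketch below uses a hypothetical helper, `validate_exp_params`, which is not part of PyTorch-VAE; it assumes the config is a plain dict with the `exp_params` keys shown in this thread. The same check would also catch values like `5e-4`, which PyYAML's YAML 1.1 resolver loads as a string because there is no decimal point.

```python
# A minimal sketch, not part of PyTorch-VAE: `validate_exp_params` is a
# hypothetical helper you could call on config["exp_params"] after loading
# the YAML, so a typo fails loudly at startup instead of deep inside training.

FLOAT_KEYS = ("LR", "weight_decay", "scheduler_gamma", "kld_weight")


def validate_exp_params(exp_params: dict, float_keys=FLOAT_KEYS) -> dict:
    """Return a copy of exp_params with the given keys coerced to float.

    Raises TypeError if any of those keys holds a non-numeric value,
    e.g. the string "0.00025d" produced by a stray character in the YAML.
    """
    checked = dict(exp_params)
    for key in float_keys:
        if key not in checked:
            continue  # missing keys are left for the caller to handle
        value = checked[key]
        # bool is a subclass of int in Python, so exclude it explicitly
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise TypeError(
                f"exp_params['{key}'] must be a number, "
                f"got {type(value).__name__}: {value!r}"
            )
        checked[key] = float(value)
    return checked


# Multiplying a str by an int is ordinary sequence repetition, so `*` itself
# does not reject a string config value; it only fails once the other operand
# cannot serve as a repeat count, like the float tensor in loss_function.
assert "0.00025d" * 2 == "0.00025d0.00025d"

ok = validate_exp_params({"LR": 0.005, "kld_weight": 0.00025, "manual_seed": 1265})
print(ok["kld_weight"])  # 0.00025

try:
    validate_exp_params({"kld_weight": "0.00025d"})
except TypeError as err:
    print(err)  # exp_params['kld_weight'] must be a number, got str: '0.00025d'
```

For the config in this thread, the fix itself is just `kld_weight: 0.00025`; with that value the helper passes it through unchanged as a float.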

folasefo commented 3 months ago

You are wonderful! The problem is solved. Thank you!

folasefo commented 3 months ago

Also, if I want to make the reconstructions and samples better, can you give me some advice? My dataset consists of pictures of galaxies, and I want to extract their geometry. Thanks again!

MisterBourbaki commented 3 months ago

This is the question of all Machine Learning :D Be sure to test a few different hyperparameters, check that the validation loss decreases, and so on. Once both train and val losses no longer seem able to decrease, the model has achieved its best (with the chosen hyperparameters). Until then, give the training a few more rounds :)
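The stopping rule described above (keep training until both train and val losses stop decreasing) can be made concrete with a small plain-Python check over per-epoch loss histories. `has_plateaued`, `should_stop`, `patience` and `min_delta` are hypothetical names chosen for illustration, not part of PyTorch-VAE; in a Lightning setup, the built-in `pytorch_lightning.callbacks.EarlyStopping` callback (with its `monitor`, `patience` and `min_delta` arguments) covers the validation side.

```python
# Plain-Python sketch of "have both losses stopped decreasing?".
# All names here are hypothetical, for illustration only.
from typing import Sequence


def has_plateaued(losses: Sequence[float], patience: int = 5, min_delta: float = 1e-4) -> bool:
    """True if the last `patience` epochs brought no improvement larger than
    `min_delta` over the best loss seen before them."""
    if len(losses) <= patience:
        return False  # not enough history to judge yet
    best_before = min(losses[:-patience])
    best_recent = min(losses[-patience:])
    return best_recent > best_before - min_delta


def should_stop(train_losses: Sequence[float], val_losses: Sequence[float], patience: int = 5) -> bool:
    # Mirrors the advice above: treat the run as finished only once
    # BOTH curves have flattened out.
    return has_plateaued(train_losses, patience) and has_plateaued(val_losses, patience)


val_history = [1.0, 0.8, 0.7, 0.65, 0.65, 0.65, 0.65, 0.65, 0.65]
print(has_plateaued(val_history))  # True: no gain over 0.65 in the last 5 epochs
```

Comparing the best recent loss against the best earlier loss, rather than consecutive epochs, keeps a single noisy epoch from deciding the outcome on its own.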

folasefo commented 3 months ago

Okay, I will try it :) Thanks for your patience and help, have a nice day ;)

sunny12345-bit commented 3 months ago

Were you able to get good results from training? I can't get good results with my own dataset.

MisterBourbaki commented 3 months ago

Hi @sunny12345-bit, if you have a specific issue with the code, it would be best to open a dedicated issue :) That will make it easier to help you!

folasefo commented 3 months ago

Hmm, I have been busy with other assignments and haven't started revising yet :-( I think my parameters are not suitable; my next step is to look at the loss and adjust the parameters. Does your loss decrease with each epoch on your data?

folasefo commented 3 months ago

Hi, my reconstruction results have gotten better. Did you get good results?

sunny12345-bit commented 2 months ago


May I ask which parameters you changed to achieve better results?

folasefo commented 2 months ago


I just reset all parameters to their original values and then changed the image-size parameter to match my data's image size. Also make sure your dataset and latent dim are big enough, or the model can't capture the features.
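One concrete reason the image-size setting has to match the model can be seen from the encoder's output-size arithmetic. This sketch assumes the VanillaVAE encoder as I remember it: one `Conv2d(kernel_size=3, stride=2, padding=1)` per entry of `hidden_dims = [32, 64, 128, 256, 512]`, followed by fully connected layers sized for a 2×2 feature map. Please verify these assumptions against `models/vanilla_vae.py`; `conv_out` and `encoder_flat_features` are hypothetical helpers.

```python
# Hypothetical helpers: spatial size after a stack of Conv2d layers, using the
# standard formula floor((n + 2*padding - kernel) / stride) + 1.
# Defaults ASSUME the VanillaVAE encoder layout described above.


def conv_out(n: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output height/width of one Conv2d layer for an n x n input."""
    return (n + 2 * padding - kernel) // stride + 1


def encoder_flat_features(image_size: int, hidden_dims=(32, 64, 128, 256, 512)) -> int:
    """Number of features after flattening the last encoder feature map."""
    size = image_size
    for _ in hidden_dims:  # one stride-2 conv per hidden dim
        size = conv_out(size)
    return hidden_dims[-1] * size * size


for image_size in (64, 192):
    print(image_size, encoder_flat_features(image_size))
# prints "64 2048", then "192 18432"
```

Under those assumptions, a 64×64 input ends in a 2×2 map (2048 features), whereas a 192×192 input ends in a 6×6 map (18432 features), so the layers after the encoder and the mirrored decoder need to be sized for the new resolution. Separately, the config posted earlier used `latent_dim: 3`, which leaves very few dimensions to encode galaxy shapes.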