lucidrains / lightweight-gan

Implementation of 'lightweight' GAN, proposed in ICLR 2021, in PyTorch. High-resolution image generation that can be trained within a day or two
MIT License

unable to load save model. please try downgrading the package to the version specified by the saved model #124

Closed: sebastiantrella closed this issue 2 years ago

sebastiantrella commented 2 years ago

Since today I have been getting the following error. How can I solve it?

continuing from previous epoch - 118 loading from version 0.21.4 unable to load save model. please try downgrading the package to the version specified by the saved model Traceback (most recent call last): File "/opt/conda/bin/lightweight_gan", line 8, in sys.exit(main()) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 193, in main fire.Fire(train_from_folder) File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 141, in Fire component_trace = _Fire(component, args, parsed_flag_args, context, name) File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 466, in _Fire component, remaining_args = _CallAndUpdateTrace( File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace component = fn(*varargs, **kwargs) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 184, in train_from_folder run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 59, in run_training model.load(load_from) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 1603, in load raise e File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 1600, in load self.GAN.load_state_dict(load_data['GAN']) File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for LightweightGAN: Missing key(s) in state_dict: "G.layers.0.0.2.1.weight", "G.layers.0.0.2.1.bias", "G.layers.0.0.4.weight", "G.layers.0.0.4.bias", "G.layers.0.0.4.running_mean", "G.layers.0.0.4.running_var", "G.layers.1.0.2.1.weight", "G.layers.1.0.2.1.bias", "G.layers.1.0.4.weight", "G.layers.1.0.4.bias", "G.layers.1.0.4.running_mean", "G.layers.1.0.4.running_var", 
"G.layers.2.0.2.1.weight", "G.layers.2.0.2.1.bias", "G.layers.2.0.4.weight", "G.layers.2.0.4.bias", "G.layers.2.0.4.running_mean", "G.layers.2.0.4.running_var", "G.layers.3.0.2.1.weight", "G.layers.3.0.2.1.bias", "G.layers.3.0.4.weight", "G.layers.3.0.4.bias", "G.layers.3.0.4.running_mean", "G.layers.3.0.4.running_var", "G.layers.3.2.fn.to_lin_q.weight", "G.layers.3.2.fn.to_lin_kv.net.0.weight", "G.layers.3.2.fn.to_lin_kv.net.1.weight", "G.layers.3.2.fn.to_kv.weight", "G.layers.4.0.2.1.weight", "G.layers.4.0.2.1.bias", "G.layers.4.0.4.weight", "G.layers.4.0.4.bias", "G.layers.4.0.4.running_mean", "G.layers.4.0.4.running_var", "G.layers.5.0.2.1.weight", "G.layers.5.0.2.1.bias", "G.layers.5.0.4.weight", "G.layers.5.0.4.bias", "G.layers.5.0.4.running_mean", "G.layers.5.0.4.running_var", "D.residual_layers.3.1.fn.to_lin_q.weight", "D.residual_layers.3.1.fn.to_lin_kv.net.0.weight", "D.residual_layers.3.1.fn.to_lin_kv.net.1.weight", "D.residual_layers.3.1.fn.to_kv.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_q.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.0.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.1.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_q.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.0.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.1.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.weight", "GE.layers.0.0.2.1.weight", "GE.layers.0.0.2.1.bias", "GE.layers.0.0.4.weight", "GE.layers.0.0.4.bias", "GE.layers.0.0.4.running_mean", "GE.layers.0.0.4.running_var", "GE.layers.1.0.2.1.weight", "GE.layers.1.0.2.1.bias", "GE.layers.1.0.4.weight", "GE.layers.1.0.4.bias", "GE.layers.1.0.4.running_mean", "GE.layers.1.0.4.running_var", "GE.layers.2.0.2.1.weight", "GE.layers.2.0.2.1.bias", "GE.layers.2.0.4.weight", "GE.layers.2.0.4.bias", "GE.layers.2.0.4.running_mean", "GE.layers.2.0.4.running_var", "GE.layers.3.0.2.1.weight", "GE.layers.3.0.2.1.bias", "GE.layers.3.0.4.weight", "GE.layers.3.0.4.bias", 
"GE.layers.3.0.4.running_mean", "GE.layers.3.0.4.running_var", "GE.layers.3.2.fn.to_lin_q.weight", "GE.layers.3.2.fn.to_lin_kv.net.0.weight", "GE.layers.3.2.fn.to_lin_kv.net.1.weight", "GE.layers.3.2.fn.to_kv.weight", "GE.layers.4.0.2.1.weight", "GE.layers.4.0.2.1.bias", "GE.layers.4.0.4.weight", "GE.layers.4.0.4.bias", "GE.layers.4.0.4.running_mean", "GE.layers.4.0.4.running_var", "GE.layers.5.0.2.1.weight", "GE.layers.5.0.2.1.bias", "GE.layers.5.0.4.weight", "GE.layers.5.0.4.bias", "GE.layers.5.0.4.running_mean", "GE.layers.5.0.4.running_var", "D_aug.D.residual_layers.3.1.fn.to_lin_q.weight", "D_aug.D.residual_layers.3.1.fn.to_lin_kv.net.0.weight", "D_aug.D.residual_layers.3.1.fn.to_lin_kv.net.1.weight", "D_aug.D.residual_layers.3.1.fn.to_kv.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_q.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.0.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.1.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_q.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.0.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.1.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.weight". 
Unexpected key(s) in state_dict: "G.layers.0.0.2.weight", "G.layers.0.0.2.bias", "G.layers.0.0.3.bias", "G.layers.0.0.3.running_mean", "G.layers.0.0.3.running_var", "G.layers.0.0.3.num_batches_tracked", "G.layers.1.0.2.weight", "G.layers.1.0.2.bias", "G.layers.1.0.3.bias", "G.layers.1.0.3.running_mean", "G.layers.1.0.3.running_var", "G.layers.1.0.3.num_batches_tracked", "G.layers.2.0.2.weight", "G.layers.2.0.2.bias", "G.layers.2.0.3.bias", "G.layers.2.0.3.running_mean", "G.layers.2.0.3.running_var", "G.layers.2.0.3.num_batches_tracked", "G.layers.3.0.2.weight", "G.layers.3.0.2.bias", "G.layers.3.0.3.bias", "G.layers.3.0.3.running_mean", "G.layers.3.0.3.running_var", "G.layers.3.0.3.num_batches_tracked", "G.layers.3.2.fn.to_kv.net.0.weight", "G.layers.3.2.fn.to_kv.net.1.weight", "G.layers.4.0.2.weight", "G.layers.4.0.2.bias", "G.layers.4.0.3.bias", "G.layers.4.0.3.running_mean", "G.layers.4.0.3.running_var", "G.layers.4.0.3.num_batches_tracked", "G.layers.5.0.2.weight", "G.layers.5.0.2.bias", "G.layers.5.0.3.bias", "G.layers.5.0.3.running_mean", "G.layers.5.0.3.running_var", "G.layers.5.0.3.num_batches_tracked", "D.residual_layers.3.1.fn.to_kv.net.0.weight", "D.residual_layers.3.1.fn.to_kv.net.1.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.net.0.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.net.1.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.net.0.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.net.1.weight", "GE.layers.0.0.2.weight", "GE.layers.0.0.2.bias", "GE.layers.0.0.3.bias", "GE.layers.0.0.3.running_mean", "GE.layers.0.0.3.running_var", "GE.layers.0.0.3.num_batches_tracked", "GE.layers.1.0.2.weight", "GE.layers.1.0.2.bias", "GE.layers.1.0.3.bias", "GE.layers.1.0.3.running_mean", "GE.layers.1.0.3.running_var", "GE.layers.1.0.3.num_batches_tracked", "GE.layers.2.0.2.weight", "GE.layers.2.0.2.bias", "GE.layers.2.0.3.bias", "GE.layers.2.0.3.running_mean", "GE.layers.2.0.3.running_var", "GE.layers.2.0.3.num_batches_tracked", "GE.layers.3.0.2.weight", 
"GE.layers.3.0.2.bias", "GE.layers.3.0.3.bias", "GE.layers.3.0.3.running_mean", "GE.layers.3.0.3.running_var", "GE.layers.3.0.3.num_batches_tracked", "GE.layers.3.2.fn.to_kv.net.0.weight", "GE.layers.3.2.fn.to_kv.net.1.weight", "GE.layers.4.0.2.weight", "GE.layers.4.0.2.bias", "GE.layers.4.0.3.bias", "GE.layers.4.0.3.running_mean", "GE.layers.4.0.3.running_var", "GE.layers.4.0.3.num_batches_tracked", "GE.layers.5.0.2.weight", "GE.layers.5.0.2.bias", "GE.layers.5.0.3.bias", "GE.layers.5.0.3.running_mean", "GE.layers.5.0.3.running_var", "GE.layers.5.0.3.num_batches_tracked", "D_aug.D.residual_layers.3.1.fn.to_kv.net.0.weight", "D_aug.D.residual_layers.3.1.fn.to_kv.net.1.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.net.0.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.net.1.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.net.0.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.net.1.weight". size mismatch for G.layers.0.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.1.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.2.0.3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.3.0.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.3.2.fn.to_out.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]). size mismatch for G.layers.4.0.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.5.0.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1]). 
size mismatch for D.residual_layers.3.1.fn.to_out.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 1024, 1, 1]). size mismatch for D.to_shape_disc_out.1.fn.fn.to_out.weight: copying a param with shape torch.Size([64, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1024, 1, 1]). size mismatch for D.to_shape_disc_out.3.fn.fn.to_out.weight: copying a param with shape torch.Size([32, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 1024, 1, 1]). size mismatch for GE.layers.0.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.1.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.2.0.3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.3.0.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.3.2.fn.to_out.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]). size mismatch for GE.layers.4.0.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.5.0.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for D_aug.D.residual_layers.3.1.fn.to_out.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 1024, 1, 1]). 
size mismatch for D_aug.D.to_shape_disc_out.1.fn.fn.to_out.weight: copying a param with shape torch.Size([64, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1024, 1, 1]). size mismatch for D_aug.D.to_shape_disc_out.3.fn.fn.to_out.weight: copying a param with shape torch.Size([32, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 1024, 1, 1]).

lucidrains commented 2 years ago

@sebastiantrella hey! you'll need to run pip install lightweight-gan==0.21.4 to fix your problem

i just uploaded a new release that should give more informative instructions in the future
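
Before retraining, it can help to confirm which package version is actually installed and compare it with the version the checkpoint reports ("loading from version 0.21.4" in the log above). The following is a minimal sketch, not part of lightweight-gan: the helper names `installed_version` and `downgrade_hint` are made up for illustration, and only the standard-library `importlib.metadata` module is used.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def downgrade_hint(saved_version, current_version, package="lightweight-gan"):
    """Return the pip command to run when the versions differ, else None.

    `saved_version` is the version printed after "loading from version";
    `current_version` is what is installed right now.
    """
    if current_version == saved_version:
        return None
    return f"pip install {package}=={saved_version}"


if __name__ == "__main__":
    current = installed_version("lightweight-gan")
    print("installed:", current)
    print("suggested:", downgrade_hint("0.21.4", current))
```

If the hint is `None` and loading still fails, the installed release and the checkpoint disagree on architecture even though the version strings match, which points to the checkpoint having been written by a different release than the one recorded.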

sebastiantrella commented 2 years ago

@lucidrains, thanks for your help, but I still had no success.

I downgraded, but still get:

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com Collecting lightweight-gan==0.21.4 Downloading lightweight_gan-0.21.4-py3-none-any.whl (19 kB) Requirement already satisfied: pillow in /opt/conda/lib/python3.8/site-packages (from lightweight-gan==0.21.4) (8.2.0) Requirement already satisfied: kornia>=0.5.4 in /opt/conda/lib/python3.8/site-packages (from lightweight-gan==0.21.4) (0.6.4) Requirement already satisfied: retry in /opt/conda/lib/python3.8/site-packages (from lightweight-gan==0.21.4) (0.9.2) Requirement already satisfied: torchvision in /opt/conda/lib/python3.8/site-packages (from lightweight-gan==0.21.4) (0.12.0) Requirement already satisfied: torch>=1.10 in /opt/conda/lib/python3.8/site-packages (from lightweight-gan==0.21.4) (1.11.0) Requirement already satisfied: einops>=0.3 in /opt/conda/lib/python3.8/site-packages (from lightweight-gan==0.21.4) (0.4.1) Requirement already satisfied: tqdm in /opt/conda/lib/python3.8/site-packages (from lightweight-gan==0.21.4) (4.62.3) Requirement already satisfied: numpy in /opt/conda/lib/python3.8/site-packages (from lightweight-gan==0.21.4) (1.21.2) Requirement already satisfied: fire in /opt/conda/lib/python3.8/site-packages (from lightweight-gan==0.21.4) (0.4.0) Requirement already satisfied: adabelief-pytorch in /opt/conda/lib/python3.8/site-packages (from lightweight-gan==0.21.4) (0.2.1) Requirement already satisfied: packaging in /opt/conda/lib/python3.8/site-packages (from kornia>=0.5.4->lightweight-gan==0.21.4) (21.0) Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.8/site-packages (from torch>=1.10->lightweight-gan==0.21.4) (3.10.0.2) Requirement already satisfied: colorama>=0.4.0 in /opt/conda/lib/python3.8/site-packages (from adabelief-pytorch->lightweight-gan==0.21.4) (0.4.4) Requirement already satisfied: tabulate>=0.7 in /opt/conda/lib/python3.8/site-packages (from adabelief-pytorch->lightweight-gan==0.21.4) (0.8.9) Requirement already satisfied: six 
in /opt/conda/lib/python3.8/site-packages (from fire->lightweight-gan==0.21.4) (1.16.0) Requirement already satisfied: termcolor in /opt/conda/lib/python3.8/site-packages (from fire->lightweight-gan==0.21.4) (1.1.0) Requirement already satisfied: pyparsing>=2.0.2 in /opt/conda/lib/python3.8/site-packages (from packaging->kornia>=0.5.4->lightweight-gan==0.21.4) (2.4.7) Requirement already satisfied: py<2.0.0,>=1.4.26 in /opt/conda/lib/python3.8/site-packages (from retry->lightweight-gan==0.21.4) (1.10.0) Requirement already satisfied: decorator>=3.4.2 in /opt/conda/lib/python3.8/site-packages (from retry->lightweight-gan==0.21.4) (5.1.0) Requirement already satisfied: requests in /opt/conda/lib/python3.8/site-packages (from torchvision->lightweight-gan==0.21.4) (2.26.0) Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.8/site-packages (from requests->torchvision->lightweight-gan==0.21.4) (3.1) Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.8/site-packages (from requests->torchvision->lightweight-gan==0.21.4) (2021.5.30) Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.8/site-packages (from requests->torchvision->lightweight-gan==0.21.4) (1.26.7) Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/conda/lib/python3.8/site-packages (from requests->torchvision->lightweight-gan==0.21.4) (2.0.0) Installing collected packages: lightweight-gan Attempting uninstall: lightweight-gan Found existing installation: lightweight-gan 0.22.1 Uninstalling lightweight-gan-0.22.1: Successfully uninstalled lightweight-gan-0.22.1 Successfully installed lightweight-gan-0.21.4 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv continuing from previous epoch - 118 loading from version 0.21.4 unable to load save model. 
please try downgrading the package to the version specified by the saved model Traceback (most recent call last): File "/opt/conda/bin/lightweight_gan", line 8, in sys.exit(main()) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 190, in main fire.Fire(train_from_folder) File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 141, in Fire component_trace = _Fire(component, args, parsed_flag_args, context, name) File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 466, in _Fire component, remaining_args = _CallAndUpdateTrace( File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace component = fn(*varargs, **kwargs) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 181, in train_from_folder run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 59, in run_training model.load(load_from) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 1527, in load raise e File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 1524, in load self.GAN.load_state_dict(load_data['GAN']) File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for LightweightGAN: Missing key(s) in state_dict: "G.layers.0.0.2.1.weight", "G.layers.0.0.2.1.bias", "G.layers.0.0.4.weight", "G.layers.0.0.4.bias", "G.layers.0.0.4.running_mean", "G.layers.0.0.4.running_var", "G.layers.1.0.2.1.weight", "G.layers.1.0.2.1.bias", "G.layers.1.0.4.weight", "G.layers.1.0.4.bias", "G.layers.1.0.4.running_mean", "G.layers.1.0.4.running_var", "G.layers.2.0.2.1.weight", "G.layers.2.0.2.1.bias", "G.layers.2.0.4.weight", "G.layers.2.0.4.bias", "G.layers.2.0.4.running_mean", 
"G.layers.2.0.4.running_var", "G.layers.3.0.2.1.weight", "G.layers.3.0.2.1.bias", "G.layers.3.0.4.weight", "G.layers.3.0.4.bias", "G.layers.3.0.4.running_mean", "G.layers.3.0.4.running_var", "G.layers.3.2.fn.to_lin_q.weight", "G.layers.3.2.fn.to_lin_kv.net.0.weight", "G.layers.3.2.fn.to_lin_kv.net.1.weight", "G.layers.3.2.fn.to_kv.weight", "G.layers.4.0.2.1.weight", "G.layers.4.0.2.1.bias", "G.layers.4.0.4.weight", "G.layers.4.0.4.bias", "G.layers.4.0.4.running_mean", "G.layers.4.0.4.running_var", "G.layers.5.0.2.1.weight", "G.layers.5.0.2.1.bias", "G.layers.5.0.4.weight", "G.layers.5.0.4.bias", "G.layers.5.0.4.running_mean", "G.layers.5.0.4.running_var", "D.residual_layers.3.1.fn.to_lin_q.weight", "D.residual_layers.3.1.fn.to_lin_kv.net.0.weight", "D.residual_layers.3.1.fn.to_lin_kv.net.1.weight", "D.residual_layers.3.1.fn.to_kv.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_q.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.0.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.1.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_q.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.0.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.1.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.weight", "GE.layers.0.0.2.1.weight", "GE.layers.0.0.2.1.bias", "GE.layers.0.0.4.weight", "GE.layers.0.0.4.bias", "GE.layers.0.0.4.running_mean", "GE.layers.0.0.4.running_var", "GE.layers.1.0.2.1.weight", "GE.layers.1.0.2.1.bias", "GE.layers.1.0.4.weight", "GE.layers.1.0.4.bias", "GE.layers.1.0.4.running_mean", "GE.layers.1.0.4.running_var", "GE.layers.2.0.2.1.weight", "GE.layers.2.0.2.1.bias", "GE.layers.2.0.4.weight", "GE.layers.2.0.4.bias", "GE.layers.2.0.4.running_mean", "GE.layers.2.0.4.running_var", "GE.layers.3.0.2.1.weight", "GE.layers.3.0.2.1.bias", "GE.layers.3.0.4.weight", "GE.layers.3.0.4.bias", "GE.layers.3.0.4.running_mean", "GE.layers.3.0.4.running_var", "GE.layers.3.2.fn.to_lin_q.weight", "GE.layers.3.2.fn.to_lin_kv.net.0.weight", 
"GE.layers.3.2.fn.to_lin_kv.net.1.weight", "GE.layers.3.2.fn.to_kv.weight", "GE.layers.4.0.2.1.weight", "GE.layers.4.0.2.1.bias", "GE.layers.4.0.4.weight", "GE.layers.4.0.4.bias", "GE.layers.4.0.4.running_mean", "GE.layers.4.0.4.running_var", "GE.layers.5.0.2.1.weight", "GE.layers.5.0.2.1.bias", "GE.layers.5.0.4.weight", "GE.layers.5.0.4.bias", "GE.layers.5.0.4.running_mean", "GE.layers.5.0.4.running_var", "D_aug.D.residual_layers.3.1.fn.to_lin_q.weight", "D_aug.D.residual_layers.3.1.fn.to_lin_kv.net.0.weight", "D_aug.D.residual_layers.3.1.fn.to_lin_kv.net.1.weight", "D_aug.D.residual_layers.3.1.fn.to_kv.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_q.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.0.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.1.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_q.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.0.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.1.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.weight". 
Unexpected key(s) in state_dict: "G.layers.0.0.2.weight", "G.layers.0.0.2.bias", "G.layers.0.0.3.bias", "G.layers.0.0.3.running_mean", "G.layers.0.0.3.running_var", "G.layers.0.0.3.num_batches_tracked", "G.layers.1.0.2.weight", "G.layers.1.0.2.bias", "G.layers.1.0.3.bias", "G.layers.1.0.3.running_mean", "G.layers.1.0.3.running_var", "G.layers.1.0.3.num_batches_tracked", "G.layers.2.0.2.weight", "G.layers.2.0.2.bias", "G.layers.2.0.3.bias", "G.layers.2.0.3.running_mean", "G.layers.2.0.3.running_var", "G.layers.2.0.3.num_batches_tracked", "G.layers.3.0.2.weight", "G.layers.3.0.2.bias", "G.layers.3.0.3.bias", "G.layers.3.0.3.running_mean", "G.layers.3.0.3.running_var", "G.layers.3.0.3.num_batches_tracked", "G.layers.3.2.fn.to_kv.net.0.weight", "G.layers.3.2.fn.to_kv.net.1.weight", "G.layers.4.0.2.weight", "G.layers.4.0.2.bias", "G.layers.4.0.3.bias", "G.layers.4.0.3.running_mean", "G.layers.4.0.3.running_var", "G.layers.4.0.3.num_batches_tracked", "G.layers.5.0.2.weight", "G.layers.5.0.2.bias", "G.layers.5.0.3.bias", "G.layers.5.0.3.running_mean", "G.layers.5.0.3.running_var", "G.layers.5.0.3.num_batches_tracked", "D.residual_layers.3.1.fn.to_kv.net.0.weight", "D.residual_layers.3.1.fn.to_kv.net.1.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.net.0.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.net.1.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.net.0.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.net.1.weight", "GE.layers.0.0.2.weight", "GE.layers.0.0.2.bias", "GE.layers.0.0.3.bias", "GE.layers.0.0.3.running_mean", "GE.layers.0.0.3.running_var", "GE.layers.0.0.3.num_batches_tracked", "GE.layers.1.0.2.weight", "GE.layers.1.0.2.bias", "GE.layers.1.0.3.bias", "GE.layers.1.0.3.running_mean", "GE.layers.1.0.3.running_var", "GE.layers.1.0.3.num_batches_tracked", "GE.layers.2.0.2.weight", "GE.layers.2.0.2.bias", "GE.layers.2.0.3.bias", "GE.layers.2.0.3.running_mean", "GE.layers.2.0.3.running_var", "GE.layers.2.0.3.num_batches_tracked", "GE.layers.3.0.2.weight", 
"GE.layers.3.0.2.bias", "GE.layers.3.0.3.bias", "GE.layers.3.0.3.running_mean", "GE.layers.3.0.3.running_var", "GE.layers.3.0.3.num_batches_tracked", "GE.layers.3.2.fn.to_kv.net.0.weight", "GE.layers.3.2.fn.to_kv.net.1.weight", "GE.layers.4.0.2.weight", "GE.layers.4.0.2.bias", "GE.layers.4.0.3.bias", "GE.layers.4.0.3.running_mean", "GE.layers.4.0.3.running_var", "GE.layers.4.0.3.num_batches_tracked", "GE.layers.5.0.2.weight", "GE.layers.5.0.2.bias", "GE.layers.5.0.3.bias", "GE.layers.5.0.3.running_mean", "GE.layers.5.0.3.running_var", "GE.layers.5.0.3.num_batches_tracked", "D_aug.D.residual_layers.3.1.fn.to_kv.net.0.weight", "D_aug.D.residual_layers.3.1.fn.to_kv.net.1.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.net.0.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.net.1.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.net.0.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.net.1.weight". size mismatch for G.layers.0.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.1.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.2.0.3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.3.0.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.3.2.fn.to_out.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]). size mismatch for G.layers.4.0.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.5.0.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1]). 
size mismatch for D.residual_layers.3.1.fn.to_out.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 1024, 1, 1]). size mismatch for D.to_shape_disc_out.1.fn.fn.to_out.weight: copying a param with shape torch.Size([64, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1024, 1, 1]). size mismatch for D.to_shape_disc_out.3.fn.fn.to_out.weight: copying a param with shape torch.Size([32, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 1024, 1, 1]). size mismatch for GE.layers.0.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.1.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.2.0.3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.3.0.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.3.2.fn.to_out.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]). size mismatch for GE.layers.4.0.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.5.0.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for D_aug.D.residual_layers.3.1.fn.to_out.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 1024, 1, 1]). 
size mismatch for D_aug.D.to_shape_disc_out.1.fn.fn.to_out.weight: copying a param with shape torch.Size([64, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1024, 1, 1]). size mismatch for D_aug.D.to_shape_disc_out.3.fn.fn.to_out.weight: copying a param with shape torch.Size([32, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 1024, 1, 1]).
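
An error message this long is easier to read once the "size mismatch" entries are pulled out into (key, checkpoint shape, model shape) triples. The sketch below assumes only the message format visible in the log above; the names `MISMATCH`, `parse_shape`, and `parse_mismatches` are illustrative and do not come from PyTorch or lightweight-gan.

```python
import re

# Matches one "size mismatch for <key>: ..." entry in the format shown in the log.
MISMATCH = re.compile(
    r"size mismatch for (?P<key>\S+): copying a param with shape "
    r"torch\.Size\(\[(?P<ckpt>[\d, ]*)\]\) from checkpoint, "
    r"the shape in current model is torch\.Size\(\[(?P<model>[\d, ]*)\]\)"
)


def parse_shape(text):
    """Turn '256, 512, 1, 1' into (256, 512, 1, 1)."""
    return tuple(int(part) for part in text.split(",") if part.strip())


def parse_mismatches(message):
    """Return a list of (key, checkpoint_shape, model_shape) tuples."""
    return [
        (m["key"], parse_shape(m["ckpt"]), parse_shape(m["model"]))
        for m in MISMATCH.finditer(message)
    ]


if __name__ == "__main__":
    sample = (
        "size mismatch for G.layers.0.0.3.weight: copying a param with shape "
        "torch.Size([1024]) from checkpoint, the shape in current model is "
        "torch.Size([1]). size mismatch for G.layers.3.2.fn.to_out.weight: "
        "copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, "
        "the shape in current model is torch.Size([256, 1024, 1, 1])."
    )
    for key, ckpt, model in parse_mismatches(sample):
        print(f"{key}: checkpoint {ckpt} vs model {model}")
```

Tabulated this way, the log shows two recurring patterns: per-channel weights of size 1024, 512, 256, 128, or 64 in the checkpoint against size 1 in the current model, and `to_out` weights whose second dimension is 512 in the checkpoint against 1024 in the current model.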

lucidrains commented 2 years ago

@sebastiantrella darn, I may have messed up somewhere - you may have to try downgrading until you hit the version that supports your model (0.21.3, 0.21.2, 0.21.1, 0.21.0)

lucidrains commented 2 years ago

@sebastiantrella i could also offer an option to force the loading of parameters for whichever modules match, and perhaps you can still salvage by continuing training from there on the newer architecture
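
One way to load only the parameters that still match is to filter the checkpoint by both key and shape before handing it to the model. This matters because PyTorch's `load_state_dict(..., strict=False)` ignores missing and unexpected keys but still raises on shape mismatches. The sketch below is an assumption about how such a filter could look, not the option that was added to lightweight-gan. To stay runnable without PyTorch it uses a stand-in `Shaped` type; the function `filter_matching` and the key `example.weight` are hypothetical.

```python
from collections import namedtuple

# Stand-in for a tensor: anything with a .shape attribute works here.
Shaped = namedtuple("Shaped", ["shape"])


def filter_matching(checkpoint_state, model_state):
    """Split checkpoint entries into those safe to load and those to skip.

    An entry is kept only if its key exists in the model and the shapes agree.
    """
    kept, skipped = {}, []
    for key, value in checkpoint_state.items():
        target = model_state.get(key)
        if target is not None and tuple(target.shape) == tuple(value.shape):
            kept[key] = value
        else:
            skipped.append(key)
    return kept, skipped


if __name__ == "__main__":
    # Shapes for the first two keys are taken from the error log;
    # "example.weight" is a made-up key that matches in both dicts.
    checkpoint = {
        "G.layers.0.0.3.weight": Shaped((1024,)),
        "G.layers.3.2.fn.to_out.weight": Shaped((256, 512, 1, 1)),
        "G.layers.0.0.3.bias": Shaped((1024,)),
        "example.weight": Shaped((8, 8)),
    }
    model = {
        "G.layers.0.0.3.weight": Shaped((1,)),
        "G.layers.3.2.fn.to_out.weight": Shaped((256, 1024, 1, 1)),
        "example.weight": Shaped((8, 8)),
    }
    kept, skipped = filter_matching(checkpoint, model)
    print("kept:", sorted(kept))
    print("skipped:", skipped)
```

With real tensors, the two dictionaries would be the checkpoint's `load_data['GAN']` (the name seen in the traceback) and the model's `state_dict()`, and the kept entries would then be passed to `load_state_dict(kept, strict=False)`. Skipped parameters keep their fresh initialization, so how much of the earlier training survives depends on how many layers changed.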

lucidrains commented 2 years ago

@sebastiantrella ok, in the latest version, try passing --noload-strict or --load-strict=False

sebastiantrella commented 2 years ago

@lucidrains with the newest version and each of the two parameters, I still get:

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com Collecting lightweight-gan Downloading lightweight_gan-0.22.3-py3-none-any.whl (20 kB) Collecting kornia>=0.5.4 Downloading kornia-0.6.4-py2.py3-none-any.whl (493 kB) |████████████████████████████████| 493 kB 16.5 MB/s Collecting einops>=0.3 Downloading einops-0.4.1-py3-none-any.whl (28 kB) Collecting fire Downloading fire-0.4.0.tar.gz (87 kB) |████████████████████████████████| 87 kB 34.4 MB/s Requirement already satisfied: numpy in /opt/conda/lib/python3.8/site-packages (from lightweight-gan) (1.21.2) Requirement already satisfied: pillow in /opt/conda/lib/python3.8/site-packages (from lightweight-gan) (8.2.0) Collecting torch>=1.10 Downloading torch-1.11.0-cp38-cp38-manylinux1_x86_64.whl (750.6 MB) |████████████████████████████████| 750.6 MB 24.7 MB/s Requirement already satisfied: torchvision in /opt/conda/lib/python3.8/site-packages (from lightweight-gan) (0.11.0a0) Collecting adabelief-pytorch Downloading adabelief_pytorch-0.2.1-py3-none-any.whl (5.8 kB) Collecting retry Downloading retry-0.9.2-py2.py3-none-any.whl (8.0 kB) Requirement already satisfied: tqdm in /opt/conda/lib/python3.8/site-packages (from lightweight-gan) (4.62.3) Requirement already satisfied: packaging in /opt/conda/lib/python3.8/site-packages (from kornia>=0.5.4->lightweight-gan) (21.0) Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.8/site-packages (from torch>=1.10->lightweight-gan) (3.10.0.2) Requirement already satisfied: colorama>=0.4.0 in /opt/conda/lib/python3.8/site-packages (from adabelief-pytorch->lightweight-gan) (0.4.4) Requirement already satisfied: tabulate>=0.7 in /opt/conda/lib/python3.8/site-packages (from adabelief-pytorch->lightweight-gan) (0.8.9) Requirement already satisfied: six in /opt/conda/lib/python3.8/site-packages (from fire->lightweight-gan) (1.16.0) Collecting termcolor Downloading termcolor-1.1.0.tar.gz (3.9 kB) Requirement already satisfied: pyparsing>=2.0.2 
in /opt/conda/lib/python3.8/site-packages (from packaging->kornia>=0.5.4->lightweight-gan) (2.4.7)
Requirement already satisfied: decorator>=3.4.2 in /opt/conda/lib/python3.8/site-packages (from retry->lightweight-gan) (5.1.0)
Requirement already satisfied: py<2.0.0,>=1.4.26 in /opt/conda/lib/python3.8/site-packages (from retry->lightweight-gan) (1.10.0)
Collecting torchvision
  Downloading torchvision-0.12.0-cp38-cp38-manylinux1_x86_64.whl (21.0 MB)
     |████████████████████████████████| 21.0 MB 19.3 MB/s
Requirement already satisfied: requests in /opt/conda/lib/python3.8/site-packages (from torchvision->lightweight-gan) (2.26.0)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.8/site-packages (from requests->torchvision->lightweight-gan) (3.1)
Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/conda/lib/python3.8/site-packages (from requests->torchvision->lightweight-gan) (2.0.0)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.8/site-packages (from requests->torchvision->lightweight-gan) (1.26.7)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.8/site-packages (from requests->torchvision->lightweight-gan) (2021.5.30)
Building wheels for collected packages: fire, termcolor
  Building wheel for fire (setup.py) ... done
  Created wheel for fire: filename=fire-0.4.0-py2.py3-none-any.whl size=115943 sha256=644bbe017e008c9ae8857790a050344193a5b78b9f98ab7b4777011a162d0213
  Stored in directory: /tmp/pip-ephem-wheel-cache-20gi31ng/wheels/1f/10/06/2a990ee4d73a8479fe2922445e8a876d38cfbfed052284c6a1
  Building wheel for termcolor (setup.py) ... done
  Created wheel for termcolor: filename=termcolor-1.1.0-py3-none-any.whl size=4847 sha256=c1ee025a1fd8c30edbc4d7c9437bf7f66bd9ce7609d68221f667e116b2bda7a1
  Stored in directory: /tmp/pip-ephem-wheel-cache-20gi31ng/wheels/a0/16/9c/5473df82468f958445479c59e784896fa24f4a5fc024b0f501
Successfully built fire termcolor
Installing collected packages: torch, termcolor, torchvision, retry, kornia, fire, einops, adabelief-pytorch, lightweight-gan
  Attempting uninstall: torch
    Found existing installation: torch 1.10.0a0+0aef44c
    Uninstalling torch-1.10.0a0+0aef44c:
      Successfully uninstalled torch-1.10.0a0+0aef44c
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.11.0a0
    Uninstalling torchvision-0.11.0a0:
      Successfully uninstalled torchvision-0.11.0a0
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchtext 0.11.0a0 requires torch==1.10.0a0+0aef44c, but you have torch 1.11.0 which is incompatible.
Successfully installed adabelief-pytorch-0.2.1 einops-0.4.1 fire-0.4.0 kornia-0.6.4 lightweight-gan-0.22.3 retry-0.9.2 termcolor-1.1.0 torch-1.11.0 torchvision-0.12.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
continuing from previous epoch - 118
loading from version 0.21.4
unable to load save model. please try downgrading the package to the version specified by the saved model (to do so, just run pip install lightweight-gan=={saved_version}
Traceback (most recent call last):
  File "/opt/conda/bin/lightweight_gan", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 195, in main
    fire.Fire(train_from_folder)
  File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 466, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 186, in train_from_folder
    run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash)
  File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 59, in run_training
    model.load(load_from)
  File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 1613, in load
    raise e
  File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 1609, in load
    self.GAN.load_state_dict(load_data['GAN'], strict = self.load_strict)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for LightweightGAN:
    size mismatch for G.layers.0.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]).
    size mismatch for G.layers.1.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]).
    size mismatch for G.layers.2.0.3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1]).
    size mismatch for G.layers.3.0.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1]).
    size mismatch for G.layers.3.2.fn.to_out.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]).
    size mismatch for G.layers.4.0.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1]).
    size mismatch for G.layers.5.0.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1]).
    size mismatch for D.residual_layers.3.1.fn.to_out.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 1024, 1, 1]).
    size mismatch for D.to_shape_disc_out.1.fn.fn.to_out.weight: copying a param with shape torch.Size([64, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1024, 1, 1]).
    size mismatch for D.to_shape_disc_out.3.fn.fn.to_out.weight: copying a param with shape torch.Size([32, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 1024, 1, 1]).
    size mismatch for GE.layers.0.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]).
    size mismatch for GE.layers.1.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]).
    size mismatch for GE.layers.2.0.3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1]).
    size mismatch for GE.layers.3.0.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1]).
    size mismatch for GE.layers.3.2.fn.to_out.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]).
    size mismatch for GE.layers.4.0.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1]).
    size mismatch for GE.layers.5.0.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1]).
    size mismatch for D_aug.D.residual_layers.3.1.fn.to_out.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 1024, 1, 1]).
    size mismatch for D_aug.D.to_shape_disc_out.1.fn.fn.to_out.weight: copying a param with shape torch.Size([64, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1024, 1, 1]).
    size mismatch for D_aug.D.to_shape_disc_out.3.fn.fn.to_out.weight: copying a param with shape torch.Size([32, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 1024, 1, 1]).
root@na2b0dnel7:/notebooks#

sebastiantrella commented 2 years ago

I tried to downgrade, and for some reason it is working with 0.20.5. I will complete the training of this set now and then switch to a newer version. Thanks for your help! Really appreciated!
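
For anyone hitting the same error: the `size mismatch` lines mean the parameter shapes stored in the checkpoint differ from the shapes the installed package builds. Here the version printed at load time (0.21.4) was not the one that ended up working (0.20.5), so it can help to list the differences directly after each install. Below is a minimal sketch of such a comparison. The helper names `shape_of` and `diff_state_dicts` are hypothetical and not part of lightweight-gan, and the example shapes are copied from the error output above.

```python
from typing import Dict, List, Mapping, Tuple


def shape_of(value) -> Tuple[int, ...]:
    # Accepts anything with a .shape attribute (e.g. a tensor) or a plain tuple/list.
    return tuple(getattr(value, "shape", value))


def diff_state_dicts(checkpoint: Mapping, model: Mapping) -> Dict[str, List]:
    """Compare parameter names and shapes between a checkpoint and a model.

    Hypothetical helper for illustration; not part of lightweight-gan.
    """
    ckpt_shapes = {k: shape_of(v) for k, v in checkpoint.items()}
    model_shapes = {k: shape_of(v) for k, v in model.items()}
    return {
        # Keys the model expects but the checkpoint lacks.
        "missing": sorted(model_shapes.keys() - ckpt_shapes.keys()),
        # Keys the checkpoint has but the model does not know.
        "unexpected": sorted(ckpt_shapes.keys() - model_shapes.keys()),
        # Keys present on both sides whose shapes disagree.
        "mismatched": sorted(
            (k, ckpt_shapes[k], model_shapes[k])
            for k in ckpt_shapes.keys() & model_shapes.keys()
            if ckpt_shapes[k] != model_shapes[k]
        ),
    }


# Two of the shapes reported in the error output above.
checkpoint = {
    "G.layers.0.0.3.weight": (1024,),
    "G.layers.3.2.fn.to_out.weight": (256, 512, 1, 1),
}
current_model = {
    "G.layers.0.0.3.weight": (1,),
    "G.layers.3.2.fn.to_out.weight": (256, 1024, 1, 1),
}

report = diff_state_dicts(checkpoint, current_model)
for name, saved, expected in report["mismatched"]:
    print(f"{name}: checkpoint {saved} vs model {expected}")
```

With PyTorch available, the same function should accept real state dicts, for example `torch.load(path, map_location="cpu")["GAN"]` for the checkpoint (the `'GAN'` key appears in the traceback; `path` is a placeholder for your own checkpoint file) and the `state_dict()` of the freshly constructed model. An empty report after a downgrade suggests that version can load the checkpoint.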