giuvecchio / matfuse-sd

MatFuse: Controllable Material Generation with Diffusion Models (CVPR2024)
https://gvecchio.com/matfuse/
MIT License

Error loading autoencoder weights, and wandb fails in LDM training #7

Status: Closed (Night1099 closed this issue 4 months ago)

Night1099 commented 4 months ago

I changed this line in autoencoder.py (init_from_ckpt):

sd = torch.load(path, map_location="cpu")["vqmodel"]

to

sd = torch.load(path, map_location="cpu")

to fix this error when running LDM training:

Traceback (most recent call last):
  File "/workspace/matfuse-sd/src/main.py", line 697, in <module>
    model = instantiate_from_config(config.model).to(device)
  File "/workspace/matfuse-sd/src/ldm/util.py", line 119, in instantiate_from_config
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
  File "/workspace/matfuse-sd/src/ldm/models/diffusion/ddpm.py", line 678, in __init__
    self.instantiate_first_stage(first_stage_config)
  File "/workspace/matfuse-sd/src/ldm/models/diffusion/ddpm.py", line 749, in instantiate_first_stage
    model = instantiate_from_config(config)
  File "/workspace/matfuse-sd/src/ldm/util.py", line 119, in instantiate_from_config
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
  File "/workspace/matfuse-sd/src/ldm/models/autoencoder.py", line 393, in __init__
    self.init_from_ckpt(ckpt_path, ignore_keys=[])
  File "/workspace/matfuse-sd/src/ldm/models/autoencoder.py", line 396, in init_from_ckpt
    sd = torch.load(path, map_location="cpu")["vqmodel"]
KeyError: 'vqmodel'
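As a sketch of a less destructive alternative (an assumption, not the fix adopted in the repo), the loader could accept several checkpoint layouts instead of either hard-coding or dropping the "vqmodel" key. The helper name extract_state_dict is hypothetical; the "state_dict" key is where PyTorch Lightning checkpoints store model weights.

```python
from typing import Any, Mapping


def extract_state_dict(ckpt: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return the model weights from an already-loaded checkpoint object.

    Hypothetical helper. Checks, in order:
      - "vqmodel": the key init_from_ckpt originally expected,
      - "state_dict": the usual PyTorch Lightning checkpoint layout,
    and otherwise assumes the checkpoint itself is a plain state dict.
    """
    for key in ("vqmodel", "state_dict"):
        if key in ckpt:
            return ckpt[key]
    # No known wrapper key: treat the whole object as the weights.
    return ckpt
```

Inside init_from_ckpt it would be called as sd = extract_state_dict(torch.load(path, map_location="cpu")), so the same code path works for both wrapped and plain checkpoints.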

With that change, the script loads the weights correctly, but then fails during wandb logging with this error:

Traceback (most recent call last):
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/pytorch_lightning/plugins/training_type/ddp_spawn.py", line 201, in new_process
    results = trainer.run_stage()
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 996, in run_stage
    return self._run_train()
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1031, in _run_train
    self._run_sanity_check(self.lightning_module)
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1115, in _run_sanity_check
    self._evaluation_loop.run()
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 110, in advance
    dl_outputs = self.epoch_loop.run(
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 110, in advance
    output = self.evaluation_step(batch, batch_idx, dataloader_idx)
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 154, in evaluation_step
    output = self.trainer.accelerator.validation_step(step_kwargs)
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/pytorch_lightning/accelerators/accelerator.py", line 211, in validation_step
    return self.training_type_plugin.validation_step(*step_kwargs.values())
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/pytorch_lightning/plugins/training_type/ddp_spawn.py", line 362, in validation_step
    return self.model(*args, **kwargs)
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1040, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1000, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/pytorch_lightning/overrides/base.py", line 93, in forward
    output = self.module.validation_step(*inputs, **kwargs)
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/matfuse-sd/src/ldm/models/diffusion/ddpm.py", line 572, in validation_step
    wandb.log(loss_dict_ema)
  File "/root/anaconda3/envs/sdiff/lib/python3.10/site-packages/wandb/sdk/lib/preinit.py", line 36, in preinit_wrapper
    raise wandb.Error(f"You must call wandb.init() before {name}()")
wandb.errors.Error: You must call wandb.init() before wandb.log()
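A defensive option, sketched here as an assumption rather than something the repo does, is to skip logging in any process where no run exists. wandb.run is None until wandb.init() has been called in that process; the helper name log_if_active is hypothetical.

```python
from typing import Any, Mapping, Optional, Protocol


class SupportsLog(Protocol):
    """Anything with a wandb-style log(dict) method, e.g. a wandb Run."""

    def log(self, data: Mapping[str, Any]) -> None: ...


def log_if_active(metrics: Mapping[str, Any], run: Optional[SupportsLog]) -> bool:
    """Log metrics only if a run is active; return True when logged.

    Hypothetical helper: pass wandb.run as `run`.
    """
    if run is None:
        # No active run in this process (wandb.init() never ran here): skip.
        return False
    run.log(dict(metrics))
    return True
```

In validation_step, wandb.log(loss_dict_ema) would become log_if_active(loss_dict_ema, wandb.run). Note that this silently drops metrics in processes without a run, so it hides the problem rather than fixing initialisation.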
Night1099 commented 4 months ago

Fixed by moving

wandb.init(project=os.environ.get("WANDB_PROJECT", "matfuse"), entity=os.environ.get("WANDB_ENTITY"))

under the main function.
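A minimal sketch of that layout, assuming a main() entry point in src/main.py (its real structure is not shown in this thread). The wandb.init() arguments are the ones from the comment above; wandb_init_kwargs is a hypothetical helper that only builds them from environment variables, and main() is a placeholder for the script's real body.

```python
import os
from typing import Mapping


def wandb_init_kwargs(environ: Mapping[str, str] = os.environ) -> dict:
    """Build the wandb.init() arguments used in the issue from env vars."""
    return {
        "project": environ.get("WANDB_PROJECT", "matfuse"),
        "entity": environ.get("WANDB_ENTITY"),  # None lets wandb pick the default
    }


def main() -> None:
    # Hypothetical stand-in for the entry point in src/main.py.
    import wandb  # imported lazily so the helper above works without wandb

    # Initialise the run inside main(), before any training code that
    # calls wandb.log() starts.
    wandb.init(**wandb_init_kwargs())
    # ... build the model and trainer here, then start training ...
```

Keeping the environment lookup in a separate function makes the defaults ("matfuse" project, no explicit entity) easy to check without contacting wandb.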

Night1099 commented 4 months ago

Fixed in Pull Request