Sharath-girish / Shacira

Official Pytorch implementation of SHACIRA: Scalable HAsh-grid Compression for Implicit Neural Representations
Other

KeyError: 'state_dict' when using pearl.yaml #5

Open HyungGeun-Cho opened 8 months ago

HyungGeun-Cho commented 8 months ago

Hi! Thank you for your great work.

I'm an undergraduate student interested in reconstructing gigapixel images with implicit neural representations (INRs). While searching for INR papers that attempted gigapixel image fitting, I found your SHACIRA!

First, I tried to reproduce the Kodak results using the provided kodak.yaml, and it worked. (The PSNR was a bit lower, but I assume that might depend on which GPU I use.)

python3 app/image/main_image.py --config app/image/configs/kodak.yaml --dataset-path data/kodak

Then, I tried to fit the pearl image (23466x20000) using the provided pearl.yaml, but the following error occurred.

(The option valid_every: 1 also caused a different error, so I temporarily changed it to valid_every: -1.) (I actually ran this with the pearl image and pearl.yaml, but the log below is only meant to illustrate the error, so its epoch count and other details differ from that run.)

python3 app/image/main_image.py --config app/image/configs/pearl.yaml --dataset-path data/gigapixel

2024-01-02 10:42:22,451|    INFO| Image 1/ 2 EPOCH 1/3 | PSNR: 7.02E+00 | BPP: 1.29E+00 | total size (kB): 6.34E+01 | total loss: 2.67E-01 | rgb loss: 2.66E-01 | ent loss: 1.28E-04 | total size (kB): 6.34E+01 | temp: 4.26E-01 | sga: True
100%|████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 11.87it/s]
2024-01-02 10:42:22,681|    INFO| Image 1/ 2 EPOCH 2/3 | PSNR: 7.01E+00 | BPP: 1.29E+00 | total size (kB): 6.34E+01 | total loss: 2.61E-01 | rgb loss: 2.60E-01 | ent loss: 1.28E-04 | total size (kB): 6.34E+01 | temp: 1.82E-01 | sga: True
100%|████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 12.34it/s]
2024-01-02 10:42:22,899|    INFO| Image 1/ 2 EPOCH 3/3 | PSNR: 7.02E+00 | BPP: 1.29E+00 | total size (kB): 6.34E+01 | total loss: 2.55E-01 | rgb loss: 2.54E-01 | ent loss: 1.28E-04 | total size (kB): 6.34E+01 | temp: 1.00E-01 | sga: False
Traceback (most recent call last):
  File "/code/Shacira/app/image/main_image.py", line 604, in <module>
    trainer.train()
  File "/code/Shacira/wisp/trainers/image_trainer.py", line 375, in train
    self.iterate()
  File "/code/Shacira/wisp/trainers/base_trainer.py", line 396, in iterate
    self.post_training()
  File "/code/Shacira/wisp/trainers/image_trainer.py", line 477, in post_training
    self.pipeline.load_state_dict(self.best_state['state_dict'])
KeyError: 'state_dict'
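One way this KeyError can arise is sketched below. It assumes (not confirmed from the repository) that `best_state` starts as an empty dict and is only filled when a validation step records a new best PSNR; with `valid_every: -1` validation never runs, so `post_training` looks up a key that was never written. `ToyTrainer` and its methods are hypothetical stand-ins, not SHACIRA's actual `ImageTrainer`, and this toy does not model the reported dependence on `sample_mode`.

```python
import copy


class ToyTrainer:
    """Hypothetical trainer that keeps a copy of the best weights seen in validation."""

    def __init__(self, num_epochs, valid_every):
        self.num_epochs = num_epochs
        self.valid_every = valid_every  # assumed semantics: -1 disables validation
        self.weights = {"w": 0.0}
        self.best_state = {}            # assumed: stays empty until validate() runs
        self.best_psnr = float("-inf")

    def validate(self, epoch):
        psnr = float(epoch)  # placeholder metric that improves every epoch
        if psnr > self.best_psnr:
            self.best_psnr = psnr
            # deepcopy so later in-place updates do not overwrite the snapshot
            self.best_state = {"state_dict": copy.deepcopy(self.weights), "epoch": epoch}

    def train(self):
        for epoch in range(1, self.num_epochs + 1):
            self.weights["w"] += 1.0  # stand-in for an optimization step
            if self.valid_every > 0 and epoch % self.valid_every == 0:
                self.validate(epoch)
        return self.post_training()

    def post_training(self):
        # Unguarded lookup: raises KeyError('state_dict') if validate() never ran.
        return self.best_state["state_dict"]

    def post_training_guarded(self):
        # Defensive variant: fall back to the final weights when no best state exists.
        if "state_dict" in self.best_state:
            return self.best_state["state_dict"]
        return self.weights
```

With `ToyTrainer(3, 1).train()` the best snapshot is returned, while `ToyTrainer(3, -1).train()` reproduces the same `KeyError: 'state_dict'`; the guarded variant is one possible workaround, at the cost of silently using the last weights instead of the best ones.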

After changing various parameters, I found that the error depends on which sample_mode I choose. It seems that kodak.yaml uses all the coordinates, while pearl.yaml samples only a subset of coordinates for memory efficiency.

dataset:
    dataloader_num_workers: 1
    num_samples: 18
    sample_mode: 'wreplace'
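To illustrate the difference between using every coordinate and sampling a subset, here is a minimal plain-Python sketch. It assumes, from the name alone, that 'wreplace' means sampling pixel coordinates uniformly with replacement; the function names and batch size are hypothetical, and the meaning of `num_samples: 18` in SHACIRA is not inferred here.

```python
import random


def all_coordinates(height, width):
    """Every (row, col) pixel coordinate, as when the whole image is used per step."""
    return [(r, c) for r in range(height) for c in range(width)]


def sample_with_replacement(height, width, batch_size, rng):
    """Uniformly sample batch_size (row, col) coordinates; duplicates are allowed."""
    return [(rng.randrange(height), rng.randrange(width)) for _ in range(batch_size)]


rng = random.Random(0)
full = all_coordinates(4, 5)                  # 20 coordinates for a 4x5 image
batch = sample_with_replacement(4, 5, 8, rng)  # 8 random coordinates
print(len(full), len(batch))  # prints "20 8"
```

For the 23466x20000 pearl image, the full grid would hold 469,320,000 coordinates, which is why materializing all of them per step is impractical and sampling is used instead.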

I would really appreciate it if you could tell me how to resolve this issue!

Thank you in advance.