vsitzmann / siren

Official implementation of "Implicit Neural Representations with Periodic Activation Functions"
MIT License
1.72k stars, 247 forks

How to run "Image Fitting" on GPU with low memory ~ 4GB? #16

Closed zohaibmohammad closed 3 years ago

zohaibmohammad commented 4 years ago

Hi, I am trying to run the image fitting code. I set the batch size to 1 (the default value) because my GPU has only 4 GB of memory, but training still stops with a GPU out-of-memory error. Can anyone tell me how to solve this problem? Thanks.

 python experiment_scripts/train_img.py --model_type=sine --experiment_name=output

SingleBVPNet(
  (image_downsampling): ImageDownsampling()
  (net): FCBlock(
    (net): MetaSequential(
      (0): MetaSequential(
        (0): BatchLinear(in_features=2, out_features=256, bias=True)
        (1): Sine()
      )
      (1): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (2): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (3): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (4): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=1, bias=True)
      )
    )
  )
)

  0%|                                                                          | 0/10000 [00:00<?, ?it/s]

Traceback (most recent call last):
  File "experiment_scripts/train_img.py", line 62, in <module>
    model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn)
  File "/home/mz/code/siren/training.py", line 92, in train
    summary_fn(model, model_input, gt, model_output, writer, total_steps)
  File "/home/mz/code/siren/utils.py", line 334, in write_image_summary
    img_gradient = diff_operators.gradient(model_output['model_out'], model_output['model_in'])
  File "/home/mz/code/siren/diff_operators.py", line 42, in gradient
    grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0]
  File "/home/mz/anaconda3/envs/siren/lib/python3.6/site-packages/torch/autograd/__init__.py", line 158, in grad
    inputs, allow_unused)

RuntimeError: CUDA out of memory. Tried to allocate 256.00 MiB (GPU 0; 3.94 GiB total capacity; 2.76 GiB already allocated; 215.06 MiB free; 2.78 GiB reserved in total by PyTorch) (malloc at /opt/conda/conda-bld/pytorch_1587428091666/work/c10/cuda/CUDACachingAllocator.cpp:289)

YuraYelisieiev commented 4 years ago

@engrmz After running torchsummary on the model, I found that SIREN consumes 10 GB of memory during the forward and backward passes. It's probably a CUDA problem.

YuraYelisieiev commented 4 years ago

Can someone explain why SIREN consumes so much memory?

EdgeLLM commented 4 years ago

Can someone explain why SIREN consumes so much memory?

In my opinion, the reason is that computing the first-order or second-order derivatives during training takes up a lot of memory.
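
The traceback in the original post fails inside `torch.autograd.grad(..., create_graph=True)`. Here is a minimal sketch of why that flag can matter for memory. It uses a tiny stand-in network: the `TinySine` module, the `input_gradient` helper and all sizes are illustrative assumptions, not the repository's code. With `create_graph=True` the returned gradient is itself attached to an autograd graph, so the intermediate tensors needed to differentiate it again are kept alive.

```python
import torch
from torch import nn


class TinySine(nn.Module):
    """Illustrative stand-in for a small sine network (not the repo's FCBlock)."""

    def __init__(self, in_features=2, hidden=16):
        super().__init__()
        self.linear = nn.Linear(in_features, hidden)
        self.out = nn.Linear(hidden, 1)

    def forward(self, coords):
        return self.out(torch.sin(self.linear(coords)))


def input_gradient(net, coords, create_graph):
    """Gradient of the network output with respect to its input coordinates."""
    coords = coords.clone().requires_grad_(True)
    y = net(coords)
    grad = torch.autograd.grad(
        y, [coords], grad_outputs=torch.ones_like(y), create_graph=create_graph
    )[0]
    return grad


torch.manual_seed(0)
net = TinySine()
coords = torch.rand(8, 2)

g_graph = input_gradient(net, coords, create_graph=True)
g_plain = input_gradient(net, coords, create_graph=False)

# With create_graph=True the gradient carries its own graph (needed for
# second derivatives such as a Laplacian); without it, it is a plain tensor.
print(g_graph.requires_grad, g_graph.grad_fn is not None)
print(g_plain.requires_grad, g_plain.grad_fn is None)
```

The two gradients have the same values. Only the first one keeps the extra graph, which is where the additional memory goes.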

YuraYelisieiev commented 4 years ago

@xiaulinhu can you explain why the derivatives affect the size of the forward pass? I thought it was calculated from the input and output shapes of the layers. I used torchsummary to determine the size.
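
A rough back-of-the-envelope estimate of the activation memory may help here. This is a sketch under stated assumptions: the image is 512 x 512 and is fed as one batch, values are float32, and about two activation tensors per hidden layer are kept for backpropagation. Only the hidden width of 256 and the four hidden layers come from the model printout in the original post. The helper `activation_mib` is hypothetical.

```python
def activation_mib(num_points, width, bytes_per_value=4):
    """Size in MiB of one (num_points x width) activation tensor."""
    return num_points * width * bytes_per_value / 2**20


# Assumptions: a 512 x 512 image fed as one batch, float32 values.
# The hidden width of 256 comes from the model printout in the issue.
num_points = 512 * 512
per_tensor = activation_mib(num_points, width=256)

# Rough count: each of the 4 hidden layers keeps about two such tensors
# for backpropagation (the linear output and the sine output).
saved_tensors = 4 * 2
forward_total_gib = per_tensor * saved_tensors / 1024

print(f"one activation tensor: {per_tensor:.2f} MiB")
print(f"approx. saved for backward: {forward_total_gib:.2f} GiB")
```

Under these assumptions a single activation tensor is 256 MiB, the same size as the failed allocation in the error message, and the forward pass alone already keeps roughly 2 GiB. This would also explain why a batch size of 1 does not help: one batch is the whole image.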

MahirGulzar commented 3 years ago

Has anyone managed to solve this?

alexanderbergman7 commented 3 years ago

You can modify the logging functions (in utils.py) so that they do not log the image gradients and Laplacian. This should reduce the memory footprint of the model and still meet the "image fitting" experiment criteria. Those results are included to show that the analytical derivatives of the SIREN are accurate even though they are never supervised.
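
A minimal sketch of what such a reduced logging function could look like. The function name `write_image_summary_light` is hypothetical, and several details are assumptions rather than the repository's actual code: the argument order, the `gt["img"]` key, and pixel values in [-1, 1]. Only `model_output["model_out"]` is taken from the traceback. The point is that it never calls `diff_operators.gradient`, so no derivative graph is built during logging.

```python
import torch


def write_image_summary_light(image_resolution, model, model_input, gt,
                              model_output, writer, total_steps, prefix="train_"):
    """Hypothetical reduced summary: logs images and MSE, no derivatives."""
    height, width = image_resolution
    with torch.no_grad():  # logging needs no autograd graph
        pred = model_output["model_out"].detach()
        target = gt["img"]  # assumed key for the ground-truth pixels

        def to_image(flat):
            # (batch, H*W, channels) -> (batch, channels, H, W)
            channels = flat.shape[-1]
            return flat.reshape(-1, height, width, channels).permute(0, 3, 1, 2)

        # Assumes pixel values in [-1, 1]; rescale to [0, 1] for display.
        pred_img = (to_image(pred) * 0.5 + 0.5).clamp(0.0, 1.0)
        gt_img = (to_image(target) * 0.5 + 0.5).clamp(0.0, 1.0)

        writer.add_images(prefix + "pred_img", pred_img, global_step=total_steps)
        writer.add_images(prefix + "gt_img", gt_img, global_step=total_steps)
        mse = torch.mean((pred - target) ** 2).item()
        writer.add_scalar(prefix + "mse", mse, global_step=total_steps)
```

If the training script binds the image resolution with `functools.partial`, a function with this signature could be passed as `summary_fn` in place of the original. That wiring is an assumption and should be checked against the script.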