aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License
467 stars, 72 forks

OOM when trying to fit LA on ViT #232

Closed ruili-pml closed 1 month ago

ruili-pml commented 2 months ago

Hi,

I'm trying to fit a subnetwork Laplace with diagonal covariance on a ViT base. For now I'm trying to make the MLP part of the transformer block Bayesian, and I get an OOM error caused by the Jacobian computation. As I'm not familiar with how the backend computes the Jacobian, I was wondering: is this something I can fix relatively easily, or is it just not feasible?

Thanks. Rui

  File "/scratch/work/src/vit_la_fit.py", line 51, in <module>
    la.fit(train_loader)
  File "/scratch/work/lir3/Laplace/laplace/baselaplace.py", line 703, in fit
    loss_batch, H_batch = self._curv_closure(X, y, N=N)
  File "/scratch/work/lir3/Laplace/laplace/baselaplace.py", line 1758, in _curv_closure
    return self.backend.diag(X, y, N=N, **self._asdl_fisher_kwargs)
  File "/scratch/work/lir3/Laplace/laplace/curvature/curvature.py", line 417, in diag
    Js, f = self.last_layer_jacobians(x) if self.last_layer else self.jacobians(x)
  File "/scratch/work/lir3/Laplace/laplace/curvature/curvature.py", line 115, in jacobians
    Js, f = torch.func.jacrev(model_fn_params_only, has_aux=True)(
  File "/scratch/work/lir3/.conda_envs/dbnn_env/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 601, in wrapper_fn
    flat_jacobians_per_input = compute_jacobian_stacked()
  File "/scratch/work/lir3/.conda_envs/dbnn_env/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 532, in compute_jacobian_stacked
    chunked_result = vmap(vjp_fn)(basis)
  File "/scratch/work/lir3/.conda_envs/dbnn_env/lib/python3.9/site-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/scratch/work/lir3/.conda_envs/dbnn_env/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 266, in vmap_impl
    return _flat_vmap(
  File "/scratch/work/lir3/.conda_envs/dbnn_env/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 38, in fn
    return f(*args, **kwargs)
  File "/scratch/work/lir3/.conda_envs/dbnn_env/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 379, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/scratch/work/lir3/.conda_envs/dbnn_env/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 328, in wrapper
    result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
  File "/scratch/work/lir3/.conda_envs/dbnn_env/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 116, in _autograd_grad
    grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
  File "/scratch/work/lir3/.conda_envs/dbnn_env/lib/python3.9/site-packages/torch/autograd/__init__.py", line 394, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.71 GiB. GPU 0 has a total capacty of 31.73 GiB of which 102.19 MiB is free. Including non-PyTorch memory, this process has 31.63 GiB memory in use. Of the allocated memory 31.17 GiB is allocated by PyTorch, and 81.54 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

wiseodd commented 2 months ago

Can you try using the AsdlGGN backend? Something like:

from laplace import Laplace
from laplace.curvature import AsdlGGN

la = Laplace(model, ..., backend=AsdlGGN)

The reason: the default backend is CurvlinopsGGN, which uses torch.func (vmap + jacrev) for the Jacobian computation. That is memory-intensive for large nets because of the vmap.
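To get a feel for why the Jacobian is the bottleneck, here is a rough back-of-the-envelope estimate. It is only a sketch under assumptions that are not stated in this thread: standard ViT-Base sizes (12 blocks, hidden size 768, MLP size 3072), only the MLP weight matrices counted, 1000 output classes, float32, and a dense Jacobian with one entry per (sample, output, parameter) triple. The helper name `jacobian_bytes` is made up for illustration.

```python
def jacobian_bytes(batch_size, n_outputs, n_params, bytes_per_element=4):
    """Size in bytes of a dense Jacobian with batch_size * n_outputs * n_params entries."""
    return batch_size * n_outputs * n_params * bytes_per_element


# Assumed ViT-Base dimensions (not taken from the thread).
hidden_dim, mlp_dim, n_blocks = 768, 3072, 12

# Each block's MLP has two weight matrices: hidden -> mlp and mlp -> hidden.
mlp_weights_per_block = 2 * hidden_dim * mlp_dim
n_params = n_blocks * mlp_weights_per_block

# Assumed 1000-way classification and a single input sample.
size_bytes = jacobian_bytes(batch_size=1, n_outputs=1000, n_params=n_params)
size_gib = size_bytes / 1024**3

print(f"MLP weights: {n_params:,}")
print(f"Dense Jacobian for one sample: {size_gib:.1f} GiB")
```

Under these assumptions a single sample already needs far more than the 31.73 GiB reported in the error above, so shrinking the batch alone would not help if the whole Jacobian is materialized at once.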

ruili-pml commented 2 months ago

Thanks for the quick reply! I tried changing the backend but got a new error:

    la = Laplace(model, 'classification', subset_of_weights="subnetwork",
  File "/scratch/work/lir3/Laplace/laplace/laplace.py", line 48, in Laplace
    return laplace_class(model, likelihood, *args, **kwargs)
  File "/scratch/work/lir3/Laplace/laplace/subnetlaplace.py", line 107, in __init__
    raise ValueError("SubnetLaplace can only be used with GGN and EF.")
ValueError: SubnetLaplace can only be used with GGN and EF.

wiseodd commented 2 months ago

Yes, that's a bug. I opened a fix here: https://github.com/aleximmer/Laplace/pull/239. You can install it with:

pip install git+https://github.com/aleximmer/Laplace.git@subnet-backend

Let me know if using AsdlGGN (or AsdlEF) fixes your problem.

Also, depending on what you want to achieve, you may want to just use Laplace with subset_of_weights="full" and then disable the gradients of the parameters you don't want to include. Laplace will automatically ignore them.

See https://github.com/aleximmer/Laplace?tab=readme-ov-file#subnetwork-laplace (third paragraph) and https://github.com/aleximmer/Laplace/issues/217#issuecomment-2278311460
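As a rough illustration of the "disable the gradients" route, the sketch below freezes everything except the MLP weight matrices in plain PyTorch. The `TinyBlock` module and the helper `keep_only_mlp_weights` are hypothetical stand-ins, and matching on a submodule literally named `mlp` is an assumption. Parameter names in a real ViT depend on the implementation, so it is worth inspecting `model.named_parameters()` first.

```python
import torch
from torch import nn


class TinyBlock(nn.Module):
    """Hypothetical stand-in for one transformer block (not a real ViT block)."""

    def __init__(self, dim=8, hidden=16):
        super().__init__()
        self.attn = nn.Linear(dim, dim)  # placeholder for the attention part
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim)
        )

    def forward(self, x):
        x = x + self.attn(x)
        return x + self.mlp(x)


def keep_only_mlp_weights(model):
    """Leave requires_grad=True only on weights inside submodules named 'mlp'.

    Biases and all other parameters get requires_grad=False.
    Returns the names of the parameters that stay trainable.
    """
    kept = []
    for name, param in model.named_parameters():
        parts = name.split(".")
        keep = "mlp" in parts and parts[-1] == "weight"
        param.requires_grad_(keep)
        if keep:
            kept.append(name)
    return kept


# Toy model: two blocks plus a classification head (sizes are arbitrary).
model = nn.Sequential(TinyBlock(), TinyBlock(), nn.Linear(8, 3))
kept = keep_only_mlp_weights(model)
print(kept)
```

After this, the model would be passed to Laplace with subset_of_weights="full" as described above. That call is left out here because its remaining arguments depend on the setup.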

ruili-pml commented 1 month ago

Thanks a lot! I couldn't install it the way you suggested, but after making the change manually, the previous bug was resolved.

And yeah, I ended up using the approach you suggested. It's easier that way, since I'm ignoring the distribution over the biases, and with the subnetwork class I would need to hack into the library to get that to work.