bitsandbytes-foundation / bitsandbytes

Accessible large language models via k-bit quantization for PyTorch.
https://huggingface.co/docs/bitsandbytes/main/en/index

Bug when using optimizer LAMB 32bits #1350

Open · FrsECM opened 2 months ago

System Info

WSL, Ubuntu 22.04, Python 3.10, bitsandbytes 0.43.1
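For completeness, a minimal sketch (my addition, not part of the original report) for capturing the exact versions in the environment; it only uses version attributes both libraries expose:

import sys
import torch
import bitsandbytes as bnb

# Print the interpreter, PyTorch and bitsandbytes versions for the bug report.
print(sys.version)          # expected: Python 3.10.x
print(torch.__version__)    # PyTorch build in use
print(bnb.__version__)      # expected: 0.43.1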

Reproduction

To reproduce the issue, run the following:

import bitsandbytes as bnb
import torch
import torch.nn as nn

model = nn.Linear(10, 2).cuda()
model.train()
# Create the 32-bit LAMB optimizer
optimizer = bnb.optim.LAMB(model.parameters())
# Create dummy input / target
input = torch.rand(size=(10, 10)).cuda()
target = torch.zeros(10).cuda()

# Compute prediction / loss
optimizer.zero_grad()
prediction = model(input)
loss = nn.CrossEntropyLoss()(prediction, target.long())

loss.backward()
optimizer.step()

This results in the following traceback:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[12], line 19
     16 loss = nn.CrossEntropyLoss()(prediction, target.long())
     18 loss.backward()
---> 19 optimizer.step()

File /home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/optim/optimizer.py:391, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    386         else:
    387             raise RuntimeError(
    388                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    389             )
--> 391 out = func(*args, **kwargs)
    392 self._optimizer_step_code()
    394 # call optimizer step post hooks

File /home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/bitsandbytes/optim/optimizer.py:287, in Optimizer8bit.step(self, closure)
    284             self.init_state(group, p, gindex, pindex)
    286         self.prefetch_state(p)
...
File /home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/bitsandbytes/functional.py:1584
-> 1584     optim_func = str2optimizer32bit[optimizer_name][0]
   1585 elif g.dtype == torch.float16:
   1586     optim_func = str2optimizer32bit[optimizer_name][1]

KeyError: 'lamb'
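The KeyError is raised by the str2optimizer32bit lookup in bitsandbytes/functional.py (line 1584 in the traceback), which apparently has no entry for 'lamb'. A quick way to check which optimizer names are registered for the 32-bit path, assuming the table is importable at module level as the traceback suggests:

import bitsandbytes.functional as F

# str2optimizer32bit maps optimizer names to their 32-bit CUDA kernels
# (per the frame in functional.py shown above). If 'lamb' is missing,
# LAMB with 32-bit optimizer state cannot dispatch a kernel and raises this KeyError.
print(sorted(F.str2optimizer32bit.keys()))
print('lamb' in F.str2optimizer32bit)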

Expected behavior

The optimizer.step() call should complete without raising a KeyError.
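Until a 'lamb' entry exists for the 32-bit path, a possible interim workaround (an assumption on my side, not a verified fix) is to switch to the 8-bit LAMB variant or to another bitsandbytes 32-bit optimizer:

import bitsandbytes as bnb
import torch.nn as nn

model = nn.Linear(10, 2).cuda()

# Untested assumption: LAMB8bit goes through the 8-bit kernel table instead
# of str2optimizer32bit, so it may avoid the KeyError above.
optimizer = bnb.optim.LAMB8bit(model.parameters())

# Alternative (also an assumption): another bnb optimizer with 32-bit state.
# optimizer = bnb.optim.Adam(model.parameters(), optim_bits=32)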