import bitsandbytes as bnb
import torch
import torch.nn as nn

# Build a small linear model on the GPU and switch it to training mode.
net = nn.Linear(10, 2).cuda()
net.train()

# Instantiate the bitsandbytes LAMB optimizer over the model parameters.
optimizer = bnb.optim.LAMB(net.parameters())

# Dummy input features and integer class targets on the GPU.
features = torch.rand(size=(10, 10)).cuda()
labels = torch.zeros(10).cuda()

# One training step: zero grads, forward, loss, backward, optimizer update.
optimizer.zero_grad()
logits = net(features)
loss = nn.CrossEntropyLoss()(logits, labels.long())
loss.backward()
optimizer.step()
Running this code produces an error like the following:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[12], [line 19](vscode-notebook-cell:?execution_count=12&line=19)
[16](vscode-notebook-cell:?execution_count=12&line=16) loss =nn.CrossEntropyLoss()(prediction, target.long())
[18](vscode-notebook-cell:?execution_count=12&line=18) loss.backward()
---> [19](vscode-notebook-cell:?execution_count=12&line=19) optimizer.step()
File /home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/optim/optimizer.py:391, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
[386](https://vscode-remote+wsl-002bwsl4datascience.vscode-resource.vscode-cdn.net/home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/optim/optimizer.py:386) else:
[387](https://vscode-remote+wsl-002bwsl4datascience.vscode-resource.vscode-cdn.net/home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/optim/optimizer.py:387) raise RuntimeError(
[388](https://vscode-remote+wsl-002bwsl4datascience.vscode-resource.vscode-cdn.net/home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/optim/optimizer.py:388) f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
[389](https://vscode-remote+wsl-002bwsl4datascience.vscode-resource.vscode-cdn.net/home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/optim/optimizer.py:389) )
--> [391](https://vscode-remote+wsl-002bwsl4datascience.vscode-resource.vscode-cdn.net/home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/optim/optimizer.py:391) out = func(*args, **kwargs)
[392](https://vscode-remote+wsl-002bwsl4datascience.vscode-resource.vscode-cdn.net/home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/optim/optimizer.py:392) self._optimizer_step_code()
[394](https://vscode-remote+wsl-002bwsl4datascience.vscode-resource.vscode-cdn.net/home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/optim/optimizer.py:394) # call optimizer step post hooks
File /home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
[112](https://vscode-remote+wsl-002bwsl4datascience.vscode-resource.vscode-cdn.net/home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/utils/_contextlib.py:112) @functools.wraps(func)
[113](https://vscode-remote+wsl-002bwsl4datascience.vscode-resource.vscode-cdn.net/home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/utils/_contextlib.py:113) def decorate_context(*args, **kwargs):
[114](https://vscode-remote+wsl-002bwsl4datascience.vscode-resource.vscode-cdn.net/home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/utils/_contextlib.py:114) with ctx_factory():
--> [115](https://vscode-remote+wsl-002bwsl4datascience.vscode-resource.vscode-cdn.net/home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/utils/_contextlib.py:115) return func(*args, **kwargs)
File /home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/bitsandbytes/optim/optimizer.py:287, in Optimizer8bit.step(self, closure)
[284](https://vscode-remote+wsl-002bwsl4datascience.vscode-resource.vscode-cdn.net/home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/bitsandbytes/optim/optimizer.py:284) self.init_state(group, p, gindex, pindex)
[286](https://vscode-remote+wsl-002bwsl4datascience.vscode-resource.vscode-cdn.net/home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/bitsandbytes/optim/optimizer.py:286) self.prefetch_state(p)
...
-> [1584](https://vscode-remote+wsl-002bwsl4datascience.vscode-resource.vscode-cdn.net/home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/bitsandbytes/functional.py:1584) optim_func = str2optimizer32bit[optimizer_name][0]
[1585](https://vscode-remote+wsl-002bwsl4datascience.vscode-resource.vscode-cdn.net/home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/bitsandbytes/functional.py:1585) elif g.dtype == torch.float16:
[1586](https://vscode-remote+wsl-002bwsl4datascience.vscode-resource.vscode-cdn.net/home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/bitsandbytes/functional.py:1586) optim_func = str2optimizer32bit[optimizer_name][1]
KeyError: 'lamb'
System Info
WSL Ubuntu 22.04, Python 3.10, bitsandbytes 0.43.1
Reproduction
To reproduce the issue, run the code snippet shown above.
It produces the traceback shown above, ending in `KeyError: 'lamb'`.
Expected behavior
`optimizer.step()` should complete successfully instead of raising `KeyError: 'lamb'`.