Open asaluja opened 1 year ago
Same error with bitsandbytes==0.41.1 on 8x A100 GPUs.
Update: this problem was caused by the paged_adamw_32bit optimizer. I solved it by switching to the adamw_torch optimizer and converting the optimizer.pt file with the following script.
import os
import torch

if __name__ == "__main__":
    # Keep the original paged-optimizer checkpoint around as a backup.
    os.rename("optimizer.pt", "optimizer.paged.pt")
    state_dict = torch.load("optimizer.paged.pt")
    # bitsandbytes stores the Adam moments as "state1"/"state2";
    # torch.optim.AdamW expects them as "exp_avg"/"exp_avg_sq".
    for i in state_dict["state"].keys():
        exp_avg = state_dict["state"][i].pop("state1")
        exp_avg_sq = state_dict["state"][i].pop("state2")
        state_dict["state"][i]["exp_avg"] = exp_avg
        state_dict["state"][i]["exp_avg_sq"] = exp_avg_sq
    # Fill in the param-group defaults that torch's AdamW expects to find.
    for i in range(len(state_dict["param_groups"])):
        state_dict["param_groups"][i]["amsgrad"] = False
        state_dict["param_groups"][i]["foreach"] = None
        state_dict["param_groups"][i]["maximize"] = False
        state_dict["param_groups"][i]["capturable"] = False
        state_dict["param_groups"][i]["differentiable"] = False
        state_dict["param_groups"][i]["fused"] = None
    torch.save(state_dict, "optimizer.pt")
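If you want to check the conversion before kicking off a long resume, the converted file should load cleanly into plain torch.optim.AdamW. A minimal sanity-check sketch, where build_model() is a hypothetical stand-in for however you construct the model this checkpoint belongs to:

import torch
from my_project.modeling import build_model  # hypothetical; replace with your own model constructor

model = build_model()
optimizer = torch.optim.AdamW(model.parameters())
# load_state_dict raises if the converted keys or shapes don't match the model's params
optimizer.load_state_dict(torch.load("optimizer.pt"))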
I'm getting this issue too on T4 and L4. I'm going to dig into it more, but do you have more insight into it? Is the state dict serialization wrong after a change? Or the state dict loading?
I'm seeing this same issue on resume with paged_adamw_8bit.
@markrmiller I'm facing the same issue with paged_adamw_8bit, did you find any solution for it? thanks 🙏
Only to start training over with standard Adam. Later I used paged_adamw_32bit and hit the same thing - so now I'm just avoiding all these optimizers.
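If you're launching through the HF Trainer, switching away from the paged optimizers should just be the optim argument. A sketch, assuming transformers' TrainingArguments and an output_dir of your choosing:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",  # hypothetical path
    optim="adamw_torch",  # instead of "paged_adamw_32bit" / "paged_adamw_8bit"
)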
@markrmiller thanks for the advice
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
I am facing the same issue on an 8xA100 machine with bitsandbytes==0.42.0. I am using the paged_adamw_32bit optimizer. Did any of you find a solution? Please help, thank you 🙏
I have the exact same issue
I got a similar error when trying to save optimizer states (PagedAdam8bit):
Traceback (most recent call last):
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/ubuntu/projects/scripts/train_ranker.py", line 404, in <module>
    train(accelerator, args)
  File "/home/ubuntu/projects/scripts/train_ranker.py", line 357, in train
    save_training_artifacts(
  File "/home/ubuntu/projects/scripts/train_ranker.py", line 74, in save_training_artifacts
    accelerator.save_state(save_dir)
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/accelerate/accelerator.py", line 2958, in save_state
    save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 168, in save_fsdp_optimizer
    optim_state = FSDP.optim_state_dict(model, optimizer)
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1840, in optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_impl(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1263, in _optim_state_dict_impl
    return _optim_state_dict(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1971, in _optim_state_dict
    fsdp_osd_state = convert_fn(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1794, in _convert_state_with_orig_params
    _gather_all_orig_param_state(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1688, in _gather_all_orig_param_state
    output_states = _allgather_orig_param_states(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1518, in _allgather_orig_param_states
    dtype, state_buffers = _convert_all_state_info(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1377, in _convert_all_state_info
    assert dtype == info.dtype
AssertionError
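The assertion lives in FSDP's _convert_all_state_info, which expects the gathered state tensors for a given key to share a single dtype across ranks, so a dtype mismatch in the paged optimizer's state buffers would be consistent with this failure. A diagnostic sketch to print what dtypes the optimizer actually stores, assuming a CUDA machine with bitsandbytes installed (the Linear model here is just a stand-in to populate optimizer state):

import torch
import bitsandbytes as bnb

# Tiny stand-in model; the point is only to populate optimizer state.
model = torch.nn.Linear(8, 8).cuda()
optimizer = bnb.optim.PagedAdamW(model.parameters())
model(torch.randn(2, 8, device="cuda")).sum().backward()
optimizer.step()  # first step allocates the paged state buffers (state1/state2)

# Print the dtype of every per-parameter state tensor the optimizer holds.
for group in optimizer.param_groups:
    for p in group["params"]:
        for name, buf in optimizer.state.get(p, {}).items():
            if torch.is_tensor(buf):
                print(name, buf.dtype, tuple(buf.shape))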
I am facing the same issue. Is there any solution other than changing optimizers?
Hello,
I'm running bitsandbytes==0.41.1 in a Python 3.10 miniconda environment on 8xA100 GPUs (using accelerate for multi-GPU), CUDA 12.2. I'm having problems resuming training (DPO) from a checkpoint.
I'm not quite sure how to debug this error, any ideas?