IST-DASLab / MicroAdam

This repository contains code for the MicroAdam paper.
Apache License 2.0

Crashes when trying to use MicroAdam optimizer with float32 parameters and device other than cuda:0 #1

Open fzmushko opened 1 week ago

fzmushko commented 1 week ago

I am trying to use the MicroAdam optimizer, but it crashes when I call optimizer.step().

Setup: an empty conda environment with only torch and ista-daslab-optimizers installed via pip install torch ista-daslab-optimizers. Versions: torch 2.4.1, ista-daslab-optimizers 1.1.3, CUDA 12.1, Python 3.9. GPU: A100-SXM4-80GB. I run the following code:

import torch
from ista_daslab_optimizers import MicroAdam 

device = "cuda:0"
dtype = torch.float32
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024)).to(device).to(dtype)
opt = MicroAdam(model.parameters(), m=10, lr=1e-3, quant_block_size=100_000, k_init=0.01)

loss = model(torch.randn(1024, device=device, dtype=dtype)).norm()
loss.backward()
opt.step()
opt.zero_grad()

and receive the following error:

[CUDA Kernel] maxSharedMemSizePerSM_bytes = 167936, maxSharedMemSizePerSM_kilobytes = 164, floats_count = 41728
python3: /place/vartmp/pip-install-eb9bmit3/ista-daslab-optimizers_308cf0f0f9d840d1b694b7e4ad9fb163/kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu:63: void asymm_block_quant_inv_cuda(LL, LL, at::Tensor, at::Tensor, at::Tensor, at::Tensor, float): Assertion `torch::ScalarType::BFloat16 == x.scalar_type()' failed.
Aborted

The same error occurs when I try to run it in mixed precision with torch.autocast.

Since the assertion mentions BFloat16, I set dtype = torch.bfloat16, and with that the code works. However, if I change the device from cuda:0 to cuda:1, I encounter another error:

[CUDA Kernel] maxSharedMemSizePerSM_bytes = 167936, maxSharedMemSizePerSM_kilobytes = 164, floats_count = 41728
Traceback (most recent call last):
  File "/extra_disk_1/zmushko-fa/test_ist.py", line 13, in <module>
    opt.step()
  File "/home/zmushko-fa/miniconda3/envs/microadam_new_torch/lib/python3.9/site-packages/torch/optim/optimizer.py", line 484, in wrapper
    out = func(*args, **kwargs)
  File "/home/zmushko-fa/miniconda3/envs/microadam_new_torch/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/zmushko-fa/miniconda3/envs/microadam_new_torch/lib/python3.9/site-packages/ista_daslab_optimizers/micro_adam/micro_adam.py", line 131, in step
    nqe, ng, nu, ne, sp_u, sp_qe = self.update_step(p, lr, wd)
  File "/home/zmushko-fa/miniconda3/envs/microadam_new_torch/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/zmushko-fa/miniconda3/envs/microadam_new_torch/lib/python3.9/site-packages/ista_daslab_optimizers/micro_adam/micro_adam.py", line 193, in update_step
    I[index, :k_index] = torch.topk(input=grad[0:d_index_topk].abs().view(topk_full_blocks_count, d_block_size),
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Running with CUDA_LAUNCH_BLOCKING=1 returns:

[CUDA Kernel] maxSharedMemSizePerSM_bytes = 167936, maxSharedMemSizePerSM_kilobytes = 164, floats_count = 41728
Error detected by GPU Assert:
        Error 700: an illegal memory access was encountered 
        File: /place/vartmp/pip-install-eb9bmit3/ista-daslab-optimizers_308cf0f0f9d840d1b694b7e4ad9fb163/kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu
        Line: 80

To sum up: 1) It looks like the optimizer doesn't work with float32 master weights. Is that intended? 2) It looks like there is a bug when using a device other than cuda:0.

fzmushko commented 1 week ago

I have also tried a new environment with torch==2.3.0, which was used in the original paper (install.sh script). In this case I encounter another error:

Traceback (most recent call last):
  File "/extra_disk_1/zmushko-fa/test_ist.py", line 2, in <module>
    from ista_daslab_optimizers import MicroAdam
  File "/home/zmushko-fa/miniconda3/envs/microadam/lib/python3.9/site-packages/ista_daslab_optimizers/__init__.py", line 2, in <module>
    from .micro_adam import *
  File "/home/zmushko-fa/miniconda3/envs/microadam/lib/python3.9/site-packages/ista_daslab_optimizers/micro_adam/__init__.py", line 1, in <module>
    from .micro_adam import MicroAdam
  File "/home/zmushko-fa/miniconda3/envs/microadam/lib/python3.9/site-packages/ista_daslab_optimizers/micro_adam/micro_adam.py", line 7, in <module>
    from ..tools import get_first_device, get_gpu_mem_usage, block_split, CopyDirection
  File "/home/zmushko-fa/miniconda3/envs/microadam/lib/python3.9/site-packages/ista_daslab_optimizers/tools.py", line 6, in <module>
    import ista_daslab_tools
ImportError: /home/zmushko-fa/miniconda3/envs/microadam/lib/python3.9/site-packages/ista_daslab_tools.cpython-39-x86_64-linux-gnu.so: undefined symbol: _ZN3c1021throwNullDataPtrErrorEv

fzmushko commented 1 week ago

I will also open a copy of this issue in the ista-daslab repository.

ionutmodo commented 6 days ago

Hi! Thank you for reaching out! I will have a deeper look at this starting next week and will try to solve it as soon as possible.

ionutmodo commented 6 days ago

> I have also tried a new environment with torch==2.3.0, which was used in the original paper (install.sh script). In this case I encounter another error:
>
> Traceback (most recent call last):
>   File "/extra_disk_1/zmushko-fa/test_ist.py", line 2, in <module>
>     from ista_daslab_optimizers import MicroAdam
>   File "/home/zmushko-fa/miniconda3/envs/microadam/lib/python3.9/site-packages/ista_daslab_optimizers/__init__.py", line 2, in <module>
>     from .micro_adam import *
>   File "/home/zmushko-fa/miniconda3/envs/microadam/lib/python3.9/site-packages/ista_daslab_optimizers/micro_adam/__init__.py", line 1, in <module>
>     from .micro_adam import MicroAdam
>   File "/home/zmushko-fa/miniconda3/envs/microadam/lib/python3.9/site-packages/ista_daslab_optimizers/micro_adam/micro_adam.py", line 7, in <module>
>     from ..tools import get_first_device, get_gpu_mem_usage, block_split, CopyDirection
>   File "/home/zmushko-fa/miniconda3/envs/microadam/lib/python3.9/site-packages/ista_daslab_optimizers/tools.py", line 6, in <module>
>     import ista_daslab_tools
> ImportError: /home/zmushko-fa/miniconda3/envs/microadam/lib/python3.9/site-packages/ista_daslab_tools.cpython-39-x86_64-linux-gnu.so: undefined symbol: _ZN3c1021throwNullDataPtrErrorEv

In your initial message I see that you are using CUDA 12.1, whereas we used CUDA 12.2 during development. If you have access to a cluster, please try running module load cuda/12.2 or module load cuda/12.4 and then activating your environment. This error happens because the CUDA kernels for MicroAdam were built using CUDA 12.2, which includes some fixes over CUDA 12.1.
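When comparing an environment against the one the kernels were built in, it can help to print the relevant versions in one place. The following is only a minimal sketch: installed_version and environment_report are hypothetical helper names, not part of ista-daslab-optimizers, and the package list is an assumption taken from the setup described in this thread. torch.version.cuda reports the CUDA version that the installed torch build was compiled against, which may differ from the toolkit loaded on the machine.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def environment_report(packages=("torch", "ista-daslab-optimizers")):
    """Collect versions that are useful to quote when a compiled extension fails."""
    report = {name: installed_version(name) for name in packages}
    try:
        import torch  # imported lazily so the report also works without torch

        # CUDA version the torch wheel was built with (None for CPU-only builds).
        report["torch CUDA build"] = torch.version.cuda
    except ImportError:
        report["torch CUDA build"] = None
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

Pasting this output into an issue alongside the traceback makes it easier to see whether the versions in the environment match the ones the extension was built with.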

ionutmodo commented 6 days ago

> To sum up:
>
> 1. It looks like the optimizer doesn't work with float32 master weights. Is that intended?
> 2. It looks like there is a bug when using a device other than cuda:0.

  1. MicroAdam currently supports only bfloat16 because our main focus was to reduce memory usage, as stated in the paper. Please use our optimizer with dtype=torch.bfloat16.
  2. I suggest setting CUDA_VISIBLE_DEVICES accordingly. For example, if your system has 8 GPUs and you want to use GPU 1, please run your program with CUDA_VISIBLE_DEVICES=1 python main.py and keep device=cuda:0. I believe this should be a quick fix until I can reproduce the error.
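The same workaround can also be applied from inside the script instead of on the command line. This is a minimal sketch under the assumption that nothing has initialized CUDA yet; pin_physical_gpu is a hypothetical helper name, not part of torch or ista-daslab-optimizers.

```python
import os


def pin_physical_gpu(index):
    """Expose only physical GPU `index` to this process.

    Call this before the first CUDA call (safest: before `import torch`),
    because CUDA_VISIBLE_DEVICES is read when CUDA is initialized.
    Returns the device string to use afterwards.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(index)
    # The single visible GPU is renumbered, so it is addressed as cuda:0.
    return "cuda:0"


device = pin_physical_gpu(1)  # physical GPU 1 is now seen as cuda:0

# Import torch only after the variable is set, then keep using `device`:
# import torch
# model = model.to(device)
```

Setting the variable in the shell, as suggested above, is the more robust option, since it cannot be affected by import order.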

Thank you for opening this issue and please let me know how it works! I am happy to help!

fzmushko commented 6 days ago

> 1. MicroAdam currently supports only bfloat16 because our main focus was to reduce memory usage, as stated in the paper. Please use our optimizer with dtype=torch.bfloat16.

Ah, I see. I just didn't notice that mentioned in either the article or the README (though perhaps I wasn't paying enough attention). I assumed that float32 support would be present because pure bf16 training sometimes leads to worse results without advanced techniques such as stochastic rounding.

However, I tried running the same code with float32 weights and manually casting the gradients to bf16 after loss.backward(). In this case the code doesn't break, and I suppose this approach might make sense: the final bf16 update is first cast to fp32 before it is added to the fp32 weights, which could keep the weight update accurate enough.

> 2. I suggest setting CUDA_VISIBLE_DEVICES accordingly. For example, if your system has 8 GPUs and you want to use GPU 1, please run your program with CUDA_VISIBLE_DEVICES=1 python main.py and keep device=cuda:0. I believe this should be a quick fix until I can reproduce the error.

It works, thank you.

ionutmodo commented 6 days ago

I agree, manually casting the gradient to bfloat16 works. We did this in the FFCV repository, where we had some issues with mixed precision. I will mention the bfloat16 requirement in the README; I think we missed that. Thank you for pointing it out! Please let me know if there are any other issues I can help with.