pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

CPUoffloadOptimizer issues #1209

Open felipemello1 opened 16 hours ago

felipemello1 commented 16 hours ago

Hi all, I was giving the CPUOffloadOptimizer a try and found two issues when using it with QLoRA single device in torchtune:

  1. When using an LR scheduler I got the error below. Maybe there is a way to inherit from the optimizer class?

    File "/data/users/felipemello/torchtune/torchtune/training/lr_schedulers.py", line 58, in get_cosine_schedule_with_warmup
    return LambdaLR(optimizer, lr_lambda, last_epoch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 336, in __init__
    super().__init__(optimizer, last_epoch, verbose)
    File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 99, in __init__
    raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
    TypeError: CPUOffloadOptimizer is not an Optimizer
  2. When passing model.parameters() I got the error below. I imagine a simple fix is to keep only the params that require grad, like the AdamW implementation does.

    File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torchao/prototype/low_bit_optim/cpu_offload.py", line 76, in __init__
    p_cuda.register_post_accumulate_grad_hook(backward_hook)
    File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_tensor.py", line 678, in register_post_accumulate_grad_hook
    raise RuntimeError(
    RuntimeError: cannot register a hook on a tensor that doesn't require gradient
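
The first TypeError comes from an `isinstance` check that torch's LR scheduler constructors perform on the optimizer they receive. A minimal torch-free sketch of that check (the classes here are stand-ins, not the real torch or torchao code):

```python
# Stand-in for torch.optim.Optimizer.
class Optimizer:
    pass

# Stand-in for CPUOffloadOptimizer: it wraps optimizers but does not
# inherit from the Optimizer base class, which is the root of issue 1.
class CPUOffloadOptimizer:
    pass

def make_scheduler(optimizer):
    # Mirrors the check in torch/optim/lr_scheduler.py that raises
    # "CPUOffloadOptimizer is not an Optimizer".
    if not isinstance(optimizer, Optimizer):
        raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
    return "scheduler"

try:
    make_scheduler(CPUOffloadOptimizer())
except TypeError as e:
    print(e)  # CPUOffloadOptimizer is not an Optimizer
```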

cc: @gau-nernst

gau-nernst commented 15 hours ago

1 is a known issue. You can see my view here: https://github.com/pytorch/ao/issues/959#issuecomment-2378225308. I will look into the torch.optim.Optimizer base class to see what could go wrong if I make CPUOffloadOptimizer inherit from it. For example, off the top of my head, CPUOffloadOptimizer will not have self.state.

In the meantime, CPUOffloadOptimizer requires setting the LR manually: https://github.com/pytorch/ao/pull/584#issuecomment-2364915318
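
A sketch of what setting the LR manually could look like: compute the cosine-with-warmup factor yourself each step and write it into the optimizer's param groups instead of using LambdaLR. The `DummyOptimizer` below is a stand-in; whether CPUOffloadOptimizer exposes the same `param_groups` attribute is an assumption, not something confirmed by this thread.

```python
import math

def cosine_with_warmup(step, warmup_steps, total_steps, base_lr):
    # Linear warmup, then cosine decay to zero; mimics the usual
    # get_cosine_schedule_with_warmup shape without needing torch.
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

class DummyOptimizer:
    # Stand-in for an optimizer exposing param_groups (assumption for
    # CPUOffloadOptimizer; adapt to whatever attribute it actually exposes).
    def __init__(self, base_lr):
        self.param_groups = [{"lr": base_lr}]

base_lr = 1e-3
opt = DummyOptimizer(base_lr)
for step in range(100):
    lr = cosine_with_warmup(step, warmup_steps=10, total_steps=100, base_lr=base_lr)
    for group in opt.param_groups:
        group["lr"] = lr  # the manual update replacing the scheduler
```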

For 2, it's an oversight on my part. We can simply add a requires_grad check here. Will push a fix: https://github.com/pytorch/ao/blob/27619174ed5a372a1ce96a0615089c5a08c88566/torchao/prototype/low_bit_optim/cpu_offload.py#L68-L77
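
A torch-free sketch of what that fix amounts to: filter out parameters with `requires_grad == False` before registering hooks, since `register_post_accumulate_grad_hook` raises on tensors that don't require grad. The `Param` class and `register_hooks` function here are illustrative stand-ins, not the actual torchao code.

```python
from dataclasses import dataclass

@dataclass
class Param:
    # Stand-in for a torch parameter; only the attribute that matters here.
    name: str
    requires_grad: bool

def register_hooks(params):
    registered = []
    for p in params:
        if not p.requires_grad:
            # The missing check: skip frozen params instead of trying to
            # register a post-accumulate-grad hook on them.
            continue
        registered.append(p.name)  # stand-in for the real hook registration
    return registered

# In QLoRA, base weights are frozen and only adapter params require grad.
params = [Param("lora_a", True), Param("base_weight", False)]
print(register_hooks(params))  # ['lora_a']
```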