pytorch / torchrec

Pytorch domain library for recommendation systems
BSD 3-Clause "New" or "Revised" License

[Question] Is there gradient accumulation support for training? #2332

Open liuslnlp opened 4 weeks ago

liuslnlp commented 4 weeks ago

I am tuning hyper-parameters on two different compute clusters. Since the number of GPUs differs between the clusters, I need gradient accumulation (GA) to keep the effective global batch size the same on both. Does torchrec support GA?
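
For clarity, by GA I mean the usual pattern below, shown here as a minimal sketch with a toy dense model (no torchrec involved):

```python
import torch
from torch import nn

# Toy model and synthetic data, just to illustrate the accumulation pattern.
model = nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accum_steps = 4  # e.g. emulate a 4x larger global batch on the smaller cluster

optimizer.zero_grad()
for step in range(16):
    features = torch.randn(8, 16)
    labels = torch.rand(8, 1)
    loss = nn.functional.binary_cross_entropy_with_logits(model(features), labels)
    (loss / accum_steps).backward()      # gradients accumulate in param.grad
    if (step + 1) % accum_steps == 0:
        optimizer.step()                 # one update per accumulation window
        optimizer.zero_grad()
```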

JacoCheung commented 2 days ago

Although this is a feature I'm looking for as well, considering that the embedding lookup backend is FBGEMM, which fuses the optimizer update into the backward pass at each step, I would expect that GA is not supported.
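
If I understand the usual torchrec setup correctly, the sparse optimizer is fused into backward via `apply_optimizer_in_backward`; a rough unsharded sketch (made-up table and feature names) of why there is no gradient left over to accumulate:

```python
import torch
import torchrec
from torch.distributed.optim import apply_optimizer_in_backward

# Unsharded EBC on CPU, just to illustrate the fused-optimizer setup;
# in real training this module is wrapped by DistributedModelParallel.
ebc = torchrec.EmbeddingBagCollection(
    device=torch.device("cpu"),
    tables=[
        torchrec.EmbeddingBagConfig(
            name="t1", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
        )
    ],
)

# Fuse SGD into backward: the embedding update runs during .backward();
# after sharding, torchrec hands this config to the fused FBGEMM kernels,
# which never materialize a weight gradient at all.
apply_optimizer_in_backward(torch.optim.SGD, ebc.parameters(), {"lr": 0.02})

kjt = torchrec.KeyedJaggedTensor(
    keys=["f1"],
    values=torch.tensor([1, 2, 3]),
    lengths=torch.tensor([2, 1]),
)
pooled = ebc(kjt).to_dict()["f1"]   # (2, 8) pooled embeddings
pooled.sum().backward()             # the weights of "t1" are updated here
```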

gouchangjiang commented 2 days ago

> Although this is a feature I'm looking for as well, considering that the embedding lookup backend is FBGEMM, which fuses the optimizer update into the backward pass at each step, I would expect that GA is not supported.

Hi Jaco. In your experience, how hard would it be to add GA support to the FBGEMM CPU/CUDA kernels?

JacoCheung commented 2 days ago

Hi @gouchangjiang, I'm not an FBGEMM expert, but I don't think it's a trivial amount of work. Even though it's feasible, it may go against FBGEMM's design principles.

FBGEMM's design deliberately eliminates the wgrad write-back, so users cannot access the wgrad at all. You could of course allocate a buffer, pass it into the backward kernels, and strip out the update and optimizer-state code (the original FBGEMM kernels are templated on the optimizer and partially instantiated), but you would have to pay for:

  1. Extra memory footprint and time. The wgrad is typically a sparse tensor (you probably don't want a dense one), so its shape is dynamic.
  2. Sparse-tensor accumulation and an exposed update. GA means you have to explicitly trigger an update method. If the wgrad is a sparse tensor, you have to implement your own accumulation ops and optimizer (see the sketch after this list).
  3. An adapter from FBGEMM to torchrec EBC/EC. TorchRec has a deep call stack, so even if you manage to expose the wgrad from FBGEMM, you still need changes in the torchrec codebase.
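
For what it's worth, here is roughly what point 2 looks like when you use a plain `nn.EmbeddingBag(sparse=True)` as a stand-in (not the FBGEMM path): the wgrad comes back as a sparse COO tensor that you accumulate and coalesce yourself before an explicit optimizer step, which is exactly the bookkeeping the fused kernels are designed to avoid.

```python
import torch
from torch import nn

# Toy accumulate-then-update loop with sparse embedding gradients.
emb = nn.EmbeddingBag(num_embeddings=100, embedding_dim=8, mode="sum", sparse=True)
opt = torch.optim.SGD(emb.parameters(), lr=0.1)  # plain SGD supports sparse grads
accum_steps = 4

opt.zero_grad()
for step in range(8):
    ids = torch.randint(0, 100, (16,))
    offsets = torch.arange(0, 16, 4)               # batch of 4 bags, 4 ids each
    loss = emb(ids, offsets).sum()
    (loss / accum_steps).backward()                # .grad is a sparse COO tensor;
                                                   # repeated backwards keep appending values
    if (step + 1) % accum_steps == 0:
        emb.weight.grad = emb.weight.grad.coalesce()  # merge duplicate row indices
        opt.step()                                 # explicit, user-triggered update
        opt.zero_grad()
```
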
gouchangjiang commented 1 day ago

Thank you @JacoCheung. That's quite a lot of work.