GanjinZero / RRHF

[NIPS2023] RRHF & Wombat

This loss seems to consume a lot of memory. #13

Open piekey1994 opened 1 year ago

piekey1994 commented 1 year ago

The idea of this paper is really great and much easier to understand than PPO. However, if there are six candidate responses, the batch size must be at least 6 just to compute the loss once. For a large model, it seems difficult for a single GPU to handle even one forward pass. I think the generated tokens in the paper were cut to 192, far below the 2048 typically configured for ordinary LLM training. Is this also the reason? Is there any optimization strategy to solve this problem? For example, each step could compute only the rank loss for one pair of responses plus the SFT loss on the better response of that pair. I don't know if this is feasible.
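To make the proposed pair-at-a-time idea concrete, here is a minimal pure-Python sketch of the RRHF objective as we understand it from the paper: each response gets a length-normalized log-probability score p_i, a hinge rank loss max(0, p_i - p_j) is applied to every pair where reward r_i < r_j, and a fine-tuning (negative log-likelihood) loss is added on the highest-reward response. The helpers `score`, `rrhf_terms`, `rrhf_loss` and `pairwise_terms` are hypothetical names, not functions from the RRHF repository, and per-token log-probs are plain floats rather than outputs of a real model.

```python
from itertools import combinations
from typing import List, Sequence, Tuple


def score(token_logprobs: Sequence[float]) -> float:
    """p_i: length-normalized log-probability of one response."""
    return sum(token_logprobs) / len(token_logprobs)


def rrhf_terms(
    logprobs: List[Sequence[float]], rewards: Sequence[float]
) -> Tuple[float, float]:
    """Return (rank loss, fine-tuning loss) for all responses to one query."""
    p = [score(lp) for lp in logprobs]
    k = len(p)
    # Hinge penalty whenever a lower-reward response i is scored above
    # a higher-reward response j.
    rank = sum(
        max(0.0, p[i] - p[j])
        for i in range(k)
        for j in range(k)
        if rewards[i] < rewards[j]
    )
    # Negative log-likelihood (summed over tokens) of the highest-reward response.
    best = max(range(k), key=lambda i: rewards[i])
    ft = -sum(logprobs[best])
    return rank, ft


def rrhf_loss(logprobs, rewards) -> float:
    rank, ft = rrhf_terms(logprobs, rewards)
    return rank + ft


def pairwise_terms(logprobs, rewards) -> Tuple[float, float]:
    """Hypothetical one-pair-per-step variant: sum the terms over every pair."""
    rank_total, ft_total = 0.0, 0.0
    for i, j in combinations(range(len(rewards)), 2):
        rank, ft = rrhf_terms([logprobs[i], logprobs[j]], [rewards[i], rewards[j]])
        rank_total += rank
        ft_total += ft
    return rank_total, ft_total


# Toy per-token log-probs for three responses and their reward-model scores.
logprobs = [[-1.0, -1.0], [-0.5, -0.5, -0.5], [-2.0]]
rewards = [0.2, 0.9, 0.5]
print(rrhf_terms(logprobs, rewards))      # (1.0, 1.5)
print(pairwise_terms(logprobs, rewards))  # (1.0, 5.0)
```

In this toy case (and ignoring that parameters change between steps), summing the pairwise steps reproduces the full rank loss exactly, since each pair with unequal rewards contributes its hinge once. The fine-tuning term does change: it now trains on the better response of every pair (1.5 + 2.0 + 1.5) instead of only the overall best (1.5), so a pair-at-a-time scheme would need to adjust that term to match the original objective. Each step also still needs both responses of the pair in the same forward pass.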

GanjinZero commented 1 year ago

In ordinary LLM training, your batch size is the maximum number of responses you can fit in one step. If you can train an LLM with length 2048 and bsz=4, you can also train RRHF with length 2048 and query=4. I don't think our loss consumes more memory than vanilla pre-training.
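The comparison above can be checked with simple arithmetic, assuming activation memory scales roughly with the number of padded token positions in one forward pass (this ignores attention's quadratic term, weights and optimizer state). Under that assumption, and reading an RRHF batch as num_queries x responses_per_query sequences, the helpers below (hypothetical names) show that four responses to one query occupy the same budget as pre-training with bsz=4.

```python
def sequences_per_step(num_queries: int, responses_per_query: int) -> int:
    # Every (query, response) pair becomes one padded sequence in the batch.
    return num_queries * responses_per_query


def tokens_per_step(num_sequences: int, seq_len: int) -> int:
    # Token positions in one forward pass: a rough proxy for activation memory.
    return num_sequences * seq_len


pretrain = tokens_per_step(4, 2048)                                 # bsz=4
rrhf = tokens_per_step(sequences_per_step(1, 4), 2048)              # 1 query x 4 responses
rrhf_two_queries = tokens_per_step(sequences_per_step(2, 2), 2048)  # 2 queries x 2 responses
print(pretrain, rrhf, rrhf_two_queries)  # 8192 8192 8192
```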

One idea for reducing memory consumption: pre-select queries and compute the loss only on them.

piekey1994 commented 1 year ago

I know what you mean, but suppose I want to train a 60B LLaMA model: I might only be able to use a batch size of 2 or 1. In that case, how can we train an RRHF model? By "consuming more GPU memory", I mean that when training PPO, I don't need to compute the loss over multiple responses in a single step.

GanjinZero commented 1 year ago

If you can only use bsz=2, you can still use RRHF to rank those two responses. If you can only use bsz=1, you would need to either truncate the input or use something like LoRA.
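As a rough illustration of why LoRA helps, the sketch below compares trainable parameters for one square projection matrix: full fine-tuning updates all d_in x d_out weights, while LoRA trains two low-rank factors with rank x (d_in + d_out) parameters in total. The hidden size 8192 and rank 8 are assumed values for illustration only. Note that this shrinks gradients and optimizer state; it does not by itself reduce the activation memory of the forward pass over all responses, which is what truncation addresses.

```python
def full_params(d_in: int, d_out: int) -> int:
    # Full fine-tuning updates the entire weight matrix.
    return d_in * d_out


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    # LoRA trains A (rank x d_in) and B (d_out x rank); the base weight stays frozen.
    return rank * (d_in + d_out)


d = 8192  # assumed hidden size
r = 8     # assumed LoRA rank
full = full_params(d, d)
lora = lora_params(d, d, r)
print(full, lora, full // lora)  # 67108864 131072 512
```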

GanjinZero commented 1 year ago

One possible way to save GPU memory, which we have not implemented, is to exploit the fact that every response shares the same query, so the query does not need to be recomputed for each response. If the query is much longer than the responses, this would save a lot of GPU memory.
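A back-of-the-envelope count shows where the saving comes from. When each response is concatenated with its own copy of the query, k responses process k x (query + response) token positions; if the query is encoded once and its keys/values are reused by every response (one possible implementation, assumed here), only query + sum(responses) positions are processed. The example uses an assumed 1024-token query with six 192-token responses; the helper names are hypothetical, and in training, gradients from all responses would still need to flow back through the shared query.

```python
from typing import Sequence


def positions_naive(query_len: int, response_lens: Sequence[int]) -> int:
    # Each response is paired with its own copy of the query.
    return sum(query_len + r for r in response_lens)


def positions_shared(query_len: int, response_lens: Sequence[int]) -> int:
    # The query is processed once and reused by every response.
    return query_len + sum(response_lens)


responses = [192] * 6
naive = positions_naive(1024, responses)
shared = positions_shared(1024, responses)
print(naive, shared)  # 7296 2176
```

The saving is (k - 1) x query_len positions, here 5 x 1024, which is why it matters most when the query is much longer than the responses.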