pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

[Feature request] Add grad norm monitoring/logging #1407

Closed · gau-nernst closed this 1 week ago

gau-nernst commented 2 weeks ago

Personally, I have found that monitoring the grad norm is useful for understanding training stability. It is also useful for setting an appropriate clipping value (though I don't think torchtune supports grad norm clipping atm?).
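
For concreteness, this is roughly the value I mean; a minimal sketch (the helper is hypothetical, not existing torchtune code):

```python
import torch

def total_grad_norm(model: torch.nn.Module, norm_type: float = 2.0) -> float:
    """Norm over all parameter grads, i.e. the same value that
    torch.nn.utils.clip_grad_norm_ computes (and returns) before clipping."""
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if not grads:
        return 0.0
    per_param = torch.stack([torch.linalg.vector_norm(g, norm_type) for g in grads])
    # .item() forces a device sync, which is part of why we may not want to
    # compute this on every step.
    return torch.linalg.vector_norm(per_param, norm_type).item()
```

In a recipe this would be called right after backward() and the value added to the log dict.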

Some considerations:

  • Computing the grad norm has a cost, so this shouldn't noticeably slow down training.
  • One option is to only calculate it on the logging step (i.e. when log_every_n_steps > 1, skip the computation on non-logging steps).

ebsmothers commented 2 weeks ago

@gau-nernst thanks for this suggestion. Actually, we had a similar discussion a few months back in #897; maybe I was too much of a stickler about it at that time 😅. In addition to the comments I left there, I agree with your point about not slowing down training. I think your proposal to only calculate the grad norm on the logging step is reasonable, but in practice many of our configs set log_every_n_steps=1, so in that case it doesn't do much for us anyway.

One alternative to using nn.utils.clip_grad_norm is to provide a bit more flexibility in the metrics we log in general. There is a world where we allow passing a list of callables that return any custom metrics a user may want to compute, and update the log dict accordingly somewhere around here. The main question is how we actually design this, but I can imagine the signature of each callable being optional loss, optional model, and optional optimizer, which should be reasonably flexible (I don't wanna get into per-batch vs. per-step vs. per-epoch logging though, cause that's a whole other can of worms). The downside of this is that (a) we don't actually clip the grad norm (though technically I guess the user could do it themselves by calling that API) and (b) it might be a bit overengineered.
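
Rough sketch of what I'm picturing, purely to make the shape concrete (the hook names and signature are made up, not an existing torchtune API):

```python
from typing import Callable, Dict, Optional
import torch

# A user-supplied metric hook: it receives whatever training state we decide
# to expose and returns extra key/value pairs to merge into the log dict.
MetricHook = Callable[..., Dict[str, float]]

def grad_norm_hook(
    *,
    loss: Optional[torch.Tensor] = None,
    model: Optional[torch.nn.Module] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
) -> Dict[str, float]:
    assert model is not None
    # max_norm=inf measures the total grad norm without actually clipping.
    total = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float("inf"))
    return {"grad_norm": total.item()}

# In the recipe, roughly where the log dict is assembled:
#   for hook in custom_metric_hooks:
#       log_dict.update(hook(loss=loss, model=model, optimizer=optimizer))
```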

Otherwise there is the nn.utils.clip_grad_norm route. My main question here is how we enable optional grad norm logging (so we don't have to slow down training) and/or clipping nicely from the config. Maybe something like clip_grad_norm: Optional[float] and log_grad_norm: bool?
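
Sketch of how those two fields could play together in the training step (the field names follow the proposal above; nothing here is existing torchtune code):

```python
from typing import Optional
import torch

def backward_and_maybe_clip(
    model: torch.nn.Module,
    loss: torch.Tensor,
    clip_grad_norm: Optional[float],  # proposed config field
    log_grad_norm: bool,              # proposed config field
    should_log: bool,                 # e.g. global_step % log_every_n_steps == 0
) -> Optional[float]:
    """Returns the total grad norm whenever it was computed, else None."""
    loss.backward()
    total_norm = None
    if clip_grad_norm is not None:
        # clip_grad_norm_ returns the pre-clip norm, so when clipping is on
        # we get the logging value for free.
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
    elif log_grad_norm and should_log:
        # max_norm=inf never rescales grads, so this only measures the norm.
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float("inf"))
    return None if total_norm is None else total_norm.item()
```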

Personally I am open to either of these approaches, would be interested to hear your thoughts on the pros and cons here as well.

Agree that doing this properly for FSDP will need a bit more thought (I assume we would want the norm across all ranks and not per-rank? Also I believe if we use ...). But fine to punt on it for now.
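
For reference, my rough mental model of the distributed wrinkle (assumptions only, this glosses over FSDP1 vs. FSDP2 details):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def clip_and_get_norm(model: torch.nn.Module, max_norm: float) -> float:
    """Each rank only holds a shard of the grads, so a naive per-rank norm is
    not the global norm. Must be called on all ranks (it's a collective)."""
    if isinstance(model, FSDP):
        # FSDP exposes its own clip_grad_norm_ that reduces across ranks.
        total_norm = model.clip_grad_norm_(max_norm)
    else:
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    # With DTensor-sharded params the returned norm can itself be a DTensor;
    # full_tensor() (when present) materializes the global value.
    if hasattr(total_norm, "full_tensor"):
        total_norm = total_norm.full_tensor()
    return float(total_norm)
```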

gau-nernst commented 2 weeks ago

From the discussion in #897, I agree with you that we should not enable gradient clipping by default. It should be set explicitly by the user (I've been burned many times by default hparams differing across HF models :new_moon_with_face:)

The default log_every_n_steps=1 is not a big problem, I think. If a user cares about perf, they will set this to a larger value (perhaps we should benchmark this some time too: what is the impact of logging every step?). Keeping the default at 1 is fine; it's useful for debugging (logs show up early).

In terms of benchmarking, we can also check how much calculating the grad norm every step will cost us. Maybe it's not that much? :thinking:
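
Something crude like this would probably be enough for a ballpark number (illustrative micro-benchmark only; real numbers would have to come from an actual recipe run):

```python
import time
import torch

def grad_norm_ms(model: torch.nn.Module, iters: int = 100) -> float:
    """Average milliseconds per grad-norm computation. Assumes grads are
    already populated, i.e. run this right after a backward() pass."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        # max_norm=inf: measure the norm without clipping anything.
        torch.nn.utils.clip_grad_norm_(model.parameters(), float("inf"))
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e3
```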

"provide a bit more flexibility in the metrics we log in general"

I think this is nice on paper but a nightmare to design and maintain :cry:. Realistically, I'm not sure there are that many other useful metrics to log (apart from task-specific things, which should already be hard-coded in their respective recipes).

For now, I think it is reasonable to have clip_grad_norm: Optional[float].

So now the question is what to do when clip_grad_norm=None. To summarize the options we have so far:

  • Always log the grad norm, but only calculate it on the logging step. We should benchmark the impact of this to make sure.
  • Expose an optional flag log_grad_norm: bool, which lets the user control whether to log the grad norm.

ebsmothers commented 2 weeks ago

@gau-nernst I think your proposal makes sense.

So now the question is what to do when clip_grad_norm=None. To summarize the options we have so far:

  • Always log the grad norm, but only calculate it on the logging step. We should benchmark the impact of this to make sure.
  • Expose an optional flag log_grad_norm: bool, which lets the user control whether to log the grad norm.

In the absence of any data, I would lean towards the second option just to be safe. However, if we do find that the perf impact of logging the grad norm is negligible, the first option would be fine too (and simpler). For benchmarking purposes we may also want to look at distributed, since we will inevitably want to add it at some point and the clip_grad_norm call is likely to be more expensive in that case.

ebsmothers commented 2 weeks ago

Hey @gau-nernst, are you working on this one already? If not, we may have someone who can help out here.

gau-nernst commented 2 weeks ago

I'm not working on this. You can assign this to someone else.