Closed — samuelstevens closed this issue 9 months ago
Neither. It's a fair point. In Levanter I just have flags for places where I want to upcast the op (e.g. https://github.com/stanford-crfm/levanter/blob/main/src/levanter/models/gpt2.py#L181), which I think is more or less how it's done in Flax?
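For readers unfamiliar with that pattern, here is a minimal numpy sketch of a flag-based upcast (not Levanter's actual code; the `upcast` flag and function are illustrative): run the numerically sensitive op in float32, then cast back to the input dtype.

```python
import numpy as np

def softmax(x, axis=-1, upcast=True):
    # Hypothetical upcast flag, in the spirit of Levanter's option:
    # do the sensitive computation in float32, return the input dtype.
    orig_dtype = x.dtype
    if upcast:
        x = x.astype(np.float32)
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for stability
    e = np.exp(x)
    out = e / e.sum(axis=axis, keepdims=True)
    return out.astype(orig_dtype)

scores = np.array([1.0, 2.0, 3.0], dtype=np.float16)
probs = softmax(scores)  # computed in fp32, returned as fp16
```

The nice part of a per-op flag is that the rest of the model stays in the low-precision dtype; only the ops you mark pay the float32 cost.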
When I first started designing Levanter, I thought about arrays/modules/ops having a "semantic dtype" component (output, compute, parameter) and threading jmp through, but decided against it.
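The semantic-dtype idea can be sketched as a small policy object, similar in spirit to jmp's policies (this `Policy` class and its method names are illustrative, not jmp's API):

```python
import numpy as np
from dataclasses import dataclass

@dataclass
class Policy:
    # Hypothetical policy: each array role gets its own dtype,
    # roughly the (param, compute, output) split described above.
    param_dtype: type
    compute_dtype: type
    output_dtype: type

    def cast_to_compute(self, x):
        return x.astype(self.compute_dtype)

    def cast_to_output(self, x):
        return x.astype(self.output_dtype)

policy = Policy(np.float32, np.float16, np.float32)
w = np.ones((4, 4), dtype=policy.param_dtype)  # params stored in fp32
x = np.ones((2, 4), dtype=np.float16)
# matmul runs in the compute dtype; the result is cast to the output dtype
y = policy.cast_to_output(x @ policy.cast_to_compute(w))
```

Threading such a policy through every module is exactly the plumbing cost that makes the per-op-flag approach attractive by comparison.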
If you want something transparent, Haiku has a mechanism that's worth checking out; it uses context mappings on ops to do it.
What are your thoughts?
I'm very new to Jax and have only used Equinox, without looking much at Flax or Haiku yet. I ended up simply casting everything to bfloat16 since my training runs were diverging with fp16, even when manually upcasting softmax and layernorms.
I think manually upcasting in model definitions is probably the best practice. I'm used to PyTorch, where I often don't write models from scratch anymore because paper authors provide fairly optimized implementations. But I guess it's fine to write models from scratch in Jax because XLA will optimize the CUDA ops and such.
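As an example of the manual-upcast pattern for layernorm, here is a minimal numpy sketch (using float16 for the activations since numpy has no bfloat16; the function is illustrative):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Manual upcast: compute mean/variance in float32 even when the
    # activations are half precision, then cast back to the input dtype.
    x32 = x.astype(np.float32)
    mean = x32.mean(axis=-1, keepdims=True)
    var = x32.var(axis=-1, keepdims=True)
    out = (x32 - mean) / np.sqrt(var + eps)
    return out.astype(x.dtype)

h = np.linspace(-2.0, 2.0, 16, dtype=np.float16).reshape(2, 8)
out = layer_norm(h)  # statistics in fp32, result back in fp16
```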
Thanks for the discussion!
oh yeah. if you can avoid fp16, avoid fp16. it's awful
These PyTorch docs list fp16-safe and fp16-unsafe ops. I want to make sure my softmax operations run in fp32.
I read the jmp tutorial for haliax, but I didn't see anything about promoting softmax to fp32. Is this done automatically by Jax? Or does haliax do this automatically somehow?
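To see why softmax in particular needs the promotion: float16 overflows above roughly 65504, so `exp()` of even modest logits produces `inf` unless the computation runs in float32. A small numpy demonstration:

```python
import numpy as np

logits = np.array([12.0, 0.0], dtype=np.float16)

# Naive fp16 softmax: exp(12) ~ 1.6e5 overflows float16 to inf,
# and inf / inf then yields nan.
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(logits) / np.exp(logits).sum()

# Upcast to fp32 first and the same logits are handled fine.
safe = np.exp(logits.astype(np.float32))
safe = safe / safe.sum()
```

This is the failure mode the "fp16-unsafe ops" lists are warning about, and why frameworks either upcast these ops for you or expect you to do it in the model definition.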