stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0
141 stars 9 forks

How does haliax work with mixed precision? #43

Closed: samuelstevens closed this issue 9 months ago

samuelstevens commented 9 months ago

These PyTorch docs list which ops are safe to run in fp16 and which are not. I want to make sure my softmax operations run in fp32.

I read the jmp tutorial for Haliax, but I didn't see anything about promoting softmax to fp32. Is this done automatically by JAX, or does Haliax somehow do it automatically?

dlwh commented 9 months ago

Neither. It's a fair point. In Levanter I just have flags for the places where I want to upcast an op (e.g. https://github.com/stanford-crfm/levanter/blob/main/src/levanter/models/gpt2.py#L181), which I think is more or less how Flax does it?
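A minimal sketch of the per-op flag approach, written in plain JAX rather than Haliax. The `attention_weights` function and its `upcast` parameter are hypothetical names for illustration, not Levanter's or Haliax's actual API: the scores are cast to float32 for the softmax and the result is cast back to the original dtype.

```python
import jax
import jax.numpy as jnp


def attention_weights(scores: jax.Array, upcast: bool = True) -> jax.Array:
    """Softmax over the last axis, optionally computed in float32.

    `upcast` is a hypothetical per-op flag, not a real Levanter/Haliax option.
    """
    orig_dtype = scores.dtype
    if upcast:
        # Evaluate the exponentials and normalization in float32.
        scores = scores.astype(jnp.float32)
    weights = jax.nn.softmax(scores, axis=-1)
    # Cast back so the following matmul stays in the low-precision compute dtype.
    return weights.astype(orig_dtype)


scores = jnp.array([[1.0, 2.0, 3.0]], dtype=jnp.float16)
weights = attention_weights(scores)
```

`jax.nn.softmax` has no compute-dtype argument and returns the dtype it was given, so an explicit cast like this is a straightforward way to make the intent unambiguous.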

When I first started designing Levanter, I thought about arrays/modules/ops having a "semantic dtype" component (output, compute, parameter) and threading jmp through, but decided against it.
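For readers new to jmp: its `Policy` carries three dtypes, `param_dtype`, `compute_dtype`, and `output_dtype`, which line up with the "semantic dtype" roles above. The sketch below re-implements that idea in a few lines so it runs without jmp installed; the `Policy` class and `dense` layer here are hypothetical stand-ins, not jmp's or Haliax's code. "Threading" the policy means every layer receives it and applies the casts explicitly.

```python
from dataclasses import dataclass
from typing import Any

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class Policy:
    """Hypothetical stand-in for jmp.Policy: one dtype per semantic role."""
    param_dtype: Any = jnp.float32     # dtype parameters are stored in
    compute_dtype: Any = jnp.bfloat16  # dtype matmuls etc. run in
    output_dtype: Any = jnp.float32    # dtype layer outputs are returned in

    def _cast(self, tree, dtype):
        # Cast floating-point leaves only; integer leaves (e.g. token ids) are left alone.
        return jax.tree_util.tree_map(
            lambda x: x.astype(dtype) if jnp.issubdtype(x.dtype, jnp.floating) else x,
            tree,
        )

    def cast_to_param(self, tree):
        return self._cast(tree, self.param_dtype)

    def cast_to_compute(self, tree):
        return self._cast(tree, self.compute_dtype)

    def cast_to_output(self, tree):
        return self._cast(tree, self.output_dtype)


def dense(params, x, policy: Policy):
    """A hypothetical linear layer that threads the policy through explicitly."""
    w, b = policy.cast_to_compute((params["w"], params["b"]))
    y = policy.cast_to_compute(x) @ w + b
    return policy.cast_to_output(y)


params = Policy().cast_to_param({"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))})
y = dense(params, jnp.ones((4, 2)), Policy())
```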

If you want something transparent, Haiku has a mechanism that's worth checking out: it uses context mappings on ops to do it.
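Haiku's own API attaches jmp policies to module classes via `hk.mixed_precision.set_policy`. The snippet below is not that API; it's a hypothetical, framework-free sketch of the same transparent idea, where a context manager maps op names to a compute dtype and wrapped ops consult that mapping when called. The names `op_precision`, `precision_aware`, and `softmax` are made up for illustration. Because the lookup happens in Python, under `jax.jit` it takes effect at trace time.

```python
import contextlib
import contextvars

import jax
import jax.numpy as jnp

# Hypothetical registry: op name -> dtype that op should compute in.
_OP_DTYPES = contextvars.ContextVar("op_dtypes", default={})


@contextlib.contextmanager
def op_precision(**overrides):
    """Within this block, run the named ops in the given dtypes."""
    token = _OP_DTYPES.set({**_OP_DTYPES.get(), **overrides})
    try:
        yield
    finally:
        _OP_DTYPES.reset(token)


def precision_aware(name):
    """Wrap a single-array op so it honours the active op_precision mapping."""
    def wrap(fn):
        def inner(x, *args, **kwargs):
            dtype = _OP_DTYPES.get().get(name)
            if dtype is None:
                return fn(x, *args, **kwargs)
            # Upcast the input, run the op, and cast back to the caller's dtype.
            return fn(x.astype(dtype), *args, **kwargs).astype(x.dtype)
        return inner
    return wrap


softmax = precision_aware("softmax")(jax.nn.softmax)

x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.bfloat16)
with op_precision(softmax=jnp.float32):
    y = softmax(x)  # computed in float32, returned as bfloat16
```

Model code just calls `softmax(x)`; the precision decision lives entirely in the `with` block around it.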

What are your thoughts?

samuelstevens commented 9 months ago

I'm very new to JAX and have only used Equinox, without looking much at Flax or Haiku yet. I ended up simply casting everything to bfloat16, since my training runs were diverging with fp16 even when I manually upcast softmax and layernorms.
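One common explanation for fp16 runs diverging where bfloat16 runs don't is dynamic range: bfloat16 keeps float32's 8 exponent bits, while fp16 has only 5, so fp16 overflows above 65504 and rounds very small values to zero. The snippet below just checks that with `jnp.finfo` and a couple of casts; the helper name `survives` is made up for illustration.

```python
import jax.numpy as jnp


def survives(value: float, dtype) -> bool:
    """True if `value` is still finite and nonzero after casting to `dtype`."""
    x = jnp.asarray(value, dtype=jnp.float32).astype(dtype)
    return bool(jnp.isfinite(x)) and bool(x != 0)


fp16_max = float(jnp.finfo(jnp.float16).max)   # 65504.0
bf16_max = float(jnp.finfo(jnp.bfloat16).max)  # about 3.4e38, the same range as float32

# A large activation and a tiny gradient-sized value.
for value in (70000.0, 1e-8):
    print(value, "fp16:", survives(value, jnp.float16), "bf16:", survives(value, jnp.bfloat16))
```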

I think manually upcasting in model definitions is probably the best practice. I'm used to PyTorch, where I rarely write models from scratch anymore because paper authors provide fairly optimized implementations. But I guess writing models from scratch is fine in JAX, because XLA compiles and optimizes the resulting GPU kernels.
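As a concrete instance of manually upcasting inside a model definition, here is a sketch of a layer norm that computes its mean and variance in float32 and returns the input's dtype. `layer_norm` is an illustrative plain-JAX function, not Haliax's own layer-norm module.

```python
import jax
import jax.numpy as jnp


def layer_norm(x: jax.Array, gamma: jax.Array, beta: jax.Array, eps: float = 1e-5) -> jax.Array:
    """Normalize over the last axis, doing the statistics in float32."""
    orig_dtype = x.dtype
    x32 = x.astype(jnp.float32)
    mean = x32.mean(axis=-1, keepdims=True)
    var = x32.var(axis=-1, keepdims=True)
    y = (x32 - mean) * jax.lax.rsqrt(var + eps)
    y = y * gamma.astype(jnp.float32) + beta.astype(jnp.float32)
    # Return to the low-precision dtype for the rest of the network.
    return y.astype(orig_dtype)


x = jnp.array([[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]], dtype=jnp.bfloat16)
out = jax.jit(layer_norm)(x, gamma=jnp.ones(4, jnp.bfloat16), beta=jnp.zeros(4, jnp.bfloat16))
```

Under `jax.jit`, XLA is free to fuse the casts with the surrounding elementwise ops, so the extra `astype` calls don't necessarily cost extra passes over memory.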

Thanks for the discussion!

dlwh commented 9 months ago

Oh yeah. If you can avoid fp16, avoid fp16. It's awful.

(Sent by email in reply to Sam's comment of Mon, Oct 16, 2023, 6:19 AM.)