It seems that JAX isn't doing the bf16 conversion just in time. Currently in Levanter, we do something like this:
def loss(m, x):
    m = convert(m, bf16)  # produces a sharded bf16 model
    for layer in m.layers:
        x = layer(x)  # implicit all-gather of layer
    ...
Unfortunately, JAX isn't smart enough (maybe it shouldn't be? I dunno) to push the conversion into the fold alongside the all-gather, so we store a full bf16 copy of all the parameters on each device, which can add up to a lot of memory.
What would be better (I think?) is to push the conversion into the fold, so that our loop looks like this:
def loss(m, x):
    for layer in m.layers:
        layer = convert(layer, bf16)  # produces sharded bf16 copy of just this layer
        x = layer(x)  # implicit all-gather of layer
    ...
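For concreteness, here is a minimal runnable sketch of the per-layer version in plain JAX. It is not Levanter's code: `cast_floating` is a hypothetical stand-in for `convert`, the layers are plain dicts of arrays rather than real modules, and nothing is sharded. It only shows the shape of the loop; whether XLA actually keeps each bf16 copy short-lived would still have to be checked in a memory profile.

```python
import jax
import jax.numpy as jnp


def cast_floating(tree, dtype):
    """Hypothetical stand-in for `convert`: cast floating-point leaves only."""
    def cast(leaf):
        if hasattr(leaf, "dtype") and jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(dtype)
        return leaf  # leave ints, bools, and non-array leaves untouched
    return jax.tree_util.tree_map(cast, tree)


def loss(layers, x):
    # `layers` is a list of {"w": ..., "b": ...} dicts stored in fp32.
    x = x.astype(jnp.bfloat16)
    for layer in layers:
        # Cast just this layer, right before it is used.
        layer = cast_floating(layer, jnp.bfloat16)
        x = jnp.tanh(x @ layer["w"] + layer["b"])
    # Do the final reduction in fp32.
    return jnp.mean(x.astype(jnp.float32) ** 2)


layers = [
    {"w": jnp.full((4, 4), 0.1, jnp.float32), "b": jnp.zeros((4,), jnp.float32)}
    for _ in range(3)
]
x = jnp.ones((2, 4), jnp.float32)
value = jax.jit(loss)(layers, x)
grads = jax.grad(loss)(layers, x)
```

The stored parameters stay in fp32 and the gradients come back in fp32, since the cast is part of the differentiated function.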
I think what I'd like to do is introduce in Haliax a mixed-precision context manager, analogous to axis_mapping, that just bundles a jmp.Policy, and then adjust Linear and conv to use it. (Ideally they'd take a dtype argument that, if None, defaults to the context's precision.)
Might even make sense to make a single "ComputeContext" (need a better name) that bundles axis_mapping, mesh, and jmp.Policy?
Fixing this could reduce memory usage a lot. Example OOM dump