stanford-crfm / levanter

Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax
https://levanter.readthedocs.io/en/latest/
Apache License 2.0

Just-In-Time Mixed Precision #426

Open dlwh opened 9 months ago

dlwh commented 9 months ago

It seems that JAX isn't doing the bf16 conversion just in time. Currently in Levanter, we do something like this:

def loss(m, x):
    m = convert(m, bf16)  # produces a sharded bf16 model
    for layer in m.layers:
        x = layer(x)  # implicit all-gather of layer
    ...

Unfortunately JAX isn't smart enough (maybe it shouldn't be smart enough? I dunno) to push the conversion into the fold with the all-gather, so we store a full bf16 copy of all the parameters on each device, which can add up to a lot of memory.

What would be better (I think?) is to push the conversion into the fold, so that our loop looks like this:

def loss(m, x):
    for layer in m.layers:
        layer = convert(layer, bf16)  # produces a sharded bf16 copy of just this layer
        x = layer(x)  # implicit all-gather of layer
    ...
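
For intuition, here's a minimal single-device JAX sketch of the per-layer pattern (not Levanter's actual code; the function, shapes, and tanh layers are made up, and the sharding/all-gather isn't modeled): the parameters stay in fp32 and each layer's weights are cast to bf16 inside the scanned body, so only one layer's bf16 copy needs to be live at a time.

import jax
import jax.numpy as jnp

def mlp_stack(params_f32, x):
    # params_f32: stacked per-layer weights, shape [num_layers, d, d], kept in float32
    def one_layer(h, w_f32):
        w_bf16 = w_f32.astype(jnp.bfloat16)  # cast only this layer's weights
        return jnp.tanh(h @ w_bf16), None
    h, _ = jax.lax.scan(one_layer, x.astype(jnp.bfloat16), params_f32)
    return h

params = jnp.zeros((4, 128, 128), dtype=jnp.float32)  # toy sizes
x = jnp.ones((8, 128), dtype=jnp.float32)
y = jax.jit(mlp_stack)(params, x)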

I think what I'd like to do is introduce in Haliax a mixed-precision context manager, analogous to axis_mapping, that just bundles a jmp.Policy, and then adjust Linear and Conv to use it. (Ideally they'd take a dtype argument that, when None, defaults to the context's compute precision.) A rough sketch is below.
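
As a very rough sketch (the names and API here are hypothetical, not Haliax's), the context manager could just stash a jmp.Policy in a context variable that Linear/Conv consult when their dtype argument is None:

import contextlib
import contextvars
import jmp

_current_policy = contextvars.ContextVar(
    "mp_policy", default=jmp.get_policy("p=f32,c=f32,o=f32")
)

@contextlib.contextmanager
def mixed_precision(policy: jmp.Policy):
    # install the policy for the duration of the block
    token = _current_policy.set(policy)
    try:
        yield policy
    finally:
        _current_policy.reset(token)

def current_policy() -> jmp.Policy:
    return _current_policy.get()

# A Linear-like module with dtype=None would then do something like
#     w = current_policy().cast_to_compute(self.weight)
# inside __call__, e.g.:
#
# with mixed_precision(jmp.get_policy("p=f32,c=bf16,o=bf16")):
#     loss = loss_fn(model, batch)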

It might even make sense to have a single "ComputeContext" (need a better name) that bundles axis_mapping, mesh, and jmp.Policy?
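
Equally rough, and reusing the hypothetical names above, a bundled ComputeContext might just be a frozen dataclass carrying all three, installed and looked up the same way as mixed_precision:

from dataclasses import dataclass
from typing import Any, Mapping
import jmp

@dataclass(frozen=True)
class ComputeContext:
    # hypothetical bundle of the ambient compute state
    axis_mapping: Mapping[str, str]  # logical axis name -> physical mesh axis
    mesh: Any                        # a jax.sharding.Mesh
    policy: jmp.Policy               # param/compute/output dtypes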

Fixing this could reduce memory usage a lot. Example OOM dump:



Total hbm usage >= 21.91G:
    reserved        530.00M
    program           8.88G
    arguments        12.51G

Output size 12.51G; shares 12.51G with arguments.

Program hbm requirement 8.88G:
    global           321.0K
    scoped            3.66M
    HLO temp          8.87G (100.0% utilization: Unpadded (4.61G) Padded (4.61G), 48.1% fragmentation (4.26G))

  Largest program allocations in hbm:

... (omitting the first and largest allocation, which is unrelated)

  2. Size: 560.00M
     Operator: op_name="jit(train_step)/jit(main)/convert_element_type[new_dtype=bfloat16 weak_type=False]" source_file="/home/dlwh/venv310/lib/python3.10/site-packages/jmp/_src/policy.py" source_line=31
     Shape: bf16[80,128,28672]{2,1,0:T(8,128)(2,1)}
     Unpadded size: 560.00M
     XLA label: convert.150 = convert(param.14)
     Allocation type: HLO temp
     ==========================

  3. Size: 560.00M
     Operator: op_name="jit(train_step)/jit(main)/convert_element_type[new_dtype=bfloat16 weak_type=False]" source_file="/home/dlwh/venv310/lib/python3.10/site-packages/jmp/_src/policy.py" source_line=31
     Shape: bf16[80,128,28672]{2,1,0:T(8,128)(2,1)}
     Unpadded size: 560.00M
     XLA label: convert.151 = convert(param.15)
     Allocation type: HLO temp
     ==========================

  4. Size: 560.00M
     Operator: op_name="jit(train_step)/jit(main)/convert_element_type[new_dtype=bfloat16 weak_type=False]" source_file="/home/dlwh/venv310/lib/python3.10/site-packages/jmp/_src/policy.py" source_line=31
     Shape: bf16[80,28672,128]{2,1,0:T(8,128)(2,1)}
     Unpadded size: 560.00M
     XLA label: convert.152 = convert(param.16)
     Allocation type: HLO temp
     ==========================

  5. Size: 160.00M
     Shape: bf16[80,128,8,8,128]{1,4,3,2,0:T(8,128)(2,1)}
     Unpadded size: 160.00M
     XLA label: copy.221 = copy(param.8), sharding={devices=[1,64,1,1,1]<=[64]}
     Allocation type: HLO temp
     ==========================
dlwh commented 8 months ago

This is implemented in the use_jamp branch and the accompanying Haliax jamp branch.