pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

Fused RMSNorm incompatible with PP tracing (dynamic stride) #217

Open wconstab opened 3 months ago

wconstab commented 3 months ago

The incompatibility is that during backward, fused_rmsnorm performs dynamic control flow over strides, which isn't safe for the export tracing used by PP.

        dy = dy.view(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()

This leads to a stack trace ending in:

    File "/data/users/whc/pytorch/torch/_dynamo/variables/tensor.py", line 326, in var_getattr
      unimplemented(f"Illegal getattr invocation {name} in strict mode")
    File "/data/users/whc/pytorch/torch/_dynamo/exc.py", line 204, in unimplemented
      raise Unsupported(msg)
  torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode

Would it be possible to refactor this in a more export friendly way, or is that difficult?

cc @lessw2020, @kwen2501

lessw2020 commented 3 months ago

Short term, the stride check can be removed to explore tracing (the check is rarely needed; confirmed on llama_7b).

Longer term, this will either need a refactor to support dynamic strides (harder) or, given the rarity, just a simple assert that non-contiguous inputs are unsupported.
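A minimal sketch of the assert option, with a hypothetical helper name (`prep_grad_strict` is not in torchtitan); whether the stride query itself traces cleanly in strict export is a separate question:

```python
import torch

def prep_grad_strict(dy: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: fail loudly on non-contiguous grads instead of
    branching, replacing `if dy.stride(-1) != 1: dy = dy.contiguous()`."""
    dy = dy.view(-1, dy.shape[-1])
    assert dy.stride(-1) == 1, "fused_rmsnorm backward requires contiguous last dim"
    return dy
```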

awgu commented 3 months ago

I did not look into this closely, but could we rely on .contiguous() being a no-op if already contiguous and remove the stride check? (There might be ever-so-slightly more CPU overhead if there is a Python <> C++ switch from .contiguous(), but I think this should be okay for our purpose.)
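A sketch of that suggestion, again with a hypothetical helper name: dropping the branch entirely and relying on `.contiguous()` returning the same tensor when the layout is already contiguous:

```python
import torch

def prep_grad_contiguous(dy: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: call .contiguous() unconditionally.
    On an already-contiguous tensor this returns the tensor itself,
    so the only extra cost is the Python <> C++ dispatch."""
    dy = dy.view(-1, dy.shape[-1])
    return dy.contiguous()  # no-op (same tensor) if already contiguous
```

Because there is no data-dependent branch, this form avoids the `stride` getattr control flow that strict-mode export rejects.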