pytorch / PiPPy

Pipeline Parallelism for PyTorch

PP Tracer doesn't work with fused_rmsnorm #1108

Open wconstab opened 2 months ago

wconstab commented 2 months ago

Currently I have to work around this by using regular rmsnorm for PP to be enabled:

torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
        # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm

Full trace https://gist.github.com/wconstab/3b68edda6bd30c2414403e91734ccc87 cc @kwen2501 @lessw2020
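For context, here is a minimal, self-contained sketch (illustrative, not the torchtitan source) of the kind of layout check that strict-mode tracing rejects; it mirrors the `if dy.stride(-1) != 1:` branch in the fused_rmsnorm backward:

```python
import torch

# Minimal sketch of the pattern strict-mode tracing chokes on (illustrative,
# not the torchtitan source): branching on .stride() of a traced tensor.
# In strict mode, dynamo reports the .stride() access as
# "Illegal getattr invocation stride in strict mode".
class StrideCheck(torch.nn.Module):
    def forward(self, x):
        if x.stride(-1) != 1:   # same shape of check as fused_rmsnorm's backward
            x = x.contiguous()
        return x * 2.0

# Eager execution is fine, contiguous or not:
print(StrideCheck()(torch.randn(4, 8).t()).shape)
```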

wconstab commented 2 months ago

Is it safe to skip the `if` and just call `.contiguous()` all the time? Maybe that is a no-op in the case that x is already contiguous?
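A quick sanity check of that assumption (not from the issue thread):

```python
import torch

# .contiguous() on an already-contiguous tensor returns the same object (no
# copy), so calling it unconditionally should be free in the common case.
x = torch.randn(8, 16)
assert x.contiguous() is x

# On a non-contiguous view it materializes a contiguous copy.
y = x.t()
assert not y.is_contiguous()
z = y.contiguous()
assert z.is_contiguous() and z.data_ptr() != y.data_ptr()
```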

wconstab commented 2 months ago

Some attempts to fix this:

(1) Gets rid of conditionals on dynamic shapes, which gets me past the first tracing errors: https://github.com/pytorch/torchtitan/pull/300

(2) Does a hack for computing sm_count from device(0), which is unsafe; we might be able to make a version of this that is tracer-friendly somehow (see the sketch below): https://github.com/pytorch/torchtitan/pull/301
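A hypothetical sketch of what a tracer-friendlier sm_count lookup could look like (this is not the code in the PR, and whether it fully avoids the tracing issue depends on where it is called from): resolve the SM count eagerly, cache it per device, and only hand a plain int to the kernel launch.

```python
import torch

# Hypothetical sketch (not torchtitan PR #301): resolve the SM count eagerly
# and cache it per device, so the traced region only ever sees a plain Python
# int and we are not hard-coded to device(0).
_SM_COUNT_CACHE: dict[int, int] = {}

def sm_count_for(device: torch.device) -> int:
    idx = device.index if device.index is not None else torch.cuda.current_device()
    if idx not in _SM_COUNT_CACHE:
        _SM_COUNT_CACHE[idx] = torch.cuda.get_device_properties(idx).multi_processor_count
    return _SM_COUNT_CACHE[idx]
```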

After these I still hit a stride issue for the non-conditional usages of stride: (3)

    File "/data/users/whc/pytorch/torch/_dynamo/variables/tensor.py", line 322, in var_getattr                                                                                              
      unimplemented(f"Illegal getattr invocation {name} in strict mode")                      
    File "/data/users/whc/pytorch/torch/_dynamo/exc.py", line 212, in unimplemented                                                                                                         
      raise Unsupported(msg)                                                                                                                                                                
  torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode                                                                                                           

  from user code:                                                                                                                                                                           
     File "/data/users/whc/torchtitan/torchtitan/models/llama/model.py", line 428, in forward                                                                                               
      h = layer(h, freqs_cis)                                                                                                                                                               
    File "/data/users/whc/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl                                                                                                     
      return forward_call(*args, **kwargs)                                                    
    File "/data/users/whc/torchtitan/torchtitan/models/llama/model.py", line 317, in forward                                                                                                
      h = x + self.attention(self.attention_norm(x), freqs_cis)                                                                                                                             
    File "/data/users/whc/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl                                                                                                     
      return forward_call(*args, **kwargs)                                                                                                                                                  
    File "/data/users/whc/torchtitan/torchtitan/models/norms.py", line 61, in forward                                                                                                       
      return self.fused_rms_norm_fn(                                                                                                                                                        
    File "/data/users/whc/torchtitan/torchtitan/models/norms.py", line 316, in fused_rms_norm_fn                                                                                            
      return TritonFusedRMSNorm.apply(                                                                                                                                                      
    File "/data/users/whc/torchtitan/torchtitan/models/norms.py", line 294, in backward                                                                                                     
      dy.stride(0),                                                                                                                                                                         

  Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information                                                                                                                   

Finally, I tried setting export `strict=False` in PiPPy's _IR.py. This fixes the `dy.stride(0)` issue, but then I still crash with a data-ptr access during tracing.
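For reference, this is roughly what that flag looks like in plain torch.export (generic usage sketch, not PiPPy's _IR.py call site); with strict=False the stride-branching pattern from the earlier sketch traces without the illegal-getattr error:

```python
import torch

# Generic torch.export usage sketch (not PiPPy's _IR.py): non-strict export
# tolerates the .stride() branch that strict mode rejected above.
class StrideCheck(torch.nn.Module):
    def forward(self, x):
        if x.stride(-1) != 1:
            x = x.contiguous()
        return x * 2.0

ep = torch.export.export(StrideCheck(), (torch.randn(4, 8),), strict=False)
print(ep.graph_module.code)
```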

(4)

    File "/home/whc/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/triton/runtime/jit.py", line 435, in run
      spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
    File "/home/whc/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/triton/runtime/jit.py", line 435, in <genexpr>
      spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
    File "/home/whc/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/triton/runtime/jit.py", line 174, in specialization_key
      return (self.value.data_ptr() % JITFunction.divisibility == 0, )
    File "/data/users/whc/pytorch/torch/export/_safeguard.py", line 43, in __torch_function__
      return func(*args, **kwargs)
  RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ

I will try wrapping it in a custom op (?)
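A hedged sketch of what that could look like with the torch.library.custom_op API (PyTorch 2.4+); the op name and the eager body are illustrative, and a real wrapper would launch the Triton fused_rmsnorm kernel inside the opaque boundary instead of the reference math below:

```python
import torch

# Sketch of wrapping the kernel in an opaque custom op, per the suggestion in
# the error message. Op name and body are illustrative; the real op would call
# the Triton fused_rmsnorm kernel here, hidden from the tracer.
@torch.library.custom_op("torchtitan::fused_rms_norm", mutates_args=())
def fused_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    xf = x.float()
    out = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return (out * weight.float()).type_as(x)

@fused_rms_norm.register_fake
def _(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Shape/dtype propagation so export/FakeTensor tracing never touches a
    # real data pointer.
    return torch.empty_like(x)
```

A full replacement would also need a backward registered on the op (e.g. via `register_autograd`), which this sketch omits.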

Can we just `allow_in_graph` the whole call to fused_rmsnorm? Why aren't we doing that already? 🤔
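For completeness, the allow-in-graph version would look something like this (untested sketch; the import path is taken from the traceback above and assumed to be importable at module level). One caveat: unlike a custom op, allow_in_graph gives the tracer no FakeTensor implementation, so propagating shapes through the call may still hit the data_ptr access inside the Triton launch:

```python
import torch
from torchtitan.models.norms import fused_rms_norm_fn  # path from the traceback; assumed importable

# Tell dynamo to keep the whole fused_rmsnorm call as a single node instead of
# tracing into the autograd.Function / Triton internals.
torch._dynamo.allow_in_graph(fused_rms_norm_fn)
```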