wconstab opened 2 months ago
Is it safe to skip the `if` and just call `.contiguous()` all the time? Maybe that is a no-op in the case that `x` is already contiguous?
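For reference, a quick check of that assumption: `.contiguous()` returns `self` when the tensor is already contiguous, so calling it unconditionally only pays for a copy in the non-contiguous case (which would avoid the data-dependent `if` during tracing):

```python
import torch

x = torch.randn(4, 8)       # freshly allocated: already contiguous
print(x.contiguous() is x)  # True: returns self, no copy

y = x.t()                   # transposed view: not contiguous
print(y.contiguous() is y)  # False: this call actually materializes a copy
```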
Some attempts to fix this:
(1) get rid of the conditionals on dynamic shapes, which gets me past the first tracing errors: https://github.com/pytorch/torchtitan/pull/300
(2) a hack for computing sm_count from device(0), which is unsafe; we might be able to make a tracer-friendly version of this somehow: https://github.com/pytorch/torchtitan/pull/301
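A minimal sketch of what the device(0) hack from (2) amounts to (the hard-coded device index and the `NUM_SMS` name are assumptions for illustration; on multi-GPU hosts the SM count could in principle differ per rank, which is what makes it unsafe):

```python
import torch

# Query the SM count once, from a fixed device, at import time --
# so tracing never needs to inspect the runtime tensor's device.
# Hard-coding device 0 is the unsafe assumption.
if torch.cuda.is_available():
    NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count
else:
    NUM_SMS = 0  # CPU-only fallback, just for illustration

print(NUM_SMS)
```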
After these I still hit a stride issue from the unconditional uses of `.stride()`: (3)
```
File "/data/users/whc/pytorch/torch/_dynamo/variables/tensor.py", line 322, in var_getattr
  unimplemented(f"Illegal getattr invocation {name} in strict mode")
File "/data/users/whc/pytorch/torch/_dynamo/exc.py", line 212, in unimplemented
  raise Unsupported(msg)
torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode

from user code:
File "/data/users/whc/torchtitan/torchtitan/models/llama/model.py", line 428, in forward
  h = layer(h, freqs_cis)
File "/data/users/whc/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
  return forward_call(*args, **kwargs)
File "/data/users/whc/torchtitan/torchtitan/models/llama/model.py", line 317, in forward
  h = x + self.attention(self.attention_norm(x), freqs_cis)
File "/data/users/whc/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
  return forward_call(*args, **kwargs)
File "/data/users/whc/torchtitan/torchtitan/models/norms.py", line 61, in forward
  return self.fused_rms_norm_fn(
File "/data/users/whc/torchtitan/torchtitan/models/norms.py", line 316, in fused_rms_norm_fn
  return TritonFusedRMSNorm.apply(
File "/data/users/whc/torchtitan/torchtitan/models/norms.py", line 294, in backward
  dy.stride(0),

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```
Finally, I tried setting export `strict=False` in pippy `_IR.py`. This fixes the `dy.stride(0)` issue, but then I still crash with a data-ptr access during tracing:
(4)
```
File "/home/whc/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/triton/runtime/jit.py", line 435, in run
  spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
File "/home/whc/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/triton/runtime/jit.py", line 435, in <genexpr>
  spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
File "/home/whc/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/triton/runtime/jit.py", line 174, in specialization_key
  return (self.value.data_ptr() % JITFunction.divisibility == 0, )
File "/data/users/whc/pytorch/torch/export/_safeguard.py", line 43, in __torch_function__
  return func(*args, **kwargs)
RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ
```
I will try to wrap in a custom op (?)
Can we just `allow_in_graph` the whole call to fused_rmsnorm? Why aren't we doing that already? 🤔
Currently I have to work around this by using the regular rmsnorm for PP to be enabled.

Full trace: https://gist.github.com/wconstab/3b68edda6bd30c2414403e91734ccc87

cc @kwen2501 @lessw2020