Open wconstab opened 3 months ago
Short term, the stride check can be removed to explore tracing (the check is rarely needed; confirmed on llama_7b).
Longer term, this will need either a refactor to support dynamic strides (harder) or, given the rarity, a simple assert that non-contiguous inputs are not supported.
I did not look into this closely, but could we rely on .contiguous() being a no-op if already contiguous and remove the stride check? (There might be ever-so-slightly more CPU overhead if there is a Python <> C++ switch from .contiguous(), but I think this should be okay for our purpose.)
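For readers unfamiliar with what a stride check tests, here is a minimal plain-Python sketch of row-major contiguity. The helper names `contiguous_strides` and `is_contiguous` are hypothetical and this is not the fused_rmsnorm code. It is a simplified model of the layout rule (it ignores empty tensors, for example), assuming strides are counted in elements.

```python
def contiguous_strides(shape):
    """Row-major (C-order) strides, in elements, for a dense tensor of `shape`."""
    strides = []
    step = 1
    # Walk dimensions from innermost to outermost; each stride is the
    # product of the sizes of all dimensions to its right.
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    """True if `strides` match the dense row-major layout.

    Strides of size-1 dimensions are ignored, since they never affect
    which element is addressed. (Simplified: empty tensors not handled.)
    """
    expected = contiguous_strides(shape)
    return all(d == 1 or s == e for d, s, e in zip(shape, strides, expected))


# A dense 4x8 tensor versus its transpose viewed as 8x4 (same memory,
# swapped strides). Only the first passes the check.
dense = ((4, 8), (8, 1))
transposed = ((8, 4), (1, 8))

print(is_contiguous(*dense))
print(is_contiguous(*transposed))
```

A branch on this kind of property depends on the runtime layout of the input, which is why a tracer that records only one path cannot capture it faithfully. Calling .contiguous() unconditionally avoids the branch.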
The incompatibility is that, during the backward pass, fused_rmsnorm does dynamic control flow over strides, which isn't safe for the export tracing used by PP.
Which leads to a stacktrace ending in
Would it be possible to refactor this in a more export friendly way, or is that difficult?
cc @lessw2020, @kwen2501