anijain2305 closed this issue 1 year ago
cc @eellison
The root cause is that the layer_norm and layer_norm_backward decomps do incorrect stride propagation. (I am working on cleaning up the cross-ref PR and submitting it.)
CC @ngimel - was a fix for this in the works, or am I making that up?
layer_norm should be fixed by https://github.com/pytorch/pytorch/pull/84799; the backward decomp probably still does not respect eager strides.
I think aot_eager does not really go through decomps.
@anijain2305 it's used for fake tensor stride propagation.
Repro
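The original repro is not included above. A minimal sketch along these lines (the shapes and the non-contiguous input here are my assumptions, not the original repro) compares the strides of eager layer_norm output with those produced when tracing through the decomps via the aot_eager backend:

```python
import torch
import torch.nn.functional as F

def f(t):
    # layer_norm over the last dimension
    return F.layer_norm(t, (t.shape[-1],))

# Hypothetical non-contiguous input, chosen to surface
# stride-propagation differences between eager and the decomp
x = torch.randn(2, 3, 4).transpose(0, 1)

eager_out = f(x)
# aot_eager traces through AOTAutograd (whose decomps feed fake
# tensor stride propagation) without a real compiler backend
compiled_out = torch.compile(f, backend="aot_eager")(x)

print("eager strides:   ", eager_out.stride())
print("compiled strides:", compiled_out.stride())
```

If the decomp propagated strides incorrectly, the two stride tuples would disagree even though the values match.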