pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

[fused_rmsnorm] Avoid conditional on dynamic stride #300

Closed wconstab closed 3 weeks ago

wconstab commented 2 months ago

Stack from ghstack (oldest at bottom):

The conditional expression forces evaluation of a symbolic quantity. This is not a problem in eager mode, but it prevents tracing/export and blocks using Pipeline Parallelism via the tracing frontend.

The Tensor.contiguous operator already returns self when the tensor's striding is already contiguous, so the conditional is unnecessary.

Fixes https://github.com/pytorch/PiPPy/issues/1108

wconstab commented 1 month ago

I think I abandoned this because, even with this fix, the PP tracer would not work due to other issues with the fused_rmsnorm custom op registration.

Did we figure out a solution there? Or are we totally giving up on fused_rmsnorm and switching to compile?

cc @lessw2020

wconstab commented 3 weeks ago

abandoned