pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

Remove unnecessary .to() inside model forward #298

Closed wconstab closed 2 months ago

wconstab commented 2 months ago

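The `.to()` call inside the model's forward appears to be a holdover from the way initialization previously worked.

`freqs_cis` should already be on the GPU device after initialization, so moving it again on every forward pass is redundant.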
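As a rough illustration of why the extra `.to()` is not needed, here is a minimal sketch. It assumes `freqs_cis` is held as a registered buffer on the module; `TinyRotaryModel` and the helper below are hypothetical stand-ins, not torchtitan's actual code. A registered buffer moves together with the module, so by the time forward runs it already lives on the same device as the inputs.

```python
import torch
import torch.nn as nn


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Complex rotary frequencies of shape (end, dim // 2), each with magnitude 1.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    angles = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(angles), angles)


class TinyRotaryModel(nn.Module):
    # Hypothetical toy module used only for illustration.
    def __init__(self, dim: int = 8, max_seq_len: int = 16):
        super().__init__()
        # A registered buffer follows the module through model.to(device).
        self.register_buffer(
            "freqs_cis", precompute_freqs_cis(dim, max_seq_len), persistent=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim). Pair up the last dimension as complex numbers.
        seq_len = x.shape[1]
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = self.freqs_cis[:seq_len]  # no .to(x.device) needed here
        return torch.view_as_real(x_c * freqs_cis).flatten(-2)


# Falls back to CPU when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinyRotaryModel().to(device)
x = torch.randn(2, 4, 8, device=device)
out = model(x)
```

If the tensor were instead a plain attribute created on CPU, the device move would have to happen somewhere, which is presumably how the original `.to()` ended up in forward.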

See this conversation for reference.