pytorch-labs / float8_experimental

This repository contains the experimental PyTorch-native float8 training UX.
BSD 3-Clause "New" or "Revised" License

add wait_tensor() after all_gather in float8 to fix mem leak #261

Closed: bdhirsh closed this 2 months ago

bdhirsh commented 2 months ago

I'm going to write a more detailed post internally to explain this memory leak