pytorch-labs / float8_experimental

This repository contains the experimental PyTorch-native float8 training UX.
BSD 3-Clause "New" or "Revised" License

add wait_tensor() after all_gather in float8 to fix mem leak #262

Closed: bdhirsh closed this pull request 4 months ago

bdhirsh commented 4 months ago

I'm going to write a more detailed post internally to explain this memory leak.

Tracking issue for a better fix in inductor: https://github.com/pytorch/pytorch/issues/126338

facebook-github-bot commented 4 months ago

@bdhirsh has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 4 months ago

@bdhirsh merged this pull request in pytorch-labs/float8_experimental@6891cbe4293cdbef061e96cdfe06af064efe3efb.