NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Tkurth/new gbn #1678

azrael417 closed this 1 year ago

azrael417 commented 1 year ago

This MR separates plan creation from plan execution for cuDNN GBN. With recent RTC features, the old code, which created and executed the plan in the same call, breaks under CUDA graph capture. The new implementation caches the plan once it has been created and reuses it in the FWD and BWD passes.
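A minimal Python sketch of the general pattern described above: plans are created (and cached by key) in a separate preparation step, and the execution step only looks them up, so no plan creation happens on the execution path. Everything here (`Plan`, `PlanCache`, `build_plan`, and the key layout of direction, shape and dtype) is hypothetical; it does not model the cudnn-frontend API or the actual apex wrapper.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Hashable, Tuple

Key = Tuple[Hashable, ...]


# Hypothetical stand-in for an opaque execution plan; a real cuDNN plan would be
# built through cudnn-frontend, which is not modeled here.
@dataclass(frozen=True)
class Plan:
    key: Key


class PlanCache:
    """Keeps plan creation (prepare) separate from plan use (lookup)."""

    def __init__(self, build_fn: Callable[[Key], Plan]):
        self._build_fn = build_fn
        self._plans: Dict[Key, Plan] = {}
        self.builds = 0  # how many times the expensive build step actually ran

    def prepare(self, key: Key) -> Plan:
        """Creation step: build and cache the plan only if it is not cached yet."""
        if key not in self._plans:
            self._plans[key] = self._build_fn(key)
            self.builds += 1
        return self._plans[key]

    def lookup(self, key: Key) -> Plan:
        """Execution step: return a cached plan; this method never builds one."""
        try:
            return self._plans[key]
        except KeyError:
            raise RuntimeError(f"plan for {key} was not prepared before execution") from None


def build_plan(key: Key) -> Plan:
    # Hypothetical placeholder for the expensive work (heuristics, compilation, ...).
    return Plan(key)


if __name__ == "__main__":
    cache = PlanCache(build_plan)
    shape = (8, 64, 32, 32)
    # Warm-up, e.g. before graph capture: create the FWD and BWD plans once.
    for direction in ("fwd", "bwd"):
        cache.prepare((direction, shape, "float16"))
    # Later calls only look up the cached plans.
    for _ in range(3):
        plan = cache.lookup(("fwd", shape, "float16"))
    print(cache.builds)  # 2
```

Raising on a missing plan in `lookup`, instead of silently building it, is a design choice made for this sketch: it surfaces a forgotten warm-up step immediately rather than letting plan creation slip into a region where it is not allowed.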

I need to wait for cudnn-frontend to be updated to the latest version; therefore, my repo does not include cudnn-frontend as a submodule yet. I will fix this once the new cudnn-frontend has been merged. Until then, please review only the cuDNN GBN wrapper implementation and treat the rest of this MR as a draft.