yandex-research / swarm

Official code for "SWARM Parallelism: Training Large Models Can Be Surprisingly Communication-Efficient"

Backwards fault recovery #2

Open NikolayBlagoev opened 6 months ago

NikolayBlagoev commented 6 months ago

Great work with this paper, and congratulations!

I had a quick question about how disconnects are handled during the backward pass. From the paper it seems that timeouts are only triggered on the forward pass, but during the backward pass a node needs to return the gradients of its input to the specific node that the batch passed through on the forward pass. I couldn't find an explanation of how this is handled in the paper, and from what I can see in the code, a random new expert is simply chosen, which doesn't seem like a sound solution.
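To illustrate the concern, here is a minimal toy sketch (not SWARM code, just a hypothetical linear "expert") of why the input gradient depends on the weights of the expert that actually served the forward pass, so rerouting the backward pass to a random replacement gives a different gradient:

```python
import numpy as np

# Toy illustration: for a linear "expert" y = x @ W, the gradient w.r.t.
# the input is dL/dx = dL/dy @ W.T, i.e. it depends on the weights of the
# expert that actually ran the forward pass.
rng = np.random.default_rng(0)
W_a = rng.normal(size=(4, 4))  # expert that served the forward pass
W_b = rng.normal(size=(4, 4))  # hypothetical random replacement on timeout

grad_out = np.ones((2, 4))  # gradient arriving from the next stage

grad_correct = grad_out @ W_a.T  # what the batch's true expert would return
grad_wrong = grad_out @ W_b.T    # what a fresh expert would return instead

# The two input gradients disagree, so rerouting the backward pass to a
# different expert corrupts the gradient signal for the preceding stage.
print(np.allclose(grad_correct, grad_wrong))  # -> False
```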

I was wondering if I am missing something.

Also, to replicate the results, which commands do you use to run setup.py? Both install and build seem to be insufficient.