It would be helpful if we could pass a CUDA device explicitly to the train and predict nodes on DGX systems, so that multiple models can be spawned on different cards simultaneously on a single machine.
Currently both scripts default to "cuda:0" whenever a CUDA device is available.
I have added this to a fork of the repo and wanted to discuss whether it would be useful to integrate here.
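A minimal sketch of what the option could look like follows. It is not the code from the fork. The `--device` flag, `build_parser` and `resolve_device` are hypothetical names. The `cuda_available` argument stands in for `torch.cuda.is_available()` so the example runs without PyTorch installed. When no device is passed, the current "cuda:0" default is kept.

```python
import argparse
from typing import Optional


def resolve_device(requested: Optional[str], cuda_available: bool) -> str:
    """Return the device string a train/predict node should use.

    requested: value of the hypothetical --device option, e.g. "cuda:1".
    cuda_available: in the real scripts this would come from
        torch.cuda.is_available(); passed in here so the sketch has no
        torch dependency.
    """
    if not cuda_available:
        # No GPU visible: fall back to the CPU regardless of the request.
        return "cpu"
    if requested is None:
        # No explicit choice: keep the current default of the first card.
        return "cuda:0"
    # Explicit choice, e.g. "cuda:3" to pin this process to the fourth GPU.
    return requested


def build_parser() -> argparse.ArgumentParser:
    """Command-line parser with a hypothetical --device option."""
    parser = argparse.ArgumentParser(description="train/predict node (sketch)")
    parser.add_argument(
        "--device",
        default=None,
        help='CUDA device to run on, e.g. "cuda:1" (defaults to "cuda:0")',
    )
    return parser
```

On a DGX machine, each process could then be started with a different value, for example `--device cuda:0` for one model and `--device cuda:1` for another. A related approach that needs no code changes is to set `CUDA_VISIBLE_DEVICES` per process, so the chosen GPU appears as "cuda:0" inside that process. An explicit argument is easier to see in logs and configs.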
Best, Samia