Multi-GPU generation using hf.generate with device_map='auto' does pipeline parallelism, placing different modules on different GPUs. As a result, some operations receive input tensors that sit on different GPUs, which raises an error. This PR moves those tensors onto the same device as the operation's other inputs. This should not slow down training, because during training the affected tensors should already be on the same device, so the moves should be no-ops.
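
For reviewers, here is a minimal sketch of the kind of change described above, not the actual diff. The helper name `move_to_match` and the `add_bias` operation are hypothetical and only illustrate the pattern: before an operation, move one input onto the device of the other. It runs on CPU, where the move is a no-op, which is the same situation expected during training.

```python
import torch


def move_to_match(tensor: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    """Return `tensor` on the same device as `reference` (hypothetical helper)."""
    # Tensor.to returns the original tensor object when the tensor is already
    # on the requested device, so nothing is copied in the single-device case.
    return tensor.to(reference.device)


def add_bias(hidden_states: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Hypothetical operation whose inputs may sit on different GPUs."""
    # Align the bias with the activations before combining them.
    bias = move_to_match(bias, hidden_states)
    return hidden_states + bias


hidden = torch.ones(2, 3)
bias = torch.full((3,), 0.5)

out = add_bias(hidden, bias)

# Both tensors are already on the same device here, so the move is a no-op.
same_object = move_to_match(bias, hidden) is bias
```

With device_map='auto', `hidden_states` and `bias` could live on different GPUs, and the `.to(...)` call would then perform a real copy instead of returning the same tensor.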