pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

multi_tensor_sgd triggers extra xla execution #7051

Open shenh10 opened 5 months ago

shenh10 commented 5 months ago

🐛 Bug

The torch/benchmarks/dynamo test suite sets SGD as the optimizer with the foreach flag set to True, to leverage the fast foreach CUDA implementation. However, calling multi_tensor_sgd seems to trigger lazy tensor execution with the openxla backend.

          self.optimizer = torch.optim.SGD(params, lr=0.01, foreach=True)
          # Disable multi_tensor_sgd for benchmarking, there isn't a large performance benefit (~1%) to compiling
          # this optimizer because it is a single foreach add, and increases compile time.
          # After autotuning and fake tensor caching lands, we can enable, because the compile time impact will be lower.
          # Fake Tensor caching: https://github.com/pytorch/pytorch/pull/113873
          # Autotuning: https://github.com/pytorch/pytorch/issues/117447
          self.optimizer.step = torch._dynamo.disable(self.optimizer.step)

From the XLA debug info, it appears multi_tensor_sgd is called and an early execution is triggered by torch._foreach_add_(device_params, device_grads, alpha=-lr). [screenshot of XLA debug output]

To Reproduce

Use any demo and modify the optimizer to SGD as shown above; a minimal sketch follows.
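
A minimal standalone sketch (untested; assumes torch_xla is installed, an XLA device is available, and the openxla dynamo backend is registered by importing torch_xla). Running with PT_XLA_DEBUG=1 should surface what forces execution:

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    model = torch.nn.Linear(128, 128).to(device)

    # foreach=True selects the multi_tensor_sgd path, as in the benchmark suite.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, foreach=True)
    # Match the benchmark harness: run the optimizer step eagerly, not compiled.
    optimizer.step = torch._dynamo.disable(optimizer.step)

    compiled = torch.compile(model, backend="openxla")

    for _ in range(3):
        optimizer.zero_grad()
        loss = compiled(torch.randn(64, 128, device=device)).sum()
        loss.backward()
        # Expected: the update is recorded in the lazy graph;
        # reported: it triggers an extra XLA execution.
        optimizer.step()
        xm.mark_step()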

Environment

JackCaoG commented 5 months ago

hmm @bhavya01 Can you take a look at this issue since you are on call this week? I am not sure why torch._foreach_add_(device_params, device_grads, alpha=-lr) would trigger execution of the graph instead of recording the operation in the graph. My best guess is that we didn't lower _foreach_add_.
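
A quick way to check whether _foreach_add_ is lowered (a sketch, assuming the standard torch_xla.debug.metrics API; ops that are not lowered fall back to CPU and show up as aten:: counters):

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

    device = xm.xla_device()
    params = [torch.randn(8, device=device) for _ in range(4)]
    grads = [torch.randn(8, device=device) for _ in range(4)]

    # In-place foreach update, the same op the SGD foreach path uses.
    torch._foreach_add_(params, grads, alpha=-0.01)
    xm.mark_step()

    # An aten::_foreach_add_ counter here would mean the op was not lowered
    # to XLA and went through the CPU fallback, which forces execution.
    print([name for name in met.counter_names() if name.startswith("aten::")])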

bhavya01 commented 5 months ago

@shenh10 Can you help reproduce the issue with a simpler demo? I tried running a simpler script but didn't see this happening.

This gist contains the script that I used and the output. https://gist.github.com/bhavya01/0346b0d47931ba60751dbe79b01268a0

bhavya01 commented 5 months ago

@shenh10 Gentle ping if you can help reproduce the issue.

shenh10 commented 5 months ago

> @shenh10 Gentle ping if you can help reproduce the issue.

I apologize for the delay in my response as I have not been able to dedicate time to this task recently. I will provide you with an update by the end of next week.