desy-ml / cheetah

Fast and differentiable particle accelerator optics simulation for reinforcement learning and optimisation applications.
https://cheetah-accelerator.readthedocs.io
GNU General Public License v3.0

Moving elements and beams to devices doesn't work #113

Closed · jank324 closed this 2 months ago

jank324 commented 9 months ago

It should be possible, like with any other PyTorch Module, to move Elements and Beams by calling

element = element.to(device)

or

element = element.cpu()

or

element = element.cuda()

Unfortunately, this doesn't work right now.

jank324 commented 9 months ago

This will require registering all relevant variables as buffers:

https://stackoverflow.com/questions/60908827/why-pytorch-nn-module-cuda-not-moving-module-tensor-but-only-parameters-and-bu
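A minimal sketch of what this would look like (the `Drift` class below is a hypothetical stand-in, not Cheetah's actual element): once a tensor is registered as a buffer, `nn.Module.to()`, `.cpu()`, and `.cuda()` move it along with the module.

```python
import torch
import torch.nn as nn


class Drift(nn.Module):
    """Hypothetical minimal element whose length is registered as a buffer."""

    def __init__(self, length: torch.Tensor) -> None:
        super().__init__()
        # Registered buffers are part of the module's state, so they follow
        # .to(device) / .cpu() / .cuda() like parameters do.
        self.register_buffer("length", length)


drift = Drift(torch.tensor(1.0))
drift = drift.to("cpu")  # buffers are moved together with the module
print(drift.length.device)  # device(type='cpu')
```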

cr-xu commented 8 months ago

This is still not fixed, right? If I read the code correctly, the device is currently specified only when an Element is created; to change it afterwards, one needs to move every parameter to the device manually...?

Are we still planning to integrate this feature, and how exactly (by explicitly registering the relevant parameters as buffers for all elements)?

jank324 commented 8 months ago

Right now it's bodged together, so you can do what you need to do in some way, but it's far from ideal.

The fix will be buffers (for everything that isn't a parameter). If we do it that way, everything should be movable like your average nn.Linear and so on.

However, this raises the question, what would happen if you do this ...

quadrupole.k1 = torch.tensor([1.0])

... i.e. reassign a property that is now a buffer. Am I changing the values in the buffer? Or -- more likely -- am I completely overriding the buffer with a normal tensor and this assignment doesn't really work?
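For plain `nn.Module` buffers, reassignment actually goes through `nn.Module.__setattr__`, which writes the new tensor back into `_buffers` rather than shadowing it with an ordinary attribute, so the attribute stays a buffer. A quick sketch (the `Quadrupole` class below is a hypothetical stand-in, not Cheetah's real element; Cheetah's own property setters could behave differently):

```python
import torch
import torch.nn as nn


class Quadrupole(nn.Module):
    """Hypothetical stand-in element with k1 registered as a buffer."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("k1", torch.tensor([0.5]))


quadrupole = Quadrupole()
# nn.Module.__setattr__ detects that "k1" is a registered buffer and
# replaces the entry in _buffers instead of creating a plain attribute.
quadrupole.k1 = torch.tensor([1.0])
print("k1" in dict(quadrupole.named_buffers()))  # True
print(quadrupole.k1)  # tensor([1.])
```

Note the replacement tensor lands on whatever device it was created on; a later `.to(device)` would still move it, since it remains a buffer.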

I think I saw you doing something in the BO example of cheetah-demos, where GPyTorch has some (hacky) workaround for this. So we have to figure out how to make this work.

Either way, I think this will be an add-on on top of the batched execution #116 for a later version (0.7.1 or so).

jank324 commented 3 months ago

An interesting thread to keep around here: https://discuss.pytorch.org/t/what-is-the-difference-between-register-buffer-and-register-parameter-of-nn-module/32723/10