desy-ml / cheetah

Fast and differentiable particle accelerator optics simulation for reinforcement learning and optimisation applications.
https://cheetah-accelerator.readthedocs.io
GNU General Public License v3.0
27 stars, 12 forks

Error when tracking `ParameterBeam` through segment that is on CPU? #68

jank324 closed this 10 months ago

jank324 commented 10 months ago

I found this in the `ParameterBeam` branch of the `Element.__call__` method:

```python
if self.device != "cpu":
    raise DeviceError
```

I think this doesn't make any sense (though I'm not sure). We should look into this at some point.

jank324 commented 10 months ago

Also, this check runs after the actual calculations, which would already have thrown a device-mismatch error. We should really fix this!

cr-xu commented 10 months ago

I think the original reasoning was that it doesn't make sense to move a `ParameterBeam` to the GPU, because the transfer overhead would take longer than the actual tracking. But yes, it really shouldn't throw an error now. I'll fix this.
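The two points raised in this thread (validate devices *before* doing any computation, and don't hard-code `"cpu"` as the only allowed device) can be illustrated with a small sketch. This is not Cheetah's actual code: `ToyParameterBeam`, `ToyElement`, and the `scale` "transfer map" are hypothetical stand-ins that use plain strings for devices so the example runs without PyTorch, and the `DeviceError` defined here is a local placeholder for whatever exception the library actually uses.

```python
from __future__ import annotations

from dataclasses import dataclass


class DeviceError(Exception):
    """Hypothetical placeholder: raised when beam and element devices differ."""


@dataclass
class ToyParameterBeam:
    # Hypothetical stand-in for ParameterBeam: only a mean vector and a device.
    mean: list[float]
    device: str = "cpu"


@dataclass
class ToyElement:
    # Hypothetical stand-in for an Element whose "transfer map" simply
    # scales every coordinate by a constant factor.
    scale: float
    device: str = "cpu"

    def __call__(self, beam: ToyParameterBeam) -> ToyParameterBeam:
        # Check compatibility up front, so the user gets a clear error
        # before any computation could fail with a lower-level message.
        if beam.device != self.device:
            raise DeviceError(
                f"beam is on {beam.device!r} but element is on {self.device!r}"
            )
        # No hard-coded "cpu" requirement: any device is accepted as long
        # as the beam and the element agree.
        new_mean = [self.scale * x for x in beam.mean]
        return ToyParameterBeam(mean=new_mean, device=self.device)
```

With this ordering, `ToyElement(0.5, device="cuda")(ToyParameterBeam([4.0], device="cuda"))` tracks normally, while passing a CPU beam to a `"cuda"` element raises `DeviceError` immediately. In real PyTorch code, the equivalent comparison would be between the devices of the beam's and element's tensors.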