desy-ml / cheetah

Fast and differentiable particle accelerator optics simulation for reinforcement learning and optimisation applications.
https://cheetah-accelerator.readthedocs.io
GNU General Public License v3.0

Deepcopy in BPM prevents gradient calculation #225

Open Hespe opened 3 months ago

Hespe commented 3 months ago

In the BPM, there is a call to deepcopy before the beam is passed on down the line: https://github.com/desy-ml/cheetah/blob/b38a654e0e8971d42d40fad4ffd9f92068140267/cheetah/accelerator/bpm.py#L51

Unfortunately, this copy currently prevents taking gradients of the BPM reading. The reduced example

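import torch
import cheetah

# Beam whose mean horizontal position requires gradients
beam = cheetah.ParameterBeam.from_parameters(mu_x=torch.tensor([0.0], requires_grad=True))

bpm = cheetah.BPM()
# Tracking through the BPM deep-copies the beam, which raises the error
bpm.track(beam)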

fails with

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment. If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see https://github.com/pytorch/pytorch/pull/103001

Hespe commented 2 weeks ago

Maybe the deepcopy can simply be replaced by a call to torch.clone?
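
To illustrate the difference outside of Cheetah, here is a minimal sketch in plain PyTorch. The tensor named derived is a hypothetical stand-in for whatever the beam computes internally from mu_x. It is not Cheetah's actual code, only an assumption about the situation: a tensor that is the result of an operation on a leaf and so carries autograd history.

```python
import copy

import torch

# Leaf tensor created by the user, analogous to mu_x in the report above.
mu_x = torch.tensor([0.0], requires_grad=True)

# Hypothetical stand-in for a beam parameter derived from mu_x.
# Any operation on a leaf that requires grad yields a non-leaf tensor.
derived = mu_x * 2.0 + 1.0

# copy.deepcopy refuses non-leaf tensors and raises a RuntimeError.
try:
    copy.deepcopy(derived)
    deepcopy_failed = False
except RuntimeError:
    deepcopy_failed = True

# clone() copies the data but stays connected to the autograd graph,
# so gradients still flow back to mu_x.
cloned = derived.clone()
cloned.sum().backward()

print(deepcopy_failed)  # whether deepcopy raised
print(mu_x.grad)        # gradient of the clone with respect to mu_x
```

Whether the BPM can use this directly depends on how the beam object stores its tensors, which the report does not show. Each tensor attribute would probably have to be cloned individually instead of copying the whole object.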