desy-ml / cheetah

Fast and differentiable particle accelerator optics simulation for reinforcement learning and optimisation applications.
https://cheetah-accelerator.readthedocs.io
GNU General Public License v3.0

Missing batch-execution compatibility in Dipole, Rotation matrix, etc. #154

Closed by cr-xu 1 month ago

cr-xu commented 1 month ago

The rotation matrix is currently not batch-compatible. This causes problems when elements are rotated (`tilt != 0`):

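```python
import torch
from cheetah import ParticleBeam, Drift, Segment, Quadrupole

# Batch of two beams (two energies, two horizontal offsets)
beam_in = ParticleBeam.from_parameters(num_particles=torch.tensor(1000), energy=torch.tensor([1e9, 1e9]), mu_x=torch.tensor([1e-5, 2e-5]))
batch_shape = beam_in.particles.shape[:-2]  # drop the (num_particles, coordinates) dims
# Tilted quadrupole (tilt != 0) followed by a drift, broadcast to the beam's batch shape
segment = Segment([Quadrupole(length=torch.tensor([0.5]), tilt=torch.tensor(torch.pi / 4)), Drift(length=torch.tensor([0.5]))]).broadcast(batch_shape)
segment(beam_in)  # breaks: the rotation matrix is not batch-compatible
```
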
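For reference, here is a minimal sketch of what a batch-compatible rotation matrix could look like. It is not Cheetah's implementation: the helper name `batched_rotation_matrix`, the default `dim=7` (assumed to match 7-element coordinate vectors), and the sign convention of the transverse (x, px, y, py) block are all assumptions. NumPy stands in for PyTorch here because `broadcast_to` and `...` indexing behave the same way in both. The key point is that the output shape is derived from `angle.shape` instead of being hard-coded.

```python
import numpy as np


def batched_rotation_matrix(angle, dim=7):
    """Hypothetical helper: rotation about the beam axis for any batch shape.

    `angle` may have any shape (...); the result has shape (..., dim, dim).
    The sign convention of the transverse block is an assumption.
    """
    angle = np.asarray(angle, dtype=float)
    c, s = np.cos(angle), np.sin(angle)
    # Identity broadcast to the batch shape of `angle`; copy() makes it writable.
    R = np.broadcast_to(np.eye(dim), angle.shape + (dim, dim)).copy()
    # Fill the transverse (x, px, y, py) block; `...` keeps all batch dims.
    R[..., 0, 0] = c
    R[..., 0, 2] = s
    R[..., 1, 1] = c
    R[..., 1, 3] = s
    R[..., 2, 0] = -s
    R[..., 2, 2] = c
    R[..., 3, 1] = -s
    R[..., 3, 3] = c
    return R
```

With a scalar angle this returns a single `(7, 7)` matrix; with angles of shape `(2, 3)` it returns `(2, 3, 7, 7)`.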
jank324 commented 1 month ago

Hmm ... you are only giving one tilt. Does it work if you pass a batch of tilts?

This raises the question of what we want the "correct" interface to look like.
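To make the interface question concrete: under standard broadcasting rules (shared by NumPy and PyTorch), a scalar tilt is compatible with a batch of two energies, so in principle passing a single tilt should not by itself be the problem. The hypothetical helper `common_batch_shape` below (not part of Cheetah) uses `np.broadcast_shapes` to compute the batch shape that element parameters and the beam would have to agree on. The `(2, 1000, 7)` particle array layout is an assumption for illustration.

```python
import numpy as np


def common_batch_shape(*params):
    """Hypothetical helper: broadcast the shapes of all given parameters.

    Raises ValueError if the shapes cannot be broadcast together.
    """
    return np.broadcast_shapes(*(np.shape(p) for p in params))


# Assumed particle layout: (batch..., num_particles, num_coordinates).
particles = np.zeros((2, 1000, 7))
beam_batch_shape = particles.shape[:-2]  # (2,)

# Scalar tilt, length of shape (1,), beam batch of shape (2,).
print(common_batch_shape(np.pi / 4, [0.5], np.empty(beam_batch_shape)))  # (2,)
```

A batch of three lengths combined with a batch of two energies, on the other hand, raises a `ValueError`, which is the kind of mismatch an interface would have to reject.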

cr-xu commented 1 month ago

No, the problem is actually that the rotation matrix is not broadcast properly. Putting the tilts in the correct dimension doesn't help:

```python
import torch
from cheetah import ParticleBeam, Drift, Segment, Quadrupole

beam_in = ParticleBeam.from_parameters(num_particles=torch.tensor(1000), energy=torch.tensor([1e9, 1e9]), mu_x=torch.tensor([1e-5, 2e-5]))
batch_shape = beam_in.particles.shape[:-2]
# Same lattice, but every parameter (including the tilt) now has batch shape (2,)
segment = Segment([Quadrupole(length=torch.tensor([0.5, 0.5]), tilt=torch.tensor([torch.pi / 4, torch.pi / 4])), Drift(length=torch.tensor([0.5, 0.5]))])
segment(beam_in)  # still breaks: the rotation matrix is not broadcast
```

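One way to keep a tilted element batch-compatible is to build the rotation from the tilt's own shape and let matrix multiplication broadcast over all leading batch dimensions. The sketch below is an assumption, not Cheetah's code: it uses only a 4x4 transverse (x, px, y, py) block, the hypothetical names `transverse_rotation` and `tilted_map`, and the sandwich convention `R(-tilt) @ M @ R(tilt)`, whose sign convention may differ from Cheetah's. NumPy stands in for PyTorch; `@` broadcasts batch dimensions the same way in both.

```python
import numpy as np


def transverse_rotation(angle):
    """Hypothetical: rotation of (x, px, y, py) about the beam axis.

    Works for any batch shape of `angle`; returns shape (..., 4, 4).
    """
    c, s = np.cos(angle), np.sin(angle)
    z = np.zeros_like(c)
    rows = [
        [c, z, s, z],
        [z, c, z, s],
        [-s, z, c, z],
        [z, -s, z, c],
    ]
    # Stack each row's entries along a new last axis, then stack the rows.
    return np.stack([np.stack(row, axis=-1) for row in rows], axis=-2)


def tilted_map(transfer_map, tilt):
    """Rotate into the element frame, apply the map, rotate back.

    `@` broadcasts the leading (batch) dimensions of all three operands.
    """
    tilt = np.asarray(tilt, dtype=float)
    return transverse_rotation(-tilt) @ transfer_map @ transverse_rotation(tilt)
```

For example, a thin-lens map that focuses in x, tilted by π/2, becomes one that focuses in y, and tilts of shape `(3, 1)` combined with maps of batch shape `(2,)` give results of shape `(3, 2, 4, 4)`.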
cr-xu commented 1 month ago

Another thing: I found that quite a lot of the extended indexing is written as `[:, i, j]` instead of `[..., i, j]`, which would probably break down for multiple batch dimensions, right?

jank324 commented 1 month ago

Yes, indeed. There might be situations where `[:, i, j]` makes sense, but most of the time it should be `[..., i, j]`. Multi-dimensional batches were not intended in the initial vectorised implementation; I only added support for them afterwards, so it's possible I missed some places.
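A quick illustration of the difference: with two batch dimensions, `[:, i, j]` indexes the first three axes, so it silently selects the wrong elements instead of raising, whereas `[..., i, j]` always addresses the last two (matrix) axes. This sketch uses NumPy, whose `:` and `...` indexing rules match PyTorch's; the 7x7 matrix size and the hypothetical helper `set_entry` are only for illustration.

```python
import numpy as np


def set_entry(tm, i, j, value):
    """Hypothetical helper: set entry (i, j) of every matrix in a batch.

    `...` absorbs any number of leading batch dimensions.
    """
    tm[..., i, j] = value
    return tm


# Identity maps with two batch dimensions: shape (2, 3, 7, 7).
tm = np.tile(np.eye(7), (2, 3, 1, 1))
lengths = np.linspace(0.1, 0.6, 6).reshape(2, 3)

print(tm[..., 0, 1].shape)  # (2, 3): one entry per matrix, as intended
print(tm[:, 0, 1].shape)    # (2, 7): row 1 of the matrices in tm[:, 0]

set_entry(tm, 0, 1, lengths)
# `tm[:, 0, 1] = lengths` would raise ValueError: (2, 3) does not fit (2, 7).
```

With a single batch dimension, `[:, i, j]` and `[..., i, j]` select the same elements, which is why the problem only shows up once multi-dimensional batches are used.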