cr-xu closed 1 month ago
Hmm... you are only passing one tilt. Does it work if you pass a batch of tilts?
This raises the question of what we want the "correct" interface to look like.
No, the problem is actually that the rotation matrix is not broadcast properly. Putting the tilts in the correct dimension doesn't help:
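```python
import torch
from cheetah import ParticleBeam, Drift, Segment, Quadrupole

# Beam with a batch of two entries (mu_x differs between them)
beam_in = ParticleBeam.from_parameters(num_particles=torch.tensor(1000), energy=torch.tensor([1e9, 1e9]), mu_x=torch.tensor([1e-5, 2e-5]))
batch_shape = beam_in.particles.shape[:-2]
# Quadrupole tilted by pi/4 in both batch entries, followed by a drift
segment = Segment([Quadrupole(length=torch.tensor([0.5, 0.5]), tilt=torch.tensor([torch.pi / 4, torch.pi / 4])), Drift(length=torch.tensor([0.5, 0.5]))])
# Fails: the rotation matrix used for tilt != 0 is not broadcast over the batch dimension
segment(beam_in)
```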
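For illustration, a minimal sketch of what a batch-compatible rotation matrix could look like in plain PyTorch. The names `batched_rotation_matrix` and `tilted_transfer_map` are hypothetical (not Cheetah's API), and the 7×7 layout (x, px, y, py, tau, p, 1) is an assumption about the transfer-map convention. The idea is to build one identity matrix per batch entry with `repeat(*angle.shape, 1, 1)`, write the entries with `[..., i, j]`, and let `@` (`torch.matmul`) broadcast the rotation against the element's transfer map.

```python
import torch


def batched_rotation_matrix(angle: torch.Tensor) -> torch.Tensor:
    """Rotation matrices of shape (*angle.shape, 7, 7) for any batch shape.

    Hypothetical helper; assumes a 7x7 phase-space layout (x, px, y, py, tau, p, 1).
    """
    cs = torch.cos(angle)
    sn = torch.sin(angle)
    # One 7x7 identity per batch entry; works for zero, one or several batch dims
    rot = torch.eye(7, dtype=angle.dtype, device=angle.device).repeat(*angle.shape, 1, 1)
    # `...` keeps all batch dimensions and only addresses the matrix entries
    rot[..., 0, 0] = cs
    rot[..., 0, 2] = sn
    rot[..., 1, 1] = cs
    rot[..., 1, 3] = sn
    rot[..., 2, 0] = -sn
    rot[..., 2, 2] = cs
    rot[..., 3, 1] = -sn
    rot[..., 3, 3] = cs
    return rot


def tilted_transfer_map(tm: torch.Tensor, tilt: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: rotate by tilt, apply tm, rotate back by -tilt."""
    # The @ operator (torch.matmul) broadcasts over all leading batch dimensions
    return batched_rotation_matrix(-tilt) @ tm @ batched_rotation_matrix(tilt)


tilt = torch.tensor([torch.pi / 4, torch.pi / 4])  # batch of two tilts
print(batched_rotation_matrix(tilt).shape)  # torch.Size([2, 7, 7])
print(batched_rotation_matrix(torch.zeros(3, 2)).shape)  # torch.Size([3, 2, 7, 7])
```

Since every step uses either `...` indexing or broadcasting matrix multiplication, a tilt of shape `(2,)` or `(3, 2)` yields the matching batch shape without special-casing.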
Another thing: I found that quite a lot of the indexing uses `[:, i, j]` instead of `[..., i, j]`, which would probably break down for multiple batch dimensions, right?
Yes, indeed. There might be situations where `[:, i, j]` makes sense, but most of the time it should be `[..., i, j]`. In the initial vectorised implementation, multi-dimensional batches were not intended; I only added support for them afterwards, so it's possible I missed some places.
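To illustrate the difference, a small sketch with a made-up stack of 7×7 matrices whose batch shape is two-dimensional, `(3, 2)`. `[:, i, j]` only skips the first dimension, so `i` lands on the second batch dimension; `[..., i, j]` always addresses matrix entry `(i, j)`.

```python
import torch

# Hypothetical stack of 7x7 matrices with a two-dimensional batch shape (3, 2)
maps = torch.zeros(3, 2, 7, 7)

# `[:, i, j]`: the first index hits the second batch dimension and the second
# index the matrix row, leaving the matrix column dimension in the result
print(maps[:, 0, 0].shape)  # torch.Size([3, 7])

# `[..., i, j]` addresses entry (i, j) of every matrix in the batch
print(maps[..., 0, 0].shape)  # torch.Size([3, 2])

values = torch.ones(3, 2)  # one value per batch entry
try:
    maps[:, 0, 0] = values  # target shape (3, 7) cannot take values of shape (3, 2)
except RuntimeError as err:
    print("shape mismatch:", err)
maps[..., 0, 0] = values  # works for any number of batch dimensions

# With a single batch dimension both forms agree, which hides the problem
single = torch.zeros(2, 7, 7)
print(single[:, 0, 0].shape == single[..., 0, 0].shape)  # True
```

Note that with a batch shape such as `(3, 7)`, the `[:, i, j]` form would not even raise an error; it would silently read or write the wrong entries.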
The rotation matrix is currently not batch-compatible. This causes problems when elements are rotated (tilt != 0).