desy-ml / cheetah

Fast and differentiable particle accelerator optics simulation for reinforcement learning and optimisation applications.
https://cheetah-accelerator.readthedocs.io
GNU General Public License v3.0

[Vectorised branch] Tracking crashes with a single `Marker`, or with only elements without `length` between non-skippable elements #143

Closed by cr-xu 2 months ago

cr-xu commented 3 months ago

Currently, Segment tracking crashes if the segment contains only a single Marker.

Because Segment groups skippable elements into sub-segments and calls track recursively, the same crash occurs when there are only length-less elements between two non-skippable elements.
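The failure mode can be illustrated without Cheetah or PyTorch. The following is a minimal sketch, assuming that runs of skippable elements are merged and that the length of each run is computed by stacking the lengths of its members. `Elem`, `group_runs`, and `stacked_length` are hypothetical stand-ins rather than Cheetah's code, and a `ValueError` stands in for the `RuntimeError` that `torch.stack` raises on an empty list.

```python
from dataclasses import dataclass
from itertools import groupby
from typing import List, Optional


@dataclass
class Elem:
    """Hypothetical stand-in for a lattice element (not Cheetah's class)."""

    name: str
    skippable: bool
    length: Optional[float] = None  # None models an element without a length


def group_runs(elements: List[Elem]) -> List[List[Elem]]:
    """Split the lattice into maximal runs of equal skippability."""
    return [list(run) for _, run in groupby(elements, key=lambda e: e.skippable)]


def stacked_length(run: List[Elem]) -> float:
    """Sum the lengths of a run, mimicking a stack over per-element lengths."""
    lengths = [e.length for e in run if e.length is not None]
    if not lengths:
        # Stand-in for torch.stack rejecting an empty list of tensors
        raise ValueError("stack expects a non-empty list")
    return sum(lengths)


# Same layout as the second failing case: cavity, marker, cavity
lattice = [
    Elem("C2", skippable=False, length=0.1),
    Elem("start", skippable=True),  # marker: no length
    Elem("C1", skippable=False, length=0.1),
]

runs = group_runs(lattice)
```

In this sketch the marker ends up alone in the middle run, so `stacked_length(runs[1])` has nothing to stack and raises, while the two cavity runs work as expected.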

Reproducing the bug:

import cheetah
import torch

beam_in = cheetah.ParticleBeam.from_parameters(num_particles=100)

# Only Marker
segment1 = cheetah.Segment([cheetah.Marker(name="start")])

# This fails
beam_out = segment1.track(beam_in)

# Only length-less elements between non-skippable elements
segment2 = cheetah.Segment(
    [
        cheetah.Cavity(
            length=torch.tensor([0.1]), voltage=torch.tensor([1e6]), name="C2"
        ),
        cheetah.Marker(name="start"),
        cheetah.Cavity(
            length=torch.tensor([0.1]), voltage=torch.tensor([1e6]), name="C1"
        ),  # non-skippable
    ]
)

# This fails too
beam_out = segment2.track(beam_in)

# Only skippable elements
segment3 = cheetah.Segment(
    [
        cheetah.Marker(name="start"),
        cheetah.Drift(length=torch.tensor([0.1]), name="D1"),
        cheetah.Quadrupole(
            length=torch.tensor([0.1]), k1=torch.tensor([1.0]), name="Q1"
        ),
        cheetah.Marker(name="end"),
    ]
)

# This works
beam_out = segment3.track(beam_in)

The failing cases raise:

  File ".../cheetah/accelerator.py", line 2219, in length
    lengths = torch.stack(
RuntimeError: stack expects a non-empty TensorList

This happens only on the vectorised branch and should be fixed before merging #116.

cr-xu commented 3 months ago

Hmm... I guess this is a general issue with the vectorised branch, as the Marker doesn't get broadcast along with the other elements. Its transfer map is not properly determined because it doesn't have a length attribute.

One easy fix would be to make length a basic attribute of the Element class and set it to 0.0 by default.

What's your opinion, @jank324?

jank324 commented 3 months ago

I think setting length to 0.0 makes sense. On the base Element (and all subclasses that don't override length), this should probably also be a read-only property.

We might then have to check where exceptions have been made for length-less elements; that code can potentially be simplified.
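The proposal discussed here can be sketched roughly as follows. This is only an illustration under stated assumptions, not Cheetah's actual implementation: the classes below are simplified stand-ins that reuse the element names, lengths are plain floats rather than tensors, and `total_length` is a hypothetical helper.

```python
from typing import List


class Element:
    """Hypothetical base class; not Cheetah's actual implementation."""

    def __init__(self, name: str) -> None:
        self.name = name

    @property
    def length(self) -> float:
        # Read-only default: elements without physical extent report 0.0
        return 0.0


class Marker(Element):
    """Inherits the zero, read-only length from the base class."""


class Drift(Element):
    """Overrides length with a stored, settable value."""

    def __init__(self, length: float, name: str) -> None:
        super().__init__(name)
        self._length = length

    @property
    def length(self) -> float:
        return self._length

    @length.setter
    def length(self, value: float) -> None:
        self._length = value


def total_length(elements: List[Element]) -> float:
    # No special case for length-less elements: every element answers .length
    return sum(e.length for e in elements)
```

With this layout, a list containing only a `Marker` has a well-defined total length of zero, and assigning to `Marker("start").length` raises an `AttributeError` because the base property defines no setter.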

cr-xu commented 3 months ago

Are you working on this or should I do it?

jank324 commented 3 months ago

If you could do it, that would be great!