desy-ml / cheetah

Fast and differentiable particle accelerator optics simulation for reinforcement learning and optimisation applications.
https://cheetah-accelerator.readthedocs.io
GNU General Public License v3.0

Batched execution #100

Closed: jank324 closed this issue 2 months ago

jank324 commented 7 months ago

It would be in the spirit of Cheetah's applications to enable batched execution. For example, a quadrupole's strength could be set to a vector of strengths, which would then result in a 3D tensor of transfer maps and allow parallel evaluation at much higher speed.
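
As a rough illustration of the shapes involved, the sketch below builds one transfer map per quadrupole strength, giving a (n_batch, 4, 4) tensor. It uses a thin-lens approximation in (x, px, y, py) rather than Cheetah's own quadrupole model, and the function name `batched_thin_quad_tmat` is hypothetical.

```python
import torch

def batched_thin_quad_tmat(k1: torch.Tensor, length: float) -> torch.Tensor:
    """Hypothetical helper (not Cheetah's API): thin-lens quadrupole maps.

    k1 has shape (n_batch,); the result has shape (n_batch, 4, 4) with
    phase-space order (x, px, y, py).
    """
    n_batch = k1.shape[0]
    # Start from a stack of identity matrices, one per strength
    tmat = torch.eye(4, dtype=k1.dtype).repeat(n_batch, 1, 1)
    # Integrated strength k1 * L focuses in x and defocuses in y
    tmat[:, 1, 0] = -k1 * length
    tmat[:, 3, 2] = k1 * length
    return tmat

# Usage: 1000 strengths evaluated at once
k1 = torch.linspace(-5.0, 5.0, 1000)
tmats = batched_thin_quad_tmat(k1, length=0.2)
x0 = torch.tensor([1e-3, 0.0, 1e-3, 0.0])
x_out = tmats @ x0  # shape (1000, 4): one outgoing vector per strength
```

Because the maps are stacked along the first dimension, a single `matmul` produces all outgoing vectors, and the slice assignments on a freshly created tensor keep the result differentiable with respect to `k1`.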

cr-xu commented 7 months ago

This is a tricky problem, but it would be very helpful to have.

There are still two cases to consider: batched evaluation over incoming_beam and batched evaluation over segment_setting.

The incoming_beam case is easier to reason about and implement: with a fixed Segment and transfer_map, we would just need to broadcast the transfer-map multiplication properly over the batched/stacked incoming beam (which is already halfway done for ParticleBeam anyway).
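
For the fixed-segment case, a minimal sketch of the broadcasting: one transfer map is applied to a whole stack of beams in a single `matmul`. It assumes particles are stored as rows of a (..., n_particles, 7) tensor (six phase-space coordinates plus a constant entry); that layout and the helper name `track_fixed_map` are assumptions for illustration.

```python
import torch

def track_fixed_map(tmat: torch.Tensor, particles: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: apply one (7, 7) map to (..., n_particles, 7) rows."""
    # Right-multiplying by the transpose applies tmat to every particle row;
    # matmul broadcasts over all leading (beam-batch) dimensions.
    return particles @ tmat.transpose(-2, -1)

# Usage: a 0.5 m drift in x applied to 8 beams of 1000 particles each
tmat = torch.eye(7)
tmat[0, 1] = 0.5
beams = torch.randn(8, 1000, 7)
out = track_fixed_map(tmat, beams)  # shape (8, 1000, 7)
```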

The changing segment_setting case is trickier, but also more important. Effectively, we are modifying the network (the transfer maps) during the forward pass, and one concern is whether we can keep execution reasonably fast.
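
One way to keep the changing-settings case vectorized is to let each element return either a single map or a stack of maps and compose them with `matmul`, which broadcasts unbatched elements (such as drifts) against batched ones. A minimal sketch in the horizontal (x, px) plane only, with a hypothetical drift-quadrupole-drift lattice and a thin-lens quadrupole; none of these function names come from Cheetah.

```python
import torch

def drift(length: float) -> torch.Tensor:
    # Unbatched 2x2 drift map in (x, px)
    return torch.tensor([[1.0, length], [0.0, 1.0]])

def thin_quad(k1: torch.Tensor, length: float) -> torch.Tensor:
    # Batched thin-lens quadrupole maps, shape (n_batch, 2, 2)
    tmat = torch.eye(2).repeat(k1.shape[0], 1, 1)
    tmat[:, 1, 0] = -k1 * length
    return tmat

def segment_map(k1: torch.Tensor) -> torch.Tensor:
    # Hypothetical lattice: 1 m drift -> 0.2 m quad -> 1 m drift
    elements = [drift(1.0), thin_quad(k1, 0.2), drift(1.0)]
    total = torch.eye(2)
    for tm in elements:
        # Later elements multiply from the left; matmul broadcasts the
        # (2, 2) drifts against the (n_batch, 2, 2) quadrupole stack
        total = tm @ total
    return total

# Usage: 100 quadrupole settings composed in three batched products
maps = segment_map(torch.linspace(-5.0, 5.0, 100))  # shape (100, 2, 2)
```

The cost of changing settings then becomes a handful of batched matrix products rather than one Python-level iteration per setting.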

For bookkeeping purposes, the general feature I would like is fast calculation of outgoing beam parameters for one static incoming beam and vectorized accelerator settings of the general shape [..., n_batch, n_dim], where n_dim is the number of accelerator settings being changed.
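
For one static incoming beam, outgoing beam parameters can be computed for settings of shape [..., n_batch, n_dim] without any reshaping if the beam is represented by its second-moment (sigma) matrix, transported as Σ_out = M Σ_0 M^T. In this sketch the toy lattice (a thin quadrupole with integrated strength `settings[..., 0]` followed by a drift of length `settings[..., 1]`), the horizontal-only 2x2 model, and the name `outgoing_sigma` are all assumptions, not Cheetah's API.

```python
import torch

def outgoing_sigma(settings: torch.Tensor, sigma0: torch.Tensor) -> torch.Tensor:
    """Hypothetical: transport one static (2, 2) sigma matrix for settings
    of shape [..., n_batch, 2]; returns shape [..., n_batch, 2, 2]."""
    kl = settings[..., 0]         # integrated quadrupole strength
    drift_len = settings[..., 1]  # drift length after the quadrupole
    batch_shape = settings.shape[:-1]
    eye = torch.eye(2, dtype=settings.dtype)
    # One 2x2 map per setting; clone() makes the expanded view writable
    quad = eye.expand(*batch_shape, 2, 2).clone()
    quad[..., 1, 0] = -kl
    drift = eye.expand(*batch_shape, 2, 2).clone()
    drift[..., 0, 1] = drift_len
    m = drift @ quad
    # Sigma_out = M Sigma_0 M^T, broadcast over all leading dimensions
    return m @ sigma0 @ m.transpose(-2, -1)

# Usage: settings of shape [3, 5, n_dim=2] and one static incoming beam
settings = torch.rand(3, 5, 2)
sigma0 = torch.diag(torch.tensor([1e-6, 1e-8]))
sigma_out = outgoing_sigma(settings, sigma0)  # shape (3, 5, 2, 2)
beam_size_x = sigma_out[..., 0, 0].sqrt()     # shape (3, 5)
```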

Currently, I have to do something like:

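import torch

def forward(self, X: torch.Tensor) -> torch.Tensor:
    # X has shape [..., n_dim]; flatten all leading dims into one batch dim
    input_shape = X.shape
    X = X.reshape(-1, input_shape[-1])
    Y = torch.zeros(X.shape[:-1], dtype=X.dtype, device=X.device)
    for i, x in enumerate(X):
        self.segment.Q1.k1 = x[0]  # Set the input parameters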
        ...
        # Track the beam
        out_beam = self.segment(self.incoming_beam)
        # Compute the objective
        obj = ...
        Y[i] = obj
    return Y.reshape(input_shape[:-1])

which scales more or less linearly with the number of settings, i.e. poorly:

i (number of settings)   CPU user   CPU sys   CPU total   Wall time
1                        1.55 ms    1.06 ms   2.61 ms     1.85 ms
10                       4.6 ms     14 µs     4.61 ms     4.65 ms
100                      38.6 ms    266 µs    38.9 ms     39.4 ms
1000                     335 ms     1.22 ms   336 ms      347 ms
10000                    3.23 s     15.6 ms   3.25 s      3.26 s
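
To make the contrast with the loop concrete, here is a toy comparison showing that a per-setting loop and a single broadcast call give the same objective values. The model (a thin-quadrupole kick followed by a 1 m drift in x, with the rms horizontal beam size as objective) and both function names are hypothetical stand-ins for `self.segment` and `obj`.

```python
import torch

def toy_objective_single(kl: torch.Tensor, particles: torch.Tensor) -> torch.Tensor:
    """Hypothetical objective for ONE setting: rms x after quad kick + drift."""
    x, px = particles[:, 0], particles[:, 1]
    px = px - kl * x  # thin-lens quadrupole kick
    x = x + 1.0 * px  # 1 m drift
    return x.std()

def toy_objective_batched(kl: torch.Tensor, particles: torch.Tensor) -> torch.Tensor:
    """Same objective for a batch of settings kl of arbitrary shape [...]."""
    x, px = particles[:, 0], particles[:, 1]
    # kl[..., None] has shape [..., 1], so it broadcasts against the
    # particle dimension and yields tensors of shape [..., n_particles]
    px = px - kl[..., None] * x
    x = x + 1.0 * px
    return x.std(dim=-1)

# Usage: the Python loop and the single batched call agree
torch.manual_seed(0)
particles = torch.randn(1000, 2)
kl = torch.linspace(-2.0, 2.0, 50)
loop_result = torch.stack([toy_objective_single(k, particles) for k in kl])
batched_result = toy_objective_batched(kl, particles)
```

In the batched version the loop over settings becomes a broadcast over a leading dimension, so one call replaces n Python iterations; for large batches this is where most of the speed-up would be expected to come from.
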
jank324 commented 7 months ago

Maybe we can take inspiration from how this was done here: https://github.com/UM-ARM-Lab/pytorch_kinematics