jank324 closed this issue 2 months ago.
This is a tricky problem, but the feature would be very helpful to have.
There are still two cases to consider, namely batch evaluation of `incoming_beam` and batch evaluation of `segment_setting`.
The `incoming_beam` case is easier to reason about and implement: with a fixed `Segment` and `transfer_map`, we would only need to broadcast the transfer-map multiplication correctly over the batched (stacked) incoming beam, which is already halfway done for `ParticleBeam` anyway.
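A minimal sketch of that broadcasting for the fixed-segment case. It assumes particles are stored as row vectors of shape `(..., n_particles, d)` and that the whole segment is a single `(d, d)` transfer map (7D phase space in the usage is an assumption); `track_fixed_map` is a hypothetical helper, not Cheetah's API:

```python
import torch


def track_fixed_map(transfer_map: torch.Tensor, particles: torch.Tensor) -> torch.Tensor:
    """Apply one fixed (d, d) transfer map to a batch of particle sets.

    Hypothetical helper, not Cheetah's API. `particles` has shape
    (..., n_particles, d), one row per particle, with any number of leading
    batch dimensions (e.g. one per incoming beam).
    """
    # Row-vector convention: x_out = R @ x_in for every particle, i.e. X_out = X_in @ R^T.
    # torch.matmul broadcasts the (d, d) matrix over all leading dimensions.
    return particles @ transfer_map.transpose(-2, -1)


# Usage: 4 stacked incoming beams with 100 particles each in a 7D phase space.
R = torch.randn(7, 7, dtype=torch.float64)
beams = torch.randn(4, 100, 7, dtype=torch.float64)
out = track_fixed_map(R, beams)
print(out.shape)  # torch.Size([4, 100, 7])
```

Because the map is shared, no copy of it per beam is needed; the batch dimension only lives on the particle tensor.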
The case of changing `segment_setting` is trickier, but also more important. Here we are effectively modifying the network (the transfer map) during the forward pass, so one concern is whether execution can be kept reasonably fast.
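When settings change per sample, the segment map has to be rebuilt on every forward pass. A minimal sketch of that composition, assuming each element provides either a fixed `(d, d)` map or a batched `(n_batch, d, d)` map; `compose_segment` is a hypothetical helper, and the thin-lens quadrupole in the usage is only an illustrative 2D model. `torch.matmul` broadcasts the fixed maps against the batched ones, so no per-sample Python loop is needed:

```python
from functools import reduce

import torch


def compose_segment(element_maps: list[torch.Tensor]) -> torch.Tensor:
    """Compose element maps, listed in beam order, into one segment map.

    Each entry is either (d, d) for an element with fixed settings or
    (n_batch, d, d) for an element whose setting is batched. The result is
    M_last @ ... @ M_first, broadcast to (n_batch, d, d) if any entry is batched.
    """
    return reduce(lambda acc, m: m @ acc, element_maps)


# Usage: drift, thin-lens quadrupole with 5 different strengths, drift (horizontal plane).
d1 = torch.tensor([[1.0, 1.0], [0.0, 1.0]], dtype=torch.float64)
k1 = torch.linspace(-2.0, 2.0, 5, dtype=torch.float64)
quads = torch.zeros(5, 2, 2, dtype=torch.float64)
quads[:, 0, 0] = 1.0
quads[:, 1, 1] = 1.0
quads[:, 1, 0] = -k1 * 0.2  # thin-lens kick -k1 * L with L = 0.2 m
segment = compose_segment([d1, quads, d1])
print(segment.shape)  # torch.Size([5, 2, 2])
```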
For the record, the general feature I would like is fast calculation of outgoing beam parameters for one static incoming beam and vectorized accelerator settings of general shape `[..., n_batch, n_dim]`, where `n_dim` is the number of accelerator settings being changed.
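Under that convention, each column of the last dimension is one setting. A minimal sketch of unpacking such a tensor into one tensor per setting without flattening; the helper `split_settings` and the setting labels (such as `"Q1.k1"`) are hypothetical:

```python
import torch


def split_settings(X: torch.Tensor, names: list[str]) -> dict[str, torch.Tensor]:
    """Split settings of shape (..., n_batch, n_dim) into one tensor per setting.

    `names` holds n_dim hypothetical labels; each returned tensor has shape
    (..., n_batch), so it could serve as a batched element parameter.
    """
    if X.shape[-1] != len(names):
        raise ValueError(f"expected {len(names)} settings, got {X.shape[-1]}")
    return {name: X[..., i] for i, name in enumerate(names)}


X = torch.rand(3, 8, 2)  # 3 x 8 samples, n_dim = 2
params = split_settings(X, ["Q1.k1", "Q2.k1"])
print(params["Q1.k1"].shape)  # torch.Size([3, 8])
```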
Currently, I have to do something like:
```python
import torch

def forward(self, X: torch.Tensor) -> torch.Tensor:
    input_shape = X.shape
    n_dim = input_shape[-1]  # number of accelerator settings per sample
    X = X.reshape(-1, n_dim)  # flatten all leading batch dimensions
    Y = torch.zeros(X.shape[:-1], dtype=X.dtype)
    for i, x in enumerate(X):
        self.segment.Q1.k1 = x[0]  # Set the input parameters
        ...
        # Track the beam
        out_beam = self.segment(self.incoming_beam)
        # Compute the objective
        obj = ...
        Y[i] = obj
    return Y.reshape(input_shape[:-1])
```
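For comparison, a minimal sketch of what a loop-free `forward` could look like if element parameters accepted batched values. It uses a hypothetical toy model rather than Cheetah: a thin-lens quadrupole `Q1` followed by a drift, horizontal plane only, with the RMS horizontal beam size standing in for the objective:

```python
import torch


def thin_quad_maps(k1: torch.Tensor, length: float) -> torch.Tensor:
    """Thin-lens horizontal maps [[1, 0], [-k1 * length, 1]] for a batch k1 of shape (N,)."""
    one, zero = torch.ones_like(k1), torch.zeros_like(k1)
    row0 = torch.stack([one, zero], dim=-1)
    row1 = torch.stack([-k1 * length, one], dim=-1)
    return torch.stack([row0, row1], dim=-2)  # (N, 2, 2)


def vectorized_forward(X: torch.Tensor, particles: torch.Tensor,
                       quad_length: float = 0.2, drift_length: float = 1.0) -> torch.Tensor:
    """X: (..., 1) settings (only Q1.k1 here); particles: (n_particles, 2) static beam (x, x')."""
    input_shape = X.shape
    k1 = X.reshape(-1, input_shape[-1])[:, 0]  # (N,)
    drift = torch.tensor([[1.0, drift_length], [0.0, 1.0]], dtype=X.dtype)
    segment = drift @ thin_quad_maps(k1, quad_length)  # (N, 2, 2) by broadcasting
    out = particles @ segment.transpose(-2, -1)  # (N, n_particles, 2)
    obj = out[..., 0].std(dim=-1)  # RMS horizontal size for each setting
    return obj.reshape(input_shape[:-1])


torch.manual_seed(0)
beam = torch.randn(1000, 2, dtype=torch.float64) * 1e-3
X = torch.linspace(-5.0, 5.0, 12, dtype=torch.float64).reshape(3, 4, 1)
Y = vectorized_forward(X, beam)
print(Y.shape)  # torch.Size([3, 4])
```

All settings are handled by a few tensor operations, so the cost per call no longer comes from a Python loop over samples.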
This scales more or less linearly with the number of settings, i.e. poorly:

| Settings (i) | CPU user | CPU sys | CPU total | Wall time |
|---:|---:|---:|---:|---:|
| 1 | 1.55 ms | 1.06 ms | 2.61 ms | 1.85 ms |
| 10 | 4.6 ms | 14 µs | 4.61 ms | 4.65 ms |
| 100 | 38.6 ms | 266 µs | 38.9 ms | 39.4 ms |
| 1000 | 335 ms | 1.22 ms | 336 ms | 347 ms |
| 10000 | 3.23 s | 15.6 ms | 3.25 s | 3.26 s |
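As a quick check of the scaling, the wall times above can be turned into a per-setting cost and a log-log slope, where a slope of 1 means linear scaling:

```python
import math

# Wall times from the measurements above, in seconds, keyed by number of settings.
wall = {1: 1.85e-3, 10: 4.65e-3, 100: 39.4e-3, 1000: 347e-3, 10000: 3.26}

for n, t in wall.items():
    print(f"{n:>6} settings: {1e3 * t / n:.3f} ms per setting")

# Log-log slope over the last decade (1000 -> 10000 settings).
slope = math.log10(wall[10000] / wall[1000])
print(f"slope = {slope:.2f}")  # slope = 0.97
```

At large batch sizes each setting costs roughly 0.33 ms, and the slope close to 1 is consistent with the loop cost growing linearly.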
Maybe we can take inspiration from how this was done in https://github.com/UM-ARM-Lab/pytorch_kinematics.
Enabling batched execution would be in the spirit of Cheetah's applications. For example, a quadrupole's strength could be set to a vector of strengths, which would produce a 3D tensor of transfer maps (one per strength) and allow parallel evaluation at much higher speed.
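To illustrate the 3D tensor of transfer maps, a minimal sketch that builds horizontal-plane thick-quadrupole maps for a whole vector of strengths in one call. It uses the standard linear hard-edge quadrupole matrices restricted to 2D rather than Cheetah's own implementation, and the function name `quad_maps_x` is hypothetical:

```python
import torch


def quad_maps_x(k1: torch.Tensor, length: float) -> torch.Tensor:
    """Horizontal 2x2 thick-quadrupole maps for strengths k1 of shape (n,) -> (n, 2, 2).

    k1 > 0 focuses (cos/sin), k1 < 0 defocuses (cosh/sinh), k1 == 0 is a drift.
    """
    kabs = k1.abs()
    # Use sqrt(1) where k1 == 0 to avoid division by zero; those entries are replaced below.
    sqrt_k = torch.sqrt(torch.where(kabs > 0, kabs, torch.ones_like(kabs)))
    phi = sqrt_k * length
    focusing = k1 > 0
    c = torch.where(focusing, torch.cos(phi), torch.cosh(phi))
    s = torch.where(focusing, torch.sin(phi), torch.sinh(phi))
    r11, r22 = c, c
    r12 = s / sqrt_k
    r21 = torch.where(focusing, -sqrt_k * s, sqrt_k * s)
    # Drift limit for k1 == 0.
    zero = kabs == 0
    r11 = torch.where(zero, torch.ones_like(r11), r11)
    r22 = torch.where(zero, torch.ones_like(r22), r22)
    r12 = torch.where(zero, torch.full_like(r12, length), r12)
    r21 = torch.where(zero, torch.zeros_like(r21), r21)
    row0 = torch.stack([r11, r12], dim=-1)
    row1 = torch.stack([r21, r22], dim=-1)
    return torch.stack([row0, row1], dim=-2)


k1 = torch.tensor([-4.0, 0.0, 4.0], dtype=torch.float64)
maps = quad_maps_x(k1, length=0.2)
print(maps.shape)  # torch.Size([3, 2, 2])
```

Each slice `maps[i]` is the map for `k1[i]`, so the result can be fed directly into the broadcast matrix products used for tracking.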