desy-ml / cheetah

Fast and differentiable particle accelerator optics simulation for reinforcement learning and optimisation applications.
https://cheetah-accelerator.readthedocs.io
GNU General Public License v3.0

differentiable cheetah #70

Closed: jank324 closed this 9 months ago

jank324 commented 10 months ago

This is a pull request to at least make sure that tracking of previous operations works, as a prerequisite for working automatic differentiation. The latter can be merged as a second work package.

The two tasks (after each of which the branch can be merged into master) are:

NOTE: This pull request will most likely require adapting the README.
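For context, the tracking of operations in question is what PyTorch's autograd provides. A minimal illustration of the mechanism that needs to keep working end to end (the numbers are arbitrary):

```python
import torch

# Autograd records the operations applied to tensors with
# requires_grad=True, so gradients can later be computed by backprop.
k1 = torch.tensor(4.2, requires_grad=True)
out = (k1 * 0.5) ** 2   # stand-in for tracking a beam through an element
out.backward()
print(k1.grad)          # d(out)/d(k1) = 2 * 0.25 * k1 = 2.1
```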

jank324 commented 10 months ago

@cr-xu there is a design decision we need to make for the differentiability part: because everything needs to be a torch.Tensor, we need to figure out how to get the user to provide values that way. Here are some options:

Thoughts?
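For illustration, one of the simpler options would be to coerce whatever the user passes inside the element constructor with torch.as_tensor. The Drift class below is a toy sketch, not Cheetah's actual implementation:

```python
import torch
from torch import nn

class Drift(nn.Module):
    """Illustrative element that accepts floats or tensors."""

    def __init__(self, length) -> None:
        super().__init__()
        # torch.as_tensor is a no-op for tensors and converts Python
        # floats / NumPy arrays, so users never have to wrap values
        # in torch.tensor(...) themselves.
        self.length = torch.as_tensor(length, dtype=torch.float32)
```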

cr-xu commented 10 months ago

@jank324 Maybe Pydantic is a possible option?

[image attachment]
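For reference, a sketch of what that could look like, assuming pydantic v1 and illustrative field names (not something Cheetah currently does):

```python
import torch
from pydantic import BaseModel, validator

class QuadrupoleConfig(BaseModel):
    """Illustrative pydantic model that validates and coerces inputs."""

    length: torch.Tensor
    k1: torch.Tensor = torch.tensor(0.0)

    class Config:
        arbitrary_types_allowed = True  # let pydantic accept torch.Tensor

    @validator("length", "k1", pre=True)
    def to_tensor(cls, value):
        # Accept floats, lists, NumPy arrays, ... and coerce to tensors
        return torch.as_tensor(value, dtype=torch.float32)
```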
jank324 commented 10 months ago

I really liked the idea of using Pydantic at first, but now that I've played around with it, I've realised that an object cannot be a pydantic.BaseModel and a torch.nn.Module at the same time. I think the latter is more important for us right now, so I will keep playing with that for now. I think we can actually reevaluate user-friendliness once the rest is working. 🙃

jank324 commented 9 months ago

Okay ... here is the current state, for our notes. I've made every Element a subclass of nn.Module. This way, you can assign an nn.Parameter instead of a torch.Tensor to an element's attributes, and those values can then be "trained" like the trainable parameters of a neural network, i.e. we can make use of all the optimisers and so on that PyTorch provides.
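As a minimal sketch of what this enables (the element and its attribute names here are illustrative, not Cheetah's actual API):

```python
import torch
from torch import nn

class Quadrupole(nn.Module):
    """Toy element; attribute names are illustrative only."""

    def __init__(self, length: torch.Tensor, k1: torch.Tensor) -> None:
        super().__init__()
        # If k1 is an nn.Parameter, nn.Module registers it automatically
        # and it shows up in .parameters().
        self.length = length
        self.k1 = k1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for tracking: some differentiable function of k1.
        return x * self.k1 * self.length


quad = Quadrupole(
    length=torch.tensor(0.3),
    k1=nn.Parameter(torch.tensor(4.2)),
)
optimizer = torch.optim.Adam(quad.parameters(), lr=1e-2)

# One "training" step tuning k1 towards some target output
target = torch.tensor(1.0)
loss = (quad(torch.tensor(0.5)) - target) ** 2
loss.backward()
optimizer.step()
```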

However, this does not play well with Pydantic, so for now we rely on the user either having type checking activated or passing the correct types. Currently, Cheetah provides no safety net for when you don't use our classes correctly. We can check whether we want to change this once the rest is working.

Another annoyance is that implementing our own __eq__ does not play well with nn.Module. For now this is not a big issue; we can figure out the implications of this and, if needed, how to fix it once the rest works.
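For the record, at least part of the friction is a general Python gotcha rather than anything PyTorch-specific: defining __eq__ on a class implicitly sets its __hash__ to None, while nn.Module relies on modules being hashable (they get put into sets internally, e.g. the memo set used when iterating submodules). One common workaround is to restore the inherited hash explicitly; a sketch with illustrative fields:

```python
import torch
from torch import nn

class Element(nn.Module):
    """Illustrative only; not Cheetah's actual Element class."""

    def __init__(self, length: torch.Tensor) -> None:
        super().__init__()
        self.length = length

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Element):
            return NotImplemented
        return bool(torch.equal(self.length, other.length))

    # Defining __eq__ would otherwise set __hash__ to None, making the
    # module unhashable; restore the identity-based hash explicitly.
    __hash__ = nn.Module.__hash__
```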

In addition, I haven't added nn.Parameter to the Beam classes yet. I worry this might be more difficult, considering the translation of beam parameters to particles.
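My guess at where the difficulty lies (an illustration, not from the PR): turning beam parameters into particle coordinates involves sampling, and for gradients to survive, the sampling has to be written in reparameterised form, so the parameters stay in the autograd graph:

```python
import torch
from torch import nn

# Hypothetical beam parameters as trainable values
mu_x = nn.Parameter(torch.tensor(0.0))      # mean position
sigma_x = nn.Parameter(torch.tensor(1e-3))  # beam size

# Reparameterised sampling: the noise is a constant, the parameters
# stay in the graph, so gradients flow from particles back to them.
noise = torch.randn(10_000)
x = mu_x + sigma_x * noise

loss = x.var()          # some downstream, differentiable quantity
loss.backward()
print(mu_x.grad, sigma_x.grad)
```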

For now, I will first fix all the tests against the new changes.

jank324 commented 9 months ago

I think that's it for this one. Or is anything missing, @cr-xu?