@cr-xu there is a design decision we need to make for the differentiability part. That is, because everything needs to be in `torch.Tensor`s, we need to figure out how to get the user to do that. Here are some options:

- `assert` in `__init__`. This would force users to use Tensors and catch wrong use early. However, it wouldn't catch when you reassign a value later. For that we would also need to define getters and setters for every variable that is used in computations and `assert` there. That would be a lot of work and is prone to errors when we define new variables or `Element`s and forget to define getters or setters for something.
- `assert` at the start of each computation when the variables are used. This clutters up the physics code a little, but would be guaranteed to catch errors easily. It might also cause a small slow-down, but I think that's minimal.
- Automatically convert values to `torch.Tensor` in `__init__` and in getters and setters, or at the start of computations. This comes with the same aforementioned pros and cons. As an additional pro, it makes things easy and "clean" to use. As a con, it obscures what the code actually does.

Thoughts?
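For concreteness, here is a rough sketch of what the three options could look like (the `Quadrupole` classes and the `length` attribute are just for illustration, not Cheetah's actual API):

```python
import torch


# Option 1: assert in __init__, with a property so that reassignment is
# caught as well.
class QuadrupoleOption1:
    def __init__(self, length):
        self.length = length  # Goes through the setter below

    @property
    def length(self):
        return self._length

    @length.setter
    def length(self, value):
        assert isinstance(value, torch.Tensor), "length must be a torch.Tensor"
        self._length = value


# Option 2: assert at the start of each computation that uses the value.
class QuadrupoleOption2:
    def __init__(self, length):
        self.length = length

    def transfer_map(self):
        assert isinstance(self.length, torch.Tensor)
        ...  # Physics code using self.length


# Option 3: silently convert instead of asserting.
class QuadrupoleOption3:
    def __init__(self, length):
        self.length = torch.as_tensor(length)  # Accepts float, list, ndarray, ...
```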
@jank324 Maybe Pydantic is a possible option?
I really liked the idea of using Pydantic at first, but now that I have played around with it, I realised an object cannot be a `pydantic.BaseModel` and a `torch.nn.Module` at the same time. I think the latter is more important for us right now. I will keep playing with that for now. I think we can actually reevaluate user-friendliness once the rest is working. 🙃
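For our notes, this is roughly the kind of combination that fails (a sketch of the experiment, not code from Cheetah):

```python
import pydantic
import torch.nn as nn


# Both base classes rely on heavy custom attribute machinery: pydantic
# validates assignments against declared fields, while nn.Module's
# __setattr__ intercepts assignments to register parameters, buffers and
# submodules. In my experiments the two mechanisms clash, so a class cannot
# usefully be both at the same time.
class Drift(pydantic.BaseModel, nn.Module):
    length: float = 0.0
```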
Okay .... here is the current state for our notes. I've made all `Element`s subclasses of `nn.Module`. This way, you can assign an `nn.Parameter` to values instead of a `torch.Tensor`, and those values can then be "trained" like the trainable parameters in a neural network, i.e. we can make use of all the optimisers and so on that PyTorch provides.
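A minimal sketch of the idea (the element and attribute names are illustrative, not necessarily the final Cheetah API):

```python
import torch
import torch.nn as nn


class Quadrupole(nn.Module):
    def __init__(self, length, k1):
        super().__init__()
        self.length = length
        self.k1 = k1  # If this is an nn.Parameter, nn.Module registers it


# A quadrupole with fixed values ...
fixed = Quadrupole(length=torch.tensor(0.1), k1=torch.tensor(4.2))

# ... and one whose strength is trainable.
tunable = Quadrupole(length=torch.tensor(0.1), k1=nn.Parameter(torch.tensor(4.2)))

# Because elements are nn.Modules, PyTorch's optimisers can tune the
# registered parameters directly, e.g. against some beam-based objective.
optimizer = torch.optim.Adam(tunable.parameters(), lr=1e-2)
loss = (tunable.k1 * tunable.length - 1.0) ** 2  # Placeholder objective
loss.backward()
optimizer.step()
```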
However, this does not play well with `pydantic`, so for now we rely on the user either having type checking activated or using the correct types. Currently, Cheetah provides no safety net for when you don't use our classes correctly. We can check whether we want to change this once the rest is working.
Another annoyance is that implementing our own `__eq__` does not play well with `nn.Module`. For now this is not a big issue, and we can also figure out the implications of this and, if needed, how to fix it once the rest works.
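One concrete way this shows up (a sketch, not Cheetah code): defining `__eq__` in a class body implicitly sets `__hash__` to `None` in Python 3, but `nn.Module` internals such as `named_modules()` keep a memo `set` of visited modules and therefore need modules to be hashable.

```python
import torch
import torch.nn as nn


class Drift(nn.Module):
    def __init__(self, length):
        super().__init__()
        self.length = length

    def __eq__(self, other):
        return isinstance(other, Drift) and torch.equal(self.length, other.length)

    # Without this line, Drift instances are unhashable and e.g.
    # named_modules() raises a TypeError. Restoring identity-based hashing is
    # one possible fix, although equal elements then hash differently.
    __hash__ = nn.Module.__hash__
```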
In addition, I haven't added `nn.Parameter` to the Beam classes yet. I worry this might be more difficult, considering the translation of beam parameters to particles.
For now, I will first fix all the tests to work with the new changes.
I think that's it for this one. Or is anything missing, @cr-xu?
This is a pull request to at least make sure that the tracking of previous operations works, as a prerequisite for working automatic differentiation. The latter can be merged as a second work package.
The two tasks (after each of which we can merge into master) are:

- Everything is computed with `torch`, and results are always `torch.Tensor`s and carry `grad_fn` when appropriate.
- Automatic differentiation works on top of this operation tracking.

NOTE: This pull request will most likely require adapting the README.
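To illustrate what the first task means in practice (the computation below is a stand-in, not Cheetah's actual tracking code):

```python
import torch

k1 = torch.tensor(4.2, requires_grad=True)
length = torch.tensor(0.1)

result = k1 * length  # Stands in for a real tracking computation
assert result.grad_fn is not None  # The operation history was recorded

result.backward()
print(k1.grad)  # tensor(0.1000)
```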