Closed chahak13 closed 1 year ago
A problem with the separate initialization method is that the attributes set there are lost when the object is unflattened. I checked a few code bases on GitHub that seem to do something similar for some variables, so I will have to go with initializing all of them in __init__.
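For illustration, here is a minimal sketch of how an attribute set outside __init__ can disappear on unflattening. The Particles class and its initialize method are hypothetical stand-ins, not the classes in this PR; the sketch only assumes that tree_unflatten rebuilds the object through the constructor.

```python
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Particles:
    """Hypothetical class; the velocity array is allocated in a separate method."""

    def __init__(self, position):
        self.position = position

    def initialize(self):
        # Set outside __init__, so tree_flatten below never sees it.
        self.velocity = jnp.zeros_like(self.position)

    def tree_flatten(self):
        # Only position is exported as a leaf; there is no static aux data.
        return (self.position,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuilds through the constructor, which knows nothing about velocity.
        return cls(*children)


p = Particles(jnp.array([0.0, 1.0]))
p.initialize()

# Round-trip through flatten/unflatten, as JAX does around a jitted call.
leaves, treedef = tree_util.tree_flatten(p)
restored = tree_util.tree_unflatten(treedef, leaves)

print(hasattr(p, "velocity"), hasattr(restored, "velocity"))  # True False
```

The same round trip happens whenever such an object crosses a jax.jit boundary, which is why anything that has to survive needs to be both flattened and accepted back by the constructor.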
We have JIT support now. Here's an example of optimization with JIT (examples/optim_1d.py).
Any suggestions on how to select which values the solver returns? Should that be up to the user, or should it just return everything, i.e. position, velocity, stress, strain, energies, etc.? The current implementation in solver.py should definitely not be used as is, because it hard-codes things. I'm trying to think about what would make a good user experience. Also, should I add 2D to this PR itself, or merge this once the refactoring is approved and create a new PR for 2D?
Let's do a separate 2D implementation; we can approve this PR.
Okay, sounds good, thanks. Let me clean it up a little and I'll move it out of draft.
@kks32 this should be good to review now, I think. Let me know if you have any thoughts or feel free to merge otherwise. Thanks!
A quick update. With the current setup, the 1d solver "works", i.e. it can solve a 1d axial bar problem without JIT. For JIT, the problem is slightly non-trivial now because we have more classes being passed around than before. To make a class JIT-compatible, it needs the tree_flatten and tree_unflatten methods (which are then used to register the class as a PyTree node). While this is relatively easy, in our case instantiating the objects also initializes extra arrays for various things (velocity, force, etc.). Having this in __init__() doesn't work, because during JIT the objects are "reinitialized" and everything is reset. The way I got around this issue previously was to pass these variables to the constructor and initialize them only when the passed values are None, but this is not very clean. Since I know that method works, I'm going to first try moving the array initialization into a separate method that the solver calls, so that it doesn't happen on instantiation (a bit like what's done in cb-geo/mpm), and see if I can make that work with JIT. If not, I will have to go with the previous method. I'm doing more reading to see if there are any other ways.
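As a rough sketch of the "initialize only when the passed value is None" approach described above, something like the following could work. The Nodes class, its fields, and the step function are hypothetical and chosen only for illustration; they are not the classes in this PR.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Nodes:
    """Hypothetical container for nodal data."""

    def __init__(self, nnodes, position, velocity=None):
        self.nnodes = nnodes      # static metadata, not traced
        self.position = position  # array leaf
        # Allocate only when nothing was passed, so that rebuilding the
        # object in tree_unflatten keeps the existing state.
        self.velocity = jnp.zeros_like(position) if velocity is None else velocity

    def tree_flatten(self):
        children = (self.position, self.velocity)  # array leaves
        aux_data = (self.nnodes,)                  # static, hashable data
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        position, velocity = children
        return cls(aux_data[0], position, velocity)


@jax.jit
def step(nodes, dt):
    # JAX arrays are immutable, so return a new object instead of mutating.
    return Nodes(nodes.nnodes, nodes.position + dt * nodes.velocity, nodes.velocity)


nodes = Nodes(3, jnp.array([0.0, 0.5, 1.0]), velocity=jnp.ones(3))
moved = step(nodes, 0.1)
print(moved.position)
```

Because velocity travels through tree_flatten and comes back through the constructor, it is not reset to zeros when the object enters or leaves the jitted function. The cost is the longer constructor signature mentioned above.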