geoelements / diffmpm

Differentiable Material Point Method

Refactor to decouple classes #3

Closed chahak13 closed 1 year ago

chahak13 commented 1 year ago

A quick update: with the current setup, the 1D solver "works", i.e., it can solve a 1D axial bar problem without JIT. JIT is now slightly less trivial because more classes are passed around than before. To make a class JIT-compatible, it needs tree_flatten and tree_unflatten methods, which are then used to register the class as a PyTree node. That part is relatively easy, but in our case instantiating the objects also initializes extra arrays for various quantities (velocity, force, etc.). Doing this in __init__() doesn't work, because during JIT the objects get "reinitialized" (tree_unflatten rebuilds them), which resets everything. The way I got around this previously was to pass these variables to the constructor and initialize them only when the passed values are None, but that isn't very clean. Since I know that method works, I'm first going to try moving the array initialization into a separate method that the solver calls instead of running it on instantiation (a bit like what's done in cb-geo/mpm), and see whether I can make that work with JIT. If not, I'll have to go back to the previous method. I'm doing more reading to see if there are any other ways.
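For reference, the constructor-with-None workaround can be sketched as follows. This is a minimal illustration under assumptions, not diffmpm's code: the Particles class, its x, velocity and force fields, and the step function are hypothetical. JAX's register_pytree_node_class decorator registers the class using its tree_flatten method and tree_unflatten classmethod; because tree_unflatten passes the existing arrays back into the constructor, the zero arrays are only allocated on first construction.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Particles:
    # Hypothetical container; field names are illustrative only.
    def __init__(self, x, velocity=None, force=None):
        self.x = x
        # Allocate derived arrays only when they are not supplied. When JAX
        # rebuilds the object in tree_unflatten, the existing arrays are passed
        # back in, so nothing is reset to zero.
        self.velocity = jnp.zeros_like(x) if velocity is None else velocity
        self.force = jnp.zeros_like(x) if force is None else force

    def tree_flatten(self):
        # Every array that must survive a JIT boundary is a child.
        children = (self.x, self.velocity, self.force)
        aux_data = None
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def step(p, dt):
    # Toy update: restoring force -x, then semi-implicit Euler.
    force = -p.x
    velocity = p.velocity + dt * force
    return Particles(p.x + dt * velocity, velocity, force)


p = Particles(jnp.array([1.0, 2.0]))
q = step(p, 0.1)  # velocity is approximately [-0.1, -0.2]
r = step(q, 0.1)  # velocity keeps accumulating: approximately [-0.199, -0.398]
```

The second call shows velocity carrying over between jitted calls instead of being reset. JAX's pytree documentation also cautions that tree_unflatten may sometimes be called with placeholder (non-array) children, so logic in __init__ beyond this kind of defaulting is best kept minimal.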

chahak13 commented 1 year ago

One problem with the separate initialization method is that those attributes are lost when the object is unflattened. I checked a few different codebases on GitHub that seem to do something similar for some variables, so I'll have to go with having all of them in __init__:

  1. https://github.com/ott-jax/ott/blob/64d26ac4bb85a732d5e9d11010c9ae1394a99a6c/src/ott/initializers/nn/initializers.py#L35
  2. https://github.com/isabella232/ott/blob/56a9922f26b0e44be4a2d3de46808458cfe3e8cc/ott/core/sinkhorn.py#L268
  3. https://github.com/lindermanlab/ssm-jax/blob/3ee7580d0901e10ac5e237655e81133d5e2a67ca/ssm/hmm/transitions.py#L5
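The attribute loss with a separate initialization method can be reproduced with a small sketch. LazyParticles and its initialize() method are hypothetical names used only for illustration. Any round trip through flattening and unflattening (jax.tree_util.tree_map with an identity function here; passing the object into a jax.jit-compiled function behaves the same way) rebuilds the object from its flattened children alone, so an attribute that is neither a child nor a constructor argument disappears.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class LazyParticles:
    # Hypothetical class: velocity is created by initialize(), not __init__.
    def __init__(self, x):
        self.x = x

    def initialize(self):
        self.velocity = jnp.zeros_like(self.x)

    def tree_flatten(self):
        # Only x is a child; velocity is not part of the flattened state.
        return (self.x,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuilds the object through the constructor, which never sets velocity.
        return cls(*children)


p = LazyParticles(jnp.arange(3.0))
p.initialize()
q = jax.tree_util.tree_map(lambda a: a, p)  # flatten -> map leaves -> unflatten
print(hasattr(p, "velocity"), hasattr(q, "velocity"))  # True False
```

Adding velocity to the flattened children does not help on its own, since cls(*children) would then pass an argument the constructor doesn't accept; this is why the fields end up as constructor arguments.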
chahak13 commented 1 year ago

We have JIT support now. Here's an example of optimization with JIT (examples/optim_1d.py).
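examples/optim_1d.py is not reproduced here. As a self-contained sketch of the same idea, the snippet below differentiates through a jitted time-stepping loop and recovers a Young's modulus E by gradient descent. Everything in it is an assumption for illustration: simulate is a hypothetical single-degree-of-freedom mass-spring-damper standing in for the 1D bar (stiffness k = E·A/L), the constants are made up, and the target displacement is synthetic.

```python
import jax
import jax.numpy as jnp


def simulate(E, nsteps=1000, dt=0.01):
    """Hypothetical stand-in for a 1D solver: a single-DOF mass-spring-damper
    whose stiffness k = E * A / L mimics an elastic bar under a constant end load."""
    A, L, m, c, F = 1.0, 1.0, 1.0, 2.0, 1.0  # made-up constants
    E = jnp.asarray(E, dtype=jnp.float32)
    k = E * A / L

    def step(_, state):
        u, v = state
        a = (F - k * u - c * v) / m  # acceleration from the net force
        v = v + dt * a               # semi-implicit Euler update
        u = u + dt * v
        return (u, v)

    init = (jnp.zeros_like(E), jnp.zeros_like(E))
    u, _ = jax.lax.fori_loop(0, nsteps, step, init)
    return u  # tip displacement at the final step


E_true = 5.0
u_target = simulate(E_true)  # synthetic "measurement"


@jax.jit
def loss(E):
    # Relative squared error between simulated and target displacement.
    return ((simulate(E) - u_target) / u_target) ** 2


# Reverse-mode gradient through the whole time loop, compiled once.
grad_loss = jax.jit(jax.grad(loss))

E = jnp.asarray(3.0)
initial_loss = loss(E)
for _ in range(50):
    E = E - 1.0 * grad_loss(E)  # plain gradient descent, learning rate 1.0
final_loss = loss(E)
print(float(E), float(initial_loss), float(final_loss))
```

Because fori_loop is given fixed Python-int bounds, JAX can differentiate through it in reverse mode; both the loss and its gradient are traced and compiled once and then reused on every iteration.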

chahak13 commented 1 year ago

Any suggestions on how to select which values the solver should return? Should it be up to the user, or should it just return everything, i.e., position, velocity, stress, strain, energies, etc.? The current implementation in solver.py definitely shouldn't be used as is, since it hard-codes things. I'm trying to think about what would be a good user experience. Also, should I add 2D to this PR itself, or merge this once the refactoring is approved and open a new PR for 2D?
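One possible answer to the return-value question, sketched under assumptions: a hypothetical solve function takes the names of the fields to record as a static outputs tuple and returns a dict holding only their time histories. The constant-velocity update is a toy stand-in for an actual MPM step, and only position and velocity are modeled.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("nsteps", "outputs"))
def solve(x0, v0, dt, nsteps, outputs=("position", "velocity")):
    # Hypothetical solver loop; the update below is a toy placeholder.
    def step(state, _):
        x, v = state
        x = x + dt * v  # particles drift at constant velocity
        fields = {"position": x, "velocity": v}
        # Record only the requested fields. outputs is static, so the set of
        # keys is fixed for each compiled version of solve.
        return (x, v), {name: fields[name] for name in outputs}

    # scan stacks each step's recorded fields along a new leading time axis.
    _, history = jax.lax.scan(step, (x0, v0), None, length=nsteps)
    return history


out = solve(jnp.zeros(3), jnp.ones(3), 0.1, nsteps=5, outputs=("position",))
# out has a single key, "position", with shape (5, 3): one row per step.
```

Marking outputs as static keeps the returned structure fixed for each compiled version, at the cost of one recompilation per distinct selection. Returning everything is simpler but stores the full history of every field.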

kks32 commented 1 year ago

Let's do the 2D implementation separately; we can approve this PR.

chahak13 commented 1 year ago

Okay, sounds good, thanks. Let me clean it up a little and I'll move it out of draft.

chahak13 commented 1 year ago

@kks32 I think this should be good to review now. Let me know if you have any thoughts; otherwise, feel free to merge. Thanks!