ucl-bug / jaxdf

A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations
GNU Lesser General Public License v3.0
117 stars, 7 forks

fix: avoids changing parameters of OnGrid inside jax transformations #96

Closed: astanziola closed this pull request 1 year ago

astanziola commented 1 year ago

During the construction of OnGrid objects, the params field was implemented as a dynamic property with a getter and a setter. The goal was to ease the initialization of scalar fields by sparing the user from manually adding the extra dimension to the NumPy array that specifies the dimensionality of the field. This caused problems with jax.vmap, for reasons explained in the JAX documentation.
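The pitfall can be reproduced outside jaxdf with a minimal sketch, assuming a simplified stand-in for OnGrid. The class names FragileField and RobustField and the from_scalar constructor are hypothetical and not part of the jaxdf API, and this is not necessarily how the PR implements its fix. JAX transformations rebuild custom pytrees through tree_unflatten, sometimes with leaves that are not arrays (for example None or placeholder objects), so a params setter that inspects and reshapes its input can fail during that rebuild. One way around it is to keep the constructor a plain assignment and move the reshaping into a separate, user-facing constructor.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class FragileField:
    """Hypothetical field whose params setter reshapes its input (the pitfall)."""

    def __init__(self, params, ndim):
        self.ndim = ndim       # number of spatial dimensions, kept as static aux data
        self.params = params   # routed through the setter below

    @property
    def params(self):
        return self._params

    @params.setter
    def params(self, value):
        # Adds a trailing component axis to scalar fields. This assumes
        # `value` is an array, which is not true when JAX rebuilds the tree.
        if value.ndim == self.ndim:
            value = jnp.expand_dims(value, -1)
        self._params = value

    def tree_flatten(self):
        return (self._params,), (self.ndim,)

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], *aux)


@register_pytree_node_class
class RobustField:
    """Hypothetical field whose __init__ stores params untouched."""

    def __init__(self, params, ndim):
        self.params = params   # plain attribute: no validation, no reshaping
        self.ndim = ndim

    @classmethod
    def from_scalar(cls, values):
        # The extra axis is added only here, in a user-facing constructor,
        # never when JAX calls tree_unflatten.
        values = jnp.asarray(values)
        return cls(jnp.expand_dims(values, -1), values.ndim)

    def tree_flatten(self):
        return (self.params,), (self.ndim,)

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], *aux)


# A batch of four scalar fields on an 8x8 grid.
fields = jax.vmap(RobustField.from_scalar)(jnp.ones((4, 8, 8)))
totals = jax.vmap(lambda f: jnp.sum(f.params))(fields)  # one sum per field

# Rebuild the trees with None leaves, as JAX internals may do.
try:
    jax.tree_util.tree_map(lambda x: None, FragileField(jnp.ones((8, 8)), 2))
    fragile_survives = True
except AttributeError:
    fragile_survives = False
robust_rebuilt = jax.tree_util.tree_map(lambda x: None, fields)
```

Here tree_map(lambda x: None, ...) stands in for the internal rebuilds that jax.vmap performs when it processes in_axes and out_axes; the exact internals may differ between JAX versions. The fragile class raises an AttributeError because its setter calls .ndim on None, while the robust class is rebuilt without complaint and still works under jax.vmap.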

This PR fixes the problem and adds the required tests.

codecov[bot] commented 1 year ago

Codecov Report

Merging #96 (71f7edb) into main (ff8ba47) will increase coverage by 25.96%. The diff coverage is 100.00%.

```diff
@@             Coverage Diff             @@
##             main      #96       +/-   ##
===========================================
+ Coverage   36.49%   62.46%   +25.96%     
===========================================
  Files          12       12               
  Lines         959      959               
===========================================
+ Hits          350      599      +249     
+ Misses        609      360      -249     
```
| Impacted Files | Coverage Δ |
|---|---|
| jaxdf/discretization.py | 85.02% <100.00%> (+37.12%) ↑ |
| jaxdf/operators/differential.py | 46.38% <100.00%> (+31.91%) ↑ |
| jaxdf/operators/dummy.py | 93.75% <100.00%> (+56.25%) ↑ |
| jaxdf/operators/functions.py | 37.14% <100.00%> (+10.47%) ↑ |
| jaxdf/operators/magic.py | 67.07% <100.00%> (+26.21%) ↑ |
| jaxdf/core.py | 78.62% <0.00%> (+20.61%) ↑ |

... and 2 more
