ucl-bug / jwave

A JAX-based research framework for differentiable and parallelizable acoustic simulations, on CPU, GPUs and TPUs
GNU Lesser General Public License v3.0

simulate_wave_propagation solver error when using a FiniteDifferences accuracy other than the default of 8 #224

Open whajjali opened 9 months ago

whajjali commented 9 months ago

Describe the bug

When testing the time-domain simulate_wave_propagation solver with parameters and/or initial conditions given as FiniteDifferences fields, the solver appears to convert them to accuracy=8 regardless of the accuracy specified when the FiniteDifferences objects were constructed. In particular, the smoothing step applied to the initial condition converts it to accuracy=8 if it was constructed with a different accuracy. If smooth_initial=False is set, the following error is raised after the first time step:

"TypeError: Scanned function carry input and carry output must have the same pytree structure, but they differ: the input carry component fields[0] is a <class 'jaxdf.discretization.FiniteDifferences'> with pytree metadata ('params', 'domain'), ('accuracy',), (2,) but the corresponding component of the carry output is a <class 'jaxdf.discretization.FiniteDifferences'> with pytree metadata ('params', 'domain'), ('accuracy',), (8,), so the pytree node metadata does not match"

This means that mass_conservation_rhs and momentum_conservation_rhs return FiniteDifferences objects with accuracy=8 rather than the specified accuracy. On further investigation, I noticed that the replace_params method always falls back to the default accuracy of 8 instead of keeping the accuracy of the field it is called on.
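The failure mode described above can be illustrated with a minimal, self-contained sketch. It does not use jaxdf: `ToyFiniteDifferences`, `replace_params_buggy`, `replace_params_fixed`, `pytree_metadata` and `same_structure` are hypothetical stand-ins, assuming (as the traceback suggests) that `accuracy` is static pytree metadata with a default of 8. Because `lax.scan` requires the carry to have identical structure, including static metadata, before and after every step, a method that rebuilds a field without forwarding `accuracy` breaks that invariant.

```python
from dataclasses import dataclass, replace
from typing import Any, Tuple

DEFAULT_ACCURACY = 8  # assumed default, matching the value in the traceback


@dataclass(frozen=True)
class ToyFiniteDifferences:
    """Hypothetical stand-in for a finite-difference field (not jaxdf's class)."""
    params: Any
    domain: Any
    accuracy: int = DEFAULT_ACCURACY  # treated as static metadata, not a traced leaf

    def replace_params_buggy(self, new_params):
        # Rebuilds the field without forwarding `accuracy`,
        # so it silently falls back to the default of 8.
        return ToyFiniteDifferences(new_params, self.domain)

    def replace_params_fixed(self, new_params):
        # Copies every field, including `accuracy`, and swaps only `params`.
        return replace(self, params=new_params)


def pytree_metadata(field: ToyFiniteDifferences) -> Tuple:
    # Mimics the metadata tuple printed in the error message:
    # dynamic field names, static field names, static field values.
    return (("params", "domain"), ("accuracy",), (field.accuracy,))


def same_structure(carry_in, carry_out) -> bool:
    # A scan step is rejected when the output carry metadata differs from the input.
    return pytree_metadata(carry_in) == pytree_metadata(carry_out)


if __name__ == "__main__":
    u0 = ToyFiniteDifferences(params=[0.0, 1.0, 0.0], domain="grid", accuracy=2)
    bad = u0.replace_params_buggy([0.5, 0.5, 0.5])
    good = u0.replace_params_fixed([0.5, 0.5, 0.5])
    print(bad.accuracy, same_structure(u0, bad))    # 8 False
    print(good.accuracy, same_structure(u0, good))  # 2 True
```

If jaxdf's replace_params is structured similarly, the fix direction would be the same: forward the original field's accuracy when rebuilding it rather than relying on the constructor default.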

To Reproduce

See the attached PDF file: homogeneous_medium_FD_test.pdf