lululxvi / deepxde

A library for scientific machine learning and physics-informed learning
https://deepxde.readthedocs.io
GNU Lesser General Public License v2.1

Jax transform #1705

Closed. bonneted closed this 2 months ago.

bonneted commented 2 months ago

As explained in https://github.com/lululxvi/deepxde/pull/1671

There was a problem with input and output transforms in the JAX backend. Because JAX computes gradients pointwise, the input that passes through the network during the gradient calculation has shape (n_input,) instead of (batch, n_input). Accessing x[:, i] or f[:, i] inside an input or output transform therefore raised an error. Calling reshape(1, -1) before the transform and squeeze() after it solved the problem.
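Below is a minimal sketch of the failure and the fix. It assumes, as described above, that derivatives are taken one point at a time (here via jax.vmap over jax.jacfwd), so a transform written for (batch, n_input) arrays receives a single (n_input,) point. The names input_transform, pointwise_safe, W and model are hypothetical illustrations, not DeepXDE APIs. pointwise_safe squeezes only axis 0 so that a single-output network keeps its (1,) shape, a slight variation on a plain squeeze().

```python
import jax
import jax.numpy as jnp

def input_transform(x):
    # Written for a batch: x has shape (batch, 2).
    # Appends sin features of each coordinate -> shape (batch, 4).
    return jnp.concatenate(
        [x, jnp.sin(jnp.pi * x[:, 0:1]), jnp.sin(jnp.pi * x[:, 1:2])], axis=1
    )

def pointwise_safe(transform):
    # Hypothetical helper: if the first argument is a single point of shape
    # (n_input,), reshape every argument to (1, -1), apply the batched
    # transform, then remove only the added leading axis.
    def wrapped(*args):
        if args[0].ndim == 1:
            out = transform(*[a.reshape(1, -1) for a in args])
            return jnp.squeeze(out, axis=0)
        return transform(*args)
    return wrapped

# Hypothetical weights for a stand-in "network" with 4 input features.
W = jnp.array([[0.5], [-0.3], [0.2], [0.1]])

def model(x):
    # Stand-in network: one linear layer plus tanh on the transformed features.
    return jnp.tanh(pointwise_safe(input_transform)(x) @ W)

X = jnp.array([[0.2, 0.7], [0.9, 0.1], [0.0, 0.5]])

y = model(X)                        # batched forward pass, shape (3, 1)
J = jax.vmap(jax.jacfwd(model))(X)  # per-point Jacobians, shape (3, 1, 2)
```

Calling input_transform(X[0]) directly fails with an indexing error because X[0] is one-dimensional. Through the wrapper, the pointwise result jax.vmap(model)(X) matches the batched model(X).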

I also added hard boundary condition (BC) support to the elasticity plate example.
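The following is a hedged sketch of what a hard-BC output transform for a 2-D displacement field (ux, uy) can look like. The boundary conditions used here (ux = 0 on x = 0, uy = 0 on y = 0) are hypothetical and are not taken from the DeepXDE elasticity plate example. hard_bc_output_transform and net are illustrative names. Indexing with x[..., i:i+1] instead of x[:, i] is an alternative design that works for both a single point and a batch without any reshape.

```python
import jax
import jax.numpy as jnp

def hard_bc_output_transform(x, f):
    # Hypothetical hard BCs for a 2-D displacement field (ux, uy):
    #   ux = 0 on the edge x = 0, and uy = 0 on the edge y = 0.
    # Ellipsis indexing works for a single point (2,) and a batch (batch, 2).
    ux = x[..., 0:1] * f[..., 0:1]
    uy = x[..., 1:2] * f[..., 1:2]
    return jnp.concatenate([ux, uy], axis=-1)

def net(x):
    # Stand-in for the network: maps (..., 2) to (..., 2).
    W = jnp.array([[0.4, -0.2], [0.1, 0.3]])
    return jnp.tanh(x @ W + 0.5)

def model(x):
    return hard_bc_output_transform(x, net(x))

X = jnp.array([[0.0, 0.3], [0.6, 0.0], [0.5, 0.5]])

U = model(X)                        # batched displacements, shape (3, 2)
J = jax.vmap(jax.jacfwd(model))(X)  # per-point Jacobians, shape (3, 2, 2)
```

Multiplying each network output by a factor that vanishes on its boundary satisfies that condition exactly, so no boundary loss term is needed for that component.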