There was a problem with the output/input transforms in the JAX backend. Because JAX computes gradients pointwise, the input passed through the network during the gradient calculation has shape (n_input,), so accessing x[:, i] or f[:, i] inside an input/output transform raised an indexing error. Adding reshape(1, -1) before the transform and squeeze() after it fixes this.
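A minimal sketch of the shape issue and the fix, outside of DeepXDE (the transform and wrapper names here are illustrative, not the PR's actual code): a transform written for batched inputs of shape (n_points, n_dims) breaks when the pointwise gradient pass feeds it a single point of shape (n_dims,), and promoting the point to a batch of one before the call, then squeezing after, restores correctness.

```python
import jax.numpy as jnp

# Hypothetical output transform written for batched inputs of shape
# (n_points, n_dims): it indexes columns as x[:, i], which fails when
# a per-point gradient pass feeds it a single point of shape (n_dims,).
def output_transform(x, f):
    # x: spatial coordinates, f: raw network output
    return x[:, 0:1] * f  # e.g. hard-enforce f = 0 at x = 0

def apply_pointwise_safe(transform, x, f):
    # Promote a single point (n_dims,) to a batch of one, apply the
    # batched transform, then squeeze the batch dimension back out.
    if x.ndim == 1:
        return transform(x.reshape(1, -1), f.reshape(1, -1)).squeeze()
    return transform(x, f)

# Batched call works as before...
xb = jnp.ones((4, 2))
fb = jnp.ones((4, 1))
out_b = apply_pointwise_safe(output_transform, xb, fb)

# ...and a single point, as seen during pointwise gradient evaluation,
# no longer raises an indexing error.
xp = jnp.ones((2,))
fp = jnp.ones((1,))
out_p = apply_pointwise_safe(output_transform, xp, fp)
```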
I also added hard BC support for the elasticity plate example, as explained in https://github.com/lululxvi/deepxde/pull/1671.
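For context, a hard BC is typically imposed through an output transform that makes the boundary condition hold by construction. The sketch below is illustrative only (the geometry, factor, and names are assumptions, not the elasticity plate example's actual code): multiplying the network output by a factor that vanishes on the constrained edges enforces zero displacement there exactly, so no boundary loss term is needed.

```python
import jax.numpy as jnp

# Hypothetical hard-BC output transform for a plate on [0, 1] x [0, 1]
# with zero displacement on the edges x = 0 and x = 1 (illustrative).
def hard_bc_transform(x, f):
    x_ = x[:, 0:1]
    # Factor x * (1 - x) vanishes exactly on the constrained edges,
    # so the transformed output satisfies the BC by construction.
    return x_ * (1.0 - x_) * f

xb = jnp.array([[0.0, 0.5],
                [0.5, 0.5],
                [1.0, 0.5]])
fb = jnp.ones((3, 1))
u = hard_bc_transform(xb, fb)  # zero on the edges, nonzero in the interior
```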