YichengDWu / Sophon.jl

Efficient, Accurate, and Streamlined Training of Physics-Informed Neural Networks
https://yichengdwu.github.io/Sophon.jl/dev/
MIT License

Rename `residual_function_n`? #224

Closed: arthur-bizzi closed this issue 1 year ago

arthur-bizzi commented 1 year ago

As it is, it's hard to tell from the code alone what each `residual_function_n` does. The FAQ seems to indicate that `n = 1` returns the total loss, but that seems at odds with the actual implementation, where there appear to be roughly as many functions as there are equations in the loss. Would it be too breaking to expose these instead as arrays of functions, `residual_function_eq[i]` and `residual_function_bc[i]`?
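A rough sketch of what the array-based proposal could look like. `ResidualFunctions`, `mean_sq` and `total_loss` are hypothetical names invented for this illustration, not part of Sophon's API, and the residuals are toy closures with an assumed `(x, θ)` signature. The point is only that the total loss iterates over the vectors, so the number of terms always matches the system being solved.

```julia
# Hypothetical container (NOT Sophon's API): residuals stored in vectors,
# as proposed, instead of as numbered globals residual_function_1, _2, ...
struct ResidualFunctions
    eq::Vector{Function}  # one residual per PDE equation
    bc::Vector{Function}  # one residual per boundary condition
end

# Mean squared residual of one function over a set of points.
mean_sq(f, pts, θ) = sum(abs2, (f(x, θ) for x in pts)) / length(pts)

# Total loss: sum of mean squared residuals over all equations and boundaries.
# The number of terms comes from the vectors themselves.
function total_loss(r::ResidualFunctions, eq_pts, bc_pts, θ)
    l_eq = sum((mean_sq(f, p, θ) for (f, p) in zip(r.eq, eq_pts)); init = 0.0)
    l_bc = sum((mean_sq(f, p, θ) for (f, p) in zip(r.bc, bc_pts)); init = 0.0)
    return l_eq + l_bc
end

# Toy residuals (assumed, for illustration only).
r = ResidualFunctions(
    Function[(x, θ) -> θ * x - x],  # toy "equation" residual
    Function[(x, θ) -> θ * x],      # toy boundary residual, evaluated at x = 0
)
total_loss(r, [[0.5, 1.0]], [[0.0]], 2.0)  # 0.625
```

With this layout the i-th equation residual is simply `r.eq[i]`, and re-running the setup for a smaller system replaces the whole vector rather than only some of the numbered names.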

Even just a short description somewhere in the docs would be of great help to those of us using the package mostly for the symbolic machinery and the loss functions.

arthur-bizzi commented 1 year ago

Another reason for this request: as it is, the residual functions remain defined in the REPL but are not directly visible from the workspace. Running the computations again for a system of smaller order and then computing the total loss can therefore lead (and, in my case, has led) to some very hard-to-catch bugs.
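A minimal, self-contained reproduction of this failure mode, assuming residuals are defined as numbered globals and the total is collected by probing names. `define_residuals!` and `total_by_name` are hypothetical stand-ins written for this sketch, not Sophon's code generation; a fresh `Module()` plays the role of the REPL's global scope.

```julia
# Hypothetical stand-in for generated code: defines one global function
# residual_function_i(x) = coeffs[i] * x per residual in module `mod`.
function define_residuals!(mod::Module, coeffs::Vector{Float64})
    for (i, c) in enumerate(coeffs)
        name = Symbol("residual_function_", i)
        Core.eval(mod, :($name(x) = $c * x))
    end
    return nothing
end

# Hypothetical total that discovers residuals by name, probing
# residual_function_1, residual_function_2, ... until one is undefined.
function total_by_name(mod::Module, x)
    total = 0.0
    i = 1
    while isdefined(mod, Symbol("residual_function_", i))
        f = getfield(mod, Symbol("residual_function_", i))
        total += f(x)
        i += 1
    end
    return total
end

m = Module()                           # fresh namespace standing in for the REPL
define_residuals!(m, [1.0, 2.0, 3.0])  # first system: three residuals
define_residuals!(m, [10.0, 20.0])     # smaller system: only 1 and 2 are replaced
total_by_name(m, 1.0)                  # 33.0, not 30.0: stale residual_function_3 is summed
```

Nothing errors here; the stale third residual silently contributes to the total, which is what makes this kind of bug hard to spot.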

Though this workflow might not be very common (running tests for multiple systems of different order), having `discretize` return vectors of loss functions seems cleaner overall. I would be willing to try to code this up if it matches your vision for the package, @YichengDWu.

YichengDWu commented 1 year ago

> The FAQ seems to indicate that n = 1 returns the total loss

No, it would be the first residual function; `residual_function_2` is the second one, and so on.

Actually, I have explained in the docs how to inspect the generated expressions of the residual functions. Once you have those native Julia expressions, you can call `eval` on them to obtain the function objects:

```julia
julia> residual_function_exprs = Sophon.symbolic_discretize(poisson, pinn, sampler, strategy);

julia> residual_functions = eval.(residual_function_exprs)
3-element Vector{Function}:
 #9 (generic function with 1 method)
 #11 (generic function with 1 method)
 #13 (generic function with 1 method)
```
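A self-contained sketch of the same `eval` step, with toy anonymous-function expressions standing in for what `symbolic_discretize` returns (the real expressions take different arguments; `exprs` and `total` are names made up for this illustration). Because the expressions are anonymous, broadcasting `eval` yields a vector of functions without defining any `residual_function_k` globals, and a total computed from that vector always has exactly `length(exprs)` terms.

```julia
# Toy stand-ins for the generated residual expressions (assumed shapes).
exprs = [:(x -> x - 1.0), :(x -> 2.0 * x), :(x -> x^2)]

# Broadcasting eval turns each expression into a function object.
residual_functions = eval.(exprs)

# Total squared residual at a point, summed over the vector of functions.
total(x) = sum(f -> abs2(f(x)), residual_functions)

total(2.0)  # 1.0 + 16.0 + 16.0 = 33.0
```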