killah-t-cell opened 3 years ago
Yeah, the visualizations get hard there, of course. Animations, sliders, etc. built on top of Makie are probably the direction this has to go.
Agreed. But is there a simple way (meaning just using Plots) to visualize a PINN with 3 space dimensions that I am missing? We usually reserve one dimension for u and plot it as a surface, but once we have x, y, z that is not possible.
I tried to do this with a 3D wave equation, but couldn't figure it out.
using NeuralPDE, Flux, ModelingToolkit, GalacticOptim, Optim, DiffEqFlux, Plots  # Plots provides @animate and gif
import ModelingToolkit: Interval, infimum, supremum
@parameters t, x, y, z
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2
Dzz = Differential(z)^2
Dtt = Differential(t)^2
Dt = Differential(t)
# 3D wave equation (three space dimensions plus time)
C=1
eq = Dtt(u(t,x,y,z)) ~ C^2*(Dxx(u(t,x,y,z))+Dyy(u(t,x,y,z))+Dzz(u(t,x,y,z)))
# Initial and boundary conditions
bcs = [u(t,0, y, z) ~ 0.,# for all t > 0
u(t,1, y, z) ~ 0.,# for all t > 0
u(t,x, 0, z) ~ 0.,# for all t > 0
u(t,x, 1, z) ~ 0.,# for all t > 0
u(t,x, y, 0) ~ 0.,# for all t > 0
u(t,x, y, 1) ~ 0.,# for all t > 0
u(0,x, y, z) ~ x*(1. - x), # for all 0 < x, y, z < 1
Dt(u(0,x, y, z)) ~ 0. ] # for all 0 < x, y, z < 1
# Space and time domains
domains = [t ∈ Interval(0.0,1.0),
x ∈ Interval(0.0,1.0),
y ∈ Interval(0.0,1.0),
z ∈ Interval(0.0,1.0)]
# Discretization
dx = 0.1
# Neural network
chain = FastChain(FastDense(4,16,Flux.σ),FastDense(16,16,Flux.σ),FastDense(16,1))
initθ = Float64.(DiffEqFlux.initial_params(chain))
discretization = PhysicsInformedNN(chain, GridTraining(dx); init_params = initθ)
@named pde_system = PDESystem(eq,bcs,domains,[t,x,y,z],[u(t,x,y,z)])
prob = discretize(pde_system,discretization)
cb = function (p,l)
println("Current loss is: $l")
return false
end
# optimizer
opt = BFGS()
res = GalacticOptim.solve(prob,opt; cb = cb, maxiters=500)
phi = discretization.phi
### More elegant way to retrieve PINN
ts, xs, ys, zs = [infimum(d.domain):0.1:supremum(d.domain) for d in domains]
u_predict = [collect(phi([t,x,y,z], res.minimizer)[1] for x in xs, y in ys, z in zs) for t in ts]
anim = @animate for t ∈ eachindex(ts)
# how do I plot this?
end
gif(anim, "wave3d.gif", fps=10)
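Because a surface plot has room for only two independent variables, one Plots-only workaround is to fix one spatial coordinate and animate a heatmap of the remaining 2D slice. The sketch below assumes the `u_predict` layout from the script above (one array per time step, indexed `[ix, iy, iz]`). `u_stub` is a hypothetical analytic stand-in for the trained network so the block runs without NeuralPDE, and the Plots calls are left as comments.

```julia
# Grids matching the domains above (step 0.1 on [0, 1]).
ts = 0.0:0.1:1.0
xs = ys = zs = 0.0:0.1:1.0

# Hypothetical stand-in for phi([t,x,y,z], res.minimizer)[1]; replace with the trained PINN.
u_stub(t, x, y, z) = cos(pi * t) * x * (1 - x) * y * (1 - y) * z * (1 - z)

# One 3D array per time step, indexed [ix, iy, iz] (same layout as u_predict above).
u_predict = [[u_stub(t, x, y, z) for x in xs, y in ys, z in zs] for t in ts]

# Fix z at its middle grid point and keep the (x, y) plane of every frame.
kz = cld(length(zs), 2)                      # index of z = 0.5
slices = [u[:, :, kz] for u in u_predict]    # each slice is indexed [ix, iy]

# Plots' heatmap(xs, ys, M) expects M[iy, ix], hence the permutedims.
# A shared clims keeps the colors comparable between frames.
# using Plots
# clims = extrema(Iterators.flatten(slices))
# anim = @animate for k in eachindex(ts)
#     heatmap(xs, ys, permutedims(slices[k]); clims = clims,
#             xlabel = "x", ylabel = "y", title = "t = $(ts[k]), z = $(zs[kz])")
# end
# gif(anim, "wave3d_slice.gif", fps = 10)
```

Several slices (for example z = 0.25, 0.5, 0.75) can be laid out side by side with Plots' `layout` keyword if one plane is not enough.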
It took some research, but I figured out how to plot this with Makie:
using GLMakie  # a Makie backend is needed for volume and record; Makie alone does not render
ts, xs, ys, zs = [infimum(d.domain):0.1:supremum(d.domain) for d in domains]
u_predict = [collect(phi([t,x,y,z], res.minimizer)[1] for x in xs, y in ys, z in zs) for t in ts]
positions = Observable(u_predict[1])  # Observable replaces the older Node alias
crange = extrema(Iterators.flatten(u_predict))  # one color range shared by all frames
scene = volume(xs, ys, zs, positions, colormap = :plasma, colorrange = crange, figure = (; resolution = (800,800)),
axis = (; type=Axis3, perspectiveness = 0.5, azimuth = 7.19, elevation = 0.57,
aspect = (1,1,1)))
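fps = 10
# record captures one frame per element of eachindex(ts); framerate sets playback speed, so no sleep is needed
record(scene, "output.mp4", eachindex(ts); framerate = fps) do t
positions[] = u_predict[t]
end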
It also works with just Plots:
anim = @animate for t ∈ eachindex(ts)
scatter(u_predict[t])
end
gif(anim, "wave3d.gif", fps=10)
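`scatter(u_predict[t])` only receives the values, not where they sit in space. A sketch of one alternative, assuming Plots' 3D `scatter(x, y, z; marker_z = values)` form: flatten the grid into coordinate vectors and color each point by u. The grid and `u_stub` are hypothetical stand-ins for a single frame of `u_predict`.

```julia
xs = ys = zs = 0.0:0.1:1.0

# Hypothetical stand-in for one time step of u_predict.
u_stub(x, y, z) = x * (1 - x) * y * (1 - y) * z * (1 - z)
snap = [u_stub(x, y, z) for x in xs, y in ys, z in zs]   # 3D array indexed [ix, iy, iz]

# One entry per grid point, in the same column-major order as vec(snap).
idx = vec(CartesianIndices(snap))
X = [xs[I[1]] for I in idx]
Y = [ys[I[2]] for I in idx]
Z = [zs[I[3]] for I in idx]
U = vec(snap)

# using Plots
# scatter(X, Y, Z; marker_z = U, markersize = 2, legend = false,
#         xlabel = "x", ylabel = "y", zlabel = "z")
```

Inside the `@animate` loop the same flattening would be applied to `u_predict[t]`; since the coordinates do not change between frames, `X`, `Y`, `Z` only need to be built once.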
It would be cool if it were possible to get a volumetric plot by doing something like:
anim = @animate for t ∈ eachindex(ts)
plot(domains[2:4], phi([t,x,y,z], res.minimizer))
end
The domains are already defined, and generating big data arrays seems wasteful.
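Until something like that exists, one way to avoid materializing every frame up front is to keep only the grid ranges and build each frame inside the loop that consumes it. A minimal sketch: `grid_from_bounds`, `frame`, and `u_stub` are hypothetical names. With NeuralPDE the bounds would come from `infimum(d.domain)` and `supremum(d.domain)` as in the script above, and `u_stub` would wrap `phi`.

```julia
# Hypothetical helper: a grid range built from interval bounds.
grid_from_bounds(lo, hi; step = 0.1) = lo:step:hi

bounds = [(0.0, 1.0), (0.0, 1.0), (0.0, 1.0), (0.0, 1.0)]   # t, x, y, z
ts, xs, ys, zs = [grid_from_bounds(lo, hi) for (lo, hi) in bounds]

# Hypothetical stand-in for (t, x, y, z) -> phi([t, x, y, z], res.minimizer)[1].
u_stub(t, x, y, z) = cos(pi * t) * x * (1 - x) * y * (1 - y) * z * (1 - z)

# Evaluate a single time frame on demand; nothing is cached between calls.
frame(t) = [u_stub(t, x, y, z) for x in xs, y in ys, z in zs]

# In a Plots or Makie loop only the current frame lives in memory, e.g.
# anim = @animate for t in ts
#     f = frame(t)
#     # ...plot f...
# end

# Example reduction over all frames that also never stores them together.
peak = maximum(maximum(abs, frame(t)) for t in ts)
```

The trade-off is that each frame is recomputed whenever it is needed, so a second pass (for example computing a shared color range and then plotting) evaluates the network twice.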
I had a generally positive experience working with higher-dimensional PDEs, but I struggled to figure out how to analyze the results of these 3D+ PDEs effectively.
I think it is important for this package to have more 3D+ tutorials or examples. PINNs really start to make sense above four dimensions, so it is only logical to show users how to actually plot or otherwise work with data in that many dimensions.
In the medium term, it would be awesome to have a SciML-wide interface for PDESolutions https://github.com/SciML/SciMLBase.jl/issues/112. But for now, as a user, I have a few questions:
Can I use phi to find the value of a variable u at, for example, time t = 0.5?
Can I get u over an interval such as 0:0.1:1 somehow? It would be useful to grab that, put it in an array, and use it for further analysis.
These are just some questions that immediately come to mind. In general, it would be helpful to have more examples showing how users can plot results in high dimensions and how to analyze and manipulate PINN results.
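For these two questions, a sketch of one approach, assuming (as in the scripts above) that `phi(point, θ)` takes `[t, x, y, z]` and returns a one-element vector: wrap it in a scalar function and evaluate it at a point or over a grid with a comprehension. `phi_stub` and `θ_stub` are hypothetical stand-ins (an arbitrary analytic function, not a trained network) so the block runs on its own; with NeuralPDE you would use `discretization.phi` and `res.minimizer` instead.

```julia
# Hypothetical stand-in with the same calling convention as phi(point, θ) above:
# takes [t, x, y, z] and parameters, returns a one-element vector.
phi_stub(p, θ) = [θ[1] * (1 - p[1] / 2) * p[2] * (1 - p[2]) * p[3] * (1 - p[3]) * p[4] * (1 - p[4])]
θ_stub = [1.0]

# Scalar wrapper: u(t, x, y, z) as a plain number.
u(t, x, y, z) = phi_stub([t, x, y, z], θ_stub)[1]

# Value at a single point: t = 0.5 at the centre of the cube.
u_point = u(0.5, 0.5, 0.5, 0.5)

# Values along x in 0:0.1:1 at fixed t, y, z, as a Vector for further analysis.
xs = 0.0:0.1:1.0
u_line = [u(0.5, x, 0.5, 0.5) for x in xs]

# Full spatial snapshot at t = 0.5 as a 3D array indexed [ix, iy, iz].
ys = zs = xs
u_snapshot = [u(0.5, x, y, z) for x in xs, y in ys, z in zs]
```

Because the snapshot is an ordinary `Array{Float64,3}`, slicing (`u_snapshot[:, 6, 6]` recovers `u_line`), reductions, and saving work with standard Julia tools.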
Thanks and I appreciate the direction this package is going!