SciML / NeuralPDE.jl

Physics-Informed Neural Networks (PINN) Solvers of (Partial) Differential Equations for Scientific Machine Learning (SciML) accelerated simulation
https://docs.sciml.ai/NeuralPDE/stable/

Rewrite NeuralPDE #687

YichengDWu opened this issue 1 year ago (status: open)

YichengDWu commented 1 year ago

If anyone wants to rewrite NeuralPDE, I want to share some of my expectations here, some of which are already reflected in Sophon.jl.

The following code is mostly pseudocode, used only to illustrate the concepts. I will gradually add more as I think of it.

On PhysicsInformedNN

A PINN should be created and accessed in the form of named tuples, just like Lux.Chain.

julia> chain=Chain(u = Dense(2,16), v = Dense(2,16))
Chain(
    u = Dense(2 => 16),                 # 48 parameters
    v = Dense(2 => 16),                 # 48 parameters
)         # Total: 96 parameters,
          #        plus 0 states, summarysize 32 bytes.

Then wrap it in a PINN structure. It can be accessed through getindex, but getproperty is preferred.

pinn = PINN(chain)
pinn.u == chain.layers.u # Something like this, but the layer state also needs to be handled
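A minimal, self-contained sketch of such a wrapper, assuming plain functions of `(coord, θ)` stand in for Lux layers and ignoring layer state. The `PINN` struct and its `phis` field are hypothetical names for illustration, not NeuralPDE or Sophon API:

```julia
# Hypothetical PINN wrapper: one network per dependent variable, stored in a
# NamedTuple. Layer state is ignored in this sketch.
struct PINN{T<:NamedTuple}
    phis::T
end

# pinn.u forwards to the network named :u (the real field stays reachable
# through getfield).
Base.getproperty(p::PINN, name::Symbol) = getproperty(getfield(p, :phis), name)
Base.propertynames(p::PINN) = propertynames(getfield(p, :phis))

# getindex also works (pinn[:u]), but getproperty is the preferred spelling.
Base.getindex(p::PINN, name::Symbol) = getproperty(p, name)

# Toy "networks" standing in for Lux layers: functions of (coord, θ).
phi_u(coord, θ) = θ.W * coord
phi_v(coord, θ) = θ.W * coord .+ θ.b

pinn = PINN((u = phi_u, v = phi_v))
θ = (u = (W = [1.0 2.0],), v = (W = [0.5 0.5], b = 1.0))
coord = [1.0 2.0; 3.0 4.0]   # one row per independent variable, one column per point

pinn.u(coord, θ.u)  # [7.0 10.0]
```

Because each network is looked up by name, the parameters can be looked up the same way (`θ.u`), which is what makes the parsing below independent of variable order.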

Now we no longer need to record the order of dependent variables in the PDESystem. Parsing thus becomes simpler.

# u(t,x) + v(t,x) lowers to
begin
    phi_u = pinn.u
    phi_v = pinn.v
end
begin
    θ_u = θ.u
    θ_v = θ.v
end
phi_u(coord_u, θ_u) .+ phi_v(coord_v, θ_v)

The code generating the begin ... end blocks should be wrapped in a reusable function, and so should the rewrite u(t,x) -> phi_u(coord_u, θ_u). Perhaps @rule (from SymbolicUtils) can play a role here; I'm not sure. We need to break transform_expressions down into small functions or rewrite rules.
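A minimal sketch of such a reusable generator, assuming dependent variables are given as plain `Symbol`s (not Symbolics variables) and that each call has already been rewritten into `phi_u(...)` form. `gen_bindings` and `lower` are hypothetical names, and for brevity every variable here uses the full `coord`:

```julia
# Hypothetical generator for the two `begin ... end` blocks:
# phi_<var> = pinn.<var> and θ_<var> = θ.<var> for every dependent variable.
function gen_bindings(vars::Vector{Symbol})
    phis = [Expr(:(=), Symbol(:phi_, v), Expr(:., :pinn, QuoteNode(v))) for v in vars]
    θs = [Expr(:(=), Symbol(:θ_, v), Expr(:., :θ, QuoteNode(v))) for v in vars]
    return Expr(:block, phis...), Expr(:block, θs...)
end

# Wrap the bindings and an already-rewritten body into (pinn, θ, coord) -> body.
function lower(vars::Vector{Symbol}, body)
    phi_block, θ_block = gen_bindings(vars)
    return :((pinn, θ, coord) -> begin
        $phi_block
        $θ_block
        $body
    end)
end

# u(t,x) + v(t,x) after the call rewrite (both use the full coord here)
ex = lower([:u, :v], :(phi_u(coord, θ_u) .+ phi_v(coord, θ_v)))
f = eval(ex)

pinn = (u = (c, p) -> p .* c, v = (c, p) -> p .+ c)   # toy networks
f(pinn, (u = 2.0, v = 1.0), [1.0, 2.0])               # [2.0, 4.0] .+ [2.0, 3.0]
```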

Single output & multioutput

For a single-output network, I think we can just change

phi_u=phi.u, θ_u = θ.u

to

phi_u=phi, θ_u = θ

There is no need to distinguish between the two cases anywhere else in the parser.
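The single- versus multi-output switch can then live in one place. A sketch, assuming a hypothetical `binding_exprs` helper and a `multioutput` flag that is decided elsewhere:

```julia
# Hypothetical: the only parser code that depends on single vs. multi-output.
# Multi-output:  phi_u = pinn.u, θ_u = θ.u
# Single output: phi_u = pinn,   θ_u = θ
function binding_exprs(var::Symbol; multioutput::Bool)
    phi_rhs = multioutput ? Expr(:., :pinn, QuoteNode(var)) : :pinn
    θ_rhs = multioutput ? Expr(:., :θ, QuoteNode(var)) : :θ
    return (Expr(:(=), Symbol(:phi_, var), phi_rhs),
            Expr(:(=), Symbol(:θ_, var), θ_rhs))
end

binding_exprs(:u; multioutput = true)    # (:(phi_u = pinn.u), :(θ_u = θ.u))
binding_exprs(:u; multioutput = false)   # (:(phi_u = pinn), :(θ_u = θ))
```

Everything generated after these bindings, such as phi_u(coord_u, θ_u), is identical in both cases.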

On coordinates

The current parser generates code like this:

# u(x,t), v(t)
(coord, θ)->
let (x, t) = (coord[[1],:], coord[[2],:])
    coord1 = vcat(x,t)
    coord2 = vcat(t)
    # computation...
end

Here vcat causes unnecessary memory allocation: coord1 just rebuilds coord, and coord2 copies t. I would like a get_coord function that performs as few vcat calls as possible.

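function get_coord(u)
    args = arguments(u)
    if args == all_indvars
        return :(coord)            # u takes every independent variable: reuse coord, no copy
    elseif length(args) == 1
        return args[1]             # a single variable, say :t, is already bound
    else
        return :(vcat($(args...))) # only here is an allocation unavoidable
    end
end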

Similarly, I prefer named suffixes (coord_u, coord_v) to numbered ones (coord1, coord2):

# u(x,t), v(t)
(coord, θ)->
let (x, t) = (coord[[1],:], coord[[2],:])
    begin 
        coord_u = coord
        coord_v = t
    end
    # ...
end
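A sketch of generating this block, assuming variables are plain `Symbol`s and each dependent variable's arguments are listed in call order. `get_coord` here is a Symbols-only variant of the function above, and `gen_coord_block` is a hypothetical name; the `let` binds the rows as separate assignments rather than by tuple destructuring:

```julia
# Symbols-only variant of get_coord: `args` are one dependent variable's
# arguments, `indvars` the rows of `coord` in order.
function get_coord(args::Vector{Symbol}, indvars::Vector{Symbol})
    args == indvars && return :coord         # reuse coord as is, no allocation
    length(args) == 1 && return args[1]      # a single row is already bound
    return :(vcat($(args...)))               # allocate only when unavoidable
end

# Build: let; x = coord[[1], :]; t = coord[[2], :]; coord_u = ...; body; end
function gen_coord_block(indvars::Vector{Symbol}, depvars, body)
    rows = [Expr(:(=), x, :(coord[[$i], :])) for (i, x) in enumerate(indvars)]
    binds = [Expr(:(=), Symbol(:coord_, v), get_coord(args, indvars))
             for (v, args) in depvars]
    return Expr(:let, Expr(:block), Expr(:block, rows..., binds..., body))
end

indvars = [:x, :t]
depvars = [:u => [:x, :t], :v => [:t]]   # u(x, t), v(t)
body = :(sum(coord_u; dims = 1) .+ coord_v)
f = eval(:(coord -> $(gen_coord_block(indvars, depvars, body))))
f([1.0 2.0; 10.0 20.0])   # [11.0 22.0] .+ [10.0 20.0]
```

Only a permuted or partial multi-variable argument list, such as `[:t, :x]`, produces a `vcat`.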

Periodic boundary conditions

Periodic boundary conditions are currently not handled correctly (see #469), because the same dependent variable appears with different inputs. This needs special treatment.

A more general approach is to parse expressions like u(1.0,t) into

(coord, θ) -> begin
    begin
        phi_u = pinn.u
    end
    begin
        θ_u = θ.u
    end
    let (x, t) = (coord[[1], :], coord[[2], :])
        begin
            coord_u = coord
        end
        # u(1.0, t): the x row is replaced by a constant row of 1.0
        phi_u(vcat(Base.Fix1(fill, 1.0)(size(x)), t), θ_u)
    end
end

In this way, periodic conditions are naturally parsed correctly.

# u(-1.0, t) ~ u(1.0, t) lowers to
begin
    phi_u = pinn.u
end
begin
    θ_u = θ.u
end
let (x, t) = (coord[[1], :], coord[[2], :])
    begin
        coord_u = coord
    end
    phi_u(vcat(Base.Fix1(fill, -1.0)(size(x)), t), θ_u) .-
        phi_u(vcat(Base.Fix1(fill, 1.0)(size(x)), t), θ_u)
end
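A runnable sketch of what the lowered periodic condition computes, assuming a toy closed-form `phi_u` in place of a network and rows ordered `(x, t)` as above. The toy function and the θ values are illustrative only:

```julia
# Toy stand-in for a network: u(x, t) = a x² + c x + b t (hypothetical).
# It satisfies u(-1, t) == u(1, t) only when c == 0.
phi_u(input, θ) = θ.a .* input[[1], :] .^ 2 .+ θ.c .* input[[1], :] .+ θ.b .* input[[2], :]

# Residual of u(-1.0, t) ~ u(1.0, t) at every data point.
function periodic_residual(coord, θ)
    θ_u = θ.u
    x, t = coord[[1], :], coord[[2], :]
    left = vcat(fill(-1.0, size(x)), t)    # input for u(-1.0, t)
    right = vcat(fill(1.0, size(x)), t)    # input for u(1.0, t)
    return phi_u(left, θ_u) .- phi_u(right, θ_u)
end

coord = [0.3 -0.7 0.1; 0.0 0.25 0.5]      # the x row is not used by this condition
periodic_residual(coord, (u = (a = 1.0, c = 0.5, b = 2.0),))   # [-1.0 -1.0 -1.0]
periodic_residual(coord, (u = (a = 1.0, c = 0.0, b = 2.0),))   # [0.0 0.0 0.0]
```

The same `phi_u` is simply called twice with different inputs; nothing about the dependent variable has to be special-cased.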

Even if something like this appears in the equation:

u(x,t)+u(1.0,t)+v(x,1.0)

It can still be parsed correctly, since each occurrence of u is lowered with its own input.
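A sketch of the per-call lowering this relies on, assuming call arguments arrive as plain `Symbol`s or numbers and that the rows `x` and `t` are bound by the enclosing `let`. `input_expr` and `call_expr` are hypothetical helpers; a get_coord-style shortcut could replace `vcat(x, t)` with `coord`:

```julia
# Build the network input for one call: Symbols refer to rows bound by the
# enclosing `let`; numbers become constant rows shaped like the row `ref`.
function input_expr(args::Vector, ref::Symbol)
    rows = [a isa Symbol ? a : :(fill($(Float64(a)), size($ref))) for a in args]
    return :(vcat($(rows...)))
end

# Lower one call var(args...) to phi_var(input, θ_var).
call_expr(var::Symbol, args::Vector, ref::Symbol) =
    Expr(:call, Symbol(:phi_, var), input_expr(args, ref), Symbol(:θ_, var))

# u(x,t) + u(1.0,t) + v(x,1.0): each occurrence gets its own input.
body = Expr(:call, :+,
            call_expr(:u, [:x, :t], :x),
            call_expr(:u, [1.0, :t], :x),
            call_expr(:v, [:x, 1.0], :x))
```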

On derivatives

I imagine parsing functions like the following:

expr = Dxxx(Dtt(u(t,x)))
function get_directions(expr)
    # some code
    return [2,2,2,1,1] # from outermost to innermost
end

function get_derivative(expr)
    directions = get_directions(expr) # [2,2,2,1,1]
    mixed = any(!==(first(directions)), directions) # is this a mixed derivative?
    if mixed
        orders = get_orders(directions) # e.g. [2,2,2,1,1] -> [3,2]
        directions = unique(directions) # e.g. [2,1]
        εs = map(get_ε, directions, orders) # get_ε(direction, order), as in the branch below

        # generate an expression here: outer derivative along directions[1],
        # inner derivative along directions[2] (two directions only)
        return quote
            finitediff((x, ps) -> finitediff(phi_u, x, ps, $(εs[2]), $(Val(orders[2]))),
                       coord_u, θ_u, $(εs[1]), $(Val(orders[1])))
        end
    else
        order = length(directions)
        ε = get_ε(first(directions), order)
        return :(finitediff(phi_u, coord_u, θ_u, $ε, $(Val(order))))
    end
end
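The helpers that get_derivative relies on could look like this. It is a sketch under stated assumptions: `get_orders`, `get_ε` and `finitediff` are hypothetical (not NeuralPDE's API), inputs have one row per independent variable, `phi` returns a 1×N row, and standard central-difference stencils are used with a step h ≈ eps^(1/(order+2)):

```julia
# How often each direction occurs, in order of first appearance:
# [2,2,2,1,1] -> [3,2] (assumes the partial derivatives commute).
get_orders(directions) = [count(==(d), directions) for d in unique(directions)]

# Step vector along `direction`. h ≈ eps^(1/(order+2)) balances the O(h²)
# truncation error of a central difference against round-off.
function get_ε(direction, order; nvars = 2)
    h = eps(Float64)^(1 / (order + 2))
    return [i == direction ? h : 0.0 for i in 1:nvars]
end

# Central finite differences of order 1, 2 and 3 along the step vector ε.
function finitediff(phi, x, θ, ε, ::Val{1})
    h = maximum(ε)
    return (phi(x .+ ε, θ) .- phi(x .- ε, θ)) ./ (2h)
end

function finitediff(phi, x, θ, ε, ::Val{2})
    h = maximum(ε)
    return (phi(x .+ ε, θ) .- 2 .* phi(x, θ) .+ phi(x .- ε, θ)) ./ h^2
end

function finitediff(phi, x, θ, ε, ::Val{3})
    h = maximum(ε)
    return (phi(x .+ 2 .* ε, θ) .- 2 .* phi(x .+ ε, θ) .+
            2 .* phi(x .- ε, θ) .- phi(x .- 2 .* ε, θ)) ./ (2h^3)
end

# Usage with rows ordered (t, x), matching u(t,x): u = θ t² x².
phi_u(x, θ) = θ .* x[[1], :] .^ 2 .* x[[2], :] .^ 2
coord_u = [1.0 2.0; 3.0 0.5]
θ_u = 1.0

# Dtt(u) = 2θx², approximately [18.0 0.5]
utt = finitediff(phi_u, coord_u, θ_u, get_ε(1, 2), Val(2))

# Dx(Dt(u)) = 4θtx, nested the same way as the quote in get_derivative
utx = finitediff((y, ps) -> finitediff(phi_u, y, ps, get_ε(1, 1), Val(1)),
                 coord_u, θ_u, get_ε(2, 1), Val(1))
```

Dispatching on `Val(order)` lets the generated code pick the stencil at compile time, which is presumably why the order is interpolated as a `Val` in get_derivative.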

Sample

There should be a standalone sample function that can be used for resampling; its output is then passed into prob:

data = sample(pde_system, sampler)
prob = remake(prob; p = data)

Here length(pde_datasets) == length(eqs) and length(boundary_datasets) == length(bcs). Note that I believe the same data points should be used for all equations, which may help convergence; in any case, it saves memory.

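function sample(pde_system, sampler)
    bounds = get_bounds(pde_system) # hypothetical helper: domain bounds of the system
    pde_dataset = sample(bounds, sampler)
    pde_datasets = [pde_dataset for _ in 1:length(pde_system.eqs)] # same points for every equation
    boundary_datasets = [sample(bc, bounds, sampler) for bc in pde_system.bcs] # one set per condition
    return [pde_datasets; boundary_datasets]
end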
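A runnable sketch of this layout, assuming a hypothetical `UniformSampler`, box-shaped bounds given as `(lower, upper)` tuples, and boundary conditions described as `(row, value)` pairs that fix one independent variable. `sample_all` plays the role of sample(pde_system, sampler):

```julia
using Random

# Hypothetical uniform sampler: `n` points drawn from a box.
struct UniformSampler{R<:AbstractRNG}
    n::Int
    rng::R
end

# bounds: one (lower, upper) tuple per independent variable.
# Returns an nvars × n matrix, one column per point.
function sample(bounds::Vector{Tuple{Float64,Float64}}, s::UniformSampler)
    lo = first.(bounds)
    hi = last.(bounds)
    return lo .+ (hi .- lo) .* rand(s.rng, length(bounds), s.n)
end

# Hypothetical boundary description bc = (row, value): that independent
# variable is fixed to `value`, the others are sampled.
function sample_boundary(bounds, bc, s::UniformSampler)
    row, value = bc
    pts = sample(bounds, s)
    pts[row, :] .= value
    return pts
end

# neqs equations share one point set, followed by one set per boundary condition.
function sample_all(bounds, neqs::Int, bcs, s::UniformSampler)
    pde_dataset = sample(bounds, s)
    pde_datasets = fill(pde_dataset, neqs)   # same matrix for every equation, no copies
    boundary_datasets = [sample_boundary(bounds, bc, s) for bc in bcs]
    return [pde_datasets; boundary_datasets]
end

bounds = [(0.0, 1.0), (-1.0, 1.0)]           # x ∈ [0, 1], t ∈ [-1, 1]
s = UniformSampler(10, MersenneTwister(0))
data = sample_all(bounds, 2, [(1, 0.0), (1, 1.0)], s)   # 2 equations, boundaries x = 0 and x = 1
```

Since `fill` stores the same matrix reference for every equation, adding equations costs no extra point storage.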

Scalarizing

Before scalarization, each equation or boundary condition returns its residual at every data point, not a scalar. We scalarize only at the very end.

using Statistics: mean

function scalarize(phi, weights::Tuple{Vararg{Function, N}},
                   loss_functions::Tuple{Vararg{Function, N}}) where N
    # equation/condition i is weighted and evaluated on its own points, data[i]
    ex = :(mean($(weights[1])($phi, data[1], θ) .*
                abs2.($(loss_functions[1])(data[1], θ))))
    for i in 2:N
        ex = :(mean($(weights[i])($phi, data[$i], θ) .*
                    abs2.($(loss_functions[i])(data[$i], θ))) + $ex)
    end
    loss_f = :((θ, data) -> $ex)
    return eval(loss_f)
end

Note that weights is a tuple of functions, one per equation or boundary condition, each assigning a weight to every data point of that equation; the most basic case is Returns(1).

Many adaptive algorithms need point-wise residuals, and by hacking the weight functions we can implement changing weights without any further code changes.
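For comparison, the same loss can also be built without `eval` by closing over the tuples; this is an alternative sketch, not the proposal's code, and `Returns` needs Julia 1.7 or later. It also shows the `Returns(1)` baseline and a weight function that reads point-wise weights from a `Ref` that an adaptive algorithm could overwrite between iterations. The toy residuals and the |residual|-based update are illustrative assumptions:

```julia
using Statistics: mean

# Sum over equations/conditions of mean(weight .* residual²), where data[i]
# holds the points of equation or condition i.
function scalarize_closure(phi, weights::Tuple, loss_functions::Tuple)
    return (θ, data) -> sum(mean(w(phi, d, θ) .* abs2.(l(d, θ)))
                            for (w, l, d) in zip(weights, loss_functions, data))
end

# Toy point-wise residual functions of (data, θ).
res1(d, θ) = θ .* d .- 1.0
res2(d, θ) = d .- θ
data = ([1.0 2.0 3.0], [0.0 1.0])

# Baseline: every point has weight 1.
loss = scalarize_closure(nothing, (Returns(1), Returns(1)), (res1, res2))
loss(1.0, data)            # mean([0 1 4]) + mean([1 0]) = 5/3 + 1/2

# Adaptive weights: the weight function reads from a Ref, so an algorithm can
# update the weights between iterations without touching the loss code.
w1 = Ref(ones(1, 3))
loss_adaptive = scalarize_closure(nothing, ((phi, d, θ) -> w1[], Returns(1)), (res1, res2))
w1[] = abs.(res1(data[1], 1.0)) ./ 2   # e.g. weights proportional to |residual|
loss_adaptive(1.0, data)   # mean([0 0.5 4]) + 1/2 = 2.0
```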

ChrisRackauckas commented 1 year ago

@xtalax, can you describe where your parser rewrite stands? I think it would even be fine to release an earlier v6 with some features lost if it gets closer to these goals. One thing I think needs a total rewrite anyway is the integro-differential equation support.

xtalax commented 1 year ago

I have changed everything in the call stack up to parse_equation to use Symbolics, and have started on parse_equation itself. I can pivot to this style, though.

YichengDWu commented 1 year ago

Sophon can generate the symbolic loss functions I want, but it also falls short of what I'm describing here. Its underlying implementation is just as hard to read as NeuralPDE's, especially since DeepONet was introduced. I removed support for integrals and default parameters there; those should be kept in NeuralPDE.

What I would like to see is expressions being transformed with proper tools such as MacroTools.postwalk, which it seems you are already using.
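As an illustration of this style, here is a dependency-free sketch: `postwalk` below is a minimal re-implementation with the same traversal order as MacroTools.postwalk (children first, then the node), and `lower_call` with its `DEPVARS` list is a hypothetical rewrite of dependent-variable calls into network calls:

```julia
# Minimal stand-in for MacroTools.postwalk: rebuild the children first, then
# apply f to the node, so this sketch has no dependencies.
postwalk(f, x) = x isa Expr ? f(Expr(x.head, map(a -> postwalk(f, a), x.args)...)) : f(x)

# Hypothetical dependent variables and rewrite rule:
# u(args...) -> phi_u(coord_u, θ_u). Coordinates are assumed to be bound
# elsewhere (e.g. by a get_coord-style helper).
const DEPVARS = (:u, :v)

function lower_call(ex)
    if ex isa Expr && ex.head === :call && ex.args[1] in DEPVARS
        var = ex.args[1]
        return Expr(:call, Symbol(:phi_, var), Symbol(:coord_, var), Symbol(:θ_, var))
    end
    return ex
end

postwalk(lower_call, :(u(x, t) + v(t)))   # :(phi_u(coord_u, θ_u) + phi_v(coord_v, θ_v))
```

Each rewrite is a small function applied by the walker, which is the kind of decomposition of transform_expressions suggested at the top of the issue.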

I'm happy with the design of Sophon's interface:

prob = Sophon.discretize(pde_system, pinn, sampler, strategy)

By the way, I don't actually use RuntimeGeneratedFunctions explicitly at all, and everything seems to work fine.