SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License
871 stars 156 forks source link

How can I integrate an RNN into an ODEProblem in Lux.jl? #952

Open disadone opened 1 month ago

disadone commented 1 month ago

Hi! I'm wondering how an RNN could be combined with an ODEProblem.

In the Flux days, it seems a `Recur` layer had to be created. However, Lux.jl already provides a `Recurrence` layer. (Reference: "Training of UDEs with recurrent networks".)

How can this be done with Lux.jl now? I defined my own GRU cell, and it works well when combined with the beginner tutorial "Training a Simple LSTM":


using ConcreteStructs: @concrete
using Lux
using Static
using Random

# Accepted types for layer sizes and boolean flags. Both are `const` so they are fixed
# bindings: they are used in struct field constraints and in method signatures, and a
# non-const global could be silently rebound after those definitions were made.
const IntegerType = Union{Integer, Static.StaticInteger}
const BoolType = Union{StaticBool, Bool, Val{true}, Val{false}}

# Leaky GRU cell (see the `FastGRUCell(in_dims => out_dims, Δt, τ, layernormQ)`
# constructor). `@concrete` turns every `<:`-constrained or unannotated field into a type
# parameter, in field order, so the first parameter is the type of `train_state`; the
# call methods dispatch on `FastGRUCell{True}` / `FastGRUCell{False}` through it.
#
# `layernormQ` is declared with `<:` rather than `::` on purpose: `@concrete` keeps a
# `::StaticBool` annotation verbatim, which would leave the field abstractly typed and make
# the `dynamic(m.layernormQ)` branch in the forward pass non-inferable.
@concrete struct FastGRUCell <: Lux.AbstractRecurrentCell
    train_state <: StaticBool        # True() => initial hidden state is a trainable parameter
    in_dims <: IntegerType           # input feature size
    out_dims <: IntegerType          # hidden-state size
    init_bias                        # 3-tuple of bias initializers (z, r, h̃ paths)
    init_weight                      # 3-tuple of weight initializers (z, r, h̃ paths)
    init_state                       # hidden-state initializer
    dynamics_nonlinearity            # activation of the candidate state h̃
    gating_nonlinearity              # activation of the update/reset gates z and r
    α <: AbstractFloat               # leak rate Δt / τ
    layernormQ <: StaticBool         # layer-normalize hidden-path pre-activations?
end

"""
    FastGRUCell(in_dims => out_dims, Δt, τ, layernormQ; kwargs...)

Construct a leaky GRU cell whose state update is `h′ = (1 - α z) h + α z h̃`, with leak
rate `α = Δt / τ`.

# Arguments
- `in_dims => out_dims`: input size and hidden-state size.
- `Δt`, `τ`: step size and time constant (same floating-point type; both must be positive).
- `layernormQ`: if true, the hidden-path pre-activations of `z`, `r` and `h̃` are
  layer-normalized before the input-path term is added.

# Keywords
- `train_state`: if true, the initial hidden state is stored as a trainable parameter.
- `init_weight`, `init_bias`: initializers, replicated for the `z`, `r` and `h̃` paths.
- `init_state`: initializer for the hidden state.
- `dynamics_nonlinearity`: activation of the candidate state `h̃` (default `tanh_fast`).
- `gating_nonlinearity`: activation of the update/reset gates `z` and `r`
  (default `sigmoid_fast`, so the gates lie in (0, 1) as in a standard GRU).

# Throws
- `ArgumentError` if `Δt` or `τ` is not positive.
"""
function FastGRUCell(
        (in_dims, out_dims)::Pair{<:Lux.IntegerType, <:Lux.IntegerType},
        Δt::T, τ::T, layernormQ::BoolType;
        train_state::BoolType=False(),
        init_weight=Lux.glorot_normal,
        init_bias=Lux.zeros32,
        init_state=Lux.zeros32,
        dynamics_nonlinearity=Lux.tanh_fast,
        gating_nonlinearity=Lux.sigmoid_fast) where {T <: AbstractFloat}
    Δt > 0 || throw(ArgumentError("Δt must be positive, got $Δt"))
    τ > 0 || throw(ArgumentError("τ must be positive, got $τ"))

    # One initializer per path (z, r, h̃); `initialparameters` draws from each in turn.
    init_weight = ntuple(Returns(init_weight), 3)
    init_bias = ntuple(Returns(init_bias), 3)
    α = Δt / τ
    return FastGRUCell(
        static(train_state),
        in_dims, out_dims, init_bias, init_weight, init_state,
        dynamics_nonlinearity, gating_nonlinearity, α, static(layernormQ)
    )
end

# Draw the cell's parameters. The order of RNG draws is fixed:
#   1. recurrent weights Wz, Wr, Wh (out_dims × out_dims),
#   2. input weights Uz, Ur, Uh (out_dims × in_dims),
#   3. biases biasz, biasr, biash (length out_dims),
#   4. the initial hidden state, but only when the cell has a trainable state.
function Lux.initialparameters(rng::AbstractRNG, gru::FastGRUCell)
    H, D = gru.out_dims, gru.in_dims
    draw(init, dims) = Lux.init_rnn_weight(rng, init, H, dims)

    Wz, Wr, Wh = ntuple(i -> draw(gru.init_weight[i], (H, H)), 3)
    Uz, Ur, Uh = ntuple(i -> draw(gru.init_weight[i], (H, D)), 3)
    biasz, biasr, biash = ntuple(i -> draw(gru.init_bias[i], H), 3)

    ps = (; Wz, Wr, Wh, Uz, Ur, Uh, biasz, biasr, biash)
    if Lux.has_train_state(gru)
        ps = merge(ps, (hidden_state=gru.init_state(rng, H),))
    end
    return ps
end
# The cell's only state is a replicated RNG, consumed by the non-trainable-state call
# method to initialize the hidden state.
function Lux.initialstates(rng::AbstractRNG, ::FastGRUCell)
    return (; rng=Lux.Utils.sample_replicate(rng))
end

# First step with a trainable initial state: build the starting hidden state from
# `ps.hidden_state` (via `Lux.init_trainable_rnn_hidden_state`) and run one cell step.
function (gru::FastGRUCell{True})(x::AbstractMatrix, ps, st::NamedTuple)
    h0 = Lux.init_trainable_rnn_hidden_state(ps.hidden_state, x)
    carry = (h0,)
    return gru((x, carry), ps, st)
end

# First step without a trainable initial state: replicate the stored RNG, use it to
# initialize the hidden state with `gru.init_state`, store it back in the state, and run
# one cell step.
function (gru::FastGRUCell{False})(x::AbstractMatrix, ps, st::NamedTuple)
    fresh_rng = Lux.replicate(st.rng)
    h0 = Lux.init_rnn_hidden_state(fresh_rng, gru, x)
    return gru((x, (h0,)), ps, merge(st, (; rng=fresh_rng)))
end

# Input accepted by the per-step cell method: `(x, (h,))`, an input matrix paired with a
# one-element tuple carrying the hidden-state matrix.
const _FastGRUCellInputType = Tuple{<:AbstractMatrix, Tuple{<:AbstractMatrix}}

# Shared pre-activation step for the z, r and h̃ paths: optionally layer-normalize the
# hidden-path term, add the input-path term, then apply `σ` elementwise.
function _fastgru_activate(σ::F, layernormQ, hidden_term, input_term) where {F}
    if dynamic(layernormQ)
        return σ.(layernorm(hidden_term, nothing, nothing) .+ input_term)
    end
    return σ.(hidden_term .+ input_term)
end

# One leaky-GRU step on `(x, (h,))`. Returns `(h′, (h′,))` and the unchanged state `st`.
function (m::FastGRUCell)(
        (x, (h,))::_FastGRUCellInputType, ps, st::NamedTuple)
    σ_gate, σ_dyn, ln = m.gating_nonlinearity, m.dynamics_nonlinearity, m.layernormQ

    # Update gate z and reset gate r (biases sit on the hidden path).
    z = _fastgru_activate(σ_gate, ln,
        fused_dense_bias_activation(identity, ps.Wz, h, ps.biasz),
        fused_dense_bias_activation(identity, ps.Uz, x, nothing))
    r = _fastgru_activate(σ_gate, ln,
        fused_dense_bias_activation(identity, ps.Wr, h, ps.biasr),
        fused_dense_bias_activation(identity, ps.Ur, x, nothing))

    # Candidate state h̃ computed from the reset-gated hidden state.
    h̃ = _fastgru_activate(σ_dyn, ln,
        fused_dense_bias_activation(identity, ps.Wh, h .* r, ps.biash),
        fused_dense_bias_activation(identity, ps.Uh, x, nothing))

    # Leaky update: move from h toward h̃ at rate α·z.
    αz = m.α .* z
    h_next = @. (1 - αz) * h + αz * h̃
    return (h_next, (h_next,)), st
end

# --------------------------------------------------------------------------------------------------
# adapted from https://lux.csail.mit.edu/stable/tutorials/beginner/3_SimpleRNN#Creating-a-Classifier
using Lux, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics
"""
    get_dataloaders(; dataset_size=1000, sequence_length=50)

Build training and validation `DataLoader`s for the spiral-direction task.

`dataset_size` spirals are generated with `MLUtils.Datasets.make_spiral`:
- The first `dataset_size ÷ 2` keep their first `sequence_length` points and are labelled `0`.
- The remaining `dataset_size - dataset_size ÷ 2` keep the following `sequence_length`
  points and are labelled `1`.

Inputs are stacked as `features × sequence_length × samples` in `Float32`, split 80/20
(shuffled) into training/validation, and batched with `batchsize=128`. Only the training
loader reshuffles each epoch.

# Throws
- `ArgumentError` if `dataset_size` is not positive.
"""
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    dataset_size > 0 ||
        throw(ArgumentError("dataset_size must be positive, got $dataset_size"))

    # Split counts explicitly so labels and samples agree even for odd dataset_size.
    n_clockwise = dataset_size ÷ 2
    n_anticlockwise = dataset_size - n_clockwise

    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], n_clockwise), repeat([1.0f0], n_anticlockwise))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:n_clockwise]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[(n_clockwise + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end

# Sequence classifier: a recurrent cell scanned over the time axis, followed by a head
# applied to the final hidden output. As a container layer, its parameters and states are
# nested under the field names `fastgru_cell` and `classifier`.
struct SpiralClassifier{Cell, Head} <:
       Lux.AbstractLuxContainerLayer{(:fastgru_cell, :classifier)}
    fastgru_cell::Cell
    classifier::Head
end
"""
    SpiralClassifier(in_dims, hidden_dims, out_dims; Δt=0.01f0, τ=1.0f0, layernormQ=true)

Build a classifier from a `FastGRUCell(in_dims => hidden_dims, Δt, τ, layernormQ)` and
a `Dense(hidden_dims => out_dims, sigmoid)` head.

`Δt` and `τ` are promoted to a common floating-point type because `FastGRUCell` requires
both to share one type. The defaults reproduce the previous hard-coded values.
"""
function SpiralClassifier(in_dims, hidden_dims, out_dims;
        Δt=0.01f0, τ=1.0f0, layernormQ=true)
    Δt_, τ_ = float.(promote(Δt, τ))
    return SpiralClassifier(
        FastGRUCell(in_dims => hidden_dims, Δt_, τ_, layernormQ),
        Dense(hidden_dims => out_dims, sigmoid))
end

# Run the cell over every time slice of `x` (slices along dimension 2), then apply the
# classifier to the last hidden output. Returns the flattened predictions and the updated
# container state.
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    first_step, remaining = Iterators.peel(LuxOps.eachslice(x, Val(2)))

    # The first call initializes the hidden state; later calls thread the carry through.
    (h, carry), cell_st = s.fastgru_cell(first_step, ps.fastgru_cell, st.fastgru_cell)
    for xt in remaining
        (h, carry), cell_st = s.fastgru_cell((xt, carry), ps.fastgru_cell, cell_st)
    end

    ŷ, head_st = s.classifier(h, ps.classifier, st.classifier)
    new_st = merge(st, (fastgru_cell=cell_st, classifier=head_st))
    return vec(ŷ), new_st
end

 # ----- loss
# Binary cross-entropy on the classifier's sigmoid outputs; used for training and validation.
const lossfn = Lux.BinaryCrossEntropyLoss()

# Objective in the `(model, ps, st, data) -> (loss, new_state, stats)` form: run the model
# on `x`, score it against `y` with `lossfn`, and expose the predictions as `stats.y_pred`.
function compute_loss(model, ps, st, (x, y))
    y_pred, new_st = model(x, ps, st)
    return lossfn(y_pred, y), new_st, (; y_pred)
end

# Number of predictions that fall on the correct side of the 0.5 decision threshold.
matches(y_pred, y_true) = count((y_pred .> 0.5f0) .== y_true)

# Fraction of correct predictions; an empty input gives 0/0 = NaN.
function accuracy(y_pred, y_true)
    return matches(y_pred, y_true) / length(y_pred)
end

# ----- training

"""
    main(model_type; epochs=25, learning_rate=0.01f0, hidden_dims=8)

Train `model_type(2, hidden_dims, 1)` on the spiral task with Adam and AutoZygote.

Prints the training loss after every minibatch. After each epoch, prints one validation
loss/accuracy aggregated over the whole validation set. Returns the trained
`(parameters, states)` moved to the CPU. The defaults reproduce the previous
hard-coded settings.
"""
function main(model_type; epochs::Integer=25, learning_rate=0.01f0,
        hidden_dims::Integer=8)
    dev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, hidden_dims, 1)

    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(learning_rate))

    for epoch in 1:epochs
        # Train the model
        for (x, y) in train_loader
            # x: (features, time, batch), e.g. (2, 50, 128); y: (batch,)
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model over the whole validation set. The last batch can be smaller
        # than the others, so each batch loss is weighted by its size (assumes `lossfn`
        # averages over the batch, Lux's default aggregation — confirm if that changes).
        st_ = Lux.testmode(train_state.states)
        total_loss, total_correct, total_seen = 0.0, 0, 0
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            total_loss += lossfn(ŷ, y) * length(y)
            total_correct += matches(ŷ, y)
            total_seen += length(y)
        end
        # `max(…, 1)` keeps an empty validation set from producing NaN.
        val_loss = total_loss / max(total_seen, 1)
        val_acc = total_correct / max(total_seen, 1)
        @printf "Validation: Loss %4.5f Accuracy %4.5f\n" val_loss val_acc
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end
# Script entry point: train the FastGRUCell-based spiral classifier and keep the returned
# (CPU) parameters and states as globals.
ps_trained, st_trained = main(SpiralClassifier)

However, when I try to carry my custom GRU cell over to the tutorial "MNIST Classification using Neural ODEs", I don't know where to start. I would really appreciate any help. Thanks!

ChrisRackauckas commented 3 weeks ago

The question does not make much sense. Having a hidden state that is carried over to the next call makes the equation not an ODE, and thus not convergent. If you do what you do here, where you initialize the hidden state on each call, this model is equivalent to just calling the NN that is supposed to be recurrent, so you might as well call that NN directly. So I don't quite understand what you're trying to do.

disadone commented 3 weeks ago

Sorry for the confusion. I would like to train a sequence-to-sequence model in which the RNN first produces a series of values, which are then fed into a sequence-to-sequence neural ODE as the inhomogeneous input term of the equation. The RNN weights and the neural ODE parameters are trained jointly.

Maybe the question can be simplified to: "How can I train a sequence-to-sequence neural ODE driven by a series of inputs?"

ChrisRackauckas commented 3 weeks ago

Maybe @avik-pal has an example

avik-pal commented 2 weeks ago

The weight in RNN and parameters in neuralode are trained together.

Do you mean the RNN weights and the neural network weights are shared?

disadone commented 2 weeks ago

The weight in RNN and parameters in neuralode are trained together.

Do you mean the RNN weights and the neural network weights are shared?

No, I mean that the output of the RNN would be the input to the neural ODE at each time point.

ChrisRackauckas commented 1 week ago

But without state? Then it's not an RNN?