FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Potential bug in RNN training flow #2455

Status: Closed (Bonjour-Lemonde closed this issue 1 week ago)

Bonjour-Lemonde commented 4 months ago

I ran into a strange problem with Flux's RNN. My training data consists of myX (each sample is a sequence of one-hot vectors) and myY (a single number per sample). On the training data shown below, a feedforward network works very well (R² = 0.9 after 20 epochs), but a Flux RNN does poorly (R² = 0.2 after 200 epochs). I am sure the model architecture is not the problem, because the same RNN trains well on other training data, referred to here as otherX and otherY (R² = 1). I also found that the problem lies entirely in my X: the RNN works well on [otherX, myY] and [otherX, otherY], but not on [myX, myY] or [myX, otherY]. So I suspect a bug in the RNN training flow. The code is below. I hope someone can help. Thanks!
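The two models see the same data in different layouts, mirroring `Xtrain_ffnn` and the per-time-step `hcat` reshaping in the code that follows. As a minimal Base-Julia sketch (no Flux; the helpers `onehot_dna`, `ffnn_layout` and `rnn_layout` are hypothetical, and `Bool` matrices stand in for Flux's one-hot vectors), each DNA string becomes a 4×L one-hot matrix; the feedforward network gets it flattened into one column per sample, while the RNN gets a vector of L time steps, each a 4×batch matrix:

```julia
# Hypothetical helpers (Base Julia only) illustrating the two input layouts.
const ALPHABET = "ACGT"

# One-hot encode a DNA string as a 4×L Bool matrix: column j marks nucleotide j.
function onehot_dna(seq::AbstractString)
    chars = collect(seq)
    M = falses(length(ALPHABET), length(chars))
    for (j, c) in enumerate(chars)
        i = findfirst(==(c), ALPHABET)
        i === nothing && error("unexpected character: $c")
        M[i, j] = true
    end
    return M
end

# Feedforward layout: one flattened column of length 4L per sample, in the same
# order as vcat-ing the per-position one-hot vectors (position 1 first).
ffnn_layout(seqs) = reduce(hcat, [vec(onehot_dna(s)) for s in seqs])

# RNN layout: one 4×batch matrix per time step t, holding nucleotide t of every sample.
function rnn_layout(seqs)
    encoded = [onehot_dna(s) for s in seqs]
    L = size(first(encoded), 2)
    return [reduce(hcat, [E[:, t] for E in encoded]) for t in 1:L]
end

seqs = ["ATAG", "GCTA"]
X_ffnn = ffnn_layout(seqs)  # 16×2: 4 positions × 4 one-hot rows, one column per sample
X_rnn = rnn_layout(seqs)    # 4 time steps, each a 4×2 matrix
```

With the real data, each 44-nucleotide sequence gives 44 × 4 = 176 features per column, which matches `Dense(176 => 128)` in the feedforward model, while the recurrent model starts with `RNN(4 => 128)` and consumes 44 steps.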

# Julia version 1.10.3
using Flux
using Random # provides shuffle
oriX=["ATAGGAGGCGCGTGACAGAGTCCCTGTCCAATTACCTACCCAAA", "ATAGGAGGCGCAAGAGAGAAGCCCAGACCAATAACCTACCCAAA", "ATAGGAGGCTAACGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGAGGCGCCTGAGAAAAGCCCAGACCAATTACCTACCCAAA", "ATAGGACGCGCATGAGAGATGCCCTGACCAATTACCTACCCAAA", "ATAGGTGGTGCATGAGATAAGCACAGCTCAATACCCTACCCAAA", "ATAGGAGACGCAGGGGCGAAGCCCGGACCATTTACCTACCCAAA", "ATAGGTGGTGCATGAGATAATCCCAGACCAATTACCTACCCAAA", "ATAGGAGGCTCATGAGATAAGGCTTGACCAATTACCTACCCAAA", "ATAGGAGGCTCATGAGAGCAGCCCAGATTAATTACCTACCCAAA", "ATAGGAGGCGCGTGAGAGAGGACCCGACCAATTACTCACCCAAC", "ATAGGCAGCGCATGAGAGAAGCCCAGACCAATTACCTACTCAAC", "ATAGGAGGCTAACGAGAGAAGCCCAGACCACTTACCTACCCAAA", "ATAGGAGGCGCATGAGAAAAGCCCCGCCCAATTACCTACCCAAG", "ATAGGCGGCGCTTGAGAGAAGCCCATACCCATTACCTACCCAAA", "ATAGGCGGCACATGAGACAAGCCGAAGCCAATTACCTACCCAAA", "ATAGGCTGCGCATGAGAGAAGGCGACACAAATTACCTACCCAAA", "ATAGGCGGCACATGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGTGCAAGAGAGACGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCGTATGAGAGAAGCCCAGCTCAATTACCTACCCAAA", "ATAGGAGGCGCATGAGATAACCCACCACCAAGTACCTACTCAAA", "ATAGGTGGCGCATGAGAGCACCTCAGACGAAGTACCTACCCAAA", "ATAGGCGGCGCATGAGATAAGCCTAGACCATTTACCTACCCAAA", "ATAGGTGGCGCATGAGATAAGCGCATAACACCAACTTACCCAAC", "ATAGGCGGCGCATGAGACAAATCCAGGCCAATTATCTACCCAAA", "ATAGGCGGCTCATGAGATAAGCCCAGACCAAATACCTACCCAAA", "ATAGGAGGCGCATGAGAGAATCCCAAACCAATTCCCTACAAACC", "ATAGGCGGCGCATGAGACAAGCCCATACCAATTACCTACCCAAA", "ATAGGTGCGACTTGAGAGATGCCCATATCGACTACCTACCCGAA", "ATAGGCGGTGCATGACTGACGCCCAGACCAATTACCTACCCAAA", "ATAGGGGGCTAATGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGAGGCGCATGAGATAAGCCCAGACCAATTACCTACCCCGA", "ATAGGAGGTGCACGAGAGTTGCCCAGACCAATTAACTTCCCAAA", "ATAGGCGGCGCATGAGAAAAGCCCAGACCAATTACCTACCCAAA", "ATAGGTGGCCCGCGAGTTAGGACGAGACTAATTCCCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCCATTACCTACCCAAA", "ATAGGCGGCGGACGAGAGAAGCCCAGACCAATTACCTACCCATA", "ATAGGTGGCGCATGAAATAAAACCAGTGCAATTACCTACCCATA", "ATAGGCGACGCATGAGAAAAGCCCAGACCCATTACCTACCCAAA", "ATAGGCTGCGCATGAGAGAAGCCCAGACCAATTATCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAATTACCTACCCAAA", 
"ATAGGCGGCGCATGAGAGAAGCCCAGACCAATTCCCTACCCAAA", "ATAGGAGTCGCCTGACAGATGACCATACCAATTACCTATCCAAA", "ATAGGCCGCGGATTAGACAACATCTTACCAATTCCCTGCCCAAA", "ATAGGCGGTGCAAGAGCGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGAAGCTAAAGGGAGTAGCTCAGTACAGTTAACTACCCCAA", "ATAGGCCGCGCATGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAGTTACCTACCCAAA", "ATAGGAAGCGCATGAGAAAAGCCCAGACAAATCACCTACCGAAC", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAGTTACCTACCCAAC", "ATAGGCGGCACATGAGCGCAGCCCAGTCCAATTACCTACCCAAA", "ATAGGCGGCGCATGACACAGGCCCAGACCAATGACCTACCCAAA", "ATAGGCAGCGCATGAGAGAAGCCCAGACCAATTACCTACTCAAA", "ATAGGCGACGAATGAGTGAAGCCCACATTAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAACTACATACCCAAA", "ATAGGCGGCGCATGAGACAAATCCAGGCCAATTACCTACCCAAA", "ATAGGCCGCCGATGAGAAAAGCCCGACGCACTTAACTACCCGAA", "ATAGGCGGTGCATGAGAGAGACGCAGTGCAAATACCTACCCAAA", "ATAGGCGGCGGATTAGAGAAGTCCAGACTATTTACCTACCCAAA", "ATAGGCGGCGAATGAGAGAAGCCCAGACCAATTACCTACCCAGA", "ATAGGCGGCGCATGAGATAAGCCCAGTCGAATTACCTACCCAAA", "ATAGGCCGCGCATGAGAAAAGCCTAGACCAATTGCCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAATTACCTACCCACA", "ATAGGCGACGCATGAGAGAAGCCCAGACGAATTACCTACCCAAA", "ATAGGTCCAGCATTAAGGCAGGCCAGACCCTTTACCTACCCAAA", "ATAGGAGGGACATGCGATAGGCTCAGACCAATTTCCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGGCCAATTAACTACCCAAA", "ATAGGCGGCGCATGAGAGTAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGGGAAGCCCAGACCCATTCCCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAATTAGCTACCCAAA", "ATAGGCGACGTATGAGAGAATCCCTGACCATTTACCTACCCAAA", "ATAGGCGGCGCATGATATAAGCCCAGCCCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACATATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCTAAGACCCATTACCTACCCAAA", "ATAGGCCGCGCATGAGAGAAGCTCAGACCCATTACCTACCCAAA", "ATAGGTGGCGCATGAGAGAAGCCCAGACCAATTACCTACACAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAATTACCTGCCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGTCCGATTTACTACCCAAG", "ATAGGCGGAGCATATGAGATGCCCAGACCAAATACCTACCCAAA", "ATAGGCGGCGCATGACAGAAGCCCTGACCGATAACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGAGCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCTCAGACCAATTACCTACCCAAA", 
"ATAGGTGTCGCTTGAAAATAGCCCAGACGAATTACCTACCCAAA", "ATAGGCGGCGCATGAGCGTTGCACAGACCAATTACCTACCCAAA", "ATAGGCGGCGTATGAGAGAAGCGCGGCCCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAGGCCCTGACCAAATAACTACCCAAA", "ATAGGCGGCTCATGAGAGAAGCCCAGACCAACTGCCTACCCAAA", "ATAGGCAGCGCATGAGTGAAGCCCAGACCAGTTACCTCCCCAAA", "ATAGGCAGCAGATGACAGTAGCCCCGACCAAATTACTACTCAAA", "ATAGGCGGCGCATGAGAGGAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCCCAGGAGAGCATCCAAGACCAATTACCTACCCAAA", "ATAGGCGGCGCATACGAGAAGCGCAGACTAATTACCTACCCAAA", "ATAGGCGGCGCATGACATAAGCCTAGATCAATTACCTACCCAAG", "ATAGGCGGCACATGACACAGGCCCAGACCAATGACCTACCCAAA", "ATAGGCGGCGCAGGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCGCGTAAGAGAAGCCTAGACAAATTACCTACCGAAA", "ATAGGCGACGCATCTGCGAATCCCACACCAATTACCTACCCGAA", "ATAGGCGACGCATGAGAGAAGCCCAGACCAATTAACTATCCATC", "ATAGGCGGCGCATGAGAGCAGCCCAAACCAATTGCCTACCCAAA", "ATAGGCGGCGCGTGAGAGTAGCCCTGACCAGTTTCCTGCCCAAA"]

ytrain = Float32[3.5539734 2.7561886 2.8113236 2.7176633 2.7606876 2.6220577 2.3115172 2.4817004 1.9276409 2.5030751 1.8989105 1.6381569 3.112245 1.9992489 1.8364053 2.1545537 1.8151137 1.9761252 2.0710406 1.8238684 1.5769696 2.2978039 2.0652819 1.6795048 1.4621212 1.8550924 1.3247801 2.0052798 1.5950761 2.1166725 1.1718857 1.443101 1.4597932 2.0249891 1.659723 1.7782362 1.3042092 1.3574703 1.7164876 1.4561561 1.6886593 1.5327756 1.3272716 1.2478243 1.6909612 0.9371975 1.3504946 1.7342895 1.0429348 1.6653012 1.6186994 1.6343817 1.1894267 1.6500783 1.1910686 1.5190029 0.93479043 1.5677443 1.2633525 1.4441946 1.8120437 1.6296253 1.3869075 1.7520566 1.247555 1.4638474 1.4413416 1.5457458 1.3801547 1.312296 0.96203357 1.571632 0.2540248 1.0096036 0.8302187 0.73939687 1.4816427 1.1275434 1.1184824 1.3548776 1.3924822 1.2923665 0.9824461 1.2085876 1.3007151 1.4721189 1.3741052 0.7266495 0.5496262 1.3403294 0.931344 0.7101498 1.3628994 1.8999943 1.2633573 1.1379782 0.6508444 0.5403087 1.435614 1.319527]

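# One-hot encode every nucleotide: each sample becomes a vector of 44 one-hot vectors of length 4
Xtrain = [[Flux.onehot(c, "ACGT") for c in seq] for seq in oriX]
seqlen = length(first(Xtrain)) # number of time steps (44 nucleotides per sequence)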

Xtrain_ffnn = hcat([vcat(x...) for x ∈ Xtrain]...)

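# Loss function and accuracy metric
# R² (coefficient of determination) is not provided by Flux, so define it here
R²(y, ŷ) = 1 - sum(abs2, y .- ŷ) / sum(abs2, y .- sum(y) / length(y))

function accuracy(m, X, y)
    Flux.reset!(m) # Only important for recurrent network
    R²(y, m(X))
end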

function lossFun(m, X, y)
    Flux.reset!(m) # Only important for recurrent network
    Flux.mse(m(X),y)
end

# first learn the train data on feedforward
ffnn = Chain(
    Dense(176 => 128, relu),
    Dense(128 => 128, relu),
    Dense(128 => 1)
)
opt_ffnn = Adam()
θ_ffnn = Flux.params(ffnn) # Keep track of the trainable parameters
epochs = 100 # Train the model for 100 epochs
for epoch ∈ 1:epochs
    # Train the model using batches of size 32
    for idx ∈ Iterators.partition(shuffle(1:size(Xtrain_ffnn, 2)), 32)
        X, y = Xtrain_ffnn[:, idx], ytrain[:, idx]
        ∇ = gradient(θ_ffnn) do 
            # Flux.logitcrossentropy(ffnn(X), y)
            Flux.mse(ffnn(X),y)
        end
        Flux.update!(opt_ffnn, θ_ffnn, ∇)
    end
    @show accuracy(ffnn, Xtrain_ffnn, ytrain)
end

# then learn the train data by seq2one(RNN)

struct Seq2One
    rnn # Recurrent layers
    fc  # Fully-connected layers
end
Flux.@functor Seq2One # Let Flux find the trainable parameters inside the struct
# Define behavior of passing data to an instance of this struct
function (m::Seq2One)(X)
    # Run recurrent layers on all but final data point
    [m.rnn(x) for x ∈ X[1:end-1]] # outputs are discarded; only the hidden state is updated
    # Pass last data point through both recurrent and fully-connected layers
    m.fc(m.rnn(X[end])) 
end

# Create the sequence-to-one network using a similar layer architecture as above
seq2one = Seq2One(
    Chain(
        RNN(4 => 128, relu),
        RNN(128 => 128, relu)
    ),
    Dense(128 => 1)
)
opt_rnn = Adam()
θ_rnn = Flux.params(seq2one) # Keep track of the trainable parameters
epochs = 200 # Train the model for 200 epochs
for epoch ∈ 1:epochs
    # Train the model using batches of size 32
    for idx ∈ Iterators.partition(shuffle(1:size(Xtrain, 1)), 32)
        Flux.reset!(seq2one) # Reset hidden state
        X, y = Xtrain[idx], ytrain[:, idx]
        X = [hcat([x[i] for x ∈ X]...) for i ∈ 1:seqlen] # Reshape X for RNN format: one 4×batch matrix per time step
        ∇ = gradient(θ_rnn) do 
            # Flux.logitcrossentropy(seq2one(X), y)
            Flux.mse(seq2one(X),y)
        end
        Flux.update!(opt_rnn, θ_rnn, ∇)
    end
    X, y = [hcat([x[i] for x ∈ Xtrain]...) for i ∈ 1:seqlen], ytrain
    @show accuracy(seq2one, X, y)
end
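The `Seq2One` model above runs every time step through the two stacked `RNN` layers and applies `Dense(128 => 1)` only to the final hidden state. To spell that computation out, here is a rough Base-Julia sketch. It is not Flux's implementation: it assumes a plain Elman cell, h_t = relu(W_x x_t + W_h h_{t-1} + b), with a zero initial state, uses small hidden sizes and random inputs, and the names `ElmanLayer`, `reset_state!`, `step!` and `seq2one_forward` are hypothetical. Run it in a fresh session, since it defines its own `X`.

```julia
using Random

# Hypothetical Elman recurrent layer with a mutable hidden state (for illustration only).
mutable struct ElmanLayer
    Wx::Matrix{Float32}  # input-to-hidden weights
    Wh::Matrix{Float32}  # hidden-to-hidden weights
    b::Vector{Float32}   # bias
    h::Matrix{Float32}   # hidden state, hidden × batch
end

ElmanLayer(rng::AbstractRNG, nin::Int, nh::Int) =
    ElmanLayer(0.1f0 .* randn(rng, Float32, nh, nin),
               0.1f0 .* randn(rng, Float32, nh, nh),
               zeros(Float32, nh),
               zeros(Float32, nh, 0))

# Zero the hidden state for a given batch size (the role Flux.reset! plays above).
reset_state!(l::ElmanLayer, batch::Int) = (l.h = zeros(Float32, length(l.b), batch); l)

# One time step with relu activation: h = relu(Wx*x + Wh*h + b).
function step!(l::ElmanLayer, x::AbstractMatrix)
    l.h = max.(0f0, l.Wx * x .+ l.Wh * l.h .+ l.b)
    return l.h
end

# Sequence-to-one: push every time step through the stacked layers, then apply
# the dense read-out (W, c) to the final hidden state of the top layer only.
function seq2one_forward(layers, W, c, X::Vector{<:AbstractMatrix})
    foreach(l -> reset_state!(l, size(first(X), 2)), layers)
    for x in X
        h = x
        for l in layers
            h = step!(l, h)
        end
    end
    return W * layers[end].h .+ c
end

rng = MersenneTwister(0)
layers = [ElmanLayer(rng, 4, 8), ElmanLayer(rng, 8, 8)]
W, c = 0.1f0 .* randn(rng, Float32, 1, 8), zeros(Float32, 1)
X = [Float32.(rand(rng, Bool, 4, 3)) for _ in 1:44]  # 44 time steps, batch of 3
ŷ = seq2one_forward(layers, W, c, X)                  # 1×3 prediction
```

Because the hidden state is zeroed at the start of every call, two calls on the same batch give identical outputs; this is the same reason the training loop calls `Flux.reset!` before each batch.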
CarloLucibello commented 1 week ago

Probably a duplicate of #2185.