FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

RNN doesn't work as expected #2185

Open alerem18 opened 1 year ago

alerem18 commented 1 year ago

I tried to implement an RNN model to classify the MNIST dataset, but I only get an accuracy of around 40-50%, even after training for more than 20 epochs. In PyTorch, I get up to 90% accuracy after just 4-5 epochs.

Here is my code:

using Flux
using Flux: onehotbatch, onecold, params, gradient
using MLDatasets: MNIST
using Base.Iterators: partition
using TensorCast
using Statistics: mean
using Random: shuffle

#---------------------------------- DATA -------------------------------------
DATA_TRAIN = MNIST.traindata(Float32)
DATA_TEST = MNIST.testdata(Float32)

#-------------------------------- PREPROCESS DATA ------------------------------
@cast x_train[j][i, k] := DATA_TRAIN[1][i, j, k] # split into a vector of 28 matrices, each of size 28 x 60000
@cast x_test[j][i, k] := DATA_TEST[1][i, j, k] # split into a vector of 28 matrices, each of size 28 x 10000

# create onehotbatch for train label
y_train = onehotbatch(DATA_TRAIN[2], 0:9)
y_test = DATA_TEST[2]

#------------------------------ CONSTANTS ---------------------------------------
INPUT_DIM = size(x_train[1], 1)
OUTPUT_DIM = 10 # number of classes
LR = 0.001 # learning rate
EPOCHS = 100
BATCH_SIZE = 1000
TOTAL_SAMPLES = size(x_train[1], 2)

#--------------------------------- BUILD MODEL -----------------------------------
struct RnnModel
  rnn
  fc
end

Flux.@functor RnnModel

# pass input through the model
function (m::RnnModel)(input_data)

  # warm up the rnn on all but the last timestep
  [m.rnn(x) for x ∈ input_data[1:end - 1]]

  # pass the rnn output for the last timestep to the fc layer
  m.fc(m.rnn(input_data[end]))
end

# build MODEL
model = RnnModel(
  Chain(RNN(INPUT_DIM, 128), relu, RNN(128, 64), relu, RNN(64, 32), relu),
  Chain(Dense(32, OUTPUT_DIM), softmax)
)

#----------------------------- HELPER FUNCTIONS --------------------------------------
loss_fn(x, y) = Flux.Losses.logitcrossentropy(model(x), y)
function accuracy(x, y)
  Flux.reset!(model)
  mean(onecold(model(x), 0:9) .== y)
end

θ = params(model) # model parameters to be updated during training
opt = Flux.ADAM(LR) # optimizer function

#---------------------------- RUN TRAINING ----------------------------------------------
for epoch ∈ 1:EPOCHS
  for idx ∈ partition(1:TOTAL_SAMPLES, BATCH_SIZE)
    Flux.reset!(model)
    features = [x[:, idx] for x ∈ x_train]
    labels = y_train[:, idx]
    gs = gradient(θ) do
      loss = loss_fn(features, labels)
      loss
    end

    # update model
    Flux.Optimise.update!(opt, θ, gs)
  end

  # evaluate model
  @info epoch
  @show accuracy(x_test, y_test)
end

What am I doing wrong?

ToucheSir commented 1 year ago

I'm surprised this works at all with the input format given. What does the PyTorch code look like and have you verified it's doing the same thing?

alerem18 commented 1 year ago

> I'm surprised this works at all with the input format given. What does the PyTorch code look like and have you verified it's doing the same thing?

What should the format be? Please ignore the softmax combined with logitcrossentropy; I made a mistake when typing the code here. Shouldn't the input be a vector of length seq_len whose elements are matrices of size (features, batch_size)?

alerem18 commented 1 year ago

PyTorch is quite different: its input has shape (batch_size, seq_len, features). I also get much worse results just by reshaping the data differently:

    @cast x_train[i][j, k] := DATA_TRAIN[1][i, j, k] # split into a vector of 28 matrices, each of size 28 x 60000
    @cast x_test[i][j, k] := DATA_TEST[1][i, j, k] # split into a vector of 28 matrices, each of size 28 x 10000

These reshapes lead to a worse result.

ToucheSir commented 1 year ago

> pytorch is quite different, it got a shape of (batch_size, seq_len, features)

Flux supports something very similar. This is why it's important to see the PyTorch code as well, I have a feeling this is not an apples-to-apples comparison.

alerem18 commented 1 year ago

> pytorch is quite different, it got a shape of (batch_size, seq_len, features)
>
> Flux supports something very similar. This is why it's important to see the PyTorch code as well, I have a feeling this is not an apples-to-apples comparison.

PyTorch implementation:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torchvision.datasets import MNIST
    from torchvision.transforms import ToTensor
    from torch.utils.data import DataLoader

    # ------------------------------ DATA -----------------------------------
    train_data = MNIST(train=True, root='data', transform=ToTensor())
    test_data = MNIST(train=False, root='data', transform=ToTensor())
    train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=1000, shuffle=True)

    # ---------------------------- MODEL --------------------------------------
    class RNN(nn.Module):
        def __init__(self, input_dim, output_dim):
            super(RNN, self).__init__()
            self.rnn = nn.RNN(input_dim, 128, batch_first=True)
            self.fc = nn.Linear(128, output_dim)

        def forward(self, x, h):
            x, h = self.rnn(x, h)
            x = F.relu(x)
            x = self.fc(x)
            # get last layer from rnn
            return x[:, -1, :], h

        def init_hidden(self, batch_size):
            return torch.zeros([1, batch_size, 128])

    # ----------------------- HELPER -----------------------------------
    # seq_len = 28, input_dim=28, num_classes=10
    model = RNN(input_dim=28, output_dim=10)
    loss_fn = nn.CrossEntropyLoss()  # includes softmax layer too so we don't need it in the model

    def accuracy(X, y):
        total_samples = X.shape[0]
        h = model.init_hidden(batch_size=total_samples)
        with torch.no_grad():
            pred_values, _ = model(X, h)
            return torch.sum(pred_values.max(1)[1] == y) / total_samples

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # --------------------------------- TRAIN LOOP ------------------------
    for epoch in range(1, 11):
        for data in train_loader:
            features = data[0].squeeze(1) # convert (batch_size, 1, 28, 28) to (batch_size, 28, 28)
            h = model.init_hidden(batch_size=features.shape[0]) # hidden state
            labels = data[1]
            predicted_values, _ = model(features, h)
            loss = loss_fn(predicted_values, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # get test data
        test_data = next(iter(test_loader))
        test_features = test_data[0].squeeze(1) # convert (batch_size, 1, 28, 28) to (batch_size, 28, 28)
        test_labels = test_data[1]
        print(f"epoch : {epoch}\\10\taccuracy : {accuracy(test_features, test_labels)}")

epoch : 1\10 accuracy : 0.8370000123977661
epoch : 2\10 accuracy : 0.8920000195503235

Of course, I also tried smaller batch sizes in Flux, such as 64 and 32, but got the same result. I could reach 74% accuracy in Flux using 30 epochs and the Momentum optimizer, whereas in PyTorch the accuracy is already high after the first epoch.

ToucheSir commented 1 year ago

Thanks, will try to take a look over the next couple of days. One quick observation though:

> # includes softmax layer too so we don't need it in the model

This is also true for Flux's logitcrossentropy. So you shouldn't have a softmax in your Flux model either, and I suspect that having it is hurting performance because of the redundancy with the loss function.
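A minimal pure-Julia sketch of why the redundant softmax can hurt. It re-implements the loss from scratch; the helper names `logsoftmax_cols`, `softmax_cols` and `logit_ce` are hypothetical, not Flux functions. It follows what we understand to be `logitcrossentropy`'s default behavior: a log-softmax over `dims=1`, averaged over the batch. Because softmax outputs lie in (0, 1), feeding them in as logits caps how confident the prediction can be. With 10 classes, the per-sample loss can never drop below log(1 + 9/e) ≈ 1.46, even for a perfect prediction.

```julia
# Column-wise log-softmax over the class dimension (dims=1), numerically stable.
function logsoftmax_cols(z::AbstractMatrix)
    m = maximum(z; dims=1)
    return z .- m .- log.(sum(exp.(z .- m); dims=1))
end

softmax_cols(z::AbstractMatrix) = exp.(logsoftmax_cols(z))

# Cross-entropy computed from raw logits, averaged over the batch (columns).
# Written from scratch here to mirror what logitcrossentropy is meant to compute.
logit_ce(z::AbstractMatrix, y::AbstractMatrix) = -sum(y .* logsoftmax_cols(z)) / size(z, 2)

# One sample, 10 classes, a very confident and correct prediction for class 3.
z = zeros(10, 1); z[3] = 50.0
y = zeros(10, 1); y[3] = 1.0

loss_logits  = logit_ce(z, y)                 # raw logits: loss is essentially zero
loss_softmax = logit_ce(softmax_cols(z), y)   # softmax applied twice: loss is stuck near the bound

println("loss on logits:           ", loss_logits)
println("loss on softmax outputs:  ", loss_softmax)
println("lower bound log(1 + 9/e): ", log(1 + 9 / exp(1)))
```

In a real model, this floor also shrinks the gradient signal near a correct answer, which fits the suspicion that the extra softmax slows training.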

alerem18 commented 1 year ago

Yes, I know. I've used logitcrossentropy without softmax, and also softmax with crossentropy, but I still get the same results. I just made a mistake typing the Julia code here.

skyleaworlder commented 1 year ago

I haven't run the code, but is it possible that the model input is wrong? 50% accuracy really reminds me of one of my own data-processing experiences.

alerem18 commented 1 year ago

> I haven't run the code. But is there a possibility that model input is mistaken? 50% accuracy really reminds me of my one data-processing experience.

How should I prepare my data, then? MNIST has a shape of 28 x 28 x 60000. If I use the second dimension as seq_len, the results differ from using the first dimension as seq_len, although there shouldn't be any difference at all. The accuracy is not high for either input.
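A minimal pure-Julia sketch of the two layouts being compared. It uses a random array as a hypothetical stand-in for the 28 x 28 x N MNIST features. Layout A is what `@cast x_train[j][i, k] := ...[i, j, k]` produces, and layout B is what the `@cast x_train[i][j, k]` variant produces. Both give the format asked about earlier in the thread: a vector of 28 timesteps, each a (features, batch_size) matrix whose column n belongs to sample n. They only walk through each image along different axes, so both should be valid sequences for an RNN.

```julia
# Hypothetical stand-in for the MNIST feature array, shaped (28, 28, N).
N = 5
imgs = rand(Float32, 28, 28, N)

# Layout A: timestep t is imgs[:, t, :], i.e. slicing along dims=2
# (the same slices that eachslice(imgs, dims=2) yields).
seq_a = [imgs[:, t, :] for t in 1:size(imgs, 2)]

# Layout B: timestep t is imgs[t, :, :], i.e. slicing along dims=1.
seq_b = [imgs[t, :, :] for t in 1:size(imgs, 1)]

# Either way: 28 timesteps, each a (features, batch_size) = (28, N) matrix,
# where column n always belongs to sample n.
@show length(seq_a) size(seq_a[1])
@show length(seq_b) size(seq_b[1])
```

Layout A is also the slicing that jeremiedb's rewrite performs with `eachslice(..., dims=2)`, so the preprocessing alone is unlikely to explain the accuracy gap.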

jeremiedb commented 1 year ago

I took a shot at rewriting the model the way I would have implemented it. It reaches 91% accuracy right after the first epoch with batchsize=64. A batch size of 1000 also works fine: it just starts with an accuracy of 77%, but reaches 96% after 10 epochs.

using Flux
using Flux: onehotbatch, onecold, params, gradient
using MLDatasets: MNIST
using Base.Iterators: partition
using Statistics: mean
using Random: shuffle

#---------------------------------- DATA -------------------------------------
DATA_TRAIN = MNIST.traindata(Float32)
DATA_TEST = MNIST.testdata(Float32)

#-------------------------------- PREPROCESS DATA ------------------------------
x_train = [x for x in eachslice(DATA_TRAIN[1], dims=2)] # split into a vector of 28 matrices, each of size 28 x 60000
x_test = [x for x in eachslice(DATA_TEST[1], dims=2)] # split into a vector of 28 matrices, each of size 28 x 10000

# create onehotbatch for train label
y_train = onehotbatch(DATA_TRAIN[2], 0:9)
y_test = DATA_TEST[2]

#------------------------------ CONSTANTS ---------------------------------------
INPUT_DIM = size(x_train[1], 1)
OUTPUT_DIM = 10 # number of classes
LR = 0.001f0 # learning rate
EPOCHS = 10
BATCH_SIZE = 64
TOTAL_SAMPLES = size(x_train[1], 2)

#--------------------------------- BUILD MODEL -----------------------------------
model = Chain(
  RNN(INPUT_DIM => 128, relu),
  Dense(128, OUTPUT_DIM)
)

#----------------------------- HELPER FUNCTIONS --------------------------------------
function loss_fn_2(m, x, y)
  out = [m(xi) for xi in x] # generate output for each of the 28 timesteps
  Flux.Losses.logitcrossentropy(out[end], y) # compute loss based on predictions of the latest timestep
end

function accuracy_eval(m, x, y)
  Flux.reset!(m)
  out = [m(xi) for xi in x]
  mean(onecold(out[end], 0:9) .== y)
end  

θ = params(model) # model parameters to be updated during training
opt = Flux.ADAM(LR) # optimizer function

#---------------------------- RUN TRAINING ----------------------------------------------
for epoch ∈ 1:EPOCHS
  for idx ∈ partition(1:TOTAL_SAMPLES, BATCH_SIZE)
    features = [x[:, idx] for x ∈ x_train]
    labels = y_train[:, idx]
    Flux.reset!(model)
    gs = gradient(θ) do
      loss = loss_fn_2(model, features, labels)
    end
    # update model
    Flux.Optimise.update!(opt, θ, gs)
  end

  # evaluate model
  @info epoch
  @show accuracy_eval(model, x_test, y_test)
end

I think the data preprocessing was fine. I just dropped the TensorCast dependency because I ran into an issue with it, and it felt simpler not to use it.

I'm really not sure what went wrong with your implementation. It's pure speculation, but perhaps the gradients didn't get propagated through the following part:

  [m.rnn(x) for x ∈ input_data[1:end - 1]]
  m.fc(m.rnn(input_data[end]))

because the result of the initial (warm-up) computation is not explicitly passed to the second call. Again, just a wild guess here.

alerem18 commented 1 year ago

Your code works, but I really don't know why mine isn't working if the data preprocessing is the same. I tried a different implementation, similar to yours, for calculating the loss:

    using Flux
    using Flux: onehotbatch, onecold, params, gradient
    using MLDatasets: MNIST
    using Base.Iterators: partition, product
    using TensorCast
    using Statistics: mean
    using Random: shuffle
    using StatsBase
    using ChainRulesCore, Zygote
    ChainRulesCore.@non_differentiable foreach(f, ::Tuple{})
    Zygote.refresh()
    # ---------------------------------- DATA -------------------------------------
    TRAIN_DATA, TRAIN_LABELS = MNIST.traindata(Float32)
    TEST_DATA, TEST_LABELS = MNIST.testdata(Float32)
    TRAIN_LABELS = onehotbatch(TRAIN_LABELS, 0:9)
    # convert 3d arrays to vector of 2d arrays
    @cast TRAIN_FEATURES[i][j, k] := TRAIN_DATA[i, j, k]
    @cast TEST_FEATURES[i][j, k] := TEST_DATA[i, j, k]

    INPUT_DIM = size(TRAIN_FEATURES[1], 1)
    DATA = [([x[:, idx] for x in TRAIN_FEATURES], TRAIN_LABELS[:, idx]) for idx ∈ partition(shuffle(1:size(TRAIN_LABELS, 2)), 1000)]

    # ----------------------------------- MODEL --------------------------------------------
    model = Chain(
        RNN(INPUT_DIM, 128, relu),
        Dense(128, 10)
    )
    # --------------------------------- HELPER -----------------------------------------------
    function loss_fn(X, Y)
        Flux.reset!(model)
        out = [model(x) for x ∈ X]
        Flux.Losses.logitcrossentropy(out[end], Y)
    end

    function accuracy(X, Y)
        Flux.reset!(model) # Only important for recurrent network
        out = [model(x) for x ∈ X]
        mean(onecold(out[end], 0:9) .== Y)
    end

    θ = params(model)
    opt = Flux.ADAM()
    evalcb() = @show(accuracy(TEST_FEATURES, TEST_LABELS))
    # ----------------------------------- TRAIN -------------------------
    Flux.@epochs 30 Flux.train!(loss_fn, θ, DATA, opt, cb = Flux.throttle(evalcb, 5))

It still doesn't work.

jeremiedb commented 1 year ago

Not on a computer right now, but I think you should remove the reset! from the loss function, and therefore stick to a custom training loop instead of using train!.

alerem18 commented 1 year ago

> Not on a computer right now, but I think you should remove the reset! from the loss function. And therefore, stick to a custom training loop instead of train!

I found out that if I remove the model argument from the loss and accuracy functions, I get bad results. Otherwise, it works as expected:

    loss_fn(X, Y), accuracy(X, Y) ===> bad results
    loss_fn(m, X, Y), accuracy(m, X, Y) ==> good results

Can you explain why this happens? It's really weird.

ToucheSir commented 1 year ago

Can you show the before and after code for that change? It's not immediately clear what the difference would be.

CarloLucibello commented 1 year ago

@alerem18 if you can clarify which difference causes the bad result, we can decide whether we have an actual bug or not.

alerem18 commented 1 year ago

    loss_fn(X, Y), accuracy(X, Y) ===> bad results
    loss_fn(m, X, Y), accuracy(m, X, Y) ==> good results

Passing the model through to the loss and accuracy functions works as expected. If you don't pass it to those functions, you get bad results: the model stops improving after a while, and accuracy stays around 50-60%.

ToucheSir commented 1 year ago

What we're asking for is full code examples that show the good and bad results. Without that, loss_fn(X, Y) and loss_fn(m, X, Y) could be completely different functions for all we know. Having a complete example will allow us to run and try to reproduce the behaviour.

alerem18 commented 1 year ago
using Flux
using Flux: gradient, logitcrossentropy, params, Momentum
using OneHotArrays: onecold, onehotbatch
using MLDatasets: MNIST
using Random: shuffle
using Statistics: mean
using Base.Iterators: partition

# ------------------- data --------------------------
train_x, train_y = MNIST(split=:train).features, MNIST(split=:train).targets
test_x, test_y = MNIST(split=:test).features, MNIST(split=:test).targets
train_y = onehotbatch(train_y, 0:9)
train_x = [x for x ∈ eachslice(train_x, dims=2)]
test_x = [x for x ∈ eachslice(test_x, dims=2)]
# ------------------ constants ---------------------
INPUT_SIZE = 28
NUM_CLASSES = 10
BATCH_SIZE = 1000
EPOCHS = 5
# ------------------ model --------------------------
model = Chain(
    RNN(INPUT_SIZE, 128, relu),
    RNN(128, 64, relu),
    Dense(64, NUM_CLASSES)
)

# ---------------- helper --------------------------
loss_fn(m, X, y) = logitcrossentropy([m(x) for x ∈ X][end], y)
accuracy(m, X, y) = mean(onecold([m(x) for x ∈ X][end], 0:9) .== y)
opt = Momentum()
θ = params(model)

# --------------- train -----------------------------
for epoch ∈ 1:EPOCHS
    for idx ∈  partition(shuffle(1:size(train_y, 2)), BATCH_SIZE)
        Flux.reset!(model)
        X = [x[:, idx] for x ∈ train_x]
        y = train_y[:, idx]
        gs = gradient(θ) do 
            loss_fn(model, X, y)
        end
        Flux.Optimise.update!(opt, θ, gs)
    end
    Flux.reset!(model)
    test_acc = accuracy(model, test_x, test_y)
    @info "Epoch : $epoch | accuracy : $test_acc"
end

[ Info: Epoch : 1 | accuracy : 0.3968
[ Info: Epoch : 2 | accuracy : 0.7918
[ Info: Epoch : 3 | accuracy : 0.896
[ Info: Epoch : 4 | accuracy : 0.9365
[ Info: Epoch : 5 | accuracy : 0.9465

Edit the loss and accuracy functions as below and you get these results:

loss_fn(X, y) = logitcrossentropy([model(x) for x ∈ X][end], y)
accuracy(X, y) = mean(onecold([model(x) for x ∈ X][end], 0:9) .== y)

[ Info: Epoch : 1 | accuracy : 0.2795
[ Info: Epoch : 2 | accuracy : 0.4944
[ Info: Epoch : 3 | accuracy : 0.3561
[ Info: Epoch : 4 | accuracy : 0.5146
[ Info: Epoch : 5 | accuracy : 0.5229
[ Info: Epoch : 6 | accuracy : 0.5467
[ Info: Epoch : 7 | accuracy : 0.598
[ Info: Epoch : 8 | accuracy : 0.6085
[ Info: Epoch : 9 | accuracy : 0.5953
[ Info: Epoch : 10 | accuracy : 0.6038
[ Info: Epoch : 11 | accuracy : 0.6063
[ Info: Epoch : 12 | accuracy : 0.6336
[ Info: Epoch : 13 | accuracy : 0.65
[ Info: Epoch : 14 | accuracy : 0.6488
[ Info: Epoch : 15 | accuracy : 0.5951
[ Info: Epoch : 16 | accuracy : 0.5911
[ Info: Epoch : 17 | accuracy : 0.61
[ Info: Epoch : 18 | accuracy : 0.6357
[ Info: Epoch : 19 | accuracy : 0.6215
[ Info: Epoch : 20 | accuracy : 0.6576

jeremiedb commented 1 year ago

Thanks, I can reproduce it. The cause isn't obvious to me, but the behavior suggests that the reset! performed within the training loop fails to affect the model state actually used by the loss and accuracy functions.

In any case, it appears safer to pass an explicit reference to the model to the loss and accuracy functions. This non-obvious behavior can also lead to unexpectedly bad results, so it would be worth documenting if we can confirm the root cause.

ToucheSir commented 1 year ago

A quick sanity check would be moving θ = params(model) inside the inner training loop and seeing if that makes a difference. It's not immediately obvious to me why it would, but might as well eliminate one possibility.

jeremiedb commented 1 year ago

Unfortunately, no luck with instantiating the training params inside the training loop. The following results in the same accuracy plateau of around 60%:

        ps = params(model)
        Flux.reset!(model)
        gs = gradient(ps) do 
            loss_fn2(X, y)
        end
alerem18 commented 1 year ago

In the modified code, the loss_fn and accuracy functions do not take the params of the model as input, and they call the model directly within the function to compute the loss and accuracy.

The params function is used to extract the trainable parameters of a model, which is necessary for computing gradients and updating the model parameters during training. When params is used, the optimizer is able to track the gradients of the model parameters and update them accordingly during optimization.

By not using params, the optimizer is not able to track the gradients of the model parameters correctly and this can lead to incorrect optimization and lower accuracy.

Therefore, not using params in the modified code is a mistake and can result in lower accuracy.

Any thoughts?

ToucheSir commented 1 year ago

Part of the "magic" of passing a Params to gradient is that the trainable parameters do not have to be passed directly to the loss function. Instead, Zygote tracks trainable parameters by object ID (basically hashes of memory addresses) and accumulates gradients accordingly. This is why we describe using params as working with "implicit parameters".

The problem here is that something is causing the aforementioned tracking to not work. Ordinarily both versions of the code should behave the same, so this is a bug. It's also why we're moving away from magical implicit params towards passing the model/trainable params directly to gradient and the loss function: it's much less bug-prone, easier for users to understand, and easier for developers to debug.
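Below is a toy illustration of the two styles in plain Julia. It is not Zygote's actual implementation. Implicit-style gradients can be pictured as an `IdDict` keyed by the parameter arrays themselves, where lookup uses object identity (`===`) rather than equality. Explicit-style gradients come back as a structure that mirrors the model. The names `model`, `explicit_grad` and `sgd_step` are hypothetical.

```julia
# A toy "model": one weight matrix and one bias vector.
W = [1.0 2.0; 3.0 4.0]
b = [0.0, 0.0]
model = (W = W, b = b)

# Implicit style (sketch): gradients stored in a dictionary keyed by object identity.
implicit_grads = IdDict{Any,Any}()
implicit_grads[W] = ones(2, 2)
implicit_grads[b] = ones(2)

found_same_object = haskey(implicit_grads, model.W)  # true: the very same array object
found_equal_copy  = haskey(implicit_grads, copy(W))  # false: equal values, different object

# Explicit style (sketch): the gradient mirrors the model's structure.
# For the loss L(m, x) = sum(m.W * x .+ m.b): dL/dW = ones(2) * x', dL/db = ones(2).
explicit_grad(m, x) = (W = ones(size(m.W, 1)) * x', b = ones(length(m.b)))

x = [1.0, 2.0]
g = explicit_grad(model, x)

# A plain gradient-descent step that walks the model and gradient structures together.
sgd_step(m, g, η) = (W = m.W .- η .* g.W, b = m.b .- η .* g.b)
new_model = sgd_step(model, g, 0.1)
```

The failed lookup on `copy(W)` shows why identity-based tracking is fragile: an array with the right values but a different identity is simply not found. In the explicit style, the gradient's structure is tied to whatever model was actually passed in.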

jeremiedb commented 1 year ago

For reference, this is how you could use the new explicit gradient / Optimisers.jl mode:


loss_fn1(m, X, y) = logitcrossentropy([m(x) for x ∈ X][end], y)
accuracy1(m, X, y) = mean(onecold([m(x) for x ∈ X][end], 0:9) .== y)

rule = Flux.Optimisers.Adam()
opts = Flux.Optimisers.setup(rule, model);

for epoch ∈ 1:5
    for idx ∈  partition(shuffle(1:size(train_y, 2)), BATCH_SIZE)
        X = [x[:, idx] for x ∈ train_x]
        y = train_y[:, idx]
        Flux.reset!(model)
        gs = gradient(model) do m
            loss_fn1(m, X, y)
        end
        Flux.Optimisers.update!(opts, model, gs[1]);
    end
    Flux.reset!(model)
    test_acc = accuracy1(model, test_x, test_y)
    @info "Epoch : $epoch | accuracy : $test_acc"
end
fujiehuang commented 1 year ago

The RNN gradient with Zygote might have a bug. Here's my short test code: keeping the outputs in an array versus in a scalar gives me different gradients. How come?

using Flux 
using Random
Random.seed!(149)

layer1 = Flux.Recur(Flux.RNNCell(1 => 1, identity))

x = Float32[0.8, 0.9]
y = Float32(-0.7)

Flux.reset!(layer1)
e1, g1 = Flux.withgradient(layer1) do m
    yhat = 0.0
    for i in 1:2 
        yhat = m([x[i]])
    end
    loss = Flux.mse(yhat, y)
    println(loss)
    return loss 
end
println("flux gradients: ", g1[1])

Flux.reset!(layer1)
e2, g2 = Flux.withgradient(layer1) do m
    yhat = [m([x[i]]) for i in 1:2]
    loss = Flux.mse(yhat[end], y)
    println(loss)
    return loss 
end
println("flux gradients: ", g2[1])
jeremiedb commented 1 year ago

There is indeed something fishy going on with the RNN gradients.

using Flux 
layer2 = Flux.Recur(Flux.RNNCell(1, 1, identity))
layer2.cell.Wi .= 5.0
layer2.cell.Wh .= 4.0
layer2.cell.b .= 0f0
layer2.cell.state0 .= 7.0
x = [[2f0], [3f0]]
Flux.reset!(layer2)
ps = Flux.params(layer2)
e2, g2 = Flux.withgradient(ps) do
    out = [layer2(xi) for xi in x]
    sum(out[2])
end

julia> g2[ps[1]]
1×1 Matrix{Float32}:
 3.0

julia> g2[ps[2]]
1×1 Matrix{Float32}:
 38.0

julia> g2[ps[3]]
1-element Fill{Float32}, with entry equal to 1.0

julia> g2[ps[4]] # nothing

Theoretical gradients are:

julia> ∇Wi = x[1] .* layer2.cell.Wh .+ x[2] 
1×1 Matrix{Float32}:
 11.0

julia> ∇Wh = 2 .* layer2.cell.Wh .* layer2.cell.state0 .+ x[1] .* layer2.cell.Wi 
1×1 Matrix{Float32}:
 66.0

julia> ∇b = layer2.cell.Wh .+ 1
1×1 Matrix{Float32}:
 5.0

julia> ∇state0 = layer2.cell.Wh .^ 2
1×1 Matrix{Float32}:
 16.0
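The sketch below is an independent check of these theoretical values, written in pure Julia with no Flux or Zygote; the function names are hypothetical. It computes the same scalar RNN's gradients in two ways: by manual backpropagation through time and by central finite differences. It also shows that the values Flux returned above for Wi, Wh and b (3, 38 and 1) are exactly the contributions of the last timestep alone (x_2, h_1 and 1). That pattern is consistent with the gradient not being carried back through the hidden state.

```julia
# Scalar RNN with identity activation, matching the snippet above:
#   h_t = Wi * x_t + Wh * h_{t-1} + b,   h_0 = state0,   loss = h_T
function rnn_states(Wi, Wh, b, state0, xs)
    hs = [state0]                     # hs[t] holds h_{t-1}; hs[end] is h_T
    for x in xs
        push!(hs, Wi * x + Wh * hs[end] + b)
    end
    return hs
end

rnn_loss(Wi, Wh, b, state0, xs) = rnn_states(Wi, Wh, b, state0, xs)[end]

# Manual backpropagation through time (BPTT).
function rnn_bptt(Wi, Wh, b, state0, xs)
    hs = rnn_states(Wi, Wh, b, state0, xs)
    gWi, gWh, gb = 0.0, 0.0, 0.0
    dh = 1.0                          # dL/dh_T, because the loss is h_T itself
    for t in length(xs):-1:1
        gWi += dh * xs[t]
        gWh += dh * hs[t]             # hs[t] is h_{t-1}
        gb  += dh
        dh  *= Wh                     # carry the gradient back to h_{t-1}
    end
    return (Wi = gWi, Wh = gWh, b = gb, state0 = dh)
end

# Central finite differences as an independent check.
function fd_grads(Wi, Wh, b, state0, xs; ε = 1e-6)
    f(p) = rnn_loss(p..., xs)
    p0 = [Wi, Wh, b, state0]
    g = map(1:4) do i
        e = zeros(4); e[i] = ε
        (f(p0 .+ e) - f(p0 .- e)) / (2ε)
    end
    return (Wi = g[1], Wh = g[2], b = g[3], state0 = g[4])
end

xs = [2.0, 3.0]
g  = rnn_bptt(5.0, 4.0, 0.0, 7.0, xs)
gf = fd_grads(5.0, 4.0, 0.0, 7.0, xs)
println("BPTT:               ", g)
println("finite differences: ", gf)
```

Both methods should give 11, 66, 5 and 16 for Wi, Wh, b and state0, matching the theoretical gradients above.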

Worse, the gradients are different (yet still wrong) when using the explicit mode :\

I tested older versions of Flux and things got even weirder. I got the same bad gradients going back to v0.11.4. However, when trying Julia 1.6.5, I got correct gradients with all tested Flux versions, from v0.11.4 up to v0.13.4, and the latest Zygote v0.6.58 (both implicit and explicit modes)!

The same bad gradients were observed on Julia 1.7.2 and 1.9.0-rc1.

So it seems that something changed between Julia v1.6 and v1.7 that affected gradient correctness. Any idea @ToucheSir?

ToucheSir commented 1 year ago

If I had to guess, something about lowering changed between those two versions. The more concerning part is that our test suite didn't catch this. I've always had a sinking feeling that https://github.com/FluxML/Flux.jl/blob/master/test/layers/recurrent.jl did not provide sufficient coverage, and unfortunately this only confirms that...

jeremiedb commented 1 year ago

I'll open a PR by tomorrow to add the above gradient tests. I'm also disappointed that I hadn't taken the time to manually validate those RNN gradients until now. Zygote is quite a footgun :\

liuyxpp commented 11 months ago

> For reference, this is how you could use the new explicit gradient / Optimisers.jl mode:
> [...]

In this implementation, Flux.reset!(model) is outside the loss function. Does that mean the model's initial state is shared across a batch? I don't think that is the expected behavior in most cases. The model should be reset for each individual sample, not for each batch.

ToucheSir commented 11 months ago

If you consider the initial state non-trainable, then I think it's mostly equivalent since other libraries are passing all zeros as the initial state. If you have a custom initial state however or want it to be trainable (which PyTorch at least does not appear to support directly), then it is not the same as you say. I'm unsure why the original design is the way it is (cc @mkschleg for possible theories), but reworking the initial state is one of those things we're investigating for our overhaul of the RNN API.

jeremiedb commented 10 months ago

Regarding the initialization of the initial state: although it may not be the common form encountered in PyTorch, this paper, co-authored by one of the authors of LSTM, points to the relevance of learning the initial state (see section 5.1 on page 135). See also this blog post discussing it: https://r2rt.com/non-zero-initial-states-for-recurrent-neural-networks.html. I also vaguely recall that MXNet used to offer a learnable initial state as a feature, but couldn't confirm it.

By applying reset!, the state of the model is set to a learnable initial state that is applied to all observations in the input data X. Notice how the loss function iterates over all the "time-steps" of the input data: [m(x) for x ∈ X]. This means that at the first timestep, every observation in the batch gets the same hidden-state input. After that first timestep, the state of the model will be different for each observation in the batch. It's not advisable to put reset! inside a loss function, because it's not a learnable operation. It is an assignment that sets the model's state to the initial-state parameter, so that the model is ready to receive a new batch of sequence data.
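Below is a small pure-Julia sketch of the mechanism described above. `ToyRecur`, `toy_reset!` and `run_sequence` are hypothetical names that only loosely mirror the Recur/reset! pattern from this thread; they are not Flux's implementation. After `toy_reset!`, the shared `state0` vector is broadcast across every batch column at the first timestep, and from then on each column evolves independently. As a result, running a whole batch after a single reset gives the same outputs as resetting before each sample, as long as every sample is meant to start from the same initial state.

```julia
# Toy recurrent layer with a learnable initial state (hypothetical names,
# loosely mirroring the Recur/reset! pattern discussed in this thread).
mutable struct ToyRecur
    Wi::Matrix{Float64}
    Wh::Matrix{Float64}
    b::Vector{Float64}
    state0::Vector{Float64}                          # one initial state shared by all samples
    state::Union{Vector{Float64}, Matrix{Float64}}   # running hidden state
end

ToyRecur(Wi, Wh, b, state0) = ToyRecur(Wi, Wh, b, state0, state0)

# "reset": assign the initial-state parameter to the running state (no learning here).
toy_reset!(m::ToyRecur) = (m.state = m.state0; m)

# One timestep on a (features, batch) input. Right after a reset the state is a
# vector, so Wh * state is broadcast to every batch column (a common hidden input).
function (m::ToyRecur)(x::AbstractMatrix)
    h = tanh.(m.Wi * x .+ m.Wh * m.state .+ m.b)
    m.state = h                                      # from now on, one state column per sample
    return h
end

# Reset once, feed every timestep, keep the output of the last one.
run_sequence(m, xs) = (toy_reset!(m); [m(x) for x in xs][end])

hidden, features, batch, seq_len = 3, 2, 4, 5
m  = ToyRecur(randn(hidden, features), randn(hidden, hidden), randn(hidden), randn(hidden))
xs = [randn(features, batch) for _ in 1:seq_len]

# Whole batch after a single reset ...
out_batch = run_sequence(m, xs)

# ... versus a reset before each individual sample (column j of every timestep).
out_single = reduce(hcat, [run_sequence(m, [x[:, j:j] for x in xs]) for j in 1:batch])

println("max difference: ", maximum(abs.(out_batch .- out_single)))
```

The difference should be zero up to floating-point rounding. Only the gradient of `state0` is special, because that single vector is shared by every sample in the batch.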