FluxML / MLJFlux.jl

Wrapping deep learning models from the package Flux.jl for use in the MLJ.jl toolbox
http://fluxml.ai/MLJFlux.jl/
MIT License

Error when batch_size different from 1 in NeuralNetworkRegressor #225

Open MathNog opened 1 year ago

MathNog commented 1 year ago

When I pass batch_size as a parameter to NeuralNetworkRegressor(), the model can't be fitted because of a dimension mismatch.

I have written the following code:

using MLJFlux, Flux, Random

mutable struct LSTMBuilder <: MLJFlux.Builder
    input_size :: Int
    num_units :: Vector{Int}   # hidden units of each LSTM layer
    num_layers :: Int
end

function MLJFlux.build(lstm::LSTMBuilder, rng, n_in, n_out)
    input_size, num_units, num_layers = lstm.input_size, lstm.num_units, lstm.num_layers
    init = Flux.glorot_uniform(rng)
    Random.seed!(1234)
    # first LSTM layer maps the n_in input features to num_units[1] hidden units
    layers = [LSTM(n_in, num_units[1]), Dropout(0.1)]
    # stack the remaining LSTM layers, each followed by dropout
    for i in 1:num_layers-1
        layers = vcat(layers, [LSTM(num_units[i], num_units[i+1]), Dropout(0.1)])
    end
    # dense output layer
    layers = vcat(layers, Dense(num_units[num_layers], n_out))
    Random.seed!(1234)
    model = Chain(layers...)
    return model
end

model = NeuralNetworkRegressor(builder=LSTMBuilder(60, [4, 4], 2),
                        rng = Random.GLOBAL_RNG,
                        epochs = 200,
                        loss = Flux.mse,
                        optimiser = ADAM(0.001),
                        batch_size = 16)

The error message when training it is:

[ Info: Training machine(JackknifeRegressor(model = NeuralNetworkRegressor(builder = LSTMBuilder(input_size = 60, …), …), …), …).
Optimising neural net: 100%[=========================] Time: 0:00:03
┌ Error: Problem fitting the machine machine(JackknifeRegressor(model = NeuralNetworkRegressor(builder = LSTMBuilder(input_size = 60, …), …), …), …). 
└ @ MLJBase C:\Users\matheuscn.ELE\.julia\packages\MLJBase\5cxU0\src\machines.jl:682
[ Info: Running type checks... 
[ Info: Type checks okay. 
ERROR: DimensionMismatch: array could not be broadcast to match destination

I suspect this error is caused by the absence of a Flux.reset! call after each batch update inside the training loop.
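The mismatch can be reproduced with a bare Flux layer. This is a minimal sketch, assuming Flux 0.13/0.14, where `LSTM` is a stateful `Recur` whose hidden state keeps the batch size of the last input it saw (later Flux releases reworked the recurrent layers, so it will not run there unchanged). The sizes used (3 features, 2 hidden units, batches of 16 and 5) are arbitrary.

```julia
using Flux  # assumes Flux 0.13/0.14, where LSTM is a stateful `Recur` layer

m = LSTM(3, 2)                  # 3 input features, 2 hidden units
x16 = rand(Float32, 3, 16)      # one time step for a batch of 16
x5  = rand(Float32, 3, 5)       # a smaller final batch of 5

y16 = m(x16)                    # the hidden state is now 2×16
@show size(m.state[1])

# Feeding the smaller batch without a reset: the 2×16 state cannot be
# combined with a batch of 5, so Julia throws a DimensionMismatch.
caught = try
    m(x5)
    false
catch err
    err isa DimensionMismatch || rethrow()
    true
end

Flux.reset!(m)                  # state goes back to the 2×1 initial state
y5 = m(x5)                      # a 2×1 state broadcasts against any batch size
@show size(y5)
```

The initial state has a batch dimension of 1, which is why a freshly reset layer accepts any batch size while a layer that has just seen 16 columns does not.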

ablaom commented 1 year ago

Thanks @MathNog for reporting.

I've not tried to reproduce, but your analysis sounds reasonable. (Current tests do include changing batch size for some non-recurrent networks.)

Each time MLJModelInterface.fit is called, a new Flux model is built, so I suppose the issue is that the last batch within an epoch can be smaller than the others (if I remember correctly, we allow this rather than just dropping the last batch). Is this also your thinking? If so, it may suffice to rule that out.
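To make the "smaller last batch" concrete: with, say, 100 observations and batch_size = 16, splitting without dropping anything leaves a final batch of 4. This is a hypothetical illustration using Julia's `Iterators.partition`, not MLJFlux's internal batching code.

```julia
# Hypothetical illustration (not MLJFlux's internal code): split 100
# observation indices into batches of 16 without dropping anything.
n, batch_size = 100, 16
batches = collect(Iterators.partition(1:n, batch_size))
sizes = length.(batches)
println(sizes)   # [16, 16, 16, 16, 16, 16, 4]
```

With a stateful recurrent layer, the state left behind by the sixth batch (16 columns) then meets a batch of 4 columns.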

It's a while since I looked at RNNs, but I would have thought calling reset! after every batch update would muck up inference. Do I misunderstand?

MathNog commented 1 year ago

Thanks for the comment, @ablaom, and I believe you are correct in your suggestion.

I have altered both MLJFlux.fit! and MLJFlux.train! within my own project, adding the Flux.reset! command exactly as you said. To add that line I also had to restructure the code a little, while making sure the final result is the same.

using Statistics: mean

function MLJFlux.fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, verbosity, X, y)
    loss = model.loss
    # initiate history with the loss of the untrained chain,
    # resetting the recurrent state after each batch:
    n_batches = length(y)
    parameters = Flux.params(chain)
    losses = Vector{Float32}(undef, n_batches)
    for i in 1:n_batches
        losses[i] = loss(chain(X[i]), y[i]) + penalty(parameters) / n_batches
        Flux.reset!(chain)
    end
    history = [mean(losses),]
    for i in 1:epochs
        current_loss = MLJFlux.train!(model, penalty, chain, optimiser, X, y)
        push!(history, current_loss)
    end
    return chain, history
end

"train! taken from MLJFlux, with a Flux.reset! after each batch update"
function MLJFlux.train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
    loss = model.loss
    n_batches = length(y)
    training_loss = zero(Float32)
    for i in 1:n_batches
        parameters = Flux.params(chain)
        gs = Flux.gradient(parameters) do
            yhat = chain(X[i])
            batch_loss = loss(yhat, y[i]) + penalty(parameters) / n_batches
            training_loss += batch_loss
            return batch_loss
        end
        Flux.update!(optimiser, parameters, gs)
        Flux.reset!(chain)   # clear the hidden state so the next batch may have a different size
    end
    return training_loss / n_batches
end

I also noticed that, for everything to run smoothly, MLJModelInterface.predict in src/regressor.jl should also be modified to call reset!. I made it work as follows:

import MLJModelInterface

function MLJModelInterface.predict(model::MLJFlux.NeuralNetworkRegressor, fitresult, Xnew)
    chain = fitresult[1]
    Xnew_ = MLJFlux.reformat(Xnew)
    forec = Vector{Float32}(undef, size(Xnew_, 2))
    for i in 1:size(Xnew_, 2)
        # reset before each observation, so every prediction starts from the initial state
        Flux.reset!(chain)
        forec[i] = chain(values.(MLJFlux.tomat(Xnew_[:, i])))[1]
    end
    return forec
end

With all these changes, I could train a NeuralNetworkRegressor with batch_size different from 1, and predict with it, with no issues. I hope these examples help the development of the project in some way.

ablaom commented 1 year ago

Thanks for that, but I think I was not clear enough. My understanding is that a Flux RNN must be trained on batches that are all the same size. Calling reset! between batches will stop Flux complaining, but by doing so you are interfering with the normal training of the weights. It's roughly akin to, say, resetting some random weights to zero between batches.
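Whether reset! interferes with the weights can be checked directly. The sketch below, assuming Flux 0.13/0.14 (stateful `Recur`-based `LSTM`, implicit `Flux.params`), snapshots the trainable arrays, runs one batch, resets, and compares. It shows that reset! leaves the trainable arrays untouched and only restores the hidden state; what it does discard is the context carried over from the previous batch, which matters only if consecutive batches are continuations of the same sequences.

```julia
using Flux  # assumes Flux 0.13/0.14 (stateful `Recur`-based LSTM)

m = LSTM(3, 2)
params_before = [copy(p) for p in Flux.params(m)]   # snapshot of the trainable arrays

m(rand(Float32, 3, 16))          # run one batch: the hidden state changes
Flux.reset!(m)                   # put the hidden state back to its initial value

weights_unchanged = params_before == [p for p in Flux.params(m)]
state_is_initial = m.state === m.cell.state0
println((weights_unchanged, state_is_initial))   # (true, true)
```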

I'm not an expert on RNNs, so I may have this wrong. Perhaps @ToucheSir can comment.

If I'm right, then the more appropriate remedy is to ensure all batches have the same size. When the batch size does not divide the number of observations, the last batch is smaller than the others; we could, for example, simply ignore that last batch. To justify this, we would need to ensure we are also shuffling observations between epochs, which is not implemented, if I remember correctly.
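The "drop the last batch, but shuffle between epochs" remedy can be sketched in a few lines. The helper `epoch_batches` below is hypothetical (it is not part of MLJFlux); it only uses Julia's `Random.randperm`. Without the per-epoch shuffle, the same trailing observations would be discarded every epoch; with it, the dropped remainder changes from one epoch to the next.

```julia
using Random

# Hypothetical helper (not part of MLJFlux): shuffle the observation indices
# afresh each epoch and keep only complete batches, dropping the remainder.
function epoch_batches(rng::AbstractRNG, n::Integer, batch_size::Integer)
    perm = randperm(rng, n)          # new random order for this epoch
    n_full = div(n, batch_size)      # number of complete batches
    return [perm[(k - 1) * batch_size + 1 : k * batch_size] for k in 1:n_full]
end

rng = MersenneTwister(42)
for epoch in 1:3
    batches = epoch_batches(rng, 100, 16)
    dropped = setdiff(1:100, reduce(vcat, batches))
    println("epoch $epoch: ", length(batches), " batches of 16; dropped ", sort(dropped))
end
```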

ToucheSir commented 1 year ago

With the caveat that I have not read through the entire thread: it's perfectly fine to have different batch sizes while training an RNN. reset! exists precisely to, well, reset the internal state before feeding in the next batch. What you do want to be careful of, however, is how the batch dimension is represented, because it's different from most other NN models you'd deal with (the batch dim is not the last dim, the input is a sequence of timesteps, etc.).
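A plain Flux loop that follows this advice might look like the sketch below. It assumes Flux 0.13/0.14 (stateful `Recur` layers, implicit `Flux.params`, `Descent`); the data, the `make_batch` and `train_epoch!` helpers, and all sizes are hypothetical, and this is not MLJFlux's training code. Each batch is a vector of time steps, each step a features × batch matrix, so the batch dimension is the second dimension of every step rather than the last dimension of one big array. Batch sizes differ (16, 16, 5), and reset! is called at the start of each batch because each batch is a fresh set of sequences.

```julia
using Flux    # assumes Flux 0.13/0.14: stateful `Recur` layers, implicit `Flux.params`
using Random

Random.seed!(0)
model = Chain(LSTM(3, 8), Dense(8, 1))
ps = Flux.params(model)
opt = Descent(0.01)

# Hypothetical data: each batch holds B sequences of 10 time steps.
# Time step t is a 3×B matrix (features × batch); B varies between batches.
make_batch(B) = ([rand(Float32, 3, B) for _ in 1:10], rand(Float32, 1, B))
batches = [make_batch(B) for B in (16, 16, 5)]

function train_epoch!(model, ps, opt, batches)
    losses = Float32[]
    for (xs, y) in batches
        Flux.reset!(model)            # new batch of sequences: start from the initial state
        batch_loss = 0f0
        gs = Flux.gradient(ps) do
            yhat = model(xs[1])
            for x in xs[2:end]        # feed the remaining time steps in order
                yhat = model(x)
            end
            l = Flux.mse(yhat, y)     # loss on the output after the last step
            batch_loss = l            # record the value for reporting
            return l
        end
        Flux.update!(opt, ps, gs)
        push!(losses, batch_loss)
    end
    return losses
end

losses = train_epoch!(model, ps, opt, batches)
println(losses)
```

Time steps are fed with an explicit loop rather than `map` or broadcasting, so the stateful layer sees them in a guaranteed order.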