JuliaAI / MLJModelInterface.jl

Lightweight package to interface with MLJ

Question on the use of the update method and is_same_except() #212

Open pasq-cat opened 22 hours ago

pasq-cat commented 22 hours ago

Hi, I was trying to implement the update method for LaplaceRedux but I am having a problem.

This is the model:

MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
    model::Flux.Chain = nothing
    flux_loss = Flux.Losses.mse
    optimiser = Adam()
    epochs::Integer = 1000::(_ > 0)
    batch_size::Integer = 32::(_ > 0)
    subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
    subnetwork_indices = nothing
    hessian_structure::Union{HessianStructure,Symbol,String} =
        :full::(_ in (:full, :diagonal))
    backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
    σ::Float64 = 1.0
    μ₀::Float64 = 0.0
    P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
    fit_prior_nsteps::Int = 100::(_ > 0)
end
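
For reference, @mlj_model generates a keyword constructor that fills in these defaults and enforces the constraints, so (given a Flux chain flux_model) the struct is instantiated along the lines of:

model = LaplaceRegressor(model = flux_model, epochs = 500)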

This is the fit function that I have written:

function MMI.fit(m::LaplaceRegressor, verbosity, X, y)

    #X = MLJBase.matrix(X) |> permutedims
    #y = reshape(y, 1, :)

    if Tables.istable(X)
        X = Tables.matrix(X) |> permutedims
    end

    # Reshape y if necessary
    y = reshape(y, 1, :)

    data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
    opt_state = Flux.setup(m.optimiser, m.model)
    loss_history = []
    push!(loss_history, m.flux_loss(m.model(X), y))

    for epoch in 1:(m.epochs)

        loss_per_epoch = 0.0

        for (X_batch, y_batch) in data_loader
            # Forward pass: compute predictions
            y_pred = m.model(X_batch)

            # Compute loss
            loss = m.flux_loss(y_pred, y_batch)

            # Compute gradients 
            grads = gradient(m.model) do model
                # Recompute predictions inside gradient context
                y_pred = model(X_batch)
                m.flux_loss(y_pred, y_batch)
            end

            # Update parameters using the optimizer and computed gradients
            Flux.Optimise.update!(opt_state, m.model, grads[1])

            # Accumulate the loss for this batch
            loss_per_epoch += sum(loss)  # Summing the batch loss

        end

        push!(loss_history, loss_per_epoch)

        # Print loss every 100 epochs if verbosity is 1 or more
        if verbosity >= 1 && epoch % 100 == 0
            println("Epoch $epoch: Loss: $loss_per_epoch ")
        end
    end

    la = LaplaceRedux.Laplace(
        m.model;
        likelihood=:regression,
        subset_of_weights=m.subset_of_weights,
        subnetwork_indices=m.subnetwork_indices,
        hessian_structure=m.hessian_structure,
        backend=m.backend,
        σ=m.σ,
        μ₀=m.μ₀,
        P₀=m.P₀,
    )

    # fit the Laplace model:
    LaplaceRedux.fit!(la, data_loader)
    optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)

    fitresult = la
    report = (loss_history = loss_history,)
    cache = (deepcopy(m), opt_state, loss_history)
    return fitresult, cache, report
end
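
Note that whatever fit returns as cache is handed back verbatim to MMI.update when the machine is re-fit, so the deepcopy(m) stored above is the hyperparameter snapshot that update compares against below.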

And now follows the incomplete update function that I was trying. I have removed the training loop since it's not important here.

function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X, y)

    println("running MMI.update")

    old_model = old_cache[1]

    if Tables.istable(X)
        X = Tables.matrix(X) |> permutedims
    end

    # Reshape y if necessary
    y = reshape(y, 1, :)

    println(MMI.is_same_except(m, old_model, :epochs))

    cache = ()
    report = ()
    return old_fitresult, cache, report
end
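
To make the intent concrete, here is a self-contained toy of the warm-restart pattern I am after (Toy is a hypothetical model, nothing to do with Laplace): when only :epochs grows, update does the incremental work; anything else triggers a cold restart.

using MLJModelInterface
import MLJModelInterface as MMI

# Hypothetical minimal iterative model, for illustration only
mutable struct Toy <: MMI.Probabilistic
    epochs::Int
end

function MMI.fit(m::Toy, verbosity, X, y)
    fitresult = m.epochs          # stand-in for "state after m.epochs of training"
    cache = (deepcopy(m),)        # hyperparameter snapshot, as in my fit above
    return fitresult, cache, NamedTuple()
end

function MMI.update(m::Toy, verbosity, old_fitresult, old_cache, X, y)
    old_m = old_cache[1]
    if MMI.is_same_except(m, old_m, :epochs) && m.epochs >= old_m.epochs
        # warm restart: only the extra epochs are "trained"
        fitresult = old_fitresult + (m.epochs - old_m.epochs)
        return fitresult, (deepcopy(m),), NamedTuple()
    end
    return MMI.fit(m, verbosity, X, y)   # anything else changed: cold restart
end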

The issue is that if I rerun the model after changing only the number of epochs, is_same_except still gives me

false

even though :epochs is listed as an exception:

using MLJ, Flux
flux_model = Chain(
    Dense(4, 10, relu),
    Dense(10, 10, relu),
    Dense(10, 1)
)
model = LaplaceRegressor(model=flux_model)

X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
mach = machine(model, X, y) 
MLJBase.fit!(mach)

model.epochs = 2000

MLJBase.fit!(mach)
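
(The second fit! does reach MMI.update: when an already-trained machine is re-fit after a hyperparameter mutation, MLJ calls update rather than fit, and that is where the false above gets printed.)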

So what is the correct way to use is_same_except here? Thank you.

ablaom commented 13 hours ago

Not sure what the problem might be. Can you provide an MWE demonstrating that is_same_except is not working as you expect? I.e., some variation of this (which works for me):

using MLJModelInterface
import MLJModelInterface as MMI

mutable struct Classifier <: Probabilistic
    x::Int
    y::Int
end

model = Classifier(1, 2)
model2 = deepcopy(model)
model2.y = 7

@assert MMI.is_same_except(model, model2, :y)

Or, if you suspect some other problem, a more self-contained MWE would be helpful.
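
One thing that might be worth checking in the meantime: for properties that are not themselves MLJ models, is_same_except falls back to plain ==, and Flux layers define no == of their own, so comparing a Chain with a deepcopy of itself compares the underlying weight arrays by identity and returns false. A quick check (a sketch, assuming Flux is loaded):

using Flux

chain = Chain(Dense(4, 2), Dense(2, 1))
chain == deepcopy(chain)  # false: comparison falls back to identity of the weight arrays

Since your fit stores deepcopy(m) in the cache, the model field alone would then make is_same_except(m, old_model, :epochs) return false, independently of :epochs.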