FluxML / FluxTraining.jl

A flexible neural net training library inspired by fast.ai
https://fluxml.ai/FluxTraining.jl
MIT License
117 stars 25 forks source link

Read access to `Learner.model` disallowed, cannot override via `stateaccess()` #158

Closed alec-hoyland closed 8 months ago

alec-hoyland commented 8 months ago

Expected Behavior

I can override default state access restrictions using the stateaccess function to read learner.model.weight within the context of FluxTraining.step!(metric::MyMetric) to do some computation during the validation phase.

What I want to do is write a custom metric that has access to learner.model.weight, e.g. something like:

# Desired metric step: read `learner.model.weight` during the configured phase
# and fold the resulting scalar into the running statistic.
function FluxTraining.step!(metric::MyMetric, learner, phase)
    if phase isa metric.P
        # `learner` arrives wrapped as `FluxTraining.Protected{Learner}` (see the
        # stacktrace below), so this property read is what raises ProtectedException.
        metric.last = l1_metric(learner.model.weight)
        OnlineStats.fit!(metric.statistic, metric.last)
    else
        # Outside the active phase, mark the metric as "not computed this step".
        metric.last = nothing
    end
end

Error

I can't get this to work and I am unsure if it is a mistake on my end or if this is a bug.

Package info

[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FluxTraining = "7bf95e4d-ca32-48da-9824-f0dc5310474f"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"

System info

Pop!_OS 22.04 LTS x86_64 with Julia version 1.9.0

Stacktrace

julia> include("mwe.jl")
main (generic function with 1 method)

julia> main()
Epoch 1 TrainingPhase() ...
┌─────────────────┬───────┬─────────┐
│           Phase │ Epoch │    Loss │
├─────────────────┼───────┼─────────┤
│ TrainingPhase() │   1.0 │ 0.23645 │
└─────────────────┴───────┴─────────┘
Epoch 1 ValidationPhase() ...
ERROR: FluxTraining.ProtectedException("Read access to Learner.model disallowed.")
Stacktrace:
  [1] getfieldperm(data::Learner, field::Symbol, perm::Nothing)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/protect.jl:63
  [2] getproperty(protected::FluxTraining.Protected{Learner}, field::Symbol)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/protect.jl:18
  [3] step!(metric::MyMetric{Number}, learner::FluxTraining.Protected{Learner}, phase::ValidationPhase)
    @ Main ~/code/TinnitusStimulusFitter.jl/scripts/stimuli_modeling/mwe.jl:50
  [4] on(#unused#::FluxTraining.Events.StepEnd, phase::ValidationPhase, metrics::Metrics, learner::FluxTraining.Protected{Learner})
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/metrics.jl:74
  [5] _on(e::FluxTraining.Events.StepEnd, p::ValidationPhase, cb::Metrics, learner::Learner)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/callback.jl:254
  [6] handle(runner::FluxTraining.LinearRunner, event::FluxTraining.Events.StepEnd, phase::ValidationPhase, learner::Learner)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/callbacks/execution.jl:12
  [7] (::FluxTraining.var"#handlefn#82"{Learner, ValidationPhase})(e::FluxTraining.Events.StepEnd)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:129
  [8] runstep(stepfn::FluxTraining.var"#79#80"{Learner}, learner::Learner, phase::ValidationPhase, initialstate::NamedTuple{(:xs, :ys), Tuple{Matrix{Float32}, Matrix{Float32}}})
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:134
  [9] step!(learner::Learner, phase::ValidationPhase, batch::Tuple{Matrix{Float32}, Matrix{Float32}})
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:84
 [10] (::FluxTraining.var"#71#72"{Learner, ValidationPhase, MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random._GLOBAL_RNG, Val{nothing}}})(#unused#::Function)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:24
 [11] runepoch(epochfn::FluxTraining.var"#71#72"{Learner, ValidationPhase, MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random._GLOBAL_RNG, Val{nothing}}}, learner::Learner, phase::ValidationPhase)
    @ FluxTraining ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:105
 [12] epoch!
    @ ~/.julia/packages/FluxTraining/xCOPx/src/training.jl:22 [inlined]
 [13] main()
    @ Main ~/code/TinnitusStimulusFitter.jl/scripts/stimuli_modeling/mwe.jl:100
 [14] top-level scope
    @ REPL[3]:1

Minimal Working Example


using Flux
using FluxTraining
using OnlineStats
using LinearAlgebra

"""
Custom type for my metric.
This is basically a duplicate of the standard Metric.
"""
mutable struct MyMetric{T} <: FluxTraining.AbstractMetric
    statistic::OnlineStats.OnlineStat{T}
    _statistic::Any
    name::Any
    device::Any
    P::Any
    last::Union{Nothing, T}
end

"""
Outer constructor for MyMetric.
"""
function MyMetric(;
        statistic = OnlineStats.Mean(Float32),
        device = cpu,
        phase = ValidationPhase,
        name = "MyMetric")
    return MyMetric(statistic, deepcopy(statistic), name, device, phase, nothing)
end

"""
Reset MyMetric back to the initial value.
"""
function FluxTraining.reset!(metric::MyMetric{T}) where T
    metric.statistic = deepcopy(metric._statistic)
end

"""
We will use the L1 norm as an "example function"
that requires access to model weights.
"""
function l1_metric(W::Matrix)
    return norm(W, 1) / size(W, 1)
end

"""
Compute the metric by taking the L1 norm of the model weight matrix.
"""
function FluxTraining.step!(metric::MyMetric, learner, phase)
    if phase isa metric.P
        metric.last = l1_metric(learner.model.weight)
        OnlineStats.fit!(metric.statistic, metric.last)
    else
        metric.last = nothing
    end
end

# Compact display, mirroring the built-in `Metric`: prints as `Metric(<name>)`.
function Base.show(io::IO, metric::MyMetric{T}) where {T}
    print(io, "Metric($(metric.name))")
end

# Hook ordering: run this metric's callbacks after the Recorder.
FluxTraining.runafter(::MyMetric) = (FluxTraining.Recorder,)
# Value reported for the current step (`nothing` outside the active phase).
FluxTraining.stepvalue(metric::MyMetric) = metric.last
# Name used in the printed metrics table.
FluxTraining.metricname(metric::MyMetric) = metric.name

# Aggregated value for the epoch, or `nothing` when no step was computed
# during the current phase (i.e. `last` was never set).
function FluxTraining.epochvalue(m::MyMetric)
    return isnothing(m.last) ? nothing : OnlineStats.value(m.statistic)
end

# NOTE(review): this is the part that does not work as intended. Despite
# requesting `model = Read()`, the validation phase still raises
# `ProtectedException` (see the stacktrace above). Per the follow-up comment,
# `stateaccess` is consulted for the wrapping `Callback` (here
# `FluxTraining.Metrics`), not for the individual metric, so this method is
# never used when granting access.
function FluxTraining.stateaccess(::MyMetric)
    return (
        model = FluxTraining.Read(),
        params = FluxTraining.Read(),
        cbstate = (metricsstep = FluxTraining.Write(), metricsepoch = FluxTraining.Write(), history = FluxTraining.Read()),
        step = FluxTraining.Read(),
    )
end

function main()
    # Tiny linear-regression setup: 10 features -> 1 output, no bias term.
    n_features = 10
    n_targets = 1
    n_obs = 64
    model = Dense(n_features => n_targets, identity; bias=false)

    features = f32(rand(n_features, n_obs))
    targets = f32(rand(n_targets, n_obs))
    train_loader = Flux.DataLoader((features, targets))
    # Reuse the same data for validation; a copy keeps the loaders independent.
    val_loader = deepcopy(train_loader)

    cbs = [FluxTraining.Metrics(MyMetric())]
    optim = Flux.Adam(1f-4)

    learner = FluxTraining.Learner(model, Flux.mse; callbacks = cbs, optimizer = optim)

    # Alternate training and validation epochs; the validation epoch is where
    # the metric's step! runs (and where the ProtectedException surfaces).
    for _ in 1:3
        FluxTraining.epoch!(learner, FluxTraining.TrainingPhase(), train_loader)
        FluxTraining.epoch!(learner, ValidationPhase(), val_loader)
    end
end
alec-hoyland commented 8 months ago

Associated post: https://discourse.julialang.org/t/how-to-read-model-weights-using-fluxtraining-stateaccess-issues/106409

alec-hoyland commented 8 months ago

`stateaccess` is looked up on the wrapping `Callback` (here `FluxTraining.Metrics`), not on the individual metric, so it should be:

# Working version: define `stateaccess` for the `Metrics` callback that wraps
# the metric, and grant *nested* read access down to `model.weight`.
# NOTE(review): this adds a method on a FluxTraining-owned function for a
# FluxTraining-owned type — presumably intended as a workaround here; confirm
# it does not shadow/replace the library's own permissions for `Metrics`.
function FluxTraining.stateaccess(::FluxTraining.Metrics)
    return (
        # Nested permission: only `model.weight` becomes readable, not the whole model.
        model = (weight = FluxTraining.Read(),),
        params = FluxTraining.Read(),
        # Metrics must write their step/epoch values into the callback state.
        cbstate = (metricsstep = FluxTraining.Write(), metricsepoch = FluxTraining.Write(), history = FluxTraining.Read()),
        step = FluxTraining.Read(),
    )
end