FluxML / FluxTraining.jl

A flexible neural net training library inspired by fast.ai
https://fluxml.ai/FluxTraining.jl
MIT License

Cannot recover model saved with Checkpointer() #161

Open · cirobr opened 7 months ago

cirobr commented 7 months ago

Cheers,

Checkpointer() saves the entire trained model with BSON. For several models I have been able to recover them with BSON.@load modeladdress model, but I am now facing a case where I cannot recover the saved checkpoint. The error message is shown below. Any hint on how to recover the saved data?
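For reference, a minimal sketch of the failing load step; the path below is a placeholder for a .bson file written by Checkpointer():

using BSON
modeladdress = "checkpoints/checkpoint_epoch_001.bson"   # placeholder path to a Checkpointer() output
BSON.@load modeladdress model                            # fails with the error shown below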

In any case, it seems advisable to change the function from saving the whole model to saving the outcome of Flux.setup(model).

Thanks.

ERROR: type CodeInfo has no field pure
Stacktrace:
  [1] getproperty(ci::Core.CodeInfo, s::Symbol)
    @ Base ./deprecated.jl:326
  [2] 
    @ BSON ~/.julia/packages/BSON/DOYqe/src/anonymous.jl:58
  [3] newstruct_raw(cache::IdDict{Any, Any}, T::Type, d::Dict{Symbol, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/extensions.jl:169
  [4] (::BSON.var"#49#50")(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/extensions.jl:184
  [5] raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:92
  [6] (::BSON.var"#23#24"{IdDict{Any, Any}, Module})(x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:98
  [7] applychildren!(f::BSON.var"#23#24"{IdDict{Any, Any}, Module}, x::Vector{Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/BSON.jl:26
  [8] raise_recursive
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:98 [inlined]
--- the last 3 lines are repeated 1 more time ---
 [12] newstruct_raw(cache::IdDict{Any, Any}, ::Type{Core.TypeName}, d::Dict{Symbol, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/anonymous.jl:146
 [13] (::BSON.var"#49#50")(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/extensions.jl:184
 [14] raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:92
 [15] (::BSON.var"#18#21"{IdDict{Any, Any}, Module})(x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:82
 [16] applychildren!(f::BSON.var"#18#21"{IdDict{Any, Any}, Module}, x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/BSON.jl:19
 [17] _raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:82
 [18] raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:93
 [19] (::BSON.var"#23#24"{IdDict{Any, Any}, Module})(x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:98
 [20] applychildren!(f::BSON.var"#23#24"{IdDict{Any, Any}, Module}, x::Vector{Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/BSON.jl:26
 [21] raise_recursive
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:98 [inlined]
 [22] (::BSON.var"#17#20"{IdDict{Any, Any}, Module})(x::Vector{Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:80
 [23] applychildren!(f::BSON.var"#17#20"{IdDict{Any, Any}, Module}, x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/BSON.jl:19
 [24] _raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:80
--- the last 7 lines are repeated 3 more times ---
 [46] raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:93
 [47] (::BSON.var"#49#50")(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/extensions.jl:182
 [48] raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:92
 [49] (::BSON.var"#23#24"{IdDict{Any, Any}, Module})(x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:98
 [50] applychildren!(f::BSON.var"#23#24"{IdDict{Any, Any}, Module}, x::Vector{Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/BSON.jl:26
 [51] raise_recursive
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:98 [inlined]
 [52] (::BSON.var"#18#21"{IdDict{Any, Any}, Module})(x::Vector{Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:82
 [53] applychildren!(f::BSON.var"#18#21"{IdDict{Any, Any}, Module}, x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/BSON.jl:19
 [54] _raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:82
 [55] (::BSON.var"#49#50")(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/extensions.jl:183
--- the last 8 lines are repeated 2 more times ---
 [72] raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:92
 [73] (::BSON.var"#19#22"{IdDict{Any, Any}, Module})(x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:86
 [74] applychildren!(f::BSON.var"#19#22"{IdDict{Any, Any}, Module}, x::Dict{Symbol, Any})
    @ BSON ~/.julia/packages/BSON/DOYqe/src/BSON.jl:19
 [75] _raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:86
 [76] raise_recursive(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:93
 [77] raise_recursive
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:103 [inlined]
 [78] load
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:108 [inlined]
 [79] load(x::String)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/read.jl:108
 [80] macro expansion
    @ ~/.julia/packages/BSON/DOYqe/src/BSON.jl:50 [inlined]
ToucheSir commented 7 months ago

We'd need to use Flux.state instead of setup because setup is only for optimizer state (not model params), but it could be done. The bigger problem IMO is relying on BSON.jl, which is very buggy and basically unmaintained. For Flux's own docs, we've moved towards recommending JLD2.jl instead. FluxTraining should be switched to use that or the Serialization stdlib.
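For reference, a minimal sketch of the pattern the Flux docs recommend, combining Flux.state with JLD2 (the model and filenames below are placeholders):

using Flux, JLD2

model = Chain(Dense(10 => 5, relu), Dense(5 => 2))   # placeholder architecture

model_state = Flux.state(Flux.cpu(model))            # collect the model's parameters and buffers
jldsave("model_state.jld2"; model_state)             # write them with JLD2

# Later: rebuild the same architecture, then restore the saved state into it.
model2 = Chain(Dense(10 => 5, relu), Dense(5 => 2))
Flux.loadmodel!(model2, JLD2.load("model_state.jld2", "model_state"))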

cirobr commented 7 months ago

I have made a quick-and-dirty fix by creating a callback function that can be executed after each epoch. The model is moved to the CPU before saving its state. I have tested it with BSON, but it should work with JLD2 as well.

using Flux
using BSON
using Dates

# Move the model to the CPU and save its state (parameters and buffers) with BSON.
function saveModelState(fullpathFilename, model)
    modelcpu    = Flux.cpu(model)
    model_state = Flux.state(modelcpu)
    BSON.@save fullpathFilename model_state
end

# Callback helper: write a timestamped model-state file into `path`.
function saveModelStateCB(path, model)
    if path[end] != '/'
        path = path * "/"
    end

    fpfn = path * "model_state-" * Dates.format(Dates.now(), "yyyy-mm-ddTHH-MM-SS-sss") * ".bson"
    saveModelState(fpfn, model)
end
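To recover the model from one of these files later, something along these lines should work (the filename is a placeholder, and the architecture must be rebuilt to match the one that produced the state):

using Flux, BSON

model = Chain(Dense(10 => 5, relu), Dense(5 => 2))   # placeholder: same architecture as when saved
BSON.@load "model_state-2024-01-01T00-00-00-000.bson" model_state
Flux.loadmodel!(model, model_state)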