FluxML / FluxTraining.jl

A flexible neural net training library inspired by fast.ai
https://fluxml.ai/FluxTraining.jl
MIT License

`Scheduler` causes cycle in execution DAG? #122

Closed · darsnack closed this 2 years ago

darsnack commented 2 years ago

I have the following script:

using Flux
using FluxTraining
using ParameterSchedulers  # provides Step

lossfn = Flux.Losses.logitcrossentropy

# define schedule and optimizer
initial_lr = 0.1
schedule = Step(initial_lr, 0.5, 20)
optim = Flux.Optimiser(Momentum(initial_lr), WeightDecay(1e-3))

# callbacks
logger = TensorBoardBackend("tblogs")
schcb = Scheduler(LearningRate => schedule)
hlogcb = LogHyperParams(logger)
mlogcb = LogMetrics(logger)
valcb = Metrics(Metric(accuracy; phase = TrainingPhase, name = "train_acc"),
                Metric(accuracy; phase = ValidationPhase, name = "val_acc"))

# setup learner object
learner = Learner(m, lossfn;
                  data = (trainloader, valloader),
                  optimizer = optim,
                  callbacks = [ToGPU(), mlogcb, valcb])
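
(For reference, Step(0.1, 0.5, 20) is a ParameterSchedulers.jl schedule that halves the value every 20 steps; calling the schedule directly is a quick way to sanity-check it:)

schedule(1)   # 0.1
schedule(21)  # 0.05
schedule(41)  # 0.025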

Any time I add schcb to the list of callbacks passed to the Learner, I get an error from FluxTraining that there is a cycle in the DAG. This did not happen in previous versions of FluxTraining (though I haven't been able to bisect the change yet).
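
Concretely, the failing variant looks roughly like this (a sketch reusing the objects defined above); per the stacktrace below, the error is thrown as soon as fit! fires the first EpochBegin event:

learner = Learner(m, lossfn;
                  data = (trainloader, valloader),
                  optimizer = optim,
                  callbacks = [ToGPU(), schcb, mlogcb, valcb])
fit!(learner, 1)  # ERROR: The input graph contains at least one loop.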

lorenzoh commented 2 years ago

Since what version?

Can you give a stacktrace and dump learner.callbacks.cbs?

darsnack commented 2 years ago

Here is the trace:

ERROR: The input graph contains at least one loop.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] topological_sort_by_dfs(::Type{Graphs.IsDirected{Graphs.SimpleGraphs.SimpleDiGraph{Int64}}}, g::Graphs.SimpleGraphs.SimpleDiGraph{Int64})
    @ Graphs ~/.julia/packages/Graphs/zrMoC/src/traversals/dfs.jl:65
  [3] topological_sort_by_dfs(g::Graphs.SimpleGraphs.SimpleDiGraph{Int64})
    @ Graphs ~/.julia/packages/SimpleTraits/l1ZsK/src/SimpleTraits.jl:331
  [4] (::FluxTraining.var"#16#17"{Learner})()
    @ FluxTraining ~/.julia/packages/FluxTraining/bday3/src/callbacks/execution.jl:9
  [5] ignore
    @ ~/.julia/packages/Zygote/DkIUK/src/lib/utils.jl:25 [inlined]
  [6] handle(runner::FluxTraining.LinearRunner, event::FluxTraining.Events.EpochBegin, phase::TrainingPhase, learner::Learner)
    @ FluxTraining ~/.julia/packages/FluxTraining/bday3/src/callbacks/execution.jl:8
  [7] (::FluxTraining.var"#handlefn#77"{Learner, TrainingPhase})(e::FluxTraining.Events.EpochBegin)
    @ FluxTraining ~/.julia/packages/FluxTraining/bday3/src/training.jl:102
  [8] runepoch(epochfn::FluxTraining.var"#67#68"{Learner, TrainingPhase, DataLoaders.BufferGetObsParallel{NamedTuple{(:image, :label), Tuple{Array{Float32, 4}, Matrix{Bool}}}, BatchView{ObsView{MLUtils.MappedData{Base.Fix1{typeof(apply_augmenation), DataAugmentation.Sequence{Tuple{Pad{4}, Crop{2, DataAugmentation.FromRandom}, Rotate{Distributions.Uniform{Float64}}, Crop{2, DataAugmentation.FromCenter}, DataAugmentation.OneOfProjective{DataAugmentation.ProjectiveTransform, Distributions.Categorical{Float64, Vector{Float64}}}, ImageToTensor{Float32}, Normalize{3}}}}, NamedTuple{(:image, :label), Tuple{ObsView{MLUtils.MappedData{typeof(DataAugmentation.tensortoimage), Array{Float32, 4}}, Vector{Int64}}, SubArray{Bool, 2, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Vector{Int64}}, false}}}}, UnitRange{Int64}}, MLUtils.MappedData{Base.Fix1{typeof(apply_augmenation), DataAugmentation.Sequence{Tuple{Pad{4}, Crop{2, DataAugmentation.FromRandom}, Rotate{Distributions.Uniform{Float64}}, Crop{2, DataAugmentation.FromCenter}, DataAugmentation.OneOfProjective{DataAugmentation.ProjectiveTransform, Distributions.Categorical{Float64, Vector{Float64}}}, ImageToTensor{Float32}, Normalize{3}}}}, NamedTuple{(:image, :label), Tuple{ObsView{MLUtils.MappedData{typeof(DataAugmentation.tensortoimage), Array{Float32, 4}}, Vector{Int64}}, SubArray{Bool, 2, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Vector{Int64}}, false}}}}}}}, learner::Learner, phase::TrainingPhase)
    @ FluxTraining ~/.julia/packages/FluxTraining/bday3/src/training.jl:104
  [9] epoch!
    @ ~/.julia/packages/FluxTraining/bday3/src/training.jl:22 [inlined]
 [10] fit!(learner::Learner, nepochs::Int64, ::Tuple{DataLoaders.BufferGetObsParallel{NamedTuple{(:image, :label), Tuple{Array{Float32, 4}, Matrix{Bool}}}, BatchView{ObsView{MLUtils.MappedData{Base.Fix1{typeof(apply_augmenation), DataAugmentation.Sequence{Tuple{Pad{4}, Crop{2, DataAugmentation.FromRandom}, Rotate{Distributions.Uniform{Float64}}, Crop{2, DataAugmentation.FromCenter}, DataAugmentation.OneOfProjective{DataAugmentation.ProjectiveTransform, Distributions.Categorical{Float64, Vector{Float64}}}, ImageToTensor{Float32}, Normalize{3}}}}, NamedTuple{(:image, :label), Tuple{ObsView{MLUtils.MappedData{typeof(DataAugmentation.tensortoimage), Array{Float32, 4}}, Vector{Int64}}, SubArray{Bool, 2, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Vector{Int64}}, false}}}}, UnitRange{Int64}}, MLUtils.MappedData{Base.Fix1{typeof(apply_augmenation), DataAugmentation.Sequence{Tuple{Pad{4}, Crop{2, DataAugmentation.FromRandom}, Rotate{Distributions.Uniform{Float64}}, Crop{2, DataAugmentation.FromCenter}, DataAugmentation.OneOfProjective{DataAugmentation.ProjectiveTransform, Distributions.Categorical{Float64, Vector{Float64}}}, ImageToTensor{Float32}, Normalize{3}}}}, NamedTuple{(:image, :label), Tuple{ObsView{MLUtils.MappedData{typeof(DataAugmentation.tensortoimage), Array{Float32, 4}}, Vector{Int64}}, SubArray{Bool, 2, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Vector{Int64}}, false}}}}}}, DataLoaders.BufferGetObsParallel{NamedTuple{(:image, :label), Tuple{Array{Float32, 4}, Matrix{Bool}}}, BatchView{ObsView{MLUtils.MappedData{Base.Fix1{typeof(apply_augmenation), DataAugmentation.Sequence{Tuple{ImageToTensor{Float32}, Normalize{3}}}}, NamedTuple{(:image, :label), Tuple{ObsView{MLUtils.MappedData{typeof(DataAugmentation.tensortoimage), Array{Float32, 4}}, Vector{Int64}}, SubArray{Bool, 2, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Vector{Int64}}, false}}}}, UnitRange{Int64}}, MLUtils.MappedData{Base.Fix1{typeof(apply_augmenation), DataAugmentation.Sequence{Tuple{ImageToTensor{Float32}, Normalize{3}}}}, NamedTuple{(:image, :label), Tuple{ObsView{MLUtils.MappedData{typeof(DataAugmentation.tensortoimage), Array{Float32, 4}}, Vector{Int64}}, SubArray{Bool, 2, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Vector{Int64}}, false}}}}}}})
    @ FluxTraining ~/.julia/packages/FluxTraining/bday3/src/training.jl:168
 [11] fit!(learner::Learner, nepochs::Int64)
    @ FluxTraining ~/.julia/packages/FluxTraining/bday3/src/training.jl:174
 [12] top-level scope
    @ ~/test-cifar/test-kyle.jl:120

And the output of learner.callbacks.cbs:

8-element Vector{FluxTraining.SafeCallback}:
 ToDevice(Flux.gpu, Flux.gpu)
 Scheduler(LearningRate)
 LogMetrics((TensorBoardBackend(/home/daruwalla/test-cifar/tblogs),))
 Metrics(Loss(), Metric(train_acc), Metric(val_acc))
 ProgressPrinter()
 MetricsPrinter()
 StopOnNaNLoss()
 Recorder()

I'll try to bisect which version introduced the change later.

lorenzoh commented 2 years ago

You can also visualize the dependency graph using

using GraphPlot
gplot(learner.callbacks.graph, nodelabel = learner.callbacks.cbs, layout = stressmajorize_layout)

That, together with FluxTraining.stateaccess.(learner.callbacks.cbs), should give a better picture of where the conflict occurs.
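
For example, something along these lines (a sketch; the exact shape of the returned permissions depends on the FluxTraining version) prints the learner state each callback declares it reads and writes, which is what the dependency graph is built from:

for cb in learner.callbacks.cbs
    println(nameof(typeof(cb)), " => ", FluxTraining.stateaccess(cb))
end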

lorenzoh commented 2 years ago

I found the problem: since #115, Scheduler writes to learner.optimizer (because Optimisers.jl optimisers are immutable and must be replaced rather than mutated in place), which creates the following cyclical dependency: