FluxML / Flux.jl


Flux.train! stops working after the first iteration without an error. #1692

Closed: surya-chak closed this issue 3 years ago

surya-chak commented 3 years ago

Flux.train! goes through the first iteration, resulting in one call to the callback function, after which the process seems to get stuck. We used htop to verify that the Julia process stops using any CPU at all a few minutes into this step.

We are currently using Flux.train! for a NeuralODE problem; the code is attached below:

```julia
# =====================================
# Defining the parameters of the system
# =====================================
# nX, nU, XDat, TVec, TFin, TCtrlOn, and UVec are defined earlier in our
# script (not shown here)
using Flux, DifferentialEquations, LinearAlgebra

Linear=0.0001*rand(nX,nX);       # linear coefficient matrix
Quadratic=0.001*rand(nX,nX,nX);  # quadratic coefficient tensor

BMat=0.01*rand(nX,nU);           # control input matrix

# packing up the coefficients into parameters vector
pLin=Linear[:];
len_pL=length(pLin);

pQuad=Quadratic[:];
len_pQ=length(pQuad);

pB=BMat[:];
len_pB=length(pB);

p=[pLin;pQuad;pB];

# RHS of the system
function Syst_RHS!(dX,X,p,t)
    # Dismantling the parameters of the neural network
    LTerm=reshape(p[1:len_pL],nX,nX);
    QTerm=reshape(p[len_pL+1:len_pQ+len_pL],nX,nX,nX);
    BTerm=reshape(p[len_pL+len_pQ+1:end],nX,nU);

    # control input: active until t = TCtrlOn, zero afterwards
    U = t <= TCtrlOn ? UVec : zeros(nU)

    # dX[i] = L[i,:]·X + Xᵀ·Q[i,:,:]ᵀ·X + B[i,:]·U
    for iState in 1:nX
        dX[iState]=dot(LTerm[iState,:],X)+X'*(transpose(QTerm[iState,:,:])*X);
        dX[iState]=dX[iState]+dot(BTerm[iState,:],U); # adding the control term
    end
end

# ==================
# Setup ODE problem
# ==================
X0=XDat[1,:]; # initial conditions
println("size of X0 is ", size(X0))
TSpan=(0.0,TFin);
prob_nn = ODEProblem(Syst_RHS!, X0, TSpan, p);

println("Solving with untrained params...")
sol = Array(solve(prob_nn, Tsit5(),saveat=TVec))
# ================
# Training set up
# ================
# Forward pass function
function predict_adjoint() # Trainable layer
    Array(solve(prob_nn, Tsit5(), saveat=TVec, reltol=1e-4))
end

# Loss function
function loss_adjoint()
    prediction = predict_adjoint()
    loss = sum((prediction - XDat').^2); # L2 norm
    return loss
end

# Defining learning parameters
opt=ADAM(0.1);
params=Flux.params(p)

losshistory = []
cb = function () # callback to observe training: record and display the loss
    l = loss_adjoint();
    push!(losshistory, l);
    display(l);
end

# Display the ODE with the initial parameter values.
cb()
@info "Start training"
Flux.train!(loss_adjoint, params, Iterators.repeated((), 100), opt, cb = cb)
@info "Finished Training"
DhairyaLGandhi commented 3 years ago

It would be good to get a stacktrace by terminating the run. Also, could you please format the code snippets in the post?
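
A minimal sketch of one way to capture that stacktrace, assuming the training call from the post above: interrupting the hung run with Ctrl-C raises an `InterruptException`, and printing its backtrace shows where execution is stuck.

```julia
# Sketch (assumes loss_adjoint, params, opt, and cb from the post above):
# interrupt the hung run with Ctrl-C, then print where it was stuck.
try
    Flux.train!(loss_adjoint, params, Iterators.repeated((), 100), opt, cb = cb)
catch e
    e isa InterruptException || rethrow()
    Base.show_backtrace(stdout, catch_backtrace())
end
```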

I would also try running the different training steps by hand (that is, call Flux.gradient and update! manually).
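
A minimal sketch of that manual loop, reusing the `loss_adjoint`, `params`, `opt`, and `cb` names from the post above (this is the implicit-parameters API that `Flux.train!` wraps):

```julia
# Sketch: a hand-rolled loop equivalent to the Flux.train! call above.
for i in 1:100
    # differentiate the zero-argument loss with respect to the tracked params
    gs = Flux.gradient(() -> loss_adjoint(), params)
    Flux.Optimise.update!(opt, params, gs)  # apply one ADAM step
    cb()                                    # record and display the loss
end
```

If this loop hangs inside `Flux.gradient`, the problem is in differentiating the ODE solve rather than in `train!` itself.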

CarloLucibello commented 3 years ago

No stacktrace, no MWE, and OP not responsive; closing this.