vavrines closed this 6 months ago
function loss_function1(model, x, ps, st, y)
    smodel = StatefulLuxLayer(model, ps, st)
    ŷ = smodel(x)
    loss_emp = sum(abs2, ŷ .- y)
    J = only(Zygote.jacobian(smodel, x)) # used to be ForwardDiff
    loss_reg = abs2(norm(J))
    return loss_emp + loss_reg
end
`Zygote.<....>` functions return a tuple, so you need to add that `only` or a `first`.
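To illustrate the tuple return, here is a minimal sketch that is independent of Lux. The function `f` and the input `x` are made up for the example; only `Zygote.jacobian`, `only`, and `first` come from the discussion above.

```julia
using Zygote

# Hypothetical elementwise function, chosen only for illustration.
f(x) = x .^ 2
x = [1.0, 2.0, 3.0]

# Zygote.jacobian returns a tuple with one Jacobian per argument of f.
result = Zygote.jacobian(f, x)

# f takes a single argument, so the tuple has one entry; unwrap it.
# first(result) would work as well, but only also checks there is exactly one entry.
J = only(result)
```

Passing `result` to `norm` directly would not give the norm of the Jacobian matrix, which is why the unwrapping step matters in the loss function.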
It is mentioned in the manual that "You can use `Zygote.jacobian` as well but ForwardDiff tends to be more efficient here". After switching to Zygote, I got:
Env: