Closed d-monnet closed 10 months ago
Ok, I found the issue: `obj` calls `set_var!()`, which is not type stable. The problem comes from `nlp.w .= new_w` at https://github.com/JuliaSmoothOptimizers/FluxNLPModels.jl/blob/a17ca9448ef0396f8a8fb3ecb90e082d128defae/src/utils.jl#L6. The operator `.=` casts the right-hand-side vector into the left-hand-side vector's element type. For example:

```julia
x32 = ones(Float32, 10)
x16 = ones(Float16, 10)
x32 .= x16  # x32 is still a Vector{Float32}
```

That is, even if the argument of `obj` is a `Vector{Float16}`, it is cast to whatever the parameter type `S` of `FluxNLPModel{T,S}` is.
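To illustrate the consequence, here is a hypothetical toy sketch (not the actual FluxNLPModels code; `ToyModel` and `toy_obj` are made-up names) of how an in-place broadcast silently promotes a lower-precision input:

```julia
# Hypothetical sketch: the in-place broadcast keeps the storage type S of
# the model's weight vector, so a Float16 input is silently promoted
# before the objective is ever computed.
struct ToyModel{S}
    w::S
end

function toy_obj(nlp::ToyModel, new_w::AbstractVector)
    nlp.w .= new_w           # casts new_w to eltype(nlp.w)
    return sum(abs2, nlp.w)  # result has eltype(nlp.w), not eltype(new_w)
end

nlp = ToyModel(ones(Float32, 3))
toy_obj(nlp, ones(Float16, 3))  # returns a Float32, despite the Float16 input
```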
In fact, this is not even the bottom of the issue: `Flux.destructure` does not allow changing the FP format via the restructure mechanism. From the `destructure` documentation: "Such restoration follows the rules of ChainRulesCore.ProjectTo, and thus will restore floating point precision". Since the restructure is called in `set_var()`, we still can't switch FP formats.
Any workaround would be welcome!
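A minimal sketch of that `destructure` behavior (assuming a recent Flux version; untested here):

```julia
using Flux

m = Chain(Dense(2 => 1))      # parameters are Float32 by default
w, re = Flux.destructure(m)

w16 = Float16.(w)             # try to switch the flat vector to Float16
m16 = re(w16)                 # ProjectTo restores Float32 during restructure

eltype(Flux.destructure(m16)[1])  # expected to be Float32 again
```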
I can change the backend to change the model every time. A quick change is:

```julia
f64(m) = Flux.paramtype(Float64, m)  # similar to https://github.com/FluxML/Flux.jl/blob/d21460060e055dca1837c488005f6b1a8e87fa1b/src/functor.jl#L217
```

Then, to change our model, we use:

```julia
fluxnlp.model = f64(fluxnlp.model)
```
Flux just recently added support for this: https://fluxml.ai/Flux.jl/stable/utilities/#Flux.f16
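A short sketch of using that utility (assuming Flux is recent enough to export `Flux.f16`):

```julia
using Flux

m = Chain(Dense(2 => 2), Dense(2 => 1))
m16 = Flux.f16(m)            # converts all floating-point parameters to Float16

w, _ = Flux.destructure(m16)
eltype(w)                    # should now be Float16
```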
Hi there,
I would like to know if `Float16` is supported. I followed this tutorial https://jso.dev/FluxNLPModels.jl/dev/tutorial/ and naively tried it, but got a `Float32`. Therefore I assume at least some computations are performed with `Float32` when evaluating the objective. I also tried to modify the function `getdata()`, but still got a `Float32` when evaluating the objective. Any idea how to run in `Float16` (or any other format)?