JuliaSmoothOptimizers / FluxNLPModels.jl

Float16 compatibility #15

Closed d-monnet closed 10 months ago

d-monnet commented 1 year ago

Hi there,

I would like to know if Float16 is supported. I followed this tutorial https://jso.dev/FluxNLPModels.jl/dev/tutorial/ and naively tried

w16 = Float16.(nlp.w)
obj(nlp,w16)

but got a Float32. Therefore I assume at least some computations are performed in Float32 when evaluating the objective. I also tried to modify the function get_data() as follows:

function get_data(bs)
  ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"

  # Loading Dataset
  xtrain, ytrain = MLDatasets.MNIST(Tx = Float16, split = :train)[:]
  xtest, ytest = MLDatasets.MNIST(Tx = Float16, split = :test)[:]
  ...
end

but still got a Float32 when evaluating the objective. Any idea how to run in Float16 (or any other format)?
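(As a standalone illustration, not from the original report: even with Float16 data, the tutorial's Chain is built with Float32 parameters by default, so the forward pass promotes everything back to Float32.)

using Flux

x16 = rand(Float16, 28^2, 10)                            # Float16 "MNIST-like" batch
model = Chain(Dense(28^2 => 32, relu), Dense(32 => 10))  # Float32 parameters by default
eltype(model(x16))                                       # Float32: promoted by the Float32 weights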

d-monnet commented 1 year ago

Ok, I found the issue: obj calls set_var!(), which is not type stable. The problem comes from nlp.w .= new_w at https://github.com/JuliaSmoothOptimizers/FluxNLPModels.jl/blob/a17ca9448ef0396f8a8fb3ecb90e082d128defae/src/utils.jl#L6. The .= operator casts the right-hand side vector to the element type of the left-hand side vector. For example:

x32 = ones(Float32,10)
x16 = ones(Float16,10)
x32 .= x16 # this is still a Float32

That is, even if the argument of obj is a Vector{Float16}, it is cast to whatever the parameter type S of FluxNLPModel{T, S} is.

d-monnet commented 1 year ago

In fact, this is not even the root of the issue: Flux.destructure does not allow changing the floating-point format via the restructure mechanism. From the destructure documentation: "Such restoration follows the rules of ChainRulesCore.ProjectTo, and thus will restore floating point precision." Since restructure is called in set_var(), we still can't switch floating-point formats. Any workaround would be welcome!
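A standalone sketch of that documented behaviour (generic Flux example, not FluxNLPModels code):

using Flux

model = Chain(Dense(2 => 2))            # Float32 parameters by default
w, rebuild = Flux.destructure(model)
model16 = rebuild(Float16.(w))          # ProjectTo restores the original Float32 precision
eltype(Flux.destructure(model16)[1])    # Float32, not Float16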

farhadrclass commented 12 months ago

I can change the backend to convert the model every time.

A quick change is:

f64(m) = Flux.paramtype(Float64, m) # similar to https://github.com/FluxML/Flux.jl/blob/d21460060e055dca1837c488005f6b1a8e87fa1b/src/functor.jl#L217

Then, to convert our model, we use:

fluxnlp.model = f64(fluxnlp.model)
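The same pattern sketched for Float16, assuming Flux.paramtype accepts Float16 the same way it accepts Float64 in the linked source:

f16(m) = Flux.paramtype(Float16, m)  # assumption: behaves like the f64 definition above
fluxnlp.model = f16(fluxnlp.model)   # convert the wrapped model's float parameters to Float16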

farhadrclass commented 11 months ago

Flux just recently added support for this https://fluxml.ai/Flux.jl/stable/utilities/#Flux.f16
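For completeness, a small standalone example of that utility (Flux.f16 converts a model's floating-point parameters to Float16):

using Flux

model = Chain(Dense(28^2 => 32, relu), Dense(32 => 10))  # Float32 parameters by default
model16 = Flux.f16(model)                                # convert float parameters to Float16
eltype(Flux.destructure(model16)[1])                     # Float16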