Multiple dispatch for obj/grad method #30

Open farhadrclass opened 8 months ago

farhadrclass commented 8 months ago

We need a way to write the objective (obj) and gradient (grad) methods for the case where the element type of the weight vector passed in (w) differs from that of the weight vector stored in the NLP model (nlp.w). Currently, an if statement at line #15 handles scenarios such as the one illustrated below:

I added print statements:

function NLPModels.obj(nlp::AbstractFluxNLPModel{T, S}, w::AbstractVector{V}) where {T, S, V}
  x, y = nlp.current_training_minibatch
  print("type V is ", V, "\n")
  print("type eltype(nlp.w) is ", eltype(nlp.w), "\n")
  print("type of T is ", T, "\n")
  print("type of eltype(x) is ", eltype(x), "\n")
  eltype(nlp.w) == V || update_type!(nlp, w) # check whether the type has changed
  if eltype(x) != V
    x = V.(x) # convert the minibatch to the element type of w
  end

  set_vars!(nlp, w)
  increment!(nlp, :neval_obj)
  return nlp.loss_f(nlp.chain(x), y)
end
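For discussion, the dispatch-based alternative named in the title could look roughly like the sketch below. It is only an illustration on a toy type: `ToyModel` and `toy_obj` are hypothetical names, not part of FluxNLPModels or NLPModels, and the "loss" is just a sum of squares. One method handles the case where `w` has the model's element type, and a second, more general method handles a mismatch by evaluating on a copy of the model converted to the precision of `w`.

```julia
# Hypothetical toy model; it only stores a weight vector of element type T.
mutable struct ToyModel{T}
  w::Vector{T}
end

# Selected when w has the same element type as the model: no conversion needed.
function toy_obj(m::ToyModel{T}, w::AbstractVector{T}) where {T}
  m.w .= w                # store the new weights in place
  return sum(abs2, m.w)   # stand-in for the real loss
end

# Selected when the element types differ (the method above is more specific,
# so this one is only used for a mismatch).
function toy_obj(m::ToyModel{T}, w::AbstractVector{V}) where {T, V}
  m_converted = ToyModel(V.(m.w))   # copy of the model in precision V
  return toy_obj(m_converted, w)    # now dispatches to the matching method
end

m = ToyModel(Float32[0, 0])
obj32 = toy_obj(m, Float32[1, 2])   # matching method
obj16 = toy_obj(m, Float16[1, 2])   # mismatch method, evaluated in Float16
```

Note that this pattern only helps if the type parameter `T` always agrees with the element type of the stored weights, which is not what the output further down shows for the real model.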

Then I wrote a simple test:

@testset "Multiple precision test" begin
  # Create test and train dataloaders
  train_data, test_data = getdata(args)

  # Construct model in Float32
  DN = build_model() |> device
  nlp = FluxNLPModel(DN, train_data, test_data)

  x1 = copy(nlp.w)
  obj_x1 = obj(nlp, x1)
  grad_x1 = NLPModels.grad(nlp, x1)

  # change to Float16 
  x2 = Float16.(x1)
  obj_x2 = obj(nlp, x2)

  # change back to Float32; this is where the issue is:
  # x1 is Float32, but nlp.w is now Float16, while T and V are both Float32
  obj_x3 = obj(nlp, x1)
end

Here is the output:

type V is Float32
type eltype(nlp.w) is Float32
type of T is Float32
type of eltype(x) is Float32
type V is Float16
type eltype(nlp.w) is Float32
type of T is Float32
type of eltype(x) is Float32
type eltype(nlp.w) is after type change Float16
type V is Float32
type eltype(nlp.w) is Float16
type of T is Float32
type of eltype(x) is Float32

In the third call (obj_x3 = obj(nlp, x1), after changing back to Float32), eltype(nlp.w) is Float16 but T and V are both Float32. This is where the issue is: x1 is Float32, nlp.w is still Float16, and the model type nlp{T,S} still has T = Float32.
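The state described above can be reproduced without Flux. The sketch below is a toy illustration, not the actual FluxNLPModel definition: `LooseModel`, `report_types`, and `store!` are hypothetical names, and it assumes that the stored weights are not tied to the type parameter `T` (how the real struct declares its fields is not shown in this issue). Under that assumption, `T` is fixed when the object is constructed, so it cannot follow a later change in the element type of the stored weights.

```julia
# Hypothetical stand-in for the model; the field is deliberately not tied to T.
mutable struct LooseModel{T}
  w::AbstractVector
end

# Report the three types printed in the issue: T, eltype of the stored weights, and V.
function report_types(m::LooseModel{T}, w::AbstractVector{V}) where {T, V}
  return (T = T, stored = eltype(m.w), V = V)
end

# Mimic the effect of a type update: replace the stored weights with a copy of w.
function store!(m::LooseModel, w::AbstractVector)
  m.w = copy(w)
  return m
end

m = LooseModel{Float32}(Float32[1, 2, 3])
x1 = copy(m.w)        # Float32 weights
x2 = Float16.(x1)     # Float16 weights

store!(m, x2)                 # stored weights are now Float16
state = report_types(m, x1)   # call again with the Float32 vector
```

Here `state` matches the last block of the output: `T` and `V` are `Float32` while the stored weights are `Float16`. Any dispatch on `T` alone would therefore treat this call as a matching-precision call.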