SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

High code complexity #917

Open linusheck opened 1 year ago

linusheck commented 1 year ago

One last complaint from me :D The code for this library is quite complex: many functions contain several implementations of the same operation and switch between them internally. There are loads of expressions like this:

if dy !== nothing
    if W === nothing
        if inplace_sensitivity(S)
            f(dy, y, p, t)
        else
            recursive_copyto!(dy, vec(f(y, p, t)))
        end
    else
        if inplace_sensitivity(S)
            f(dy, y, p, t, W)
        else
            recursive_copyto!(dy, vec(f(y, p, t, W)))
        end
    end
end

I believe the library would be much easier to work with if these different implementations were put into different functions. For example, have one function that handles the computation when inplace_sensitivity(S) is true, another for when it is false, and so on.

ChrisRackauckas commented 1 year ago

The complaint is fine. It is complicated. I'm not sure it's an unnecessary complexity though. If these different calls were in a separate function, then there would be a lot more duplicated code since those dispatches are exactly the same but with the W on the end of them. A cleaner strategy might be some macro or something that's like W !== nothing ? f(dy, y, p, t, W) : f(dy, y, p, t) that is @Wcall f(dy, y, p, t, W) or something.
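A minimal sketch of what such a macro could look like. The name @Wcall comes from the comment above, but this implementation, the convention that the noise term is the last positional argument, and the toy rhs! function are all assumptions made for illustration; nothing here is part of SciMLSensitivity.jl.

```julia
# Hypothetical macro: `@Wcall f(args..., W)` calls `f(args...)` when the last
# argument evaluates to `nothing`, and `f(args..., W)` otherwise.
# Assumes a plain positional call (no keyword arguments).
macro Wcall(call)
    Meta.isexpr(call, :call) && length(call.args) >= 2 ||
        error("@Wcall expects a function call whose last argument is the noise term")
    fn = esc(call.args[1])                 # the function being called
    args = map(esc, call.args[2:end-1])    # every argument except the last
    w = esc(call.args[end])                # the (possibly `nothing`) noise term
    quote
        let W = $w
            W === nothing ? $fn($(args...)) : $fn($(args...), W)
        end
    end
end

# Toy right-hand side with and without a noise argument (illustration only).
rhs!(dy, y, p, t) = (dy .= p .* y; dy)
rhs!(dy, y, p, t, W) = (dy .= p .* y .+ W; dy)

y = [1.0, 2.0]
dy_det = zeros(2)
dy_sto = zeros(2)
@Wcall rhs!(dy_det, y, 2.0, 0.0, nothing)      # uses the 4-argument method
@Wcall rhs!(dy_sto, y, 2.0, 0.0, [0.5, 0.5])   # uses the 5-argument method
```

The macro removes the repeated if W === nothing blocks at each call site, but the branch still exists at run time unless the compiler can infer the type of W.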

linusheck commented 1 year ago

Another example:

if W === nothing
    if DiffEqBase.has_paramjac(f)
        # Calculate the parameter Jacobian into pJ
        f.paramjac(pJ, y, p, t)
    else
        pf.t = t
        pf.u = y
        if inplace_sensitivity(S)
            jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config)
        else
            temp = jacobian(pf, p, sensealg)
            pJ .= temp
        end
    end
else
    if DiffEqBase.has_paramjac(f)
        # Calculate the parameter Jacobian into pJ
        f.paramjac(pJ, y, p, t, W)
    else
        pf.t = t
        pf.u = y
        pf.W = W
        if inplace_sensitivity(S)
            jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config)
        else
            temp = jacobian(pf, p, sensealg)
            pJ .= temp
        end
    end
end

I don't really know how to fix this, but a lot of this code branches heavily even though, for a given problem, it always ends up executing the same branch.

> If these different calls were in a separate function, then there would be a lot more duplicated code since those dispatches are exactly the same but with the W on the end of them.

I think it would be better if functions like _vecjacobian only called abstract functions without any branching, and those functions figured out the details themselves. In my opinion, there is already a lot of code duplication here that such a strategy could reduce.

Wouldn't using multiple dispatch be enough for this? Maybe this is such a big change that it could only work in a rewrite: encode properties like inplace_sensitivity and W === nothing on the type level. Then just define a single in-place jacobian function that does what f.paramjac, jacobian! or jacobian does based on the types of e.g. S.

The variables like paramjac_config don't have to be globally flying around in the top-level functions. If behavior switches are encoded on the type level, they can be a property of the type.
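A rough sketch of the type-level idea, under heavy assumptions: every name below (ParamJacStrategy, AnalyticParamJac, FiniteDiffParamJac, param_jacobian!) is hypothetical and does not exist in SciMLSensitivity.jl, and a plain forward difference stands in for whatever jacobian! and jacobian actually do. The point is only that the choice of implementation, and any configuration it needs, lives in the type, so the caller makes one call with no branching.

```julia
# All names here are hypothetical, for illustration only.
abstract type ParamJacStrategy end

# Case 1: the user supplied an analytic parameter Jacobian.
struct AnalyticParamJac{F} <: ParamJacStrategy
    paramjac::F     # in-place function paramjac(pJ, y, p, t)
end

# Case 2: fall back to a numerical Jacobian. The step size `h` plays the role
# of configuration (like `paramjac_config`) carried by the type itself.
struct FiniteDiffParamJac{F} <: ParamJacStrategy
    f::F            # out-of-place right-hand side f(y, p, t)
    h::Float64
end

# One entry point; dispatch on the strategy type selects the implementation.
param_jacobian!(pJ, s::AnalyticParamJac, y, p, t) = (s.paramjac(pJ, y, p, t); pJ)

function param_jacobian!(pJ, s::FiniteDiffParamJac, y, p, t)
    f0 = s.f(y, p, t)
    for j in eachindex(p)
        pj = copy(p)
        pj[j] += s.h
        pJ[:, j] .= (s.f(y, pj, t) .- f0) ./ s.h   # forward difference in p[j]
    end
    return pJ
end

# Toy problem: f(y, p, t) = p .* y, whose parameter Jacobian is diag(y).
f(y, p, t) = p .* y
function analytic!(pJ, y, p, t)
    fill!(pJ, 0.0)
    for i in eachindex(y)
        pJ[i, i] = y[i]
    end
    return nothing
end

y, p, t = [1.0, 2.0], [3.0, 4.0], 0.0
pJ_analytic = param_jacobian!(zeros(2, 2), AnalyticParamJac(analytic!), y, p, t)
pJ_numeric = param_jacobian!(zeros(2, 2), FiniteDiffParamJac(f, 1e-6), y, p, t)
```

A caller such as _vecjacobian would then hold a strategy object and call param_jacobian! without knowing which case applies. Whether this fits the real code, where the same choice also interacts with W and with the sensitivity algorithm, is an open question in this thread.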

linusheck commented 1 year ago

IDK, you know much more about what this library is actually doing, and such an architecture may be worse, or impossible to implement. Feel free to close the issue, it's just some ideas.

ChrisRackauckas commented 1 year ago

We can probably do something via dispatch where we make all of them wrapped in a form where it's always f(dy, y, p, t, W), but then f(dy, y, p, t, W::Nothing) = f.f(dy, y, p, t). Then the code can just use the W everywhere and cut down on the number of branches. @frankschae does that sound good to you?
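A small sketch of that wrapper idea, assuming a callable struct; the name NoiseAwareRHS and the toy functions are made up for illustration and are not part of SciMLSensitivity.jl.

```julia
# Hypothetical wrapper: always called as wrapped(dy, y, p, t, W).
struct NoiseAwareRHS{F}
    f::F
end

# Deterministic case: W is `nothing`, so drop it before calling the wrapped function.
(w::NoiseAwareRHS)(dy, y, p, t, ::Nothing) = w.f(dy, y, p, t)
# Stochastic case: forward W unchanged.
(w::NoiseAwareRHS)(dy, y, p, t, W) = w.f(dy, y, p, t, W)

# Toy right-hand sides (illustration only).
ode!(dy, y, p, t) = (dy .= p .* y; nothing)
sde!(dy, y, p, t, W) = (dy .= p .* y .+ W; nothing)

dy_ode = zeros(2)
NoiseAwareRHS(ode!)(dy_ode, [1.0, 2.0], 3.0, 0.0, nothing)      # dy_ode == [3.0, 6.0]

dy_sde = zeros(2)
NoiseAwareRHS(sde!)(dy_sde, [1.0, 2.0], 3.0, 0.0, [0.5, 0.5])   # dy_sde == [3.5, 6.5]
```

Because the type of W is known at compile time, the choice between the two methods is resolved by dispatch, and calling code can pass W everywhere without an if W === nothing check.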

frankschae commented 1 year ago

yeah sounds like a good idea. I think the adjoint_common.jl file is probably the worst in that regard -- and I am not sure when AbstractDifferentiation might be ready to do that part.