SciML / LabelledArrays.jl

Arrays which also have a label for each element for easy scientific machine learning (SciML)
https://docs.sciml.ai/LabelledArrays/stable/
Other
120 stars 21 forks source link

Support FixedSizeDiffCache #137

Open ChrisRackauckas opened 1 year ago

ChrisRackauckas commented 1 year ago

This could be supported by adding the following `get_tmp` method for `LArray`s, together with the tests below:

"""
    get_tmp(dc::FixedSizeDiffCache, u::LabelledArrays.LArray{<:ForwardDiff.Dual})

Return the Dual-valued work buffer of `dc`, reinterpreted to the element type of `u`
and rewrapped as an `LArray` with the same type parameters (including the labels
`Syms`) as `u`, so label-based indexing keeps working under ForwardDiff.

Only `ForwardDiff.Dual` element types are accepted. A plain (e.g. `Float64`) `LArray`
is left to the generic `get_tmp` methods for `FixedSizeDiffCache` (presumably the
`AbstractArray` fallback returning `dc.du` — confirm). Accepting every `T` here would
send non-Dual types into `chunksize(T)`. It would also likely make Dual `LArray`s
ambiguous with the generic `AbstractArray{<:Dual}` method.
"""
function get_tmp(dc::FixedSizeDiffCache,
                 u::LabelledArrays.LArray{T, N, D, Syms}) where {T <: ForwardDiff.Dual,
                                                                 N, D, Syms}
    # Reinterpret the raw storage of the Dual buffer, not the labelled wrapper.
    # Assumes `dc.dual_du` is itself an LArray (cache built from an LArray) — confirm.
    x = reinterpret(T, getfield(dc.dual_du, :__x))
    # With the cache's chunk size the reinterpretation maps 1:1 onto the buffer.
    # Otherwise its first dimension differs from `dc.du`'s, so keep only the leading
    # `axes(dc.du)` block (presumably only valid for chunk sizes not larger than the
    # cache's — confirm).
    data = if chunksize(T) === chunksize(eltype(dc.dual_du))
        x
    else
        @view x[axes(dc.du)...]
    end
    # NOTE(review): the LArray stores its data as `D`. When `data` is a
    # ReinterpretArray/SubArray rather than a `D`, this constructor presumably converts
    # (copies) it, so writes to the result may not alias `dc.dual_du` — confirm.
    return LabelledArrays.LArray{T, N, D, Syms}(data)
end

# LArray tests: a FixedSizeDiffCache built from an LArray is queried with a plain LArray,
# a Dual LArray, a plain scalar and a Dual scalar. Each returned buffer must have the
# cache's size and the concrete type / element type of the corresponding LArray.
chunk_size = 4
u0_L = LArray((2, 2); a = 1.0, b = 1.0, c = 1.0, d = 1.0)
zerodual = zero(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
                                 chunk_size})
dual_L = LArray((2, 2); a = zerodual, b = zerodual, c = zerodual, d = zerodual)
cache_L = FixedSizeDiffCache(u0_L, chunk_size)

tmp_du_LA = get_tmp(cache_L, u0_L)              # plain LArray query
tmp_dual_du_LA = get_tmp(cache_L, dual_L)       # Dual LArray query
tmp_du_LN = get_tmp(cache_L, u0_L[1])           # plain scalar query
# NOTE(review): no method in this snippet handles a scalar Dual query; confirm that path
# returns an LArray, otherwise its `typeof` check below fails.
tmp_dual_du_LN = get_tmp(cache_L, dual_L[1])    # Dual scalar query

for (tmp, expected) in ((tmp_du_LA, u0_L), (tmp_dual_du_LA, dual_L),
                        (tmp_du_LN, u0_L), (tmp_dual_du_LN, dual_L))
    @test size(tmp) == size(u0_L)
    @test typeof(tmp) == typeof(expected)
    @test eltype(tmp) == eltype(expected)
end

# LArray ODE setup: `u0` is the initial state, `A` the matrix applied to `u` in `foo`,
# and `c` the array each FixedSizeDiffCache below is built from.
u0 = LArray((2, 2); a = 1.0, b = 1.0, c = 1.0, d = 1.0)
A = LArray((2, 2); a = 1.0, b = 1.0, c = 1.0, d = 1.0)
c = LArray((2, 2); a = 0.0, b = 0.0, c = 0.0, d = 0.0)
function foo(du, u, (A, tmp), t)
    # `tmp` is the cache from the parameters; fetch a buffer matching `u`'s element type.
    buf = get_tmp(tmp, u)
    mul!(buf, A, u)
    du .= u .+ buf
    return nothing
end

# Explicit chunk size: the cache and the TRBDF2 solver are both given `chunk_size`.
chunk_size = 4
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0),
                                                  (A, FixedSizeDiffCache(c, chunk_size)))
sol = solve(prob, TRBDF2(; chunk_size = chunk_size))
@test sol.retcode == ReturnCode.Success

# Auto-detected chunk size: neither the cache nor the solver is given one.
prob = ODEProblem(foo, u0, (0.0, 1.0), (A, FixedSizeDiffCache(c)))
sol = solve(prob, TRBDF2())
@test sol.retcode == ReturnCode.Success