JuliaRobotics / RigidBodyDynamics.jl

Julia implementation of various rigid body dynamics and kinematics algorithms

Compatibility with Tracker from Flux for AD? #570

Open ludns opened 5 years ago

ludns commented 5 years ago

Hey there,

I am trying to use Flux models as torque controllers and eventually backpropagate from a loss to the parameters of those models, but it seems that RigidBodyDynamics.jl breaks Tracker.

I know that Flux and Tracker work with DifferentialEquations.jl (DiffEqFlux.jl as an example).

Edit: I am using Tracker.TrackedReal{Float64} as the type for MechanismState because I need those parameters to record the operations executed on them (just like in the ForwardDiff.jl tutorial for this package).

Here is a minimal repro:

using RigidBodyDynamics
using RigidBodySim
using DifferentialEquations
using Flux

urdf = joinpath(dirname(pathof(RigidBodySim)), "..", "test", "urdf", "Acrobot.urdf")
mechanism = parse_urdf(Float64, urdf)
remove_fixed_tree_joints!(mechanism);
state = MechanismState{Tracker.TrackedReal{Float64}}(mechanism)

open_loop_dynamics = Dynamics(mechanism);
problem = ODEProblem(open_loop_dynamics, state, (0., 1))
sol = solve(problem, Tsit5(), abs_tol = 1e-7, dt = 0.05)

Stack trace:

ERROR: LoadError: MethodError: no method matching Float64(::Tracker.TrackedReal{Float64})
Closest candidates are:
  Float64(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:194
  Float64(::T<:Number) where T<:Number at boot.jl:718
  Float64(::Int8) at float.jl:60
  ...
Stacktrace:
 [1] convert(::Type{Float64}, ::Tracker.TrackedReal{Float64}) at ./number.jl:7
 [2] macro expansion at /Users/justinglibert/.julia/packages/StaticArrays/3KEjZ/src/util.jl:11 [inlined]
 [3] convert_ntuple at /Users/justinglibert/.julia/packages/StaticArrays/3KEjZ/src/util.jl:8 [inlined]
 [4] StaticArrays.SArray{Tuple{3},Float64,1,3}(::Tuple{Tracker.TrackedReal{Float64},Tracker.TrackedReal{Float64},Tracker.TrackedReal{Float64}}) at /Users/justinglibert/.julia/packages/StaticArrays/3KEjZ/src/SArray.jl:28
 [5] _convert at /Users/justinglibert/.julia/packages/StaticArrays/3KEjZ/src/convert.jl:29 [inlined]
 [6] convert at /Users/justinglibert/.julia/packages/StaticArrays/3KEjZ/src/convert.jl:26 [inlined]
 [7] Type at /Users/justinglibert/.julia/packages/RigidBodyDynamics/XgNGG/src/spatial/spatialmotion.jl:131 [inlined]
 [8] Twist(::CartesianFrame3D, ::CartesianFrame3D, ::CartesianFrame3D, ::TrackedArray{…,StaticArrays.SArray{Tuple{3},Float64,1,3}}, ::Array{Float64,1}) at /Users/justinglibert/.julia/packages/RigidBodyDynamics/XgNGG/src/spatial/spatialmotion.jl:161
 [9] joint_twist at /Users/justinglibert/.julia/packages/RigidBodyDynamics/XgNGG/src/joint_types/revolute.jl:67 [inlined]
 [10] joint_twist at /Users/justinglibert/.julia/packages/RigidBodyDynamics/XgNGG/src/joint.jl:378 [inlined]
 [11] macro expansion at /Users/justinglibert/.julia/packages/TypeSortedCollections/Z4ytl/src/TypeSortedCollections.jl:173 [inlined]
 [12] map! at /Users/justinglibert/.julia/packages/TypeSortedCollections/Z4ytl/src/TypeSortedCollections.jl:165 [inlined]
 [13] _update_joint_twists!(::MechanismState{Tracker.TrackedReal{Float64},Float64,Tracker.TrackedReal{Float64},TypeSortedCollections.TypeSortedCollection{Tuple{Array{Joint{Float64,Revolute{Float64}},1}},1}}) at /Users/justinglibert/.julia/packages/RigidBodyDynamics/XgNGG/src/mechanism_state.jl:723
 [14] update_joint_twists! at /Users/justinglibert/.julia/packages/RigidBodyDynamics/XgNGG/src/mechanism_state.jl:717 [inlined]
 [15] _update_twists_wrt_world!(::MechanismState{Tracker.TrackedReal{Float64},Float64,Tracker.TrackedReal{Float64},TypeSortedCollections.TypeSortedCollection{Tuple{Array{Joint{Float64,Revolute{Float64}},1}},1}}) at /Users/justinglibert/.julia/packages/RigidBodyDynamics/XgNGG/src/mechanism_state.jl:771
 [16] contact_dynamics!(::DynamicsResult{Tracker.TrackedReal{Float64},Float64}, ::MechanismState{Tracker.TrackedReal{Float64},Float64,Tracker.TrackedReal{Float64},TypeSortedCollections.TypeSortedCollection{Tuple{Array{Joint{Float64,Revolute{Float64}},1}},1}}) at /Users/justinglibert/.julia/packages/RigidBodyDynamics/XgNGG/src/mechanism_state.jl:766
 [17] #dynamics!#114(::RigidBodyDynamics.CustomCollections.ConstDict{JointID,RigidBodyDynamics.PDControl.SE3PDGains{RigidBodyDynamics.PDControl.PDGains{Tracker.TrackedReal{Float64},Tracker.TrackedReal{Float64}},RigidBodyDynamics.PDControl.PDGains{Tracker.TrackedReal{Float64},Tracker.TrackedReal{Float64}}}}, ::typeof(dynamics!), ::DynamicsResult{Tracker.TrackedReal{Float64},Float64}, ::MechanismState{Tracker.TrackedReal{Float64},Float64,Tracker.TrackedReal{Float64},TypeSortedCollections.TypeSortedCollection{Tuple{Array{Joint{Float64,Revolute{Float64}},1}},1}}, ::SegmentedVector{JointID,Tracker.TrackedReal{Float64},Base.OneTo{JointID},Array{Tracker.TrackedReal{Float64},1}}, ::RigidBodyDynamics.CustomCollections.NullDict{BodyID,Wrench{Tracker.TrackedReal{Float64}}}) at /Users/justinglibert/.julia/packages/RigidBodyDynamics/XgNGG/src/mechanism_algorithms.jl:850
 [18] dynamics! at /Users/justinglibert/.julia/packages/RigidBodyDynamics/XgNGG/src/mechanism_algorithms.jl:849 [inlined] (repeats 2 times)
 [19] (::Dynamics{Float64,TypeSortedCollections.TypeSortedCollection{Tuple{Array{Joint{Float64,Revolute{Float64}},1}},1},typeof(zero_control!),getfield(RigidBodySim.Core, Symbol("##2#4"))})(::Array{Tracker.TrackedReal{Float64},1}, ::Array{Tracker.TrackedReal{Float64},1}, ::Nothing, ::Float64) at /Users/justinglibert/.julia/packages/RigidBodySim/Beung/src/core.jl:84
 [20] ODEFunction at /Users/justinglibert/.julia/packages/DiffEqBase/XoKmO/src/diffeqfunction.jl:230 [inlined]
 [21] initialize!(::OrdinaryDiffEq.ODEIntegrator{Tsit5,true,Array{Tracker.TrackedReal{Float64},1},Float64,Nothing,Tracker.TrackedReal{Float64},Float64,Float64,Array{Array{Tracker.TrackedReal{Float64},1},1},ODESolution{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Tracker.TrackedReal{Float64},1},1},1},ODEProblem{Array{Tracker.TrackedReal{Float64},1},Tuple{Float64,Float64},true,Nothing,ODEFunction{true,Dynamics{Float64,TypeSortedCollections.TypeSortedCollection{Tuple{Array{Joint{Float64,Revolute{Float64}},1}},1},typeof(zero_control!),getfield(RigidBodySim.Core, Symbol("##2#4"))},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,CallbackSet{Tuple{},Tuple{}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{CallbackSet{Tuple{},Tuple{}}}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,Dynamics{Float64,TypeSortedCollections.TypeSortedCollection{Tuple{Array{Joint{Float64,Revolute{Float64}},1}},1},typeof(zero_control!),getfield(RigidBodySim.Core, Symbol("##2#4"))},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Tracker.TrackedReal{Float64},1},1},Array{Float64,1},Array{Array{Array{Tracker.TrackedReal{Float64},1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Tracker.TrackedReal{Float64},1},Array{Tracker.TrackedReal{Float64},1},Array{Tracker.TrackedReal{Float64},1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},ODEFunction{true,Dynamics{Float64,TypeSortedCollections.TypeSortedCollection{Tuple{Array{Joint{Float64,Revolute{Float64}},1}},1},typeof(zero_control!),getfield(RigidBodySim.Core, 
Symbol("##2#4"))},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.Tsit5Cache{Array{Tracker.TrackedReal{Float64},1},Array{Tracker.TrackedReal{Float64},1},Array{Tracker.TrackedReal{Float64},1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}},OrdinaryDiffEq.DEOptions{Tracker.TrackedReal{Float64},Tracker.TrackedReal{Float64},Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),CallbackSet{Tuple{},Tuple{}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,DataStructures.LessThan},DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Nothing,Nothing,Int64,Array{Float64,1},Array{Float64,1},Array{Float64,1}},Array{Tracker.TrackedReal{Float64},1},Tracker.TrackedReal{Float64},Nothing}, ::OrdinaryDiffEq.Tsit5Cache{Array{Tracker.TrackedReal{Float64},1},Array{Tracker.TrackedReal{Float64},1},Array{Tracker.TrackedReal{Float64},1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}) at /Users/justinglibert/.julia/packages/OrdinaryDiffEq/tQd6p/src/perform_step/low_order_rk_perform_step.jl:623
 [22] #__init#335(::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,1}, ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::CallbackSet{Tuple{},Tuple{}}, ::Bool, ::Bool, ::Float64, ::Float64, ::Float64, ::Bool, ::Bool, ::Rational{Int64}, ::Nothing, ::Nothing, ::Rational{Int64}, ::Int64, ::Int64, ::Int64, ::Rational{Int64}, ::Bool, ::Int64, ::Nothing, ::Nothing, ::Int64, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(LinearAlgebra.opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Nothing, ::Bool, ::Bool, ::Bool, ::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:abs_tol,),Tuple{Float64}}}, ::typeof(DiffEqBase.__init), ::ODEProblem{Array{Tracker.TrackedReal{Float64},1},Tuple{Float64,Float64},true,Nothing,ODEFunction{true,Dynamics{Float64,TypeSortedCollections.TypeSortedCollection{Tuple{Array{Joint{Float64,Revolute{Float64}},1}},1},typeof(zero_control!),getfield(RigidBodySim.Core, Symbol("##2#4"))},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,CallbackSet{Tuple{},Tuple{}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{CallbackSet{Tuple{},Tuple{}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5, ::Array{Array{Tracker.TrackedReal{Float64},1},1}, ::Array{Float64,1}, ::Array{Any,1}, ::Type{Val{true}}) at /Users/justinglibert/.julia/packages/OrdinaryDiffEq/tQd6p/src/solve.jl:352
 [23] (::getfield(DiffEqBase, Symbol("#kw##__init")))(::NamedTuple{(:callback, :abs_tol, :dt),Tuple{CallbackSet{Tuple{},Tuple{}},Float64,Float64}}, ::typeof(DiffEqBase.__init), ::ODEProblem{Array{Tracker.TrackedReal{Float64},1},Tuple{Float64,Float64},true,Nothing,ODEFunction{true,Dynamics{Float64,TypeSortedCollections.TypeSortedCollection{Tuple{Array{Joint{Float64,Revolute{Float64}},1}},1},typeof(zero_control!),getfield(RigidBodySim.Core, Symbol("##2#4"))},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,CallbackSet{Tuple{},Tuple{}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{CallbackSet{Tuple{},Tuple{}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5, ::Array{Array{Tracker.TrackedReal{Float64},1},1}, ::Array{Float64,1}, ::Array{Any,1}, ::Type{Val{true}}) at ./none:0 (repeats 4 times)
 [24] #__solve#334 at /Users/justinglibert/.julia/packages/OrdinaryDiffEq/tQd6p/src/solve.jl:4 [inlined]
 [25] #__solve at ./none:0 [inlined]
 [26] #solve_call#425(::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol,Symbol},NamedTuple{(:abs_tol, :dt),Tuple{Float64,Float64}}}, ::typeof(DiffEqBase.solve_call), ::ODEProblem{Array{Tracker.TrackedReal{Float64},1},Tuple{Float64,Float64},true,Nothing,ODEFunction{true,Dynamics{Float64,TypeSortedCollections.TypeSortedCollection{Tuple{Array{Joint{Float64,Revolute{Float64}},1}},1},typeof(zero_control!),getfield(RigidBodySim.Core, Symbol("##2#4"))},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,CallbackSet{Tuple{},Tuple{}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{CallbackSet{Tuple{},Tuple{}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5) at /Users/justinglibert/.julia/packages/DiffEqBase/XoKmO/src/solve.jl:38
 [27] #solve_call at ./none:0 [inlined]
 [28] #solve#426 at /Users/justinglibert/.julia/packages/DiffEqBase/XoKmO/src/solve.jl:57 [inlined]
 [29] (::getfield(DiffEqBase, Symbol("#kw##solve")))(::NamedTuple{(:abs_tol, :dt),Tuple{Float64,Float64}}, ::typeof(solve), ::ODEProblem{Array{Tracker.TrackedReal{Float64},1},Tuple{Float64,Float64},true,Nothing,ODEFunction{true,Dynamics{Float64,TypeSortedCollections.TypeSortedCollection{Tuple{Array{Joint{Float64,Revolute{Float64}},1}},1},typeof(zero_control!),getfield(RigidBodySim.Core, Symbol("##2#4"))},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,CallbackSet{Tuple{},Tuple{}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{CallbackSet{Tuple{},Tuple{}}}}},DiffEqBase.StandardODEProblem}, ::Tsit5) at ./none:0
 [30] top-level scope at /Users/justinglibert/Dev/repositories/fastbot/v1/repro.jl:14
in expression starting at /Users/justinglibert/Dev/repositories/fastbot/v1/repro.jl:14

tkoolen commented 5 years ago

I can take a closer look at this after Wednesday.

tkoolen commented 5 years ago

OK, so this appears to be because:

julia> Tracker.track(identity, rand(3)) |> typeof |> supertype
AbstractArray{Float64,1}

while

julia> Tracker.track(identity, rand(3)) |> eltype
Tracker.TrackedReal{Float64}

So in a sense, TrackedArray{T} is lying about its element type when it extends AbstractArray{T}. I think this is a very unfortunate design decision in Tracker.jl that probably breaks a lot of use cases, but I'll see if I can work around it.
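
To make the mismatch concrete, here is a minimal sketch in plain Julia that does not use Tracker at all. The names `Wrapped`, `MislabeledVector`, and `declared_eltype` are hypothetical and made up for illustration; this is not Tracker's actual implementation, only a stand-in that mimics the pattern of a container that subtypes `AbstractArray{T}` while reporting and returning a different element type.

```julia
# Hypothetical stand-in for a tracked scalar (NOT Tracker.TrackedReal).
struct Wrapped{T}
    value::T
end

# Hypothetical container that declares element type T in its supertype...
struct MislabeledVector{T} <: AbstractVector{T}
    data::Vector{T}
end

Base.size(v::MislabeledVector) = size(v.data)

# ...but hands out Wrapped{T} elements and reports Wrapped{T} as its eltype.
Base.getindex(v::MislabeledVector, i::Int) = Wrapped(v.data[i])
Base.eltype(::Type{MislabeledVector{T}}) where {T} = Wrapped{T}

# The element type that generic code sees when it dispatches on AbstractArray{T}.
declared_eltype(::AbstractArray{T}) where {T} = T

v = MislabeledVector([1.0, 2.0, 3.0])
@show declared_eltype(v)                 # type parameter of the supertype
@show eltype(v)                          # what the container reports
@show declared_eltype(v) === eltype(v)   # the two disagree for this container
```

Generic code that dispatches on `AbstractArray{T}` and then tries to convert each element to `T` can fail on such a container, which looks like the kind of `Float64(::Tracker.TrackedReal{Float64})` conversion error shown in the stack trace above.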

I also see that Tracker.jl is on its way out. I'm not familiar with Flux's status quo; do you know why they're replacing Tracker.jl and what they're replacing it with?

tkoolen commented 5 years ago

After #571, I think the next issue is with

using Tracker
using StaticArrays
x = Tracker.track(identity, 1.0)
y = [Tracker.track(identity, 2.0), Tracker.track(identity, 2.0)]
@show typeof(x * y)

which results in

TrackedArray{…,Array{Tracker.TrackedReal{Float64},1}}

i.e., a TrackedArray with TrackedReal element type. I think this should be either a TrackedArray of Float64s, or a Vector of TrackedReal{Float64}s; the current result seems weirdly nested. I'm not sure, but I suspect that this is an issue with the way Tracker.jl overrides broadcasting.

But Tracker is being replaced with Zygote of course. It might make sense to try the zygote branch of Flux; maybe that magically fixes all of the issues.

tkoolen commented 5 years ago

Also wanted to point to https://github.com/tkoolen/RigidBodyDynamicsDiff.jl, which, though experimental, has moderately optimized gradients w.r.t. q (not v yet) that I'd expect to be significantly faster than what reverse-mode AD in the style of Tracker or Zygote will give you. I'd love to be proved wrong though.

ludns commented 5 years ago

Thank you for your detailed reply! Zygote is still very unstable and just does not work with RigidBodyDynamics. Repro:

using RigidBodyDynamics
using RigidBodySim
using DifferentialEquations
using Zygote

urdf = joinpath(dirname(pathof(RigidBodySim)), "..", "test", "urdf", "Acrobot.urdf")
mechanism = parse_urdf(Float32, urdf)
remove_fixed_tree_joints!(mechanism);
state = MechanismState(mechanism)
shoulder, elbow = joints(mechanism)
# Set the initial state
configuration(state, shoulder) .= 0.3
configuration(state, elbow) .= 0.4
velocity(state, shoulder) .= 1.
velocity(state, elbow) .= 2.;

mutable struct ControllerParams
    param::Float32
end
p_test = ControllerParams(5)

function create_controller(p::ControllerParams)
    function control!(τ, t, state)
        view(τ, velocity_range(state, shoulder))  .= p.param  * sin(t)
        view(τ, velocity_range(state, elbow)) .= -configuration(state, shoulder)
    end
    return control!
end

function simulate_full(p::ControllerParams)
    cc! = create_controller(p)
    open_loop_dynamics = Dynamics(mechanism, cc!);
    problem = ODEProblem(open_loop_dynamics, state, (0., 1))
    sol = solve(problem, Tsit5(), abs_tol = 1e-7, dt = 0.05)
    println("Done")
    return sol[end][1]
end

# Check if we can get the gradient of a function using a ControllerParams
@show gradient(p -> p.param * 2, p_test)
# Check if we can simulate 1 sec of Dynamics
@show simulate_full(p_test)
# Check if we can get the gradient of the dynamics with respect to a ControllerParams
# -> Breaks
@show gradient(simulate_full, p_test)

Stack trace:

UndefVarError: S not defined
show(::Base.GenericIOBuffer{Array{UInt8,1}}, ::Type{Zygote.Pullback{Tuple{Type{UnionAll},TypeVar,Type{DynamicsResult{#s127,Float32}}},T} where T}) at show.jl:13
show_datatype(::Base.GenericIOBuffer{Array{UInt8,1}}, ::DataType) at show.jl:547
show(::Base.GenericIOBuffer{Array{UInt8,1}}, ::Type) at show.jl:428
print(::Base.GenericIOBuffer{Array{UInt8,1}}, ::Type) at io.jl:37
print(::Base.GenericIOBuffer{Array{UInt8,1}}, ::String, ::Type) at io.jl:48
show_tuple_as_call(::Base.GenericIOBuffer{Array{UInt8,1}}, ::Symbol, ::Type) at show.jl:1513
show_spec_linfo(::Base.GenericIOBuffer{Array{UInt8,1}}, ::Base.StackTraces.StackFrame) at stacktraces.jl:262
show#9(::Bool, ::typeof(show), ::Base.GenericIOBuffer{Array{UInt8,1}}, ::Base.StackTraces.StackFrame) at stacktraces.jl:272
show at stacktraces.jl:272 [inlined]
show(::Base.GenericIOBuffer{Array{UInt8,1}}, ::Atom.EvalError{UndefVarError}) at errors.jl:27
print(::Base.GenericIOBuffer{Array{UInt8,1}}, ::Atom.EvalError{UndefVarError}) at io.jl:37
print_to_string(::Atom.EvalError{UndefVarError}) at io.jl:129
string(::Atom.EvalError{UndefVarError}) at io.jl:168
macro expansion at comm.jl:66 [inlined]
(::getfield(Atom, Symbol("##125#130")){String})() at eval.jl:102
macro expansion at essentials.jl:790 [inlined]
(::getfield(Atom, Symbol("##121#126")))(::Dict{String,Any}) at eval.jl:86
handlemsg(::Dict{String,Any}, ::Dict{String,Any}) at comm.jl:164
(::getfield(Atom, Symbol("##19#21")){Array{Any,1}})() at task.jl:268

What I am trying to do is basically learn a parameterized controller with gradient descent. Any ideas or tips on how I should approach the problem with Julia and RigidBodyDynamics.jl? Theoretically it should work, and I know people have done it. I'm just not sure that AD in Julia is mature enough to do that.

ludns commented 5 years ago

I actually thought it was a bug in Atom, so I tried to execute the above script from the CLI, but I get the same error:

❯ julia --project zyg.jl

gradient((p->begin
            #= /Users/justinglibert/Dev/repositories/magic/v1/zyg.jl:49 =#
            p.param * 2
        end), p_test) = (Base.RefValue{Any}((param = 2.0f0,)),)

simulate_full(p_test) = 0.99640465f0

ERROR: LoadError: UndefVarError: S not defined
Stacktrace:
 [1] (::Type{ERROR: UndefVarError: S not defined
Stacktrace:
 [1] show(::IOContext{Base.GenericIOBuffer{Array{UInt8,1}}}, ::Type{fatal: error thrown and no exception handler available.
UndefVarError(var=:S)
rec_backtrace at /Users/sabae/buildbot/worker/package_macos64/build/src/stackwalk.c:94
record_backtrace at /Users/sabae/buildbot/worker/package_macos64/build/src/task.c:219
jl_throw at /Users/sabae/buildbot/worker/package_macos64/build/src/task.c:429
jl_undefined_var_error at /Users/sabae/buildbot/worker/package_macos64/build/src/rtutils.c:130
show at /Users/justinglibert/.julia/packages/Zygote/bdE6T/src/compiler/show.jl:13
show_datatype at ./show.jl:547
show at ./show.jl:428
print at ./strings/io.jl:37
print at ./strings/io.jl:48
show_tuple_as_call at ./show.jl:1513
show_spec_linfo at ./stacktraces.jl:262
#show#9 at ./stacktraces.jl:272
#show at ./none:0 [inlined]
#show_trace_entry#642 at ./errorshow.jl:486
#show_trace_entry at ./none:0
unknown function (ip: 0x18f3c0ce8)
show_backtrace at ./errorshow.jl:589
#showerror#625 at ./errorshow.jl:83
#showerror at ./none:0
unknown function (ip: 0x18f3bca31)
show_exception_stack at ./errorshow.jl:652
display_error at ./client.jl:110
display_error at ./client.jl:112
jl_apply at /Users/sabae/buildbot/worker/package_macos64/build/src/./julia.h:1614 [inlined]
jl_f__apply at /Users/sabae/buildbot/worker/package_macos64/build/src/builtins.c:563
jl_f__apply_latest at /Users/sabae/buildbot/worker/package_macos64/build/src/builtins.c:601
#invokelatest#1 at ./essentials.jl:790 [inlined]
invokelatest at ./essentials.jl:789 [inlined]
_start at ./client.jl:466
true_main at /usr/local/bin/julia (unknown line)
main at /usr/local/bin/julia (unknown line)

tkoolen commented 5 years ago

(away from a computer) I actually think that particular error is just due to the println you added (nothing in RBD.jl core code should use any kind of IOBuffer).

I tried Flux#zygote briefly myself last week (with RBD.jl master), and the first thing I ran into was the potrf! call (LAPACK Cholesky decomposition) in the method of dynamics_solve! that's optimized for BLAS floats. When I commented out that method so that the generic fallback is used, I think I ran into a possible Julia compiler bug.

Again though, even if we get this to work, it's likely to be quite slow, possibly prohibitively so. It's not clear to me whether reverse-mode AD is going to outperform forward mode, since the number of outputs of the dynamics (dimension of the joint acceleration vector) is on the same order as the number of inputs. If you just want to get things working ASAP, it may be a good idea to define a custom adjoint for the dynamics in Flux that just does the ForwardDiff/StateCache thing as described in the documentation. That'll maybe be 10-100x slower than decently optimized custom gradients, while I suspect that Zygote will be slower still.
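
The custom-adjoint idea can be sketched roughly as follows. This is only an illustration under stated assumptions: `toy_dynamics` is a hypothetical stand-in (a frictionless pendulum) for the real RigidBodyDynamics.jl dynamics call, and the StateCache part from the RBD.jl documentation is left out entirely. The only library calls used are `Zygote.@adjoint`, `Zygote.gradient`, and `ForwardDiff.jacobian`.

```julia
using ForwardDiff
using Zygote

# Hypothetical stand-in for the dynamics: x = [q, v] -> [q̇, v̇] of a pendulum.
# In the real use case this would wrap a call into RigidBodyDynamics.jl.
toy_dynamics(x::AbstractVector) = [x[2], -sin(x[1])]

# Custom adjoint: instead of letting Zygote trace through the function,
# compute the Jacobian with forward-mode AD and return a pullback that
# applies its transpose to the incoming cotangent Δ.
Zygote.@adjoint function toy_dynamics(x::AbstractVector)
    y = toy_dynamics(x)
    J = ForwardDiff.jacobian(toy_dynamics, x)
    return y, Δ -> (J' * Δ,)
end

# A scalar loss built on top of the dynamics; Zygote handles `sum(abs2, ...)`
# itself and uses the custom adjoint for `toy_dynamics`.
loss(x) = sum(abs2, toy_dynamics(x))

x0 = [0.3, 1.0]
g, = Zygote.gradient(loss, x0)

# Analytic gradient of x[2]^2 + sin(x[1])^2 for comparison.
g_expected = [sin(2 * x0[1]), 2 * x0[2]]
@show isapprox(g, g_expected; atol = 1e-10)
```

Because the number of outputs here equals the number of inputs, building the full Jacobian in forward mode costs about as many passes as reverse mode would need, which is the trade-off described above.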

ludns commented 5 years ago

I actually removed the println and got the exact same error. Thanks for your reply; it probably makes sense to just define a custom adjoint. I am wondering whether it will work with controllers written in Flux, though. I suppose I would need to feed the state into my model, get the torques out, run the ODE solver over a very short time range, get my gradients, and repeat (I don't know how to apply torque without using control in RBD.jl). The callback (control), being executed inside the ODE solver, will probably not be compatible with ForwardDiff and the custom adjoint strategy you mentioned.

Finn-Sueberkrueb commented 2 years ago

Is it correct that Zygote and RigidBodyDynamics are not yet compatible? In my simple test, I fail because Zygote does not yet support mutating arrays.
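
For readers unfamiliar with the mutation limitation, here is a small sketch that is independent of RigidBodyDynamics.jl. The function names are hypothetical. It shows two ways to write a computation so that Zygote can differentiate it: a purely functional version and one using `Zygote.Buffer`, which Zygote provides as a differentiable replacement for writing into a preallocated array. Whether either approach is practical for RBD.jl's heavily in-place internals is a separate question that this sketch does not answer.

```julia
using Zygote

# Non-mutating version: builds a new array instead of writing into one.
double_sum_pure(x) = sum(2 .* x)

# Buffer version: element-wise writes go into a Zygote.Buffer, which is
# turned back into a regular array with `copy` before further use.
function double_sum_buffer(x)
    buf = Zygote.Buffer(x)
    for i in eachindex(x)
        buf[i] = 2 * x[i]
    end
    return sum(copy(buf))
end

x = [1.0, 2.0, 3.0]
g_pure, = Zygote.gradient(double_sum_pure, x)
g_buffer, = Zygote.gradient(double_sum_buffer, x)

# Both functions compute 2 * sum(x), so both gradients should be all twos.
@show g_pure == g_buffer
```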