bat / BAT.jl

A Bayesian Analysis Toolkit in Julia

Pullback of DistributionTransform #322

Closed: jknollm closed this issue 3 years ago

jknollm commented 3 years ago

Hi everyone,

I'm having a problem with Zygote.pullback and BAT.DistributionTransform: the combination somehow no longer works. I'm using Julia 1.7 with a freshly installed BAT (probably the main branch). Here is a minimal example:

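using BAT, ValueShapes, Distributions
using Zygote
D = rand(3)  # placeholder: D is not defined in the original snippet, only length(D) matters
prior = NamedTupleDist( ξ = BAT.StandardMvNormal(length(D)) )
pos = rand(prior)
bwd_trafo = BAT.DistributionTransform(Normal, prior)  # maps the prior to standard normal space
fwd_trafo = inv(bwd_trafo)                            # maps back to the prior's space
pos_t = bwd_trafo(pos)
pullback(fwd_trafo,pos_t)                             # this call fails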

Does anyone have an idea what might be going on here? Here are the first lines of the error message; let me know if you need the full output:

ERROR: type DataType has no field mutable
Stacktrace:
 [1] getproperty
   @ ./Base.jl:37 [inlined]
 [2] adjoint
   @ ~/.julia/packages/Zygote/zowrf/src/lib/lib.jl:281 [inlined]
 [3] _pullback
   @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [4] _pullback
   @ ~/.julia/packages/ValueShapes/p5OTT/src/array_shape.jl:21 [inlined]
 [5] _pullback(ctx::Zygote.Context, f::Type{ArrayShape{Float64, 1}}, args::Tuple{Int64})
   @ Zygote ~/.julia/packages/Zygote/zowrf/src/compiler/interface2.jl:0
 [6] _pullback
   @ ~/.julia/packages/ValueShapes/p5OTT/src/array_shape.jl:27 [inlined]
 [7] _pullback(ctx::Zygote.Context, f::Type{ArrayShape{Float64}}, args::Tuple{Int64})
   @ Zygote ~/.julia/packages/Zygote/zowrf/src/compiler/interface2.jl:0

oschulz commented 3 years ago

I'll take a look!

oschulz commented 3 years ago

Which package versions do you use?

I can run the code above successfully on Julia v1.7 using BAT#master, ValueShapes@0.9.5, Zygote@0.6.29 and Distributions@0.25.21 (these are the current versions).

oschulz commented 3 years ago

Can you try running the following in a fresh temporary environment?

using Pkg
Pkg.activate(temp = true)
pkg"""add BAT#master ValueShapes Zygote Distributions"""

using BAT, ValueShapes, Distributions
using Zygote
prior = NamedTupleDist(a = BAT.StandardMvNormal(3))
pos = rand(prior)
bwd_trafo = BAT.DistributionTransform(Normal, prior)
fwd_trafo = inv(bwd_trafo)
pos_t = bwd_trafo(pos)
Zygote.pullback(fwd_trafo,pos_t)

jknollm commented 3 years ago

I got the following packages:

[c0cd4b16] BAT v2.0.5
[336ed68f] CSV v0.9.9
[a93c6f00] DataFrames v1.2.2
[31c24e10] Distributions v0.24.18
[7a1cc6ca] FFTW v1.4.5
[f6369f11] ForwardDiff v0.10.21
[f67ccb44] HDF5 v0.15.6
[fdae7790] MGVI v0.2.3
[429524aa] Optim v1.4.1
[90014a1f] PDMats v0.11.1
[91a5bcdd] Plots v1.23.1
[276daf66] SpecialFunctions v1.8.0
[136a8f8c] ValueShapes v0.8.3
[e88e6eb3] Zygote v0.6.12
[8bb1440f] DelimitedFiles

Your code crashes with the following message:

julia> pkg"""add BAT#master ValueShapes Zygote Distributions"""
ERROR: No more files
/home/iwsatlas1/knollmue/.julia/registries/General.tar.gz
System ERROR: Unknown error -2147024872
ERROR: failed process: Process(setenv(/opt/julia-1.7/libexec/7z x /home/iwsatlas1/knollmue/.julia/registries/General.tar.gz -so,["PATH=/opt/julia-1.7/libexec:/opt/nodejs/bin:/opt/anaconda3/bin:/opt/anaconda3/condabin:/opt/julia/bin:/opt/julia-1.7/bin:/opt/julia/bin:/opt/julia-1.6/bin:/opt/julia-1.3/bin:/opt/julia-1.0/bin:/opt/cmake/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin", "NV_LIBNCCL_DEV_PACKAGE=libnccl-devel-2.9.6-1+cuda11.3", "NV_LIBNCCL_PACKAGE_VERSION=2.9.6-1", "STY=4126094.madmax", "JUPYTER=jupyter", "NV_NVPROF_VERSION=11.3.58-1", "NVIDIA_DRIVER_CAPABILITIES=compute,utility", "LD_LIBRARY_PATH=/opt/julia-1.7/bin/../lib/julia:/opt/julia-1.7/bin/../lib:/usr/local/cuda/lib64:/usr/local/cuda/nvvm/lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/.singularity.d/libs", "DBUS_SESSION_BUS_ADDRESS=unix:path=/run/user/1404/bus", "MANPATH=/opt/nodejs/share/man:/opt/anaconda3/share/man:/opt/julia/share/man:/opt/cmake/share/man:" … "OPENBLAS_NUM_THREADS=8", "LANGUAGE=en_US:en", "SINGULARITY_CONTAINER=/remote/ceph/group/odsl/vm/singularity/images/mppmu_odsl-ml_latest.sif", "NV_NVML_DEV_VERSION=11.3.58-1", "NV_LIBNCCL_PACKAGE=libnccl-2.9.6-1+cuda11.3", "CONDA_PREFIX=/opt/anaconda3", "XDG_SESSION_ID=10394", "NV_LIBCUBLAS_VERSION=11.4.2.10064-1", "USES_VSCODE_SERVER_SPAWN=true", "OPENBLAS_MAIN_FREE=1"]), ProcessExited(2)) [2]
Stacktrace:
 [1] pipeline_error
   @ ./process.jl:531 [inlined]
 [2] open(::Pkg.Registry.var"#11#14"{IOBuffer, Vector{UInt8}, Dict{String, String}}, ::Cmd; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ Base ./process.jl:406
 [3] open

oschulz commented 3 years ago

You need to use the master branch of BAT (it should appear as `BAT v2.1.0-DEV` in the Pkg status output).

jknollm commented 3 years ago

OK, I had the wrong BAT version installed, sorry, my bad. Thanks, Oliver! It works now.

oschulz commented 3 years ago

No problem, glad it works!