slimgroup / WISE.jl

WISE: full-Waveform variational Inference via Subsurface Extensions
https://doi.org/10.1190/geo2023-0744.1
MIT License

Size of m_train and calculation of grad_train parameters #2

Closed advaitb closed 6 months ago

advaitb commented 6 months ago

Hi @ziyiyin97,

A few questions on the training script, especially lines 200, 201, and 202 (and correspondingly 204-206).

I'm a little confused as to how m_train is a 4D array. Also, the calculation of grad_train and keep_offset_idx is not very obvious. I currently have the Model, Model0, and the sscig rtm (offset, x, z) saved and wanted to know how these map to the variables you've loaded. If you could help clarify this, that would be great!

ziyiyin97 commented 6 months ago

m_train has size nx*nz*1*batchsize. We followed the ML literature in our configuration, so the third dimension is the so-called "channel" dimension -- e.g., it would have size 3 for RGB images.

grad_train has size nx*nz*nh*batchsize, where nh is the number of offsets. So you would permute your sscig rtm so that the offset dimension comes after the spatial dimensions, with the batch dimension last.

keep_offset_idx is a manually set range that selects how many offsets are used to train the CNFs. For example, if you only keep the middle index (26:26 for 51 offsets), then you are only using the zero-offset RTM. To use the full offsets, you would use 1:51.
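For concreteness, a minimal sketch of the reshaping described above (hypothetical shapes; it assumes each sscig RTM sample is stored as an (offset, x, z) array, as mentioned above):

nx, nz, nh, N = 64, 64, 51, 8                 # hypothetical sizes: grid, offsets, samples

# Velocity models stacked with a singleton "channel" dimension: nx*nz*1*N.
models = rand(Float32, nx, nz, N)             # placeholder for the real models
m_train = reshape(models, nx, nz, 1, N)

# sscig RTMs stored as (offset, x, z, sample): move offset after the spatial
# dimensions so grad_train is nx*nz*nh*N with the batch dimension last.
cigs = rand(Float32, nh, nx, nz, N)           # placeholder for the real CIGs
grad_train = permutedims(cigs, (2, 3, 1, 4))

# Select which offsets feed the CNFs: middle index only (zero-offset RTM) or all.
keep_offset_idx = 26:26                       # zero-offset for nh = 51; use 1:51 for all offsets
grad_train = grad_train[:, :, keep_offset_idx, :]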

Hope this clarifies!

advaitb commented 6 months ago

Thanks for your clarifications, @ziyiyin97! I have a couple of follow-ups on other parts of the code.

1) Could you explain this normalization a bit (especially the use of 300) and the reason behind it?

max_y = quantile(abs.(vec(grad_train[:,:,:,1:300])),0.9999);
grad_train ./= max_y;

2) Also, what does the SummarizedNet of a UNet and a CNF mean, and what is the advantage of such a formulation?

3) Are there any constraints on setting L and K? How are they determined?

ziyiyin97 commented 6 months ago

  1. Velocity models typically have a limited range, say 1.5 to 5, but CIGs can take on a wide range of values. A typical trick in the ML literature is therefore to apply a normalization so that the input and output roughly stay in the same range. After this normalization, most values in grad_train will lie between -1 and 1. The result should be insensitive to the exact choice of 300.
  2. The summary network (UNet) is applied to the CIGs, so in the end the CNFs learn from pairs of (velocity model, summary network applied to CIG). During training, the CNFs and the UNet are trained jointly. Much of the literature indicates that such a summary network improves the quality of inference because it "helps" the CNFs further process the observables (in this case, CIGs). A useful reference is https://arxiv.org/abs/2003.06281 (see Equation 18). We consider this a technical detail / engineering choice, so it is omitted from the WISE paper.
  3. We suggest setting L such that nx and nz are both divisible by 2^L, because many CNF layers partition the current image into 2-by-2 blocks and rearrange them (a quick sanity check is sketched below). For K there is no such constraint. More details can be found at https://github.com/slimgroup/InvertibleNetworks.jl/blob/1a4cba93b1976d868d2f922b7628115a7ebf5e72/src/networks/invertible_network_conditional_glow.jl#L114
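
For example, a quick hypothetical sanity check on L (the values here are placeholders, not the ones from the training script):

nx, nz = 64, 64                # hypothetical model size
L, K = 3, 9                    # hypothetical depth choices
# Each of the L multiscale levels squeezes the image into 2-by-2 blocks,
# so nx and nz should both be divisible by 2^L; K has no such constraint.
@assert nx % 2^L == 0 && nz % 2^L == 0 "choose L so that 2^L divides nx and nz"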
advaitb commented 6 months ago

@ziyiyin97 Could you confirm that both chan_target and chan_cond are equal to 1 in your case? When I try executing the line Zx, Zy, lgdet = G.forward(X |> device, Y) in the training loop, I get an index error that traces back to this line in the UNet.jl code, which means that u.conv_blocks in my case is not array-like for some reason while running the summarized net. This is part of the exact error I get:

ERROR: LoadError: MethodError: no method matching getindex(::Int64, ::UnitRange{Int64})

Closest candidates are:
  getindex(::Number, ::Integer)
   @ Base number.jl:96
  getindex(::Number)
   @ Base number.jl:95
  getindex(::Number, ::Integer...)
   @ Base number.jl:101
  ...

Stacktrace:
  [1] (::Unet)(x::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ UNet ~/.julia/packages/UNet/NMxh2/src/model.jl:73

Here is the code I'm running

ziyiyin97 commented 6 months ago

Yes, I confirm that both chan_target and chan_cond are equal to 1 in my case.

To further debug, could you make a minimal failing example with just a single y (CIG) and the unet, and do unet(y)? It seems like the error comes only from the unet evaluation.

Also, may I know which package versions you are using? I keep a https://github.com/slimgroup/WISE.jl/blob/main/Manifest.toml to track the versions of all the packages. If you follow my DrWatson-ized repo and activate the environment correctly, you will use my versions by default.
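
For reference, a minimal sketch of activating that environment so the pinned Manifest.toml versions are used (assuming you run it from inside the cloned WISE.jl directory):

using DrWatson
@quickactivate                 # activate the project environment found from the current directory

using Pkg
Pkg.instantiate()              # install the exact package versions recorded in Manifest.toml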

advaitb commented 6 months ago

I get the same error when executing println(unet(grad_train[:,:,:,1])). I have my own Manifest file here. It looks like the versions of Flux, UNet, and InvertibleNetworks are the same as in your environment, though I seem to have a newer version of JUDI.

ziyiyin97 commented 6 months ago

It needs to be unet(grad_train[:,:,:,1:1]) since the input needs to be a 4D array. Not sure whether this is related to https://github.com/slimgroup/WISE.jl/issues/2#issuecomment-1984429242 though.
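
To illustrate the difference (a toy example with hypothetical sizes):

x = rand(Float32, 64, 64, 11, 10)
size(x[:, :, :, 1])            # (64, 64, 11)    -- integer index drops the last dimension
size(x[:, :, :, 1:1])          # (64, 64, 11, 1) -- range keeps it 4D, as the network expects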

advaitb commented 6 months ago

Something weird is happening with the UNet code. Here is the minimal toy code that gives the error:

unet_lev = 4
unet = Unet(chan_obs, chan_cond, unet_lev) |> device;
trainmode!(unet, true); 

random_array = rand(Float64, 256, 256, 21, 10)
println(unet.conv_blocks)
println(unet(random_array[:,:,:,1:1]))

The result of the first print is an integer, 1. However, the UNet model.jl file has the line I referenced above, op = u.conv_blocks[1:2](x), which throws an index error (rightly so), and I'm trying to figure out how an integer appears where there should clearly be an array according to the constructor.

advaitb commented 6 months ago

@ziyiyin97 I fixed the UNet error. It turns out I was using UNet.jl from https://github.com/DhairyaLGandhi/UNet.jl instead of https://github.com/mloubout/UNet.jl. However, I'm still running into an error when running this code:

cond_net = NetworkConditionalGlow(chan_target, chan_cond, n_hidden, L, K; split_scales=true, activation=SigmoidLayer(low=low, high=1.0f0)) |> device;
X_train_batch = m_train[:,:,:,1:1] |> device
Y_train_batch = grad_train[:,:,:,1:1] |> device
println(cond_net(X_train_batch, Y_train_batch))

X_train_batch has size (64, 64, 1, 1) and Y_train_batch has size (64, 64, 11, 1). The error I get is:

ERROR: LoadError: MethodError: no method matching length(::Nothing)

Stacktrace:
  [1] #s597#122
    @ ~/.julia/packages/GPUCompiler/S3TWf/src/cache.jl:18 [inlined]
  [2] var"#s597#122"(f::Any, tt::Any, ::Any, job::Any)
    @ GPUCompiler ./none:0
  [3] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
  [4] cached_compilation(cache::Dict{UInt64, Any}, job::GPUCompiler.CompilerJob, compiler::typeof(CUDA.cufunction_compile), linker::typeof(CUDA.cufunction_link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/S3TWf/src/cache.jl:71
  [5] cufunction(f::typeof(CUDA.partial_mapreduce_grid), tt::Type{Tuple{typeof(identity), typeof(Base.add_sum), Float32, CartesianIndices{4, NTuple{4, Base.OneTo{Int64}}}, CartesianIndices{4, NTuple{4, Base.OneTo{Int64}}}, Val{true}, CUDA.CuDeviceArray{Float32, 5, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{4}, NTuple{4, Base.OneTo{Int64}}, ComposedFunction{Base.Fix1{typeof(*), Float32}, typeof(identity)}, Tuple{CUDA.CuDeviceArray{Float32, 4, 1}}}}}; name::Nothing, always_inline::Bool, kwargs::@Kwargs{})
    @ CUDA ~/.julia/packages/CUDA/BbliS/src/compiler/execution.jl:300
  [6] cufunction(f::typeof(CUDA.partial_mapreduce_grid), tt::Type{Tuple{typeof(identity), typeof(Base.add_sum), Float32, CartesianIndices{4, NTuple{4, Base.OneTo{Int64}}}, CartesianIndices{4, NTuple{4, Base.OneTo{Int64}}}, Val{true}, CUDA.CuDeviceArray{Float32, 5, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{4}, NTuple{4, Base.OneTo{Int64}}, ComposedFunction{Base.Fix1{typeof(*), Float32}, typeof(identity)}, Tuple{CUDA.CuDeviceArray{Float32, 4, 1}}}}})
    @ CUDA ~/.julia/packages/CUDA/BbliS/src/compiler/execution.jl:293
  [7] macro expansion
    @ ~/.julia/packages/CUDA/BbliS/src/compiler/execution.jl:102 [inlined]
  [8] mapreducedim!(f::typeof(identity), op::typeof(Base.add_sum), R::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, A::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{4}, NTuple{4, Base.OneTo{Int64}}, ComposedFunction{Base.Fix1{typeof(*), Float32}, typeof(identity)}, Tuple{CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}}}; init::Float32)
    @ CUDA ~/.julia/packages/CUDA/BbliS/src/mapreduce.jl:234
  [9] mapreducedim!
    @ ~/.julia/packages/CUDA/BbliS/src/mapreduce.jl:169 [inlined]
 [10] _mapreduce(f::ComposedFunction{Base.Fix1{typeof(*), Float32}, typeof(identity)}, op::typeof(Base.add_sum), As::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}; dims::Vector{Int64}, init::Nothing)
    @ GPUArrays ~/.julia/packages/GPUArrays/5XhED/src/host/mapreduce.jl:69
 [11] mapreduce(::Function, ::Function, ::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}; dims::Vector{Int64}, init::Nothing)
    @ GPUArrays ~/.julia/packages/GPUArrays/5XhED/src/host/mapreduce.jl:30
 [12] mapreduce
    @ ~/.julia/packages/GPUArrays/5XhED/src/host/mapreduce.jl:30 [inlined]
 [13] _sum
    @ ./reducedim.jl:1039 [inlined]
 [14] sum
    @ ./reducedim.jl:1011 [inlined]
 [15] _mean(f::Function, A::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, dims::Vector{Int64})
    @ GPUArrays ~/.julia/packages/GPUArrays/5XhED/src/host/statistics.jl:37
 [16] mean
    @ ~/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/share/julia/stdlib/v1.10/Statistics/src/Statistics.jl:174 [inlined]
 [17] forward(X::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, AN::ActNorm; logdet::Nothing)
    @ InvertibleNetworks ~/.julia/packages/InvertibleNetworks/vWeOw/src/layers/invertible_layer_actnorm.jl:68
 [18] forward(X::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, AN::ActNorm)
    @ InvertibleNetworks ~/.julia/packages/InvertibleNetworks/vWeOw/src/layers/invertible_layer_actnorm.jl:60
 [19] _predefined_mode(obj::ActNorm, sym::Symbol, args::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}; kwargs::@Kwargs{})
    @ InvertibleNetworks ~/.julia/packages/InvertibleNetworks/vWeOw/src/utils/neuralnet.jl:29
 [20] _predefined_mode
    @ ~/.julia/packages/InvertibleNetworks/vWeOw/src/utils/neuralnet.jl:27 [inlined]
 [21] #135
    @ ~/.julia/packages/InvertibleNetworks/vWeOw/src/utils/neuralnet.jl:40 [inlined]
 [22] forward(X::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, C::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, G::NetworkConditionalGlow)
    @ InvertibleNetworks ~/.julia/packages/InvertibleNetworks/vWeOw/src/networks/invertible_network_conditional_glow.jl:111
 [23] _predefined_mode(::NetworkConditionalGlow, ::Symbol, ::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::Vararg{CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}}; kwargs::@Kwargs{})
    @ InvertibleNetworks ~/.julia/packages/InvertibleNetworks/vWeOw/src/utils/neuralnet.jl:29
 [24] _predefined_mode
    @ ~/.julia/packages/InvertibleNetworks/vWeOw/src/utils/neuralnet.jl:27 [inlined]
 [25] #135
    @ ~/.julia/packages/InvertibleNetworks/vWeOw/src/utils/neuralnet.jl:40 [inlined]
 [26] forward_net
    @ ~/.julia/packages/InvertibleNetworks/vWeOw/src/utils/neuralnet.jl:141 [inlined]
 [27] (::NetworkConditionalGlow)(X::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Y::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ InvertibleNetworks ~/.julia/packages/InvertibleNetworks/vWeOw/src/utils/neuralnet.jl:140
 [28] top-level scope
    @ ~/Advait/train_ofwi.jl:104

I can't seem to trace the error beyond this line. Any idea what's happening here?

ziyiyin97 commented 6 months ago

Is there any reason why you dropped the summary network (Unet)? It seems like that line is pointing out that it fails to find a summary network.

advaitb commented 6 months ago

I get the same error when running with a summarized network, so I feel it's more of a cond_net issue. Another question I have is whether chan_cond should be equal to the number of offsets in the CIG, or whether it can be an arbitrary number (like 1). I'm still getting the error either way, but I want to understand that better. Beyond that, I seem to be running out of ideas as to why this code is failing.