FluxML / FastAI.jl

Repository of best practices for deep learning in Julia, inspired by fastai
https://fluxml.ai/FastAI.jl
MIT License
585 stars 51 forks source link

Fix UNet for 3D convolutions (pass `ndim` to `convxlayer` and `ResBlock`) #263

Closed itan1 closed 1 year ago

itan1 commented 1 year ago

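Fixes `UNetDynamic` construction for 3D images by passing `ndim` to `convxlayer` and `ResBlock`. Previously these layers were built with 2D convolutions even for a 3D backbone. For example, the upsampling path contained `Conv((1, 1), 512 => 2048)`, which fails on 5D input with `DimensionMismatch: Rank of x and w must match! (5 vs. 4)`. This PR also adds a test that checks the output size of a 3D UNet.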

Now:

julia> backbone = FastVision.Models.xresnet18(ndim=3);

julia> model = FastVision.Models.UNetDynamic(backbone, (128,128,128,3,1), 4);

julia> Flux.outputsize(model, (128, 128, 128, 3, 1)) == (128, 128, 128, 4, 1)
true

Previously:

julia> backbone = FastVision.Models.xresnet18(ndim=3);

julia> model = FastVision.Models.UNetDynamic(backbone, (128,128,128,3,1), 4);
┌ Error: layer SkipConnection(Chain(child = Chain(FastVision.Models.ResBlock(Chain(Chain(Conv((3, 3, 3), 256 => 512, pad=1, stride=2), BatchNorm(512, relu)), Chain(Conv((3, 3, 3), 512 => 512, pad=1), BatchNorm(512))), Chain(Conv((1, 1, 1), 256 => 512), BatchNorm(512, relu)), MeanPool((2, 2, 2))), Chain(FastVision.Models.ResBlock(Chain(Chain(Conv((3, 3, 3), 512 => 512, pad=1), BatchNorm(512, relu)), Chain(Conv((3, 3, 3), 512 => 512, pad=1), BatchNorm(512))), identity, identity), identity)), upsample = Chain(Chain(Conv((1, 1), 512 => 2048), BatchNorm(2048, relu)), PixelShuffle(2))), Parallel(catchannels, identity, BatchNorm(256))), index 2 in Chain, gave an error with input of size (8, 8, 8, 256, 1)
└ @ Flux ~/.julia/packages/Flux/4k0Ls/src/outputsize.jl:107
ERROR: DimensionMismatch: Rank of x and w must match! (5 vs. 4)
Stacktrace:
  [1] DenseConvDims(x::Array{Flux.NilNumber.Nil, 5}, w::Array{Float32, 4}; kwargs::Base.Pairs{Symbol, Any, NTuple{4, Symbol}, NamedTuple{(:stride, :padding, :dilation, :groups), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}, Int64}}})
    @ NNlib ~/.julia/packages/NNlib/0QnJJ/src/dim_helpers/DenseConvDims.jl:49
  [2] conv_dims(c::Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, x::Array{Flux.NilNumber.Nil, 5})
    @ Flux ~/.julia/packages/Flux/4k0Ls/src/layers/conv.jl:192
  [3] (::Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}})(x::Array{Flux.NilNumber.Nil, 5})
    @ Flux ~/.julia/packages/Flux/4k0Ls/src/layers/conv.jl:199
  [4] macro expansion
    @ ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:53 [inlined]
  [5] _applychain(layers::Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}, x::Array{Flux.NilNumber.Nil, 5})
    @ Flux ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:53
  [6] (::Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}})(x::Array{Flux.NilNumber.Nil, 5})
    @ Flux ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:51
  [7] macro expansion
    @ ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:53 [inlined]
  [8] _applychain(layers::Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, PixelShuffle}, x::Array{Flux.NilNumber.Nil, 5})
    @ Flux ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:53
  [9] (::Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, PixelShuffle}})(x::Array{Flux.NilNumber.Nil, 5})
    @ Flux ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:51
 [10] macro expansion
    @ ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:53 [inlined]
 [11] _applychain(layers::Tuple{Chain{Tuple{FastVision.Models.ResBlock, Chain{Tuple{FastVision.Models.ResBlock, typeof(identity)}}}}, Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, PixelShuffle}}}, x::Array{Flux.NilNumber.Nil, 5})
    @ Flux ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:53
 [12] _applychain
    @ ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:59 [inlined]
 [13] Chain
    @ ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:51 [inlined]
 [14] (::SkipConnection{Chain{NamedTuple{(:child, :upsample), Tuple{Chain{Tuple{FastVision.Models.ResBlock, Chain{Tuple{FastVision.Models.ResBlock, typeof(identity)}}}}, Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, PixelShuffle}}}}}, Parallel{typeof(FastVision.Models.catchannels), Tuple{typeof(identity), BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}})(input::Array{Flux.NilNumber.Nil, 5})
    @ Flux ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:344
 [15] outputsize(m::Chain{Tuple{FastVision.Models.ResBlock, SkipConnection{Chain{NamedTuple{(:child, :upsample), Tuple{Chain{Tuple{FastVision.Models.ResBlock, Chain{Tuple{FastVision.Models.ResBlock, typeof(identity)}}}}, Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, PixelShuffle}}}}}, Parallel{typeof(FastVision.Models.catchannels), Tuple{typeof(identity), BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}}, FastVision.Models.var"#94#95", Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}}}}}, inputsizes::NTuple{5, Int64}; padbatch::Bool)
    @ Flux ~/.julia/packages/Flux/4k0Ls/src/outputsize.jl:104
 [16] outputsize(m::Chain{Tuple{FastVision.Models.ResBlock, SkipConnection{Chain{NamedTuple{(:child, :upsample), Tuple{Chain{Tuple{FastVision.Models.ResBlock, Chain{Tuple{FastVision.Models.ResBlock, typeof(identity)}}}}, Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, PixelShuffle}}}}}, Parallel{typeof(FastVision.Models.catchannels), Tuple{typeof(identity), BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}}, FastVision.Models.var"#94#95", Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}}}}}, inputsizes::NTuple{5, Int64})
    @ Flux ~/.julia/packages/Flux/4k0Ls/src/outputsize.jl:100
 [17] unetlayers(layers::Vector{Any}, sz::NTuple{5, Int64}; k_out::Nothing, skip_upscale::Int64, m_middle::FastVision.Models.var"#91#93")
    @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:76
 [18] unetlayers(layers::Vector{Any}, sz::NTuple{5, Int64}; k_out::Nothing, skip_upscale::Int64, m_middle::FastVision.Models.var"#91#93")
    @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:66
 [19] unetlayers(layers::Vector{Any}, sz::NTuple{5, Int64}; k_out::Nothing, skip_upscale::Int64, m_middle::FastVision.Models.var"#91#93")
    @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:75
 [20] unetlayers(layers::Vector{Any}, sz::NTuple{5, Int64}; k_out::Nothing, skip_upscale::Int64, m_middle::FastVision.Models.var"#91#93") (repeats 2 times)
    @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:66
 [21] unetlayers(layers::Vector{Any}, sz::NTuple{5, Int64}; k_out::Nothing, skip_upscale::Int64, m_middle::FastVision.Models.var"#91#93")
    @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:75
 [22] unetlayers(layers::Vector{Any}, sz::NTuple{5, Int64}; k_out::Nothing, skip_upscale::Int64, m_middle::FastVision.Models.var"#91#93") (repeats 5 times)
    @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:66
 [23] unetlayers(layers::Vector{Any}, sz::NTuple{5, Int64}; k_out::Nothing, skip_upscale::Int64, m_middle::typeof(FastVision.Models.UNetMiddleBlock))
    @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:75
 [24] UNetDynamic(backbone::Chain{Tuple{Chain{Tuple{Conv{3, 6, typeof(identity), Array{Float32, 5}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, Chain{Tuple{Conv{3, 6, typeof(identity), Array{Float32, 5}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, Chain{Tuple{Conv{3, 6, typeof(identity), Array{Float32, 5}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, MaxPool{3, 6}, Chain{Tuple{FastVision.Models.ResBlock, FastVision.Models.ResBlock}}, Chain{Tuple{FastVision.Models.ResBlock, FastVision.Models.ResBlock}}, Chain{Tuple{FastVision.Models.ResBlock, FastVision.Models.ResBlock}}, Chain{Tuple{FastVision.Models.ResBlock, FastVision.Models.ResBlock}}}}, inputsize::NTuple{5, Int64}, k_out::Int64; final::typeof(FastVision.Models.UNetFinalBlock), fdownscale::Int64, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:38
 [25] UNetDynamic(backbone::Chain{Tuple{Chain{Tuple{Conv{3, 6, typeof(identity), Array{Float32, 5}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, Chain{Tuple{Conv{3, 6, typeof(identity), Array{Float32, 5}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, Chain{Tuple{Conv{3, 6, typeof(identity), Array{Float32, 5}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, MaxPool{3, 6}, Chain{Tuple{FastVision.Models.ResBlock, FastVision.Models.ResBlock}}, Chain{Tuple{FastVision.Models.ResBlock, FastVision.Models.ResBlock}}, Chain{Tuple{FastVision.Models.ResBlock, FastVision.Models.ResBlock}}, Chain{Tuple{FastVision.Models.ResBlock, FastVision.Models.ResBlock}}}}, inputsize::NTuple{5, Int64}, k_out::Int64)
    @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:31
 [26] top-level scope
    @ REPL[103]:1
 [27] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52

PR Checklist

lorenzoh commented 1 year ago

LGTM!