Closed — itan1 closed this pull request 1 year ago.
Fix `UNetDynamic` creation for 3D images by passing `ndim` to `convxlayer` and `ResBlock`. Also add a test that checks the output size of a 3D UNet.
ndim
convxlayer
ResBlock
Now:
```julia-repl
julia> backbone = FastVision.Models.xresnet18(ndim=3);

julia> model = FastVision.Models.UNetDynamic(backbone, (128,128,128,3,1), 4);

julia> Flux.outputsize(model, (128, 128, 128, 3, 1)) == (128, 128, 128, 4, 1)
true
```
Previously:
julia> backbone = FastVision.Models.xresnet18(ndim=3); julia> model = FastVision.Models.UNetDynamic(backbone, (128,128,128,3,1), 4); ┌ Error: layer SkipConnection(Chain(child = Chain(FastVision.Models.ResBlock(Chain(Chain(Conv((3, 3, 3), 256 => 512, pad=1, stride=2), BatchNorm(512, relu)), Chain(Conv((3, 3, 3), 512 => 512, pad=1), BatchNorm(512))), Chain(Conv((1, 1, 1), 256 => 512), BatchNorm(512, relu)), MeanPool((2, 2, 2))), Chain(FastVision.Models.ResBlock(Chain(Chain(Conv((3, 3, 3), 512 => 512, pad=1), BatchNorm(512, relu)), Chain(Conv((3, 3, 3), 512 => 512, pad=1), BatchNorm(512))), identity, identity), identity)), upsample = Chain(Chain(Conv((1, 1), 512 => 2048), BatchNorm(2048, relu)), PixelShuffle(2))), Parallel(catchannels, identity, BatchNorm(256))), index 2 in Chain, gave an error with input of size (8, 8, 8, 256, 1) └ @ Flux ~/.julia/packages/Flux/4k0Ls/src/outputsize.jl:107 ERROR: DimensionMismatch: Rank of x and w must match! (5 vs. 4) Stacktrace: [1] DenseConvDims(x::Array{Flux.NilNumber.Nil, 5}, w::Array{Float32, 4}; kwargs::Base.Pairs{Symbol, Any, NTuple{4, Symbol}, NamedTuple{(:stride, :padding, :dilation, :groups), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}, Int64}}}) @ NNlib ~/.julia/packages/NNlib/0QnJJ/src/dim_helpers/DenseConvDims.jl:49 [2] conv_dims(c::Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, x::Array{Flux.NilNumber.Nil, 5}) @ Flux ~/.julia/packages/Flux/4k0Ls/src/layers/conv.jl:192 [3] (::Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}})(x::Array{Flux.NilNumber.Nil, 5}) @ Flux ~/.julia/packages/Flux/4k0Ls/src/layers/conv.jl:199 [4] macro expansion @ ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:53 [inlined] [5] _applychain(layers::Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}, x::Array{Flux.NilNumber.Nil, 5}) @ Flux ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:53 [6] 
(::Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}})(x::Array{Flux.NilNumber.Nil, 5}) @ Flux ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:51 [7] macro expansion @ ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:53 [inlined] [8] _applychain(layers::Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, PixelShuffle}, x::Array{Flux.NilNumber.Nil, 5}) @ Flux ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:53 [9] (::Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, PixelShuffle}})(x::Array{Flux.NilNumber.Nil, 5}) @ Flux ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:51 [10] macro expansion @ ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:53 [inlined] [11] _applychain(layers::Tuple{Chain{Tuple{FastVision.Models.ResBlock, Chain{Tuple{FastVision.Models.ResBlock, typeof(identity)}}}}, Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, PixelShuffle}}}, x::Array{Flux.NilNumber.Nil, 5}) @ Flux ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:53 [12] _applychain @ ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:59 [inlined] [13] Chain @ ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:51 [inlined] [14] (::SkipConnection{Chain{NamedTuple{(:child, :upsample), Tuple{Chain{Tuple{FastVision.Models.ResBlock, Chain{Tuple{FastVision.Models.ResBlock, typeof(identity)}}}}, Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, PixelShuffle}}}}}, Parallel{typeof(FastVision.Models.catchannels), Tuple{typeof(identity), BatchNorm{typeof(identity), Vector{Float32}, 
Float32, Vector{Float32}}}}})(input::Array{Flux.NilNumber.Nil, 5}) @ Flux ~/.julia/packages/Flux/4k0Ls/src/layers/basic.jl:344 [15] outputsize(m::Chain{Tuple{FastVision.Models.ResBlock, SkipConnection{Chain{NamedTuple{(:child, :upsample), Tuple{Chain{Tuple{FastVision.Models.ResBlock, Chain{Tuple{FastVision.Models.ResBlock, typeof(identity)}}}}, Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, PixelShuffle}}}}}, Parallel{typeof(FastVision.Models.catchannels), Tuple{typeof(identity), BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}}, FastVision.Models.var"#94#95", Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}}}}}, inputsizes::NTuple{5, Int64}; padbatch::Bool) @ Flux ~/.julia/packages/Flux/4k0Ls/src/outputsize.jl:104 [16] outputsize(m::Chain{Tuple{FastVision.Models.ResBlock, SkipConnection{Chain{NamedTuple{(:child, :upsample), Tuple{Chain{Tuple{FastVision.Models.ResBlock, Chain{Tuple{FastVision.Models.ResBlock, typeof(identity)}}}}, Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, PixelShuffle}}}}}, Parallel{typeof(FastVision.Models.catchannels), Tuple{typeof(identity), BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}}, FastVision.Models.var"#94#95", Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, 
Vector{Float32}}}}}}}}, inputsizes::NTuple{5, Int64}) @ Flux ~/.julia/packages/Flux/4k0Ls/src/outputsize.jl:100 [17] unetlayers(layers::Vector{Any}, sz::NTuple{5, Int64}; k_out::Nothing, skip_upscale::Int64, m_middle::FastVision.Models.var"#91#93") @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:76 [18] unetlayers(layers::Vector{Any}, sz::NTuple{5, Int64}; k_out::Nothing, skip_upscale::Int64, m_middle::FastVision.Models.var"#91#93") @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:66 [19] unetlayers(layers::Vector{Any}, sz::NTuple{5, Int64}; k_out::Nothing, skip_upscale::Int64, m_middle::FastVision.Models.var"#91#93") @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:75 [20] unetlayers(layers::Vector{Any}, sz::NTuple{5, Int64}; k_out::Nothing, skip_upscale::Int64, m_middle::FastVision.Models.var"#91#93") (repeats 2 times) @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:66 [21] unetlayers(layers::Vector{Any}, sz::NTuple{5, Int64}; k_out::Nothing, skip_upscale::Int64, m_middle::FastVision.Models.var"#91#93") @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:75 [22] unetlayers(layers::Vector{Any}, sz::NTuple{5, Int64}; k_out::Nothing, skip_upscale::Int64, m_middle::FastVision.Models.var"#91#93") (repeats 5 times) @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:66 [23] unetlayers(layers::Vector{Any}, sz::NTuple{5, Int64}; k_out::Nothing, skip_upscale::Int64, m_middle::typeof(FastVision.Models.UNetMiddleBlock)) @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:75 [24] UNetDynamic(backbone::Chain{Tuple{Chain{Tuple{Conv{3, 6, typeof(identity), Array{Float32, 5}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, Chain{Tuple{Conv{3, 6, typeof(identity), Array{Float32, 5}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, Chain{Tuple{Conv{3, 6, 
typeof(identity), Array{Float32, 5}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, MaxPool{3, 6}, Chain{Tuple{FastVision.Models.ResBlock, FastVision.Models.ResBlock}}, Chain{Tuple{FastVision.Models.ResBlock, FastVision.Models.ResBlock}}, Chain{Tuple{FastVision.Models.ResBlock, FastVision.Models.ResBlock}}, Chain{Tuple{FastVision.Models.ResBlock, FastVision.Models.ResBlock}}}}, inputsize::NTuple{5, Int64}, k_out::Int64; final::typeof(FastVision.Models.UNetFinalBlock), fdownscale::Int64, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}) @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:38 [25] UNetDynamic(backbone::Chain{Tuple{Chain{Tuple{Conv{3, 6, typeof(identity), Array{Float32, 5}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, Chain{Tuple{Conv{3, 6, typeof(identity), Array{Float32, 5}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, Chain{Tuple{Conv{3, 6, typeof(identity), Array{Float32, 5}, Vector{Float32}}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}}}, MaxPool{3, 6}, Chain{Tuple{FastVision.Models.ResBlock, FastVision.Models.ResBlock}}, Chain{Tuple{FastVision.Models.ResBlock, FastVision.Models.ResBlock}}, Chain{Tuple{FastVision.Models.ResBlock, FastVision.Models.ResBlock}}, Chain{Tuple{FastVision.Models.ResBlock, FastVision.Models.ResBlock}}}}, inputsize::NTuple{5, Int64}, k_out::Int64) @ FastVision.Models ~/.julia/dev/FastAI/FastVision/src/models/unet.jl:31 [26] top-level scope @ REPL[103]:1 [27] top-level scope @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
LGTM!
Fix the `UNetDynamic` creation for 3D images by passing `ndim` to `convxlayer` and `ResBlock`. Also added a test to check the output size for a 3D UNet.
Now
Previously:
PR Checklist