FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Depthwise convolutions produce a large number of allocations #2508

Open JoshuaBillson opened 1 week ago

JoshuaBillson commented 1 week ago

Depthwise convolutions, which are currently implemented as a standard Conv layer with the number of groups equal to the number of input channels, seem to produce a very large number of allocations compared to the old DepthwiseConv layer. To confirm this, I restored the DepthwiseConv layer removed in #1921 and compared the performance to the current implementation.
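To make the description above concrete, here is a minimal dependency-free sketch of what "groups equal to the number of input channels" means. The helper `naive_grouped_conv` is hypothetical and written only for illustration: it computes a plain cross-correlation with no padding, stride, dilation, or batch dimension, and it is not how NNlib implements grouped or depthwise convolutions.

```julia
# Hypothetical reference implementation, for illustration only.
# x is width × height × channels; w is kw × kh × (cin ÷ groups) × cout.
function naive_grouped_conv(x::Array{T,3}, w::Array{T,4}, groups::Int) where T
    kw, kh, cin_per_group, cout = size(w)
    @assert size(x, 3) == cin_per_group * groups
    @assert cout % groups == 0
    cout_per_group = cout ÷ groups
    ow, oh = size(x, 1) - kw + 1, size(x, 2) - kh + 1
    y = zeros(T, ow, oh, cout)
    for o in 1:cout
        g = (o - 1) ÷ cout_per_group           # zero-based group of output channel o
        for ci in 1:cin_per_group
            c = g * cin_per_group + ci         # input channel read by this filter slice
            for j in 1:oh, i in 1:ow, b in 1:kh, a in 1:kw
                y[i, j, o] += x[i + a - 1, j + b - 1, c] * w[a, b, ci, o]
            end
        end
    end
    return y
end

x_small = rand(Float32, 8, 8, 4)
w_small = rand(Float32, 3, 3, 1, 4)

# groups == number of input channels: each channel gets its own 3×3 filter.
y_depthwise = naive_grouped_conv(x_small, w_small, 4)

# The same result for channel 2, computed as an independent single-channel convolution.
y_channel2 = naive_grouped_conv(x_small[:, :, 2:2], w_small[:, :, :, 2:2], 1)
```

With one group per input channel, each output channel depends on exactly one input channel, so the work splits into many small independent convolutions.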

Running the code below shows the following:

Conv with groups=1024 produces around 750 times as many allocations as DepthwiseConv. The result is that CNN architectures which rely on depthwise convolutions produce hundreds of thousands of allocations compared to only a few thousand for comparably sized models with standard convolutions. Is there any reason for this discrepancy?

Note: I am testing this on Julia 1.10 with Flux v0.14.22.

using Flux, BenchmarkTools

struct DepthwiseConv{N,M,F,A,V}
  σ::F
  weight::A
  bias::V
  stride::NTuple{N,Int}
  pad::NTuple{M,Int}
  dilation::NTuple{N,Int}
end
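# Method as currently defined in Flux (since #1921): forwards to a grouped Conv.
# The method with the same signature further down replaces it.
function DepthwiseConv(k::NTuple{<:Any,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;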
            stride = 1, pad = 0, dilation = 1, bias = true, init = Flux.glorot_uniform)
  Conv(k, ch, σ; groups=ch.first, stride, pad, dilation, bias, init)
end

function DepthwiseConv(w::AbstractArray{T,N}, bias = true, σ = identity;
                      stride = 1, pad = 0, dilation = 1) where {T,N}
  stride = Flux.expand(Val(N-2), stride)
  dilation = Flux.expand(Val(N-2), dilation)
  pad = Flux.calc_padding(DepthwiseConv, pad, size(w)[1:N-2], dilation, stride)
  b = Flux.create_bias(w, bias, prod(size(w)[N-1:end]))
  return DepthwiseConv(σ, w, b, stride, pad, dilation)
end

function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                init = Flux.glorot_uniform, stride = 1, pad = 0, dilation = 1,
                bias = true) where N
  @assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels"
  weight = depthwiseconvfilter(k, ch, init = init)
  return DepthwiseConv(weight, bias, σ; stride, pad, dilation)
end

Flux.@functor DepthwiseConv

depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
                    init = Flux.glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])

function (c::DepthwiseConv)(x)
  σ = NNlib.fast_act(c.σ, x)
  cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
  σ.(depthwiseconv(x, c.weight, cdims) .+ Flux.conv_reshape_bias(c))
end

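x = rand(Float32, 28, 28, 1024, 1)  # width × height × channels × batch

# Standard convolution
conv = Conv((3,3), 1024=>1024, relu, pad=SamePad())

# Grouped convolution with one group per input channel (the current depthwise implementation)
grouped_conv = Conv((3,3), 1024=>1024, relu, pad=SamePad(), groups=1024)

# Restored DepthwiseConv layer defined above
depthwise_conv = DepthwiseConv((3,3), 1024=>1024, relu, pad=SamePad())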

@btime conv(x);

@btime grouped_conv(x);

@btime depthwise_conv(x);

mcabbott commented 1 day ago

I see I'm blamed in https://github.com/FluxML/Flux.jl/pull/1921 for suggesting that change, although I've forgotten why.

With the code above, I see numbers similar to yours. grouped_conv is faster but makes many small allocations:

julia> depthwise_conv.weight .= grouped_conv.weight;

julia> y1 = @btime grouped_conv(x);
  4.430 ms (23577 allocations: 201.13 MiB)

julia> y2 = @btime depthwise_conv(x);
  15.584 ms (27 allocations: 199.06 MiB)

julia> y1 ≈ y2
true

Repeating the benchmarks from #1921 today, Flux.DepthwiseConv with groups makes many small allocations, more than were seen in #1921, although even then it was an increase over the old implementation:

julia> x = randn(Float32, 128, 128, 32, 32);

julia> dconv1 = Flux.DepthwiseConv((3,3), 32 => 64)  # using groups, after 1921
Conv((3, 3), 32 => 64, groups=32)  # 640 parameters

julia> z1 = @btime $dconv1($x);
  38.161 ms (1236 allocations: 370.29 MiB)

julia> dconv2 = DepthwiseConv((3,3), 32 => 64);  # using code above

julia> copyto!(dconv2.weight, dconv1.weight); # 3×3×2×32  from 3×3×1×64

julia> z2 = @btime $dconv2($x);
  45.090 ms (42 allocations: 370.16 MiB)

julia> z1 ≈ z2
true 

julia> Threads.nthreads()
4
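The `copyto!` between the two weight layouts above needs no reshape because both arrays hold the same number of elements and Julia stores arrays in column-major order. A small sketch with dummy values, using the shapes from the comment above (the names `w_grouped` and `w_depthwise` are illustrative only):

```julia
# Dummy weights: the values are just 1:576 so that each element is identifiable.
w_grouped = reshape(collect(Float32, 1:3*3*1*64), 3, 3, 1, 64)  # grouped Conv layout
w_depthwise = zeros(Float32, 3, 3, 2, 32)                        # restored DepthwiseConv layout

# copyto! copies element by element in linear (column-major) order.
copyto!(w_depthwise, w_grouped)

# Filter m of input channel c ends up matching grouped output channel (c - 1) * 2 + m.
for c in 1:32, m in 1:2
    @assert w_depthwise[:, :, m, c] == w_grouped[:, :, 1, (c - 1) * 2 + m]
end
```

This only demonstrates the memory correspondence; that the two layers then compute the same result is what the `z1 ≈ z2` check above confirms.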

I think the NNlib CPU conv code remains in need of some care... more and more layers of multi-threading were added & probably ought to be pruned.