IanButterworth opened 9 months ago
This also looks good compared to my code. The only notable difference I see is the use of data augmentation. It would be nice to see whether that makes a meaningful difference.
Current status
┌ Info: Benchmarking
│ epochs = 45
│ batchsize = 1000
│ device = gpu (generic function with 5 methods)
└ imsize = (32, 32)
25×5 DataFrame
Row │ model train_loss train_acc test_loss test_acc
│ String Float64? Float64? Float64? Float64?
─────┼─────────────────────────────────────────────────────────────────────────────────────────────
1 │ AlexNet(; pretrain=false, inchan… missing missing missing missing DimensionMismatch: Kernel * dilation ((3 - 1) * 1 + 1) cannot be larger than input + padding (1 + 0 + 0)!
2 │ VGG(11, batchnorm=true; pretrain… missing missing missing missing DimensionMismatch: layer Dense(25088 => 4096, relu) expects size(input, 1) == 25088, but got 512×1000 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
3 │ SqueezeNet(; pretrain=false, inc… 2.22064 0.1115 2.22714 0.112
4 │ ResNet(18; pretrain=false, incha… 1.11878 0.586 1.29114 0.524
5 │ WideResNet(50; pretrain=false, i… missing missing missing missing CUDNNError: CUDNN_STATUS_BAD_PARAM (code 3)
6 │ ResNeXt(50, cardinality=32, base… missing missing missing missing CUDNNError: CUDNN_STATUS_BAD_PARAM (code 3)
7 │ SEResNet(18; pretrain=false, inc… 1.1682 0.5885 1.31459 0.5405
8 │ SEResNeXt(50, cardinality=32, ba… missing missing missing missing CUDNNError: CUDNN_STATUS_BAD_PARAM (code 3)
9 │ Res2Net(50, base_width=26, scale… missing missing missing missing Scalar indexing is disallowed.
10 │ Res2NeXt(50; pretrain=false, inc… missing missing missing missing Scalar indexing is disallowed.
11 │ GoogLeNet(batchnorm=true; pretra… 1.45998 0.497 1.53686 0.4725
12 │ DenseNet(121; pretrain=false, in… 1.18146 0.5835 1.41045 0.5115
13 │ Inceptionv3(; pretrain=false, in… missing missing missing missing DimensionMismatch: Kernel * dilation ((3 - 1) * 1 + 1) cannot be larger than input + padding (1 + 0 + 0)!
14 │ Inceptionv4(; pretrain=false, in… missing missing missing missing DimensionMismatch: Kernel * dilation ((3 - 1) * 1 + 1) cannot be larger than input + padding (1 + 0 + 0)!
15 │ InceptionResNetv2(; pretrain=fal… missing missing missing missing DimensionMismatch: Kernel * dilation ((3 - 1) * 1 + 1) cannot be larger than input + padding (1 + 0 + 0)!
16 │ Xception(; pretrain=false, incha… 1.97242 0.304 1.97524 0.3
17 │ MobileNetv1(0.5; pretrain=false,… 2.01992 0.2465 2.07331 0.23
18 │ MobileNetv2(0.5; pretrain=false,… 2.06227 0.2065 2.08002 0.2005
19 │ MobileNetv3(:small, width_mult=0… 2.13884 0.2085 2.09306 0.223
20 │ MNASNet(:A1, width_mult=0.5; pre… 2.31312 0.118 2.32238 0.1015
21 │ EfficientNet(:b0; pretrain=false… 2.26213 0.158 2.26398 0.1455
22 │ EfficientNetv2(:small; pretrain=… 2.30173 0.1145 2.30193 0.1105
23 │ ConvMixer(:small; pretrain=false… missing missing missing missing Out of GPU memory trying to allocate 1.046 GiB
24 │ ConvNeXt(:small; pretrain=false,… missing missing missing missing Out of GPU memory trying to allocate 585.938 MiB
25 │ ViT(:tiny; pretrain=false, incha… missing missing missing missing DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 5 and 197
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Time Allocations
─────────────────────── ────────────────────────
Tot / % measured: 486s / 96.8% 60.8GiB / 100.0%
Section ncalls time %tot avg alloc %tot avg
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
AlexNet(; pretrain=false, inchannels=3, nclasses=10) 1 1.11s 0.2% 1.11s 883MiB 1.4% 883MiB
Load 1 163ms 0.0% 163ms 435MiB 0.7% 435MiB
Training 1 658ms 0.1% 658ms 12.8MiB 0.0% 12.8MiB
batch step 1 16.8ms 0.0% 16.8ms 203KiB 0.0% 203KiB
VGG(11, batchnorm=true; pretrain=false, inchannels=3, nclasses=10) 1 1.55s 0.3% 1.55s 1.93GiB 3.2% 1.93GiB
Load 1 438ms 0.1% 438ms 0.96GiB 1.6% 0.96GiB
Training 1 746ms 0.2% 746ms 12.9MiB 0.0% 12.9MiB
batch step 1 13.7ms 0.0% 13.7ms 316KiB 0.0% 316KiB
SqueezeNet(; pretrain=false, inchannels=3, nclasses=10) 1 29.8s 6.3% 29.8s 3.49GiB 5.7% 3.49GiB
Load 1 1.40ms 0.0% 1.40ms 5.59MiB 0.0% 5.59MiB
Training 1 29.8s 6.3% 29.8s 3.47GiB 5.7% 3.47GiB
batch step 90 850ms 0.2% 9.44ms 264MiB 0.4% 2.93MiB
testing 45 644ms 0.1% 14.3ms 1.06GiB 1.7% 24.0MiB
ResNet(18; pretrain=false, inchannels=3, nclasses=10) 1 31.8s 6.8% 31.8s 3.66GiB 6.0% 3.66GiB
Load 1 14.9ms 0.0% 14.9ms 85.4MiB 0.1% 85.4MiB
Training 1 31.8s 6.8% 31.8s 3.49GiB 5.7% 3.49GiB
batch step 90 1.29s 0.3% 14.3ms 295MiB 0.5% 3.28MiB
testing 45 1.24s 0.3% 27.5ms 1.05GiB 1.7% 23.9MiB
WideResNet(50; pretrain=false, inchannels=3, nclasses=10) 1 1.10s 0.2% 1.10s 1.01GiB 1.7% 1.01GiB
Load 1 182ms 0.0% 182ms 510MiB 0.8% 510MiB
Training 1 738ms 0.2% 738ms 16.7MiB 0.0% 16.7MiB
batch step 1 44.1ms 0.0% 44.1ms 3.36MiB 0.0% 3.36MiB
ResNeXt(50, cardinality=32, base_width=4; pretrain=false, inchannels=3, nclasses=10) 1 821ms 0.2% 821ms 368MiB 0.6% 368MiB
Load 1 59.1ms 0.0% 59.1ms 176MiB 0.3% 176MiB
Training 1 702ms 0.1% 702ms 16.7MiB 0.0% 16.7MiB
batch step 1 35.9ms 0.0% 35.9ms 3.36MiB 0.0% 3.36MiB
SEResNet(18; pretrain=false, inchannels=3, nclasses=10) 1 33.6s 7.1% 33.6s 3.96GiB 6.5% 3.96GiB
Load 1 16.3ms 0.0% 16.3ms 86.2MiB 0.1% 86.2MiB
Training 1 33.6s 7.1% 33.6s 3.79GiB 6.2% 3.79GiB
batch step 90 1.75s 0.4% 19.4ms 575MiB 0.9% 6.39MiB
testing 45 1.31s 0.3% 29.1ms 1.06GiB 1.7% 24.2MiB
SEResNeXt(50, cardinality=32, base_width=4; pretrain=false, inchannels=3, nclasses=10) 1 833ms 0.2% 833ms 410MiB 0.7% 410MiB
Load 1 66.0ms 0.0% 66.0ms 195MiB 0.3% 195MiB
Training 1 700ms 0.1% 700ms 19.5MiB 0.0% 19.5MiB
batch step 1 43.5ms 0.0% 43.5ms 5.75MiB 0.0% 5.75MiB
Res2Net(50, base_width=26, scale=4; pretrain=false, inchannels=3, nclasses=10) 1 787ms 0.2% 787ms 376MiB 0.6% 376MiB
Load 1 56.0ms 0.0% 56.0ms 181MiB 0.3% 181MiB
Training 1 672ms 0.1% 672ms 14.1MiB 0.0% 14.1MiB
batch step 1 19.3ms 0.0% 19.3ms 65.7KiB 0.0% 65.7KiB
Res2NeXt(50; pretrain=false, inchannels=3, nclasses=10) 1 781ms 0.2% 781ms 361MiB 0.6% 361MiB
Load 1 59.5ms 0.0% 59.5ms 173MiB 0.3% 173MiB
Training 1 661ms 0.1% 661ms 14.1MiB 0.0% 14.1MiB
batch step 1 6.05ms 0.0% 6.05ms 65.7KiB 0.0% 65.7KiB
GoogLeNet(batchnorm=true; pretrain=false, inchannels=3, nclasses=10) 1 36.7s 7.8% 36.7s 4.23GiB 7.0% 4.23GiB
Load 1 11.4ms 0.0% 11.4ms 45.9MiB 0.1% 45.9MiB
Training 1 36.7s 7.8% 36.7s 4.14GiB 6.8% 4.14GiB
batch step 90 2.92s 0.6% 32.4ms 864MiB 1.4% 9.60MiB
testing 45 1.56s 0.3% 34.7ms 1.09GiB 1.8% 24.9MiB
DenseNet(121; pretrain=false, inchannels=3, nclasses=10) 1 47.3s 10.1% 47.3s 6.16GiB 10.1% 6.16GiB
Load 1 16.5ms 0.0% 16.5ms 53.5MiB 0.1% 53.5MiB
Training 1 47.3s 10.0% 47.3s 6.05GiB 10.0% 6.05GiB
batch step 90 8.98s 1.9% 100ms 2.62GiB 4.3% 29.8MiB
testing 45 4.18s 0.9% 92.8ms 1.16GiB 1.9% 26.4MiB
Inceptionv3(; pretrain=false, inchannels=3, nclasses=10) 1 738ms 0.2% 738ms 348MiB 0.6% 348MiB
Load 1 42.4ms 0.0% 42.4ms 167MiB 0.3% 167MiB
Training 1 655ms 0.1% 655ms 14.7MiB 0.0% 14.7MiB
batch step 1 7.58ms 0.0% 7.58ms 862KiB 0.0% 862KiB
Inceptionv4(; pretrain=false, inchannels=3, nclasses=10) 1 829ms 0.2% 829ms 645MiB 1.0% 645MiB
Load 1 78.7ms 0.0% 78.7ms 314MiB 0.5% 314MiB
Training 1 669ms 0.1% 669ms 16.0MiB 0.0% 16.0MiB
batch step 1 17.3ms 0.0% 17.3ms 1.38MiB 0.0% 1.38MiB
InceptionResNetv2(; pretrain=false, inchannels=3, nclasses=10) 1 921ms 0.2% 921ms 851MiB 1.4% 851MiB
Load 1 126ms 0.0% 126ms 416MiB 0.7% 416MiB
Training 1 671ms 0.1% 671ms 19.0MiB 0.0% 19.0MiB
batch step 1 6.96ms 0.0% 6.96ms 196KiB 0.0% 196KiB
Xception(; pretrain=false, inchannels=3, nclasses=10) 1 50.8s 10.8% 50.8s 4.54GiB 7.5% 4.54GiB
Load 1 52.8ms 0.0% 52.8ms 159MiB 0.3% 159MiB
Training 1 50.7s 10.8% 50.7s 4.23GiB 7.0% 4.23GiB
batch step 90 6.78s 1.4% 75.3ms 0.94GiB 1.6% 10.7MiB
testing 45 5.13s 1.1% 114ms 1.09GiB 1.8% 24.9MiB
MobileNetv1(0.5; pretrain=false, inchannels=3, nclasses=10) 1 32.0s 6.8% 32.0s 3.57GiB 5.9% 3.57GiB
Load 1 2.42ms 0.0% 2.42ms 6.39MiB 0.0% 6.39MiB
Training 1 32.0s 6.8% 32.0s 3.56GiB 5.9% 3.56GiB
batch step 90 1.34s 0.3% 14.9ms 352MiB 0.6% 3.91MiB
testing 45 1.08s 0.2% 24.1ms 1.06GiB 1.7% 24.1MiB
MobileNetv2(0.5; pretrain=false, inchannels=3, nclasses=10) 1 33.0s 7.0% 33.0s 3.98GiB 6.5% 3.98GiB
Load 1 2.35ms 0.0% 2.35ms 5.54MiB 0.0% 5.54MiB
Training 1 33.0s 7.0% 33.0s 3.97GiB 6.5% 3.97GiB
batch step 90 1.93s 0.4% 21.4ms 721MiB 1.2% 8.02MiB
testing 45 1.32s 0.3% 29.2ms 1.08GiB 1.8% 24.5MiB
MobileNetv3(:small, width_mult=0.5; pretrain=false, inchannels=3, nclasses=10) 1 32.6s 6.9% 32.6s 4.03GiB 6.6% 4.03GiB
Load 1 2.06ms 0.0% 2.06ms 4.68MiB 0.0% 4.68MiB
Training 1 32.6s 6.9% 32.6s 4.02GiB 6.6% 4.02GiB
batch step 90 1.93s 0.4% 21.5ms 775MiB 1.2% 8.61MiB
testing 45 846ms 0.2% 18.8ms 1.08GiB 1.8% 24.5MiB
MNASNet(:A1, width_mult=0.5; pretrain=false, inchannels=3, nclasses=10) 1 33.7s 7.2% 33.7s 4.17GiB 6.9% 4.17GiB
Load 1 2.98ms 0.0% 2.98ms 6.45MiB 0.0% 6.45MiB
Training 1 33.7s 7.2% 33.7s 4.16GiB 6.8% 4.16GiB
batch step 90 2.30s 0.5% 25.5ms 900MiB 1.4% 10.0MiB
testing 45 1.33s 0.3% 29.5ms 1.08GiB 1.8% 24.7MiB
EfficientNet(:b0; pretrain=false, inchannels=3, nclasses=10) 1 38.7s 8.2% 38.7s 4.54GiB 7.5% 4.54GiB
Load 1 10.9ms 0.0% 10.9ms 30.9MiB 0.0% 30.9MiB
Training 1 38.7s 8.2% 38.7s 4.48GiB 7.4% 4.48GiB
batch step 90 3.17s 0.7% 35.3ms 1.16GiB 1.9% 13.2MiB
testing 45 2.42s 0.5% 53.9ms 1.10GiB 1.8% 25.1MiB
EfficientNetv2(:small; pretrain=false, inchannels=3, nclasses=10) 1 57.2s 12.2% 57.2s 6.22GiB 10.2% 6.22GiB
Load 1 51.8ms 0.0% 51.8ms 155MiB 0.2% 155MiB
Training 1 57.1s 12.1% 57.1s 5.92GiB 9.7% 5.92GiB
batch step 90 14.1s 3.0% 156ms 2.44GiB 4.0% 27.7MiB
testing 45 7.41s 1.6% 165ms 1.19GiB 2.0% 27.0MiB
ConvMixer(:small; pretrain=false, inchannels=3, nclasses=10) 1 1.56s 0.3% 1.56s 326MiB 0.5% 326MiB
Load 1 49.8ms 0.0% 49.8ms 155MiB 0.2% 155MiB
Training 1 1.46s 0.3% 1.46s 15.6MiB 0.0% 15.6MiB
batch step 1 802ms 0.2% 802ms 2.13MiB 0.0% 2.13MiB
ConvNeXt(:small; pretrain=false, inchannels=3, nclasses=10) 1 1.51s 0.3% 1.51s 773MiB 1.2% 773MiB
Load 1 41.3ms 0.0% 41.3ms 377MiB 0.6% 377MiB
Training 1 1.43s 0.3% 1.43s 18.8MiB 0.0% 18.8MiB
batch step 1 761ms 0.2% 761ms 4.00MiB 0.0% 4.00MiB
ViT(:tiny; pretrain=false, inchannels=3, nclasses=10) 1 662ms 0.1% 662ms 97.1MiB 0.2% 97.1MiB
Load 1 4.70ms 0.0% 4.70ms 42.0MiB 0.1% 42.0MiB
Training 1 652ms 0.1% 652ms 13.2MiB 0.0% 13.2MiB
batch step 1 5.50ms 0.0% 5.50ms 268KiB 0.0% 268KiB
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
I think the Inception errors are because that family of models uses an image size of 299×299, and I don't think they support alternate image sizes. AlexNet uses 224×224 and doesn't support anything else either. VGG needs a special imsize parameter passed to it to work with smaller image sizes, as does ViT. So those errors can at least be diagnosed at a glance.
The ResNet family errors are strange. One set occurs with the larger ResNet variants and looks like some sort of memory issue? Correct me if I'm wrong. The Res2Net variant was unfortunately never tested on GPU, which most likely means something in the code is incompatible with the GPU 😬. There's some "clever" code there that I suspect may not be GPU compatible.
I can confirm with https://github.com/FluxML/Metalhead.jl/pull/262 that the scalar indexing errors are real. Thanks for helping me corroborate those. For now, those models can be ignored here.
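When reproducing these locally or in CI, it can help to turn scalar indexing into a hard error rather than a warning, so GPU-incompatible code paths fail fast. A minimal sketch using CUDA.jl's global switch:

```julia
# Make scalar indexing of GPU arrays throw instead of warn, so models with
# GPU-incompatible code (e.g. the Res2Net forward pass) error immediately
# rather than silently falling back to slow CPU iteration.
using CUDA

CUDA.allowscalar(false)
```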
The `CUDNNError: CUDNN_STATUS_BAD_PARAM (code 3)` errors may be related to running out of memory. Try calling `CUDA.reclaim()` after `GC.gc(true)`, because the latter alone is often not sufficient to get CUDA.jl/the CUDA driver to return VRAM to the system. The reason I think this is OOM-related is that this is how the error has come up before, and because #262 CI shows the same error for a different set of models (suggesting non-determinism).
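The suggested cleanup between model runs can be sketched as a small helper (a sketch only; where exactly to call it in the benchmark loop is up to the harness):

```julia
# Free GPU memory between model benchmarks. A full GC pass finalizes
# unreachable CuArrays; CUDA.reclaim() then returns the cached device
# memory from CUDA.jl's pool back to the driver.
using CUDA

function free_gpu_memory()
    GC.gc(true)     # full collection so stale CuArray buffers are released
    CUDA.reclaim()  # hand the freed device memory back to the CUDA driver
end
```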
> I think the Inception errors are because that family of models use an image size of 299x299, and I don't think they support alternate image sizes. AlexNet uses 224x224 and doesn't support anything else either.
These work with arbitrary sizes, but there is a lower bound on how small an image they can handle. I guess 32x32 is too small.
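One way to probe the lower bound without training (or a GPU) is shape inference; a minimal sketch using `Flux.outputsize`, which runs the model on dummy arrays and surfaces the same `DimensionMismatch` errors seen in the table:

```julia
# Check whether a model's layers can handle 32×32 inputs by running
# Flux's shape inference instead of a real forward pass.
using Flux, Metalhead

for (name, model) in [("AlexNet", AlexNet()), ("SqueezeNet", SqueezeNet())]
    try
        sz = Flux.outputsize(model, (32, 32, 3, 1))  # WHCN dummy input
        println(name, " => output size ", sz)
    catch err
        println(name, " cannot handle 32×32 input: ", sprint(showerror, err))
    end
end
```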
> The ResNet family errors are weird. One set seems to be with the larger ResNet variants, which seems to be some sort of memory issue? Correct me if I'm wrong. The Res2Net variant was never tested on GPU unfortunately, which means most likely there is something in the code incompatible with GPU 😬. There's some "clever" code there which I think may not be as GPU compatible.
The use of `MLUtils.chunk` on the forward pass has two problems: it performs scalar indexing on GPU arrays (see the warnings below), and calling `NNlib.conv` on the resulting `SubArray` dispatches to the CPU `im2col` implementation, which is what actually throws the exception. Shortened stack trace (with the irrelevant part cut off):
julia> m(x)
┌ Warning: Performing scalar indexing on task Task (runnable) @0x00007f1ca13fc010.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on the GPU, but very slowly on the CPU,
│ and therefore are only permitted from the REPL for prototyping purposes.
│ If you did intend to index this array, annotate the caller with @allowscalar.
└ @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:106
┌ Warning: Performing scalar indexing on task Task (runnable) @0x00007f1a2236dc30.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on the GPU, but very slowly on the CPU,
│ and therefore are only permitted from the REPL for prototyping purposes.
│ If you did intend to index this array, annotate the caller with @allowscalar.
└ @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:106
ERROR: TaskFailedException
nested task error: TaskFailedException
nested task error: MethodError: no method matching gemm!(::Val{false}, ::Val{false}, ::Int64, ::Int64, ::Int64, ::Float32, ::CuPtr{Float32}, ::CuPtr{Float32}, ::Float32, ::CuPtr{Float32})
Closest candidates are:
gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::Float32, ::Ptr{Float32}, ::Ptr{Float32}, ::Float32, ::Ptr{Float32})
@ NNlib ~/.julia/packages/NNlib/sXmAj/src/gemm.jl:29
gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::Float64, ::Ptr{Float64}, ::Ptr{Float64}, ::Float64, ::Ptr{Float64})
@ NNlib ~/.julia/packages/NNlib/sXmAj/src/gemm.jl:29
gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::ComplexF64, ::Ptr{ComplexF64}, ::Ptr{ComplexF64}, ::ComplexF64, ::Ptr{ComplexF64})
@ NNlib ~/.julia/packages/NNlib/sXmAj/src/gemm.jl:29
...
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/NNlib/sXmAj/src/impl/conv_im2col.jl:59 [inlined]
[2] (::NNlib.var"#647#648"{CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, Float32, Float32, SubArray{Float32, 5, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Base.ReshapedArray{Float32, 5, SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, DenseConvDims{3, 3, 3, 6, 3}, Int64, Int64, Int64, UnitRange{Int64}, Int64})()
@ NNlib ./threadingconstructs.jl:416
Stacktrace:
[1] sync_end(c::Channel{Any})
@ Base ./task.jl:445
[2] macro expansion
@ ./task.jl:477 [inlined]
[3] conv_im2col!(y::SubArray{Float32, 5, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, x::SubArray{Float32, 5, Base.ReshapedArray{Float32, 5, SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{3, 3, 3, 6, 3}; col::CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, alpha::Float32, beta::Float32, ntasks::Int64)
@ NNlib ~/.julia/packages/NNlib/sXmAj/src/impl/conv_im2col.jl:50
[4] conv_im2col!(y::SubArray{Float32, 5, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, x::SubArray{Float32, 5, Base.ReshapedArray{Float32, 5, SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{3, 3, 3, 6, 3})
@ NNlib ~/.julia/packages/NNlib/sXmAj/src/impl/conv_im2col.jl:23
[5] (::NNlib.var"#305#309"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, DenseConvDims{3, 3, 3, 6, 3}, SubArray{Float32, 5, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, SubArray{Float32, 5, Base.ReshapedArray{Float32, 5, SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})()
@ NNlib ./threadingconstructs.jl:416
Stacktrace:
[1] sync_end(c::Channel{Any})
@ Base ./task.jl:445
[2] macro expansion
@ ./task.jl:477 [inlined]
[3] conv!(out::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, in1::Base.ReshapedArray{Float32, 5, SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, in2::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{3, 3, 3, 6, 3}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ NNlib ~/.julia/packages/NNlib/sXmAj/src/conv.jl:205
[4] conv!
@ ~/.julia/packages/NNlib/sXmAj/src/conv.jl:185 [inlined]
[5] #conv!#264
@ ~/.julia/packages/NNlib/sXmAj/src/conv.jl:145 [inlined]
[6] conv!
@ ~/.julia/packages/NNlib/sXmAj/src/conv.jl:140 [inlined]
[7] conv(x::SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ NNlib ~/.julia/packages/NNlib/sXmAj/src/conv.jl:88
[8] conv
@ ~/.julia/packages/NNlib/sXmAj/src/conv.jl:83 [inlined]
[9] (::Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Bool})(x::SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false})
@ Flux ~/.julia/packages/Flux/jgpVj/src/layers/conv.jl:202
[10] macro expansion
@ ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:53 [inlined]
[11] _applychain
@ ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:53 [inlined]
[12] Chain
@ ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:51 [inlined]
[13] |>
@ ./operators.jl:907 [inlined]
[14] map (repeats 2 times)
@ ./tuple.jl:302 [inlined]
[15] (::Parallel{typeof(Metalhead.cat_channels), Tuple{MeanPool{2, 4}, Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Bool}, BatchNorm{typeof(relu), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Bool}, BatchNorm{typeof(relu), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Bool}, BatchNorm{typeof(relu), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}})(::SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, ::Vararg{SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})
@ Flux ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:541
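Since `MLUtils.chunk` returns `SubArray` views (which is what sends `NNlib.conv` down the CPU `im2col` path above), one possible workaround is to materialize each chunk into a dense array before the convolution. A minimal sketch, not the actual Metalhead fix:

```julia
# MLUtils.chunk produces views; copying each chunk yields dense arrays
# (dense CuArrays on GPU), so conv on them dispatches to the cuDNN path
# instead of the generic CPU im2col fallback.
using MLUtils

x = rand(Float32, 32, 32, 64, 8)               # H × W × C × N feature map
chunks = map(copy, MLUtils.chunk(x, 4; dims=3)) # 4 dense 16-channel chunks
```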
As discussed in https://github.com/FluxML/Metalhead.jl/pull/198#issuecomment-1846064419, I think it would be good to demonstrate that each of these is trainable on a generic dataset, and to collect benchmark information while doing so.
I am happy to run this all locally, but want to collect feedback before doing so, in case these models have nuance that is worth taking into account.
The general approach here is to train each of the smallest variants of the models found in `/test`.
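The harness implied by the timer table above could be sketched roughly as follows. This is an illustrative outline only: `train_one_epoch!` and `evaluate` are hypothetical placeholders for the real training and test loops, and failures are recorded per model rather than aborting the whole run.

```julia
# Time each model's Load / Training / batch step / testing sections with
# TimerOutputs, mirroring the structure of the table above.
using TimerOutputs, Flux

const to = TimerOutput()

function benchmark(modelfn, name; epochs = 45)
    try
        model = @timeit to "$name: Load" gpu(modelfn())
        @timeit to "$name: Training" for _ in 1:epochs
            @timeit to "batch step" train_one_epoch!(model)  # hypothetical
            @timeit to "testing" evaluate(model)             # hypothetical
        end
        return nothing   # on success, losses/accuracies would be recorded
    catch err
        return sprint(showerror, err)  # becomes the error column in the table
    end
end
```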