FluxML / Metalhead.jl

Computer vision models for Flux
https://fluxml.ai/Metalhead.jl

Add training benchmarking script #264

Open IanButterworth opened 9 months ago

IanButterworth commented 9 months ago

As discussed in https://github.com/FluxML/Metalhead.jl/pull/198#issuecomment-1846064419, I think it would be good to demonstrate that each of these models is trainable on a generic dataset, and to collect benchmark information while doing so.

I am happy to run all of this locally, but I want to collect feedback before doing so, in case these models have nuances worth taking into account.

The general approach here is to train the smallest variant of each model found in /test.

darsnack commented 9 months ago

This also looks good compared to my code. The only major difference I see is the use of data augmentation. It would be nice to see whether that makes a meaningful difference.
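For comparison, a common augmentation for small 32×32 images is a random horizontal flip followed by a random crop from a zero-padded copy. The sketch below is a hypothetical Base-Julia version operating on a WHCN batch; the `augment` name and the `pad = 4` default are assumptions, not what either benchmarking script actually does.

```julia
using Random

# Hypothetical augmentation for a WHCN batch (width, height, channels, batch):
# each image is flipped horizontally with probability 1/2, then randomly cropped
# back to its original size from a zero-padded copy.
function augment(x::AbstractArray{T,4}; pad::Integer = 4,
                 rng::AbstractRNG = Random.default_rng()) where {T}
    W, H, C, N = size(x)
    y = similar(x)
    padded = zeros(T, W + 2pad, H + 2pad, C)   # border stays zero throughout
    for n in 1:N
        img = x[:, :, :, n]
        if rand(rng, Bool)
            img = reverse(img; dims = 1)       # dim 1 is width: horizontal flip
        end
        padded[pad+1:pad+W, pad+1:pad+H, :] .= img
        i, j = rand(rng, 0:2pad), rand(rng, 0:2pad)   # random crop offsets
        y[:, :, :, n] .= padded[i+1:i+W, j+1:j+H, :]
    end
    return y
end
```

Running the benchmark once with and once without a step like this (applied only to training batches) would isolate its effect on the train/test accuracy columns.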

IanButterworth commented 9 months ago

Current status

┌ Info: Benchmarking
│   epochs = 45
│   batchsize = 1000
│   device = gpu (generic function with 5 methods)
└   imsize = (32, 32)
25×5 DataFrame
 Row │ model                              train_loss     train_acc     test_loss      test_acc     
     │ String                             Float64?       Float64?      Float64?       Float64?     
─────┼─────────────────────────────────────────────────────────────────────────────────────────────
   1 │ AlexNet(; pretrain=false, inchan…  missing        missing       missing        missing      DimensionMismatch: Kernel * dilation ((3 - 1) * 1 + 1) cannot be larger than input + padding (1 + 0 + 0)!
   2 │ VGG(11, batchnorm=true; pretrain…  missing        missing       missing        missing      DimensionMismatch: layer Dense(25088 => 4096, relu) expects size(input, 1) == 25088, but got 512×1000 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
   3 │ SqueezeNet(; pretrain=false, inc…        2.22064        0.1115        2.22714        0.112
   4 │ ResNet(18; pretrain=false, incha…        1.11878        0.586         1.29114        0.524
   5 │ WideResNet(50; pretrain=false, i…  missing        missing       missing        missing      CUDNNError: CUDNN_STATUS_BAD_PARAM (code 3)
   6 │ ResNeXt(50, cardinality=32, base…  missing        missing       missing        missing      CUDNNError: CUDNN_STATUS_BAD_PARAM (code 3)
   7 │ SEResNet(18; pretrain=false, inc…        1.1682         0.5885        1.31459        0.5405
   8 │ SEResNeXt(50, cardinality=32, ba…  missing        missing       missing        missing      CUDNNError: CUDNN_STATUS_BAD_PARAM (code 3)
   9 │ Res2Net(50, base_width=26, scale…  missing        missing       missing        missing      Scalar indexing is disallowed.
  10 │ Res2NeXt(50; pretrain=false, inc…  missing        missing       missing        missing      Scalar indexing is disallowed.
  11 │ GoogLeNet(batchnorm=true; pretra…        1.45998        0.497         1.53686        0.4725
  12 │ DenseNet(121; pretrain=false, in…        1.18146        0.5835        1.41045        0.5115
  13 │ Inceptionv3(; pretrain=false, in…  missing        missing       missing        missing      DimensionMismatch: Kernel * dilation ((3 - 1) * 1 + 1) cannot be larger than input + padding (1 + 0 + 0)!
  14 │ Inceptionv4(; pretrain=false, in…  missing        missing       missing        missing      DimensionMismatch: Kernel * dilation ((3 - 1) * 1 + 1) cannot be larger than input + padding (1 + 0 + 0)!
  15 │ InceptionResNetv2(; pretrain=fal…  missing        missing       missing        missing      DimensionMismatch: Kernel * dilation ((3 - 1) * 1 + 1) cannot be larger than input + padding (1 + 0 + 0)!
  16 │ Xception(; pretrain=false, incha…        1.97242        0.304         1.97524        0.3
  17 │ MobileNetv1(0.5; pretrain=false,…        2.01992        0.2465        2.07331        0.23
  18 │ MobileNetv2(0.5; pretrain=false,…        2.06227        0.2065        2.08002        0.2005
  19 │ MobileNetv3(:small, width_mult=0…        2.13884        0.2085        2.09306        0.223
  20 │ MNASNet(:A1, width_mult=0.5; pre…        2.31312        0.118         2.32238        0.1015
  21 │ EfficientNet(:b0; pretrain=false…        2.26213        0.158         2.26398        0.1455
  22 │ EfficientNetv2(:small; pretrain=…        2.30173        0.1145        2.30193        0.1105
  23 │ ConvMixer(:small; pretrain=false…  missing        missing       missing        missing      Out of GPU memory trying to allocate 1.046 GiB
  24 │ ConvNeXt(:small; pretrain=false,…  missing        missing       missing        missing      Out of GPU memory trying to allocate 585.938 MiB
  25 │ ViT(:tiny; pretrain=false, incha…  missing        missing       missing        missing      DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 5 and 197
 ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                                                                                                           Time                    Allocations      
                                                                                                  ───────────────────────   ────────────────────────
                                         Tot / % measured:                                              486s /  96.8%           60.8GiB / 100.0%    

 Section                                                                                  ncalls     time    %tot     avg     alloc    %tot      avg
 ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
 AlexNet(; pretrain=false, inchannels=3, nclasses=10)                                          1    1.11s    0.2%   1.11s    883MiB    1.4%   883MiB
   Load                                                                                        1    163ms    0.0%   163ms    435MiB    0.7%   435MiB
   Training                                                                                    1    658ms    0.1%   658ms   12.8MiB    0.0%  12.8MiB
     batch step                                                                                1   16.8ms    0.0%  16.8ms    203KiB    0.0%   203KiB
 VGG(11, batchnorm=true; pretrain=false, inchannels=3, nclasses=10)                            1    1.55s    0.3%   1.55s   1.93GiB    3.2%  1.93GiB
   Load                                                                                        1    438ms    0.1%   438ms   0.96GiB    1.6%  0.96GiB
   Training                                                                                    1    746ms    0.2%   746ms   12.9MiB    0.0%  12.9MiB
     batch step                                                                                1   13.7ms    0.0%  13.7ms    316KiB    0.0%   316KiB
 SqueezeNet(; pretrain=false, inchannels=3, nclasses=10)                                       1    29.8s    6.3%   29.8s   3.49GiB    5.7%  3.49GiB
   Load                                                                                        1   1.40ms    0.0%  1.40ms   5.59MiB    0.0%  5.59MiB
   Training                                                                                    1    29.8s    6.3%   29.8s   3.47GiB    5.7%  3.47GiB
     batch step                                                                               90    850ms    0.2%  9.44ms    264MiB    0.4%  2.93MiB
     testing                                                                                  45    644ms    0.1%  14.3ms   1.06GiB    1.7%  24.0MiB
 ResNet(18; pretrain=false, inchannels=3, nclasses=10)                                         1    31.8s    6.8%   31.8s   3.66GiB    6.0%  3.66GiB
   Load                                                                                        1   14.9ms    0.0%  14.9ms   85.4MiB    0.1%  85.4MiB
   Training                                                                                    1    31.8s    6.8%   31.8s   3.49GiB    5.7%  3.49GiB
     batch step                                                                               90    1.29s    0.3%  14.3ms    295MiB    0.5%  3.28MiB
     testing                                                                                  45    1.24s    0.3%  27.5ms   1.05GiB    1.7%  23.9MiB
 WideResNet(50; pretrain=false, inchannels=3, nclasses=10)                                     1    1.10s    0.2%   1.10s   1.01GiB    1.7%  1.01GiB
   Load                                                                                        1    182ms    0.0%   182ms    510MiB    0.8%   510MiB
   Training                                                                                    1    738ms    0.2%   738ms   16.7MiB    0.0%  16.7MiB
     batch step                                                                                1   44.1ms    0.0%  44.1ms   3.36MiB    0.0%  3.36MiB
 ResNeXt(50, cardinality=32, base_width=4; pretrain=false, inchannels=3, nclasses=10)          1    821ms    0.2%   821ms    368MiB    0.6%   368MiB
   Load                                                                                        1   59.1ms    0.0%  59.1ms    176MiB    0.3%   176MiB
   Training                                                                                    1    702ms    0.1%   702ms   16.7MiB    0.0%  16.7MiB
     batch step                                                                                1   35.9ms    0.0%  35.9ms   3.36MiB    0.0%  3.36MiB
 SEResNet(18; pretrain=false, inchannels=3, nclasses=10)                                       1    33.6s    7.1%   33.6s   3.96GiB    6.5%  3.96GiB
   Load                                                                                        1   16.3ms    0.0%  16.3ms   86.2MiB    0.1%  86.2MiB
   Training                                                                                    1    33.6s    7.1%   33.6s   3.79GiB    6.2%  3.79GiB
     batch step                                                                               90    1.75s    0.4%  19.4ms    575MiB    0.9%  6.39MiB
     testing                                                                                  45    1.31s    0.3%  29.1ms   1.06GiB    1.7%  24.2MiB
 SEResNeXt(50, cardinality=32, base_width=4; pretrain=false, inchannels=3, nclasses=10)        1    833ms    0.2%   833ms    410MiB    0.7%   410MiB
   Load                                                                                        1   66.0ms    0.0%  66.0ms    195MiB    0.3%   195MiB
   Training                                                                                    1    700ms    0.1%   700ms   19.5MiB    0.0%  19.5MiB
     batch step                                                                                1   43.5ms    0.0%  43.5ms   5.75MiB    0.0%  5.75MiB
 Res2Net(50, base_width=26, scale=4; pretrain=false, inchannels=3, nclasses=10)                1    787ms    0.2%   787ms    376MiB    0.6%   376MiB
   Load                                                                                        1   56.0ms    0.0%  56.0ms    181MiB    0.3%   181MiB
   Training                                                                                    1    672ms    0.1%   672ms   14.1MiB    0.0%  14.1MiB
     batch step                                                                                1   19.3ms    0.0%  19.3ms   65.7KiB    0.0%  65.7KiB
 Res2NeXt(50; pretrain=false, inchannels=3, nclasses=10)                                       1    781ms    0.2%   781ms    361MiB    0.6%   361MiB
   Load                                                                                        1   59.5ms    0.0%  59.5ms    173MiB    0.3%   173MiB
   Training                                                                                    1    661ms    0.1%   661ms   14.1MiB    0.0%  14.1MiB
     batch step                                                                                1   6.05ms    0.0%  6.05ms   65.7KiB    0.0%  65.7KiB
 GoogLeNet(batchnorm=true; pretrain=false, inchannels=3, nclasses=10)                          1    36.7s    7.8%   36.7s   4.23GiB    7.0%  4.23GiB
   Load                                                                                        1   11.4ms    0.0%  11.4ms   45.9MiB    0.1%  45.9MiB
   Training                                                                                    1    36.7s    7.8%   36.7s   4.14GiB    6.8%  4.14GiB
     batch step                                                                               90    2.92s    0.6%  32.4ms    864MiB    1.4%  9.60MiB
     testing                                                                                  45    1.56s    0.3%  34.7ms   1.09GiB    1.8%  24.9MiB
 DenseNet(121; pretrain=false, inchannels=3, nclasses=10)                                      1    47.3s   10.1%   47.3s   6.16GiB   10.1%  6.16GiB
   Load                                                                                        1   16.5ms    0.0%  16.5ms   53.5MiB    0.1%  53.5MiB
   Training                                                                                    1    47.3s   10.0%   47.3s   6.05GiB   10.0%  6.05GiB
     batch step                                                                               90    8.98s    1.9%   100ms   2.62GiB    4.3%  29.8MiB
     testing                                                                                  45    4.18s    0.9%  92.8ms   1.16GiB    1.9%  26.4MiB
 Inceptionv3(; pretrain=false, inchannels=3, nclasses=10)                                      1    738ms    0.2%   738ms    348MiB    0.6%   348MiB
   Load                                                                                        1   42.4ms    0.0%  42.4ms    167MiB    0.3%   167MiB
   Training                                                                                    1    655ms    0.1%   655ms   14.7MiB    0.0%  14.7MiB
     batch step                                                                                1   7.58ms    0.0%  7.58ms    862KiB    0.0%   862KiB
 Inceptionv4(; pretrain=false, inchannels=3, nclasses=10)                                      1    829ms    0.2%   829ms    645MiB    1.0%   645MiB
   Load                                                                                        1   78.7ms    0.0%  78.7ms    314MiB    0.5%   314MiB
   Training                                                                                    1    669ms    0.1%   669ms   16.0MiB    0.0%  16.0MiB
     batch step                                                                                1   17.3ms    0.0%  17.3ms   1.38MiB    0.0%  1.38MiB
 InceptionResNetv2(; pretrain=false, inchannels=3, nclasses=10)                                1    921ms    0.2%   921ms    851MiB    1.4%   851MiB
   Load                                                                                        1    126ms    0.0%   126ms    416MiB    0.7%   416MiB
   Training                                                                                    1    671ms    0.1%   671ms   19.0MiB    0.0%  19.0MiB
     batch step                                                                                1   6.96ms    0.0%  6.96ms    196KiB    0.0%   196KiB
 Xception(; pretrain=false, inchannels=3, nclasses=10)                                         1    50.8s   10.8%   50.8s   4.54GiB    7.5%  4.54GiB
   Load                                                                                        1   52.8ms    0.0%  52.8ms    159MiB    0.3%   159MiB
   Training                                                                                    1    50.7s   10.8%   50.7s   4.23GiB    7.0%  4.23GiB
     batch step                                                                               90    6.78s    1.4%  75.3ms   0.94GiB    1.6%  10.7MiB
     testing                                                                                  45    5.13s    1.1%   114ms   1.09GiB    1.8%  24.9MiB
 MobileNetv1(0.5; pretrain=false, inchannels=3, nclasses=10)                                   1    32.0s    6.8%   32.0s   3.57GiB    5.9%  3.57GiB
   Load                                                                                        1   2.42ms    0.0%  2.42ms   6.39MiB    0.0%  6.39MiB
   Training                                                                                    1    32.0s    6.8%   32.0s   3.56GiB    5.9%  3.56GiB
     batch step                                                                               90    1.34s    0.3%  14.9ms    352MiB    0.6%  3.91MiB
     testing                                                                                  45    1.08s    0.2%  24.1ms   1.06GiB    1.7%  24.1MiB
 MobileNetv2(0.5; pretrain=false, inchannels=3, nclasses=10)                                   1    33.0s    7.0%   33.0s   3.98GiB    6.5%  3.98GiB
   Load                                                                                        1   2.35ms    0.0%  2.35ms   5.54MiB    0.0%  5.54MiB
   Training                                                                                    1    33.0s    7.0%   33.0s   3.97GiB    6.5%  3.97GiB
     batch step                                                                               90    1.93s    0.4%  21.4ms    721MiB    1.2%  8.02MiB
     testing                                                                                  45    1.32s    0.3%  29.2ms   1.08GiB    1.8%  24.5MiB
 MobileNetv3(:small, width_mult=0.5; pretrain=false, inchannels=3, nclasses=10)                1    32.6s    6.9%   32.6s   4.03GiB    6.6%  4.03GiB
   Load                                                                                        1   2.06ms    0.0%  2.06ms   4.68MiB    0.0%  4.68MiB
   Training                                                                                    1    32.6s    6.9%   32.6s   4.02GiB    6.6%  4.02GiB
     batch step                                                                               90    1.93s    0.4%  21.5ms    775MiB    1.2%  8.61MiB
     testing                                                                                  45    846ms    0.2%  18.8ms   1.08GiB    1.8%  24.5MiB
 MNASNet(:A1, width_mult=0.5; pretrain=false, inchannels=3, nclasses=10)                       1    33.7s    7.2%   33.7s   4.17GiB    6.9%  4.17GiB
   Load                                                                                        1   2.98ms    0.0%  2.98ms   6.45MiB    0.0%  6.45MiB
   Training                                                                                    1    33.7s    7.2%   33.7s   4.16GiB    6.8%  4.16GiB
     batch step                                                                               90    2.30s    0.5%  25.5ms    900MiB    1.4%  10.0MiB
     testing                                                                                  45    1.33s    0.3%  29.5ms   1.08GiB    1.8%  24.7MiB
 EfficientNet(:b0; pretrain=false, inchannels=3, nclasses=10)                                  1    38.7s    8.2%   38.7s   4.54GiB    7.5%  4.54GiB
   Load                                                                                        1   10.9ms    0.0%  10.9ms   30.9MiB    0.0%  30.9MiB
   Training                                                                                    1    38.7s    8.2%   38.7s   4.48GiB    7.4%  4.48GiB
     batch step                                                                               90    3.17s    0.7%  35.3ms   1.16GiB    1.9%  13.2MiB
     testing                                                                                  45    2.42s    0.5%  53.9ms   1.10GiB    1.8%  25.1MiB
 EfficientNetv2(:small; pretrain=false, inchannels=3, nclasses=10)                             1    57.2s   12.2%   57.2s   6.22GiB   10.2%  6.22GiB
   Load                                                                                        1   51.8ms    0.0%  51.8ms    155MiB    0.2%   155MiB
   Training                                                                                    1    57.1s   12.1%   57.1s   5.92GiB    9.7%  5.92GiB
     batch step                                                                               90    14.1s    3.0%   156ms   2.44GiB    4.0%  27.7MiB
     testing                                                                                  45    7.41s    1.6%   165ms   1.19GiB    2.0%  27.0MiB
 ConvMixer(:small; pretrain=false, inchannels=3, nclasses=10)                                  1    1.56s    0.3%   1.56s    326MiB    0.5%   326MiB
   Load                                                                                        1   49.8ms    0.0%  49.8ms    155MiB    0.2%   155MiB
   Training                                                                                    1    1.46s    0.3%   1.46s   15.6MiB    0.0%  15.6MiB
     batch step                                                                                1    802ms    0.2%   802ms   2.13MiB    0.0%  2.13MiB
 ConvNeXt(:small; pretrain=false, inchannels=3, nclasses=10)                                   1    1.51s    0.3%   1.51s    773MiB    1.2%   773MiB
   Load                                                                                        1   41.3ms    0.0%  41.3ms    377MiB    0.6%   377MiB
   Training                                                                                    1    1.43s    0.3%   1.43s   18.8MiB    0.0%  18.8MiB
     batch step                                                                                1    761ms    0.2%   761ms   4.00MiB    0.0%  4.00MiB
 ViT(:tiny; pretrain=false, inchannels=3, nclasses=10)                                         1    662ms    0.1%   662ms   97.1MiB    0.2%  97.1MiB
   Load                                                                                        1   4.70ms    0.0%  4.70ms   42.0MiB    0.1%  42.0MiB
   Training                                                                                    1    652ms    0.1%   652ms   13.2MiB    0.0%  13.2MiB
     batch step                                                                                1   5.50ms    0.0%  5.50ms    268KiB    0.0%   268KiB
 ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
theabhirath commented 9 months ago

I think the Inception errors are because that family of models uses an image size of 299x299, and I don't think they support alternate image sizes. AlexNet uses 224x224 and doesn't support anything else either. VGG needs a special imsize parameter to work with smaller image sizes, as does ViT. So those errors can at least be diagnosed at a glance.
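The VGG and ViT failures in the table can be checked with a little arithmetic. This sketch assumes that VGG's conv stack ends in 512 channels after five stride-2 max-pools, and that ViT splits the image into non-overlapping 16×16 patches and prepends one class token; the helper names are hypothetical.

```julia
# Assumed VGG layout: 512 output channels after five 2x2, stride-2 max-pools,
# so each spatial side shrinks by 2^5 = 32 before the first Dense layer.
vgg_flat_features(imsize::NTuple{2,Int}; channels::Int = 512, npools::Int = 5) =
    channels * prod(imsize .÷ 2^npools)

# Assumed ViT layout: non-overlapping square patches plus one class token.
vit_num_tokens(imsize::NTuple{2,Int}; patch::Int = 16) =
    prod(imsize .÷ patch) + 1
```

Under these assumptions `vgg_flat_features((224, 224))` is 25088 while `vgg_flat_features((32, 32))` is 512, which is exactly the `Dense(25088 => 4096)` mismatch above; likewise `vit_num_tokens` gives 197 for 224×224 and 5 for 32×32, the two lengths in the ViT broadcast error.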

The ResNet family errors are weird. One set of failures is with the larger ResNet variants (WideResNet, ResNeXt and SEResNeXt), which looks like some sort of memory issue? Correct me if I'm wrong. The Res2Net variants were unfortunately never tested on GPU, which most likely means something in the code is incompatible with GPU 😬. There's some "clever" code there which I think may not be GPU compatible.

ToucheSir commented 9 months ago

I can confirm with https://github.com/FluxML/Metalhead.jl/pull/262 that the scalar indexing errors are real. Thanks for helping me corroborate those. For now, those models can be ignored here.

The CUDNNError: CUDNN_STATUS_BAD_PARAM (code 3) errors may be related to running out of memory. Try calling CUDA.reclaim() after GC.gc(true), because GC.gc(true) alone is often not enough to get CUDA.jl/the CUDA driver to return VRAM to the system. I suspect this is OOM-related because that's how the error has come up before, and because the CI for #262 shows the same error for a different set of models (suggesting non-determinism).
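One way to apply this in a benchmark loop is to run a cleanup hook after every model, whether it succeeded or failed, and to turn failures into rows of `missing` metrics like those in the table. The harness below is a hypothetical sketch: `run_benchmarks`, the `name => train_and_eval` pair format and the returned fields are assumptions, not the script from this PR.

```julia
# Hypothetical benchmark harness: `models` is a list of `name => train_and_eval`
# pairs, where each callback trains one model and returns its test accuracy.
function run_benchmarks(models; cleanup = () -> GC.gc(true))
    results = NamedTuple[]
    for (name, train_and_eval) in models
        try
            acc = Float64(train_and_eval())
            push!(results, (model = name, test_acc = acc, error = missing))
        catch err
            # A failing model becomes a row with `missing` metrics plus its
            # error message, instead of aborting the whole run.
            push!(results, (model = name, test_acc = missing,
                            error = sprint(showerror, err)))
        finally
            # Runs after every model, successful or not, so memory held by the
            # previous model can be released before the next one starts.
            cleanup()
        end
    end
    return results
end
```

On GPU the hook would follow the suggestion above, e.g. `cleanup = () -> (GC.gc(true); CUDA.reclaim())`.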

darsnack commented 9 months ago

> I think the Inception errors are because that family of models uses an image size of 299x299, and I don't think they support alternate image sizes. AlexNet uses 224x224 and doesn't support anything else either.

These work with arbitrary sizes, but there is a lower bound on how small an image they can handle. I guess 32x32 is too small.
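The lower bound follows from the standard output-size formula for conv and pooling layers: the dilated kernel extent `(k - 1) * d + 1` must fit in `input + padding`, and each layer maps `n` to `(n + 2p - extent) ÷ s + 1`. The sketch below traces a spatial size through a layer list; the list assumes a torchvision-style AlexNet feature extractor and is an illustration, not copied from Metalhead's source.

```julia
# Output length of one spatial dimension after a conv/pool layer. Throws when
# the dilated kernel no longer fits inside input + padding.
function conv_out(n::Integer, k::Integer; stride::Integer = 1, pad::Integer = 0,
                  dilation::Integer = 1)
    extent = (k - 1) * dilation + 1
    extent <= n + 2pad ||
        throw(DimensionMismatch("kernel extent $extent exceeds input + padding $(n + 2pad)"))
    return (n + 2pad - extent) ÷ stride + 1
end

# Assumed AlexNet-style feature extractor as (kernel, stride, pad) per layer:
# conv 11/4/2, pool 3/2, conv 5/1/2, pool 3/2, 3 × conv 3/1/1, pool 3/2.
const ALEXNET_LAYERS = [(11, 4, 2), (3, 2, 0), (5, 1, 2), (3, 2, 0),
                        (3, 1, 1), (3, 1, 1), (3, 1, 1), (3, 2, 0)]

# Returns the sizes seen before any failure and the index of the first layer
# that cannot be applied (or `nothing` if every layer fits).
function trace_sizes(n::Integer, layers)
    sizes = [n]
    for (i, (k, s, p)) in enumerate(layers)
        try
            n = conv_out(n, k; stride = s, pad = p)
        catch err
            err isa DimensionMismatch || rethrow()
            return sizes, i
        end
        push!(sizes, n)
    end
    return sizes, nothing
end
```

Under that assumed layer list, `trace_sizes(32, ALEXNET_LAYERS)` shrinks the input to 1×1 by the time it reaches the final 3×3 max-pool, which is consistent with the `input + padding (1 + 0 + 0)` reported for AlexNet in the benchmark table, while 224×224 ends at the familiar 6×6.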

> The ResNet family errors are weird. One set of failures is with the larger ResNet variants (WideResNet, ResNeXt and SEResNeXt), which looks like some sort of memory issue? Correct me if I'm wrong. The Res2Net variants were unfortunately never tested on GPU, which most likely means something in the code is incompatible with GPU 😬. There's some "clever" code there which I think may not be GPU compatible.

The use of MLUtils.chunk in the forward pass has two problems:

Shortened stack trace (irrelevant parts cut off):

julia> m(x)
┌ Warning: Performing scalar indexing on task Task (runnable) @0x00007f1ca13fc010.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on the GPU, but very slowly on the CPU,
│ and therefore are only permitted from the REPL for prototyping purposes.
│ If you did intend to index this array, annotate the caller with @allowscalar.
└ @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:106
┌ Warning: Performing scalar indexing on task Task (runnable) @0x00007f1a2236dc30.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on the GPU, but very slowly on the CPU,
│ and therefore are only permitted from the REPL for prototyping purposes.
│ If you did intend to index this array, annotate the caller with @allowscalar.
└ @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:106
ERROR: TaskFailedException

    nested task error: TaskFailedException

        nested task error: MethodError: no method matching gemm!(::Val{false}, ::Val{false}, ::Int64, ::Int64, ::Int64, ::Float32, ::CuPtr{Float32}, ::CuPtr{Float32}, ::Float32, ::CuPtr{Float32})

        Closest candidates are:
          gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::Float32, ::Ptr{Float32}, ::Ptr{Float32}, ::Float32, ::Ptr{Float32})
           @ NNlib ~/.julia/packages/NNlib/sXmAj/src/gemm.jl:29
          gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::Float64, ::Ptr{Float64}, ::Ptr{Float64}, ::Float64, ::Ptr{Float64})
           @ NNlib ~/.julia/packages/NNlib/sXmAj/src/gemm.jl:29
          gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::ComplexF64, ::Ptr{ComplexF64}, ::Ptr{ComplexF64}, ::ComplexF64, ::Ptr{ComplexF64})
           @ NNlib ~/.julia/packages/NNlib/sXmAj/src/gemm.jl:29
          ...

        Stacktrace:
         [1] macro expansion
           @ ~/.julia/packages/NNlib/sXmAj/src/impl/conv_im2col.jl:59 [inlined]
         [2] (::NNlib.var"#647#648"{CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, Float32, Float32, SubArray{Float32, 5, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Base.ReshapedArray{Float32, 5, SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, DenseConvDims{3, 3, 3, 6, 3}, Int64, Int64, Int64, UnitRange{Int64}, Int64})()
           @ NNlib ./threadingconstructs.jl:416
    Stacktrace:
     [1] sync_end(c::Channel{Any})
       @ Base ./task.jl:445
     [2] macro expansion
       @ ./task.jl:477 [inlined]
     [3] conv_im2col!(y::SubArray{Float32, 5, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, x::SubArray{Float32, 5, Base.ReshapedArray{Float32, 5, SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{3, 3, 3, 6, 3}; col::CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, alpha::Float32, beta::Float32, ntasks::Int64)
       @ NNlib ~/.julia/packages/NNlib/sXmAj/src/impl/conv_im2col.jl:50
     [4] conv_im2col!(y::SubArray{Float32, 5, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, x::SubArray{Float32, 5, Base.ReshapedArray{Float32, 5, SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{3, 3, 3, 6, 3})
       @ NNlib ~/.julia/packages/NNlib/sXmAj/src/impl/conv_im2col.jl:23
     [5] (::NNlib.var"#305#309"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, DenseConvDims{3, 3, 3, 6, 3}, SubArray{Float32, 5, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, SubArray{Float32, 5, Base.ReshapedArray{Float32, 5, SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})()
       @ NNlib ./threadingconstructs.jl:416
Stacktrace:
  [1] sync_end(c::Channel{Any})
    @ Base ./task.jl:445
  [2] macro expansion
    @ ./task.jl:477 [inlined]
  [3] conv!(out::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, in1::Base.ReshapedArray{Float32, 5, SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, in2::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{3, 3, 3, 6, 3}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/sXmAj/src/conv.jl:205
  [4] conv!
    @ ~/.julia/packages/NNlib/sXmAj/src/conv.jl:185 [inlined]
  [5] #conv!#264
    @ ~/.julia/packages/NNlib/sXmAj/src/conv.jl:145 [inlined]
  [6] conv!
    @ ~/.julia/packages/NNlib/sXmAj/src/conv.jl:140 [inlined]
  [7] conv(x::SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/sXmAj/src/conv.jl:88
  [8] conv
    @ ~/.julia/packages/NNlib/sXmAj/src/conv.jl:83 [inlined]
  [9] (::Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Bool})(x::SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false})
    @ Flux ~/.julia/packages/Flux/jgpVj/src/layers/conv.jl:202
 [10] macro expansion
    @ ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:53 [inlined]
 [11] _applychain
    @ ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:53 [inlined]
 [12] Chain
    @ ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:51 [inlined]
 [13] |>
    @ ./operators.jl:907 [inlined]
 [14] map (repeats 2 times)
    @ ./tuple.jl:302 [inlined]
 [15] (::Parallel{typeof(Metalhead.cat_channels), Tuple{MeanPool{2, 4}, Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Bool}, BatchNorm{typeof(relu), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Bool}, BatchNorm{typeof(relu), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Bool}, BatchNorm{typeof(relu), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}})(::SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, ::Vararg{SubArray{Float32, 4, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})
    @ Flux ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:541
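The trace shows `conv` receiving a `SubArray` of a `CuArray` and falling through to NNlib's CPU `im2col`/`gemm!` path. One possible workaround, sketched under the assumption that the channel split is what produces those views, is to slice each channel group with ranges instead of taking views, so every group is a freshly allocated array of the parent's type (for GPU arrays this indexing is expected to stay on the device). `split_channels` is a hypothetical helper, not Metalhead's or MLUtils' API, and this is not a confirmed fix.

```julia
# Split a WHCN feature map into `n` equal channel groups. Range indexing
# (rather than `view`) copies each group into a plain array of the parent's
# type, so downstream layers never see a SubArray wrapper.
function split_channels(x::AbstractArray{T,4}, n::Integer) where {T}
    c = size(x, 3)
    c % n == 0 || throw(ArgumentError("channel count $c is not divisible by $n"))
    w = c ÷ n
    return [x[:, :, (i - 1) * w + 1:i * w, :] for i in 1:n]
end
```

The trade-off is an extra allocation per group in exchange for inputs that hit the same specialised conv methods as an unsplit array.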