FluxML / MLJFlux.jl

Wrapping deep learning models from the package Flux.jl for use in the MLJ.jl toolbox
http://fluxml.ai/MLJFlux.jl/
MIT License

Bump compat for Metalhead #232

Closed. ablaom closed this 11 months ago

ablaom commented 11 months ago

This PR bumps the [compat] for Metalhead to "0.8" and addresses resulting breakages.

Replaces #226

ablaom commented 11 months ago

Failing on GPU only. The complaint is about scalar indexing.

I've spent some time on this today, but it is hard for me to debug because I don't currently have GPU access. I conjecture that the following code fails on a GPU but not on a CPU, and that it contains the issue. It would be good if someone could confirm that it does fail and, if so, point out where the scalar indexing occurs.

import Flux
import MLJFlux
import StableRNGs.StableRNG

rng = StableRNG(123)
X, y = MLJFlux.make_images(rng);

typeof(X)
# Vector{Matrix{Gray{Float64}}}

data = MLJFlux.collate(MLJFlux.ImageClassifier(), X, y);

data = Flux.gpu(data) # gpu returns a new object; no effect on my CPU-only machine
typeof(data)
# Tuple{Vector{Array{Float32, 4}}, Vector{OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}}

n_channels = 1
n_classes = 3
init = Flux.glorot_uniform(rng)

chain = Flux.Chain(
    Flux.Conv((2, 2), n_channels=>2, init=init),
    Flux.Conv((2, 2), 2=>1, init=init),
    x->reshape(x, :, size(x)[end]),
    Flux.Dense(16, n_classes, init=init))

x = data[1][1]
typeof(x)
# Array{Float32, 4}

size(x)
# (6, 6, 1, 1)

chain(x)

ablaom commented 11 months ago

Okay. I guess the reshape is the issue. https://github.com/JuliaGPU/CUDA.jl/issues/228 .

But these docs say that reshape is special-cased by CuArrays. Mmm...

ablaom commented 11 months ago

Okay, looks like the colon is not supported in the CuArray-specialized reshape.

mohamed82008 commented 11 months ago

@ablaom would you also be open to making Metalhead an optional dependency?

ablaom commented 11 months ago

> @ablaom would you also be open to making Metalhead an optional dependency?

That's tricky because the default builder for ImageClassifier is a VGG architecture from Metalhead.jl. We could throw an error if Metalhead.jl is not loaded, but that would violate a general principle currently holding true among 200+ MLJ models (and assumed in applications of MLJTestIntegration.jl): the empty argument constructor always works unless the model is a wrapper, like TunedModel.

An alternative, which already looks too complicated to me, is to make the default

Related discussion: https://github.com/FluxML/MLJFlux.jl/issues/162

Maybe you have a better idea?

mohamed82008 commented 11 months ago

Not a better idea but moving ImageClassifier itself out to another package is another option. Not a great option though if you want all Flux-related wrappers to be in this repo.

ToucheSir commented 11 months ago

> Okay, looks like the colon is not supported in the CuArray-specialized reshape.

That seems strange to me, because it's basically what MLUtils.flatten does and I don't recall having any issues with it. Indeed, making the same reshape calls manually results in a CuArray.
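For readers following along, here is a minimal CPU-only sketch of what the `reshape(x, :, size(x)[end])` layer in the chain above does. It uses only Base Julia; the array size and the variable names are made up for illustration, and it says nothing about how the call behaves on a `CuArray`.

```julia
# CPU-only illustration; the sizes below are arbitrary, not taken from the failing test.
x = rand(Float32, 4, 4, 1, 3)   # (width, height, channels, batch)

# The colon asks reshape to infer the first dimension from the remaining ones.
flat_colon = reshape(x, :, size(x)[end])

# The same reshape with the inferred dimension spelled out by hand.
n_features = prod(size(x)[1:end-1])
flat_explicit = reshape(x, n_features, size(x)[end])

@assert size(flat_colon) == (16, 3)
@assert flat_colon == flat_explicit
```

Both forms flatten every dimension except the batch dimension, which is the behaviour described above for MLUtils.flatten.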

I had a look through the failing CI runs, and the problem is instead this warning: https://buildkite.com/julialang/mljflux-dot-jl/builds/339#018a25d7-e2c7-4c2b-a7af-c1a9d97436c4/425-777. Because we switched to package extensions in Flux 0.14, cuDNN needs to be separately added to an environment to enable the CUDA conv routines in NNlib. I'm guessing MLJFlux doesn't want to take it on as a dep, so adding it into your test env/extras should be enough.
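As a rough sketch of how one might check for this in a test environment, the snippet below asks Julia whether cuDNN.jl is resolvable from the active environment. The variable name and messages are hypothetical; `Base.find_package` only reports whether the package can be located, not whether a GPU is present or functional.

```julia
# Look up cuDNN.jl in the active environment; returns a path, or nothing if it is absent.
cudnn_path = Base.find_package("cuDNN")

if cudnn_path === nothing
    # With Flux 0.14's package extensions, the CUDA conv routines are only
    # enabled once cuDNN is installed and loaded (`using Flux, CUDA, cuDNN`).
    @warn "cuDNN.jl not found; add it to the test environment, e.g. with Pkg.add(\"cuDNN\")."
else
    @info "cuDNN.jl found" cudnn_path
end
```

A check like this could run at the top of a GPU test script, so that a missing cuDNN shows up as an explicit message rather than as a downstream scalar indexing complaint.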

codecov-commenter commented 11 months ago

Codecov Report

Patch coverage: 25.00% and project coverage change: -1.21% :warning:

Comparison is base (19dc08b) 93.26% compared to head (fa7133b) 92.06%.


Additional details and impacted files

```diff
@@            Coverage Diff             @@
##              dev     #232      +/-   ##
==========================================
- Coverage   93.26%   92.06%   -1.21%
==========================================
  Files          12       12
  Lines         312      315       +3
==========================================
- Hits          291      290       -1
- Misses         21       25       +4
```

| [Files Changed](https://app.codecov.io/gh/FluxML/MLJFlux.jl/pull/232?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=Flux) | Coverage Δ | |
|---|---|---|
| [src/metalhead.jl](https://app.codecov.io/gh/FluxML/MLJFlux.jl/pull/232?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=Flux#diff-c3JjL21ldGFsaGVhZC5qbA==) | `80.64% <14.28%> (-12.22%)` | :arrow_down: |
| [src/mlj\_model\_interface.jl](https://app.codecov.io/gh/FluxML/MLJFlux.jl/pull/232?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=Flux#diff-c3JjL21sal9tb2RlbF9pbnRlcmZhY2Uuamw=) | `94.20% <100.00%> (ø)` | |


ablaom commented 11 months ago

Thanks indeed for the help, @ToucheSir. I guess the scalar indexing error was a red herring. I've added cuDNN to the test deps and there's no sign of it now.