FluxML / FastAI.jl

Repository of best practices for deep learning in Julia, inspired by fastai
https://fluxml.ai/FastAI.jl
MIT License
585 stars 51 forks source link

Custom learning tasks tutorial gives error #285

Open usiam opened 12 months ago

usiam commented 12 months ago
using Pkg;
Pkg.activate(".")
using FastAI, FastVision, Random, Images
import CairoMakie;
CairoMakie.activate!(type="png");

path = FastAI.load(datasets()["oxford-iiit-pet"])
im_path = joinpath(path, "images")
files = loadfolderdata(im_path; filterfn=FastVision.isimagefile)

function transform_image(image, sz=224)
    image_resized = imresize(convert.(RGB{N0f8}, image), (sz, sz))
    a = permuteddimsview(channelview(image_resized), (2, 3, 1))
end

p = getobs(files, 1)
image = loadfile(p)

label_func(path) = match(r"^(.*)_\d+\.jpg$", pathname(path))[1]
label_func(p)

labels = map(label_func, files)
length(unique(labels))

data = mapobs(files) do file
    return (loadfile(file), label_func(file))
end

idxs = shuffle(1:length(files))
cut = round(Int, 0.8 * length(idxs))
trainidxs, valididxs = idxs[1:cut], idxs[cut+1:end]
trainfiles, validfiles = files[trainidxs], files[valididxs]
summary.((trainfiles, validfiles))

import FastAI.MLUtils

struct SiamesePairs
    labels
    same
    other
    valid
end

function SiamesePairs(labels; valid=false)
    ulabels = unique(labels)
    same = Dict(
        label => [i for (i, l) in enumerate(labels) if l == label]
        for label in ulabels)
    other = Dict(
        label => [i for (i, l) in enumerate(labels) if l != label]
        for label in ulabels)

    return SiamesePairs(labels, same, other, valid)
end

function MLUtils.getobs(si::SiamesePairs, idx::Int)
    rng = si.valid ? MersenneTwister(idx) : Random.GLOBAL_RNG
    if rand(rng) > 0.5
        return ((idx, rand(rng, si.same[si.labels[idx]])), true)
    else
        return ((idx, rand(rng, si.other[si.labels[idx]])), false)
    end
end

MLUtils.numobs(si::SiamesePairs) = length(si.labels)

function siamesedata(files; valid = false, transformfn = identity)
    labels = map(label_func, files)
    si = SiamesePairs(labels; valid = valid)
    return mapobs(si) do obs
        (i, j), same = obs
        image1 = transformfn(loadfile(getobs(files, i)))
        image2 = transformfn(loadfile(getobs(files, j)))
        return ((image1, image2), same)
    end
end

traindata = siamesedata(trainfiles; transformfn=transform_image)
validdata = siamesedata(validfiles; transformfn=transform_image, valid=true);

traindl = FastAI.MLUtils.DataLoader(traindata, 16)

ERROR: MethodError: no method matching MLUtils.DataLoader(::MLUtils.MappedData{:auto, var"#75#76"{typeof(transform_image), ObsView{MLDatasets.FileDataset{typeof(identity), String}, Vector{Int64}}}, SiamesePairs}, ::Int64)

I was trying to recreate the Siamese example in the docs and could not figure out why I am getting this error? And how do I fix this?