Opened by rolling-robot 1 year ago.
I was tinkering with Flux and anomaly detection and found that, starting with version 0.13, Flux moved its API to what it calls "explicit style". These changes add support for Flux 0.13 and above to the DSAD algorithm.
More details: https://fluxml.ai/Flux.jl/stable/training/training/#Model-Gradients and https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1
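For readers unfamiliar with the change, here is a minimal sketch of the difference between the two styles. It is not code from this PR. The toy `model`, `x`, `y`, and `loss` are hypothetical, and it assumes a Flux version that provides `Flux.setup` and the explicit `Flux.train!(loss, model, data, opt_state)` signature (0.14 or later, and late 0.13 releases).

```julia
using Flux

# Hypothetical toy model and data, only to illustrate the two API styles
model = Chain(Dense(2 => 4, relu), Dense(4 => 1))
x = rand(Float32, 2, 16)
y = rand(Float32, 1, 16)

# In explicit style the loss takes the model as its first argument
loss(m, x, y) = Flux.mse(m(x), y)

# Implicit style (older Flux): gradients w.r.t. a global Params collection
#   ps = Flux.params(model)
#   gs = gradient(() -> loss(model, x, y), ps)
#   Flux.Optimise.update!(opt, ps, gs)

# Explicit style: optimiser state is built from the model itself
opt_state = Flux.setup(Adam(1e-3), model)

# The gradient is taken with respect to the model; the result is a tuple
# whose first element is a NamedTuple mirroring the model's structure
grads = Flux.gradient(m -> loss(m, x, y), model)
Flux.update!(opt_state, model, grads[1])

# The same kind of step via train!, which calls loss(model, x, y) for each (x, y)
Flux.train!(loss, model, [(x, y)], opt_state)
```

In the explicit form nothing is tracked through global parameters, which is why code written around `Flux.params` has to be updated.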
My testing code:
```julia
using MLJ
using PyCall
using Flux
using OutlierDetection
using OutlierDetectionInterface: Labels, Data
using Plots
using CategoricalArrays
using Statistics  # provides `mean`, used in the callback below

n_train = 200
n_test = 200

ocnn = @load DSADDetector pkg=OutlierDetectionNetworks verbosity=0

# "Two moons" training data from scikit-learn
skl_ds = pyimport("sklearn.datasets")
data, labels = skl_ds.make_moons(n_train, noise=0.1)
fig = scatter(data[:,1], data[:,2], marker=:+)

# Label every sample "normal", then append two labelled outliers
anomaly_labels = map((x -> "normal"), labels)
push!(anomaly_labels, "outlier")
data = vcat(data, [1.5; 0.5] |> transpose)
push!(anomaly_labels, "outlier")
data = vcat(data, [-0.5; 0.] |> permutedims)

encoder = Chain(
    Dense(2 => 4, relu, bias=false),
    Dense(4 => 8, relu, bias=false),
    Dense(8 => 15, relu, bias=false))
decoder = Chain(
    Dense(15 => 8, relu, bias=false),
    Dense(8 => 4, relu, bias=false),
    Dense(4 => 2, relu, bias=false))

# The second callback pushes mean(m(x)) into loss_log
loss_log = Vector()
detector = ocnn(
    encoder=encoder,
    decoder=decoder,
    epochs=120,
    callback=(((m, x) -> ()),
              ((m, x, y) -> push!(loss_log, mean(m(x))))))

model, score = OutlierDetection.fit(detector, data |> permutedims,
    CategoricalArray(anomaly_labels, levels=["normal", "outlier"]), verbosity=2)

# Random test points in [-5, 6)^2
test = rand(Float64, (n_test, 2)) * 11 - ones(n_test, 2) * 5
ŷ = OutlierDetection.transform(detector, model, test |> permutedims)

# Overlay the anomaly score as contours
contour!(fig, range(-2, stop=3, length=200), range(-2, stop=3, length=200),
    (x, y) -> first(OutlierDetection.transform(detector, model, [x y] |> permutedims)),
    levels=30)
```
Notable changes: