philtomson (issue closed 5 years ago)
I think I figured it out for the autoencoder case (from the modelzoo):
```julia
using Flux, Flux.Data.MNIST
using Flux: @epochs, onehotbatch, argmax, mse, throttle
using Base.Iterators: partition
using Juno: @progress
using CuArrays
using Images, ImageView

# Encode MNIST images as compressed vectors that can later be decoded back into
# images.
imgs = MNIST.images()

# Partition into batches of size 1000
data = [float(hcat(vec.(imgs)...)) for imgs in partition(imgs, 1000)]
data = gpu.(data)

N = 32 # Size of the encoding

# You can try to make the encoder/decoder network larger.
# The output of the encoder is a coding of the given input.
# Here the input dimension is 28^2 and the output dimension of the
# encoder is 32, so the coding is a compressed representation.
# We can do lossy compression via this `encoder`.
encoder = Dense(28^2, N, relu) |> gpu
decoder = Dense(N, 28^2, relu) |> gpu

m = Chain(encoder, decoder)

loss(x) = mse(m(x), x)

img(x::Vector) = Gray.(reshape(clamp.(x, 0, 1), 28, 28))

function sample()
    # 20 random digits
    before = [imgs[i] for i in rand(1:length(imgs), 20)]
    # Before and after images
    after = img.(map(x -> cpu(m)(float(vec(x))).data, before))
    # Stack them all together
    hcat(vcat.(before, after)...)
end

#evalcb = throttle(() -> @show(loss(data[1])), 5)
s = sample()
guidict = imshow(s)
sleep(0.1) #<- Why is this necessary?

function evalcb()
    # Note: `throttle` wraps a zero-argument function (as in the commented-out
    # line above); `throttle(@show(loss(data[1])), 1)` would evaluate `@show`
    # immediately rather than throttle it.
    @show loss(data[1])
    canvas = guidict["gui"]["canvas"]
    s = sample()
    imshow(canvas, s)
end

opt = ADAM(params(m))
@epochs 10 Flux.train!(loss, zip(data), opt, cb = evalcb)
```
However, I'm not sure why that `sleep` is required after the first `imshow`. Without it, the graphics window doesn't show up until after the whole program has run.
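For what it's worth, a plausible explanation: ImageView is built on Gtk.jl, whose event loop runs as a cooperative Julia task, so the window is only drawn once the main task yields (via `sleep`, `yield`, or I/O). A minimal Base-Julia sketch of that scheduling behaviour (no ImageView involved; `drawn` merely stands in for the window-drawing task):

```julia
# Background tasks scheduled with @async do not run until the main task yields.
drawn = Ref(false)
t = @async (drawn[] = true)  # stands in for Gtk's drawing work

before = drawn[]  # captured before any yield point: still false
sleep(0.1)        # yield to the scheduler, as after imshow
after = drawn[]   # the task has now run: true
```

Under this reading, the `sleep(0.1)` isn't a race-condition hack so much as the point at which the event loop first gets a chance to run.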
This notebook shows an example that prints the output images in real time as training progresses.
Looking at this GAN implementation in Python/numpy: https://github.com/shinseung428/gan_numpy
Notice the animations there under "Results" showing how the generated digits improve over training.
Is there a way to do this kind of thing in Flux?
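One simple way to get that kind of animation (a sketch only, reusing `sample`, `loss`, `data`, and `opt` from the snippet above; the filename pattern and the ImageMagick command are illustrative, not part of any Flux API) is to save the before/after grid to a PNG after each epoch and stitch the frames together afterwards:

```julia
using Images, FileIO  # `save` writes the Gray matrix returned by `sample()` as a PNG

for epoch in 1:10
    Flux.train!(loss, zip(data), opt)
    frame = sample()  # before/after grid, as defined above
    # zero-padded names keep the frames in lexical order for the GIF tool
    save("frame_$(lpad(epoch, 3, '0')).png", frame)
end
# then assemble offline, e.g. with ImageMagick:
#   convert -delay 20 frame_*.png training.gif
```

Saving frames to disk also sidesteps the `imshow`/event-loop timing issue entirely, since no window needs to be redrawn during training.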