This library contains a collection of generative models. It builds on trainable conditional distributions from ConditionalDists.jl, which can be used in conjunction with Flux.jl models. Probability measures such as the KL divergence are defined in IPMeasures.jl. This package aims to make experimenting with new models as easy as possible.
As an example, check out how to build a conventional variational autoencoder (VAE) that reconstructs MNIST below.
First we load the MNIST training dataset:

```julia
using MLDatasets, Flux

# load the raw 28x28 training images and flatten each one into a column vector
train_x, _ = MNIST.traindata(Float32)
flat_x = reshape(train_x, :, size(train_x, 3)) |> gpu
data = Flux.Data.DataLoader(flat_x, batchsize=200, shuffle=true);
```
and define some parameters for a VAE with input length `xlength` and a latent vector of length `zlength`:

```julia
using ConditionalDists

xlength = size(flat_x, 1)  # 784 input pixels
zlength = 2                # dimension of the latent space
hdim    = 512              # hidden layer width
hd2     = hdim ÷ 2
```
We define an encoder with diagonal variance on the latent dimension, which is just a Flux model wrapped in a `ConditionalMvNormal`. The Flux model must return a tuple with the appropriate number of parameters; for an `MvNormal` these are two: mean and variance. Hence, the `SplitLayer` returns two vectors of length `zlength`, one of which (the variance) is constrained to be positive via `softplus`.
```julia
using ConditionalDists: SplitLayer

# mapping that will be trained to output mean and variance
enc_map = Chain(Dense(xlength, hdim, relu),
                Dense(hdim, hd2, relu),
                SplitLayer(hd2, [zlength, zlength], [identity, softplus]))

# conditional encoder (can be called e.g. like `rand(encoder,x)`, see ConditionalDists.jl)
encoder = ConditionalMvNormal(enc_map)
```
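Written out, the encoder above parameterizes a diagonal Gaussian posterior over the latent space (the standard VAE assumption, with $\mu(x)$ and $\sigma^2(x)$ the two outputs of the `SplitLayer`):

$$q(z \mid x) = \mathcal{N}\big(z;\ \mu(x),\ \mathrm{diag}(\sigma^2(x))\big)$$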
The decoder returns a multivariate Normal with scalar variance:
```julia
# `σ` keeps both outputs in (0,1), so the scalar variance stays positive
dec_map = Chain(Dense(zlength, hd2, relu),
                Dense(hd2, hdim, relu),
                SplitLayer(hdim, [xlength, 1], σ))
decoder = ConditionalMvNormal(dec_map)
```
Now we can create the VAE model and train it to maximize the ELBO.
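For reference, the ELBO is the standard variational lower bound, with $q(z \mid x)$ the encoder and $p(x \mid z)$ the decoder distribution defined above:

$$\mathcal{L}(x) = \mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big] - \mathrm{KL}\big(q(z \mid x) \,\|\, p(z)\big)$$

where $p(z)$ is the latent prior (typically $\mathcal{N}(0, I)$); the KL term is one of the probability measures provided by IPMeasures.jl.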
```julia
using GenerativeModels

model = VAE(zlength, encoder, decoder) |> gpu
loss(x) = -elbo(model, x)

ps  = Flux.params(model)
opt = ADAM()
for e in 1:50
    @info "Epoch $e" loss(flat_x)
    Flux.train!(loss, ps, data, opt)
end
```
Some test reconstructions and the corresponding latent space are shown below:
```julia
model = model |> cpu
test_x, test_y = MNIST.testdata(Float32)
p1 = plot_reconstructions(model, test_x[:, :, 1:6])
```