jbornschein / draw

Reimplementation of DRAW
MIT License

Multichannel training and sampling #21

Closed by dribnet 9 years ago

dribnet commented 9 years ago

Starting from @jbornschein's multichannel branch, which was orphaned a few months ago, I was able to get multichannel DRAW working. It turns out most of the lingering issues were in sampling, not training. Other than a small normalization bug, the main work was in getting some initial color output and then refining it.

Perhaps overkill, but here are some sample outputs from this branch, along with the settings used, in case others want to replicate my results as a starting point.


First I used the regular MNIST dataset, which is useful as a baseline and shows that this branch is backwards compatible with the single-channel capability introduced in #19.

[image: sequence]

train-draw.py --dataset=mnist --attention=2,5 --niter=64 --lr=2e-4
epochs 130
test_nll 90.6

To get started I found it useful to roll my own color version of mnist. Once it was working, the result was kind of cute.

[image: sequence]

train-draw.py --dataset=colormnist --attention=2,5 --niter=64 --lr=3e-4
epochs 108
test_nll 164.4
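The thread does not say how colormnist was built, so the following is only a minimal sketch of one way to roll a color version of MNIST: tint each grayscale digit with a random foreground color. The function name `colorize_digits`, the channel-first output layout, and the color range are all assumptions for illustration, and random arrays stand in for the real digits.

```python
import numpy as np

def colorize_digits(gray, rng):
    """Tint grayscale digits with one random RGB color per image.

    gray: float array of shape (N, 28, 28) with values in [0, 1].
    Returns a float array of shape (N, 3, 28, 28) with values in [0, 1].
    (Hypothetical helper, not taken from the branch.)
    """
    n = gray.shape[0]
    # One random color per image, kept away from black so digits stay visible.
    colors = rng.uniform(0.3, 1.0, size=(n, 3))
    # Broadcast (N, 3, 1, 1) * (N, 1, 28, 28) -> (N, 3, 28, 28).
    return colors[:, :, None, None] * gray[:, None, :, :]

rng = np.random.default_rng(0)
fake_digits = rng.uniform(0.0, 1.0, size=(4, 28, 28))  # stand-in for real MNIST
color_digits = colorize_digits(fake_digits, rng)
print(color_digits.shape)  # (4, 3, 28, 28)
```

A dataset like this keeps the small 28x28 size of MNIST, so it still trains quickly while exercising every multichannel code path.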

The first real dataset I tried was svhn2, loosely following params from the DRAW paper.

[image: sequence]

train-draw.py --dataset=svhn2 --attention=5,5 --niter=32 --lr=3e-4 --enc-dim=512 --dec-dim=512
epochs 154
test_nll 1827.4

And finally cifar10 back on the earlier settings.

[image: sequence]

train-draw.py --dataset=cifar10 --attention=2,5 --niter=64 --lr=3e-4
epochs 60
test_nll 1817.2
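The four test_nll values are not directly comparable because the images differ in size. As a rough sketch, dividing each value by the number of input dimensions puts them on a common per-dimension scale. The image shapes below are assumptions (28x28x1 for mnist, 28x28x3 for colormnist, 32x32x3 for svhn2 and cifar10), as is the premise that test_nll is summed over all dimensions of one image.

```python
# test_nll values as reported above; shapes are assumed, not stated in the thread.
results = {
    "mnist":      (90.6,   (28, 28, 1)),
    "colormnist": (164.4,  (28, 28, 3)),
    "svhn2":      (1827.4, (32, 32, 3)),
    "cifar10":    (1817.2, (32, 32, 3)),
}

per_dim = {}
for name, (test_nll, (height, width, channels)) in results.items():
    n_dims = height * width * channels
    # Per-dimension NLL, in whatever unit train-draw.py reports.
    per_dim[name] = test_nll / n_dims
    print(f"{name:10s} {n_dims:5d} dims  {per_dim[name]:.4f} per dim")
```

Under these assumptions the two natural-image datasets land close to each other per dimension, and well above both MNIST variants.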

AjayTalati commented 9 years ago

Ha, nice, 8-bit color MNIST, cute!

It would be interesting to see whether you can get this implementation of DRAW to work with a domain-specific decoder or scene-rendering engine (here's an example). Large (e.g. 150x150 pixel) multichannel images seem to be particularly challenging and useful.

dribnet commented 9 years ago

Yeah, maybe I'll make colormnist available - it trains quickly and was really helpful for debugging. Here's an example where I had the column/row/channel ordering all messed up and the poor neural net was still doing a decent job compensating for my bug.

[image: sequence]
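For anyone hitting a similar ordering bug, here is a minimal sketch of how it can arise. It assumes images are stored channel-first (C, H, W) and flattened for the model. The thread does not say which layout the branch actually uses, so the shapes and variable names are purely illustrative.

```python
import numpy as np

channels, height, width = 3, 4, 5

# Hypothetical image stored channel-first (C, H, W), then flattened for the model.
image_chw = np.arange(channels * height * width, dtype=np.float32)
image_chw = image_chw.reshape(channels, height, width)
flat = image_chw.reshape(-1)

# Correct: undo the flatten with the same axis order, then move channels last
# to get the (H, W, C) layout most plotting code expects.
correct_hwc = flat.reshape(channels, height, width).transpose(1, 2, 0)

# Buggy: reshaping straight to (H, W, C) reinterprets the buffer, so values
# from different rows and channels end up interleaved.
buggy_hwc = flat.reshape(height, width, channels)

print(np.array_equal(correct_hwc[..., 0], image_chw[0]))  # True
print(np.array_equal(buggy_hwc, correct_hwc))             # False
```

Both results have the same shape, which is why this kind of mistake raises no error and only shows up in the rendered samples.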

jbornschein commented 9 years ago

Awesome!

I think we should add one multi-channel example to the top-level README. Maybe the SVHN samples? I think they look more impressive than the CIFAR10 samples.

dribnet commented 9 years ago

@jbornschein - I've updated the README in my master branch with a fresh SVHN example as well as instructions for downloading and running the model that generated it. I did some light editing, mainly to what is saved each epoch, but I left as-is the parts I don't use and suspect might not be working (bokeh-server, log visualization).

So feel free to merge my README with the image/model update, though it might also be good to do a general refresh of the rest of what's there.