hardmaru / WorldModelsExperiments

World Models Experiments

MDN RNN Loss #6

Closed zmonoid closed 6 years ago

zmonoid commented 6 years ago

Hi,

I wonder what the loss curve looked like for you when training the Doom MDN-RNN.

For me, the negative log-likelihood of the next latent state prediction saturates at around 1.0, which means the likelihood is around exp(-1) ≈ 0.37, a low value. I am reimplementing this paper, and I have confirmed that training of the VAE and the ES controller works without problems. However, the gap between the real and dream worlds is too large for the controller, as seen in the figure below. I suspect the MDN-RNN was not trained well. May I have your opinion?

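[image: https://user-images.githubusercontent.com/16687307/42987516-2c309aac-8c2c-11e8-9c23-83ca8c31f1b8.png]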
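
For reference, the relationship between the loss value and the mixture density can be checked with a small sketch. It assumes a one-dimensional mixture of Gaussians and a loss reported per latent dimension; the function name mdn_nll and its parameters are illustrative and are not taken from the repository.

```python
import math

def mdn_nll(target, log_weights, means, log_stds):
    """Negative log-likelihood of a scalar target under a 1-D Gaussian mixture."""
    # Normalise the mixture log-weights with a stable log-sum-exp.
    m = max(log_weights)
    log_norm = m + math.log(sum(math.exp(w - m) for w in log_weights))
    log_terms = []
    for lw, mu, ls in zip(log_weights, means, log_stds):
        z = (target - mu) / math.exp(ls)
        # Log-density of a Gaussian with mean mu and standard deviation exp(ls).
        log_pdf = -0.5 * z * z - ls - 0.5 * math.log(2.0 * math.pi)
        log_terms.append(lw - log_norm + log_pdf)
    # Combine the components with a second log-sum-exp.
    top = max(log_terms)
    return -(top + math.log(sum(math.exp(t - top) for t in log_terms)))

# One unit-variance component, target exactly at the mean.
nll_unit = mdn_nll(0.0, [0.0], [0.0], [0.0])
# Same target, but a much sharper component (standard deviation 0.1).
nll_sharp = mdn_nll(0.0, [0.0], [0.0], [math.log(0.1)])
```

Here nll_unit equals 0.5 * log(2π), roughly 0.92, while nll_sharp is negative. Because these are density values rather than probabilities, a loss near 1.0 does not by itself show that the model is poorly trained; it is comparable to what a unit-variance Gaussian gives at its own mean.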

hardmaru commented 6 years ago

Hi Zhou Bin

Sure, I'll upload the loss curves for the VAE and MDN-RNN in the Doom experiment next week, when I'm back from a conference.

zmonoid commented 6 years ago

@hardmaru Thanks for your quick reply.

hardmaru commented 6 years ago

Here's a log of the vae and rnn loss in the doom experiment as requested:

https://github.com/hardmaru/WorldModelsExperiments/blob/master/doomrnn/trainlog/gpu_jobs.log.txt

zmonoid commented 6 years ago

@hardmaru Thanks a lot for your help~

zmonoid commented 6 years ago

@hardmaru May I know which package versions you are using? Strangely, my VAE loss only goes down to around 10, rather than around 3 as in your log. This happens with both my implementation and yours.

hardmaru commented 6 years ago

Below is what I get when I run pip list on the P100 GPU instance. Please note that this machine was not used to generate the training data; that was done on the 64-core CPU instance, so vizdoom, doom_py, etc. are not installed here.

Package                            Version     Location                
---------------------------------- ----------- ------------------------
absl-py                            0.2.0       
agents                             1.4.0       
alabaster                          0.7.9       
anaconda-clean                     1.0         
anaconda-client                    1.5.1       
anaconda-navigator                 1.3.1       
argcomplete                        1.0.0       
astor                              0.6.2       
astroid                            1.4.7       
astropy                            2.0.1       
Babel                              2.3.4       
backports.shutil-get-terminal-size 1.0.0       
beautifulsoup4                     4.5.1       
bitarray                           0.8.1       
blaze                              0.10.1      
bleach                             1.5.0       
bokeh                              0.12.2      
boto                               2.42.0      
Bottleneck                         1.2.1       
Box2D                              2.3.2       
cffi                               1.7.0       
chest                              0.2.3       
click                              6.6         
cloudpickle                        0.2.1       
clyent                             1.2.2       
cma                                2.2.0       
colorama                           0.3.7       
conda                              4.3.27      
conda-build                        2.0.2       
configobj                          5.0.6       
contextlib2                        0.5.3       
cryptography                       1.5         
cycler                             0.10.0      
Cython                             0.24.1      
cytoolz                            0.8.0       
dask                               0.15.2      
datashape                          0.5.2       
decorator                          4.0.10      
dill                               0.2.5       
docutils                           0.12        
dynd                               0.7.3.dev1  
et-xmlfile                         1.0.1       
fastcache                          1.0.2       
filelock                           2.0.6       
Flask                              0.11.1      
Flask-Cors                         2.1.2       
gast                               0.2.0       
gevent                             1.1.2       
gpustat                            0.3.2       
greenlet                           0.4.10      
grpcio                             1.11.0      
gym                                0.9.2       
h5py                               2.8.0rc1    
HeapDict                           1.0.0       
html5lib                           0.9999999   
idna                               2.1         
imagesize                          0.7.1       
ipykernel                          4.5.0       
ipython                            5.1.0       
ipython-genutils                   0.1.0       
ipywidgets                         5.2.2       
itsdangerous                       0.24        
jdcal                              1.2         
jedi                               0.9.0       
Jinja2                             2.8         
jsonschema                         2.5.1       
jupyter                            1.0.0       
jupyter-client                     4.4.0       
jupyter-console                    5.0.0       
jupyter-core                       4.2.0       
lazy-object-proxy                  1.2.1       
llvmlite                           0.19.0      
locket                             0.2.0       
lxml                               3.6.4       
Markdown                           2.6.11      
MarkupSafe                         0.23        
matplotlib                         2.0.2       
mistune                            0.7.3       
mpi4py                             2.0.0       
mpmath                             0.19        
multipledispatch                   0.4.8       
nb-anacondacloud                   1.2.0       
nb-conda                           2.0.0       
nb-conda-kernels                   2.0.0       
nbconvert                          4.2.0       
nbformat                           4.1.0       
nbpresent                          3.0.2       
networkx                           1.11        
nltk                               3.2.1       
nose                               1.3.7       
notebook                           4.2.3       
numba                              0.34.0      
numexpr                            2.6.2       
numpy                              1.13.3      
odo                                0.5.0       
openpyxl                           2.3.2       
pandas                             0.20.3      
partd                              0.3.6       
path.py                            0.0.0       
pathlib2                           2.1.0       
patsy                              0.4.1       
pep8                               1.7.0       
pexpect                            4.0.1       
pickleshare                        0.7.4       
Pillow                             3.3.1       
pip                                18.0        
pkginfo                            1.3.2       
ply                                3.9         
prompt-toolkit                     1.0.3       
protobuf                           3.5.2.post1 
pstar                              0.1.6       
psutil                             4.3.1       
ptyprocess                         0.5.1       
py                                 1.4.31      
pyasn1                             0.1.9       
pybullet                           1.7.5       
pycosat                            0.6.1       
pycparser                          2.14        
pycrypto                           2.6.1       
pycurl                             7.43.0      
pyflakes                           1.3.0       
pyglet                             1.2.4       
Pygments                           2.1.3       
pylint                             1.5.4       
PyOpenGL                           3.1.0       
pyOpenSSL                          16.2.0      
pyparsing                          2.1.4       
pytest                             2.9.2       
python-dateutil                    2.5.3       
pytz                               2016.6.1    
PyWavelets                         0.5.2       
PyYAML                             3.12        
pyzmq                              15.4.0      
qj                                 0.1.4       
QtAwesome                          0.3.3       
qtconsole                          4.2.1       
QtPy                               1.1.2       
redis                              2.10.5      
requests                           2.14.2      
roboschool                         1.0
rope-py3k                          0.9.4.post1 
ruamel-yaml                        -VERSION    
ruamel.yaml                        0.15.37     
scikit-image                       0.13.0      
scikit-learn                       0.19.0      
scipy                              0.19.1      
setuptools                         39.1.0      
simplegeneric                      0.8.1       
singledispatch                     3.4.0.3     
six                                1.11.0      
snowballstemmer                    1.2.1       
sockjs-tornado                     1.0.3       
Sphinx                             1.4.6       
spyder                             3.0.0       
SQLAlchemy                         1.0.13      
statsmodels                        0.8.0       
svgwrite                           1.1.6       
sympy                              1.0         
tables                             3.4.2       
tensorboard                        1.8.0       
tensorflow                         1.8.0       
tensorflow-gpu                     1.8.0       
tensorflow-tensorboard             1.5.1       
tensorsets                         0.1.0       
termcolor                          1.1.0       
terminado                          0.6         
toolz                              0.8.0       
torch                              0.2.0.post4 
torchvision                        0.1.9       
tornado                            4.4.1       
traitlets                          4.3.0       
unicodecsv                         0.14.1      
wcwidth                            0.1.7       
Werkzeug                           0.14.1      
wheel                              0.31.0      
widgetsnbextension                 1.2.6       
wrapt                              1.10.6      
xlrd                               1.0.0       
XlsxWriter                         0.9.3       
xlwt                               1.1.2 

hardmaru commented 6 years ago

For completeness, here is the pip list on the 64-core CPU only machine that was used to generate the training data for V and M (using extract.bash), so doom_py and actual doom gym environments were installed here.

Package                            Version     Location                
---------------------------------- ----------- ------------------------
absl-py                            0.2.1       
alabaster                          0.7.9       
anaconda-clean                     1.0         
anaconda-client                    1.5.1       
anaconda-navigator                 1.3.1       
argcomplete                        1.0.0       
astor                              0.6.2       
astroid                            1.4.7       
astropy                            2.0.1       
Babel                              2.3.4       
backports.shutil-get-terminal-size 1.0.0       
beautifulsoup4                     4.5.1       
bitarray                           0.8.1       
blaze                              0.10.1      
bleach                             1.5.0       
bokeh                              0.12.2      
boto                               2.42.0      
Bottleneck                         1.2.1       
Box2D                              2.3.2       
cffi                               1.7.0       
chest                              0.2.3       
click                              6.6         
cloudpickle                        0.2.1       
clyent                             1.2.2       
cma                                2.2.0       
colorama                           0.3.7       
conda                              4.3.25      
conda-build                        2.0.2       
configobj                          5.0.6       
contextlib2                        0.5.3       
cryptography                       1.5         
cycler                             0.10.0      
Cython                             0.24.1      
cytoolz                            0.8.0       
dask                               0.15.3      
datashape                          0.5.2       
decorator                          4.0.10      
dill                               0.2.5       
docutils                           0.12        
doom-py                            0.0.14      /home/hardmaru/doom-py   
dynd                               0.7.3.dev1  
et-xmlfile                         1.0.1       
fastcache                          1.0.2       
filelock                           2.0.6       
Flask                              0.11.1      
Flask-Cors                         2.1.2       
gast                               0.2.0       
gevent                             1.1.2       
greenlet                           0.4.10      
grpcio                             1.12.0      
gym                                0.9.2       
h5py                               2.7.0       
HeapDict                           1.0.0       
html5lib                           0.9999999   
idna                               2.1         
imagesize                          0.7.1       
ipykernel                          4.5.0       
ipython                            5.1.0       
ipython-genutils                   0.1.0       
ipywidgets                         5.2.2       
itsdangerous                       0.24        
jdcal                              1.2         
jedi                               0.9.0       
Jinja2                             2.8         
jsonschema                         2.5.1       
jupyter                            1.0.0       
jupyter-client                     4.4.0       
jupyter-console                    5.0.0       
jupyter-core                       4.2.0       
lazy-object-proxy                  1.2.1       
llvmlite                           0.19.0      
locket                             0.2.0       
lxml                               3.6.4       
Markdown                           2.6.9       
MarkupSafe                         0.23        
matplotlib                         2.0.2       
mistune                            0.7.3       
mpi4py                             2.0.0       
mpmath                             0.19        
multipledispatch                   0.4.8       
nb-anacondacloud                   1.2.0       
nb-conda                           2.0.0       
nb-conda-kernels                   2.0.0       
nbconvert                          4.2.0       
nbformat                           4.1.0       
nbpresent                          3.0.2       
networkx                           1.11        
nltk                               3.2.1       
nose                               1.3.7       
notebook                           4.2.3       
numba                              0.34.0      
numexpr                            2.6.2       
numpy                              1.13.3      
odo                                0.5.0       
openpyxl                           2.3.2       
pandas                             0.20.3      
partd                              0.3.6       
patch                              1.16        
path.py                            0.0.0       
pathlib2                           2.1.0       
patsy                              0.4.1       
pep8                               1.7.0       
pexpect                            4.0.1       
pickleshare                        0.7.4       
Pillow                             3.3.1       
pip                                10.0.1      
pkginfo                            1.3.2       
ply                                3.9         
ppaquette-gym-doom                 0.0.6       /home/hardmaru/gym-doom  
prompt-toolkit                     1.0.3       
protobuf                           3.4.0       
psutil                             4.3.1       
ptyprocess                         0.5.1       
py                                 1.4.31      
pyasn1                             0.1.9       
pybullet                           1.6.3       
pycosat                            0.6.1       
pycparser                          2.14        
pycrypto                           2.6.1       
pycurl                             7.43.0      
pyflakes                           1.3.0       
pyglet                             1.2.4       
Pygments                           2.1.3       
pylint                             1.5.4       
PyOpenGL                           3.1.0       
pyOpenSSL                          16.2.0      
pyparsing                          2.1.4       
pytest                             2.9.2       
python-dateutil                    2.5.3       
pytz                               2016.6.1    
PyWavelets                         0.5.2       
PyYAML                             3.12        
pyzmq                              15.4.0      
QtAwesome                          0.3.3       
qtconsole                          4.2.1       
QtPy                               1.1.2       
redis                              2.10.5      
requests                           2.14.2      
roboschool                         1.0         
rope-py3k                          0.9.4.post1 
ruamel-yaml                        -VERSION    
scikit-image                       0.13.0      
scikit-learn                       0.19.0      
scipy                              0.19.1      
setuptools                         27.2.0      
simplegeneric                      0.8.1       
singledispatch                     3.4.0.3     
six                                1.10.0      
snowballstemmer                    1.2.1       
sockjs-tornado                     1.0.3       
Sphinx                             1.4.6       
spyder                             3.0.0       
SQLAlchemy                         1.0.13      
statsmodels                        0.8.0       
sympy                              1.0         
tables                             3.4.2       
tensorboard                        1.8.0       
tensorflow                         1.8.0       
tensorflow-tensorboard             0.1.5       
termcolor                          1.1.0       
terminado                          0.6         
toolz                              0.8.0       
torch                              0.2.0.post4 
torchvision                        0.1.9       
tornado                            4.4.1       
traitlets                          4.3.0       
unicodecsv                         0.14.1      
wcwidth                            0.1.7       
Werkzeug                           0.11.11     
wheel                              0.29.0      
widgetsnbextension                 1.2.6       
wrapt                              1.10.6      
xlrd                               1.0.0       
XlsxWriter                         0.9.3       
xlwt                               1.1.2 
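
To see where the two environments above differ, the two lists can be diffed programmatically. The following is a minimal sketch, assuming the columnar `pip list` layout shown in this thread; the helper names parse_pip_list and diff_versions are made up for illustration, and the sample strings are short excerpts of the two lists.

```python
def parse_pip_list(text):
    """Map package name -> version from columnar `pip list` output."""
    versions = {}
    for line in text.splitlines():
        parts = line.split()
        # Skip blank lines, the header row and the dashed separator row.
        if len(parts) < 2 or parts[0] == "Package" or set(parts[0]) == {"-"}:
            continue
        versions[parts[0]] = parts[1]
    return versions

def diff_versions(a, b):
    """Return {name: (version_in_a, version_in_b)} for entries that differ.

    None marks a package that is missing from one side.
    """
    names = sorted(set(a) | set(b), key=str.lower)
    return {n: (a.get(n), b.get(n)) for n in names if a.get(n) != b.get(n)}

gpu_excerpt = """
Package        Version  Location
-------------- -------- --------
grpcio         1.11.0
h5py           2.8.0rc1
numpy          1.13.3
tensorflow-gpu 1.8.0
"""

cpu_excerpt = """
Package        Version  Location
-------------- -------- ----------------------
doom-py        0.0.14   /home/hardmaru/doom-py
grpcio         1.12.0
h5py           2.7.0
numpy          1.13.3
"""

differences = diff_versions(parse_pip_list(gpu_excerpt), parse_pip_list(cpu_excerpt))
for name, (gpu_version, cpu_version) in differences.items():
    print(name, gpu_version, cpu_version)
```

Packages with identical versions on both machines (numpy in this excerpt) are left out of the result, so only the entries worth inspecting remain.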

hardmaru commented 6 years ago

It might be worthwhile to look at the data created by extract.bash to see if there were any issues. If the doom screenshots produced were somehow rubbish (due to a bad doom install or something), perhaps that might explain the VAE loss not going down?
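
A quick way to screen the extracted frames is to look at a few summary statistics. This is only a sketch under assumptions: it expects the frames already loaded as a uint8 array shaped (N, H, W, C), and it leaves out the loading step because the file layout written by extract.bash is not shown in this thread. The function name summarize_frames is hypothetical.

```python
import numpy as np

def summarize_frames(frames):
    """Basic statistics for a batch of frames shaped (N, H, W, C)."""
    frames = np.asarray(frames)
    flat = frames.reshape(len(frames), -1).astype(np.float64)
    # A frame whose pixels are all identical has zero standard deviation.
    per_frame_std = flat.std(axis=1)
    return {
        "min": int(frames.min()),
        "max": int(frames.max()),
        "mean": float(flat.mean()),
        "blank_fraction": float((per_frame_std == 0).mean()),
    }

# Synthetic stand-in data: four tiny frames, three of them completely black.
demo = np.zeros((4, 2, 2, 3), dtype=np.uint8)
demo[0, 0, 0, 0] = 255
stats = summarize_frames(demo)
```

A high blank_fraction, or a min and max that are nearly equal, would suggest that the screenshots are not what the VAE expects.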

zmonoid commented 6 years ago

@hardmaru Thanks very much for your reply~~~ It was indeed a dataset problem. In my implementation, I use the vizdoom package and load the TakeCover scenario directly instead of going through ppaquette-gym-doom. After I replaced my dataset with the one generated by your code, I get the same loss as yours.

But I am curious about the reason. My data collection code is short; you can check it here: https://github.com/zmonoid/WorldModels/blob/master/collect_data.py

The video it collected looks okay: https://zhoubin.me/static/share/data.avi

The training of VAE is like: image

The training of RNN is like: image

Both the VAE and RNN training loss curves look similar to yours.

However, the earlier problem remains: the controller's behavior in the virtual environment still does not agree with the real environment. I am still working on it.

image

zmonoid commented 6 years ago

@hardmaru I finally got results similar to Fig. 28 in your paper; you may check my repo here.

The problem was caused by the game settings, in particular the Doom difficulty level, which ppaquette_gym_doom sets to 5, whereas the default was 4. The ppaquette_gym_doom wrapper also uses a thread lock, which makes data collection and evolution training very slow. One strange point: if I set the difficulty to 5, most of the collected sequences have fewer than 100 frames. I compared my game settings carefully with those of ppaquette_gym_doom, but I still have no clue why.

There is also a big gap between your score and mine; I am still looking for the reason.
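
The difficulty setting and the short-episode observation described above can be sketched as follows. The helper names apply_doom_skill and short_episode_fraction are hypothetical, the 1 to 5 skill range is an assumption about ViZDoom, and the only library call relied on is DoomGame.set_doom_skill. The helper takes any object exposing that method, so the sketch runs without ViZDoom installed.

```python
def apply_doom_skill(game, skill):
    """Set the Doom difficulty on a ViZDoom-style game object."""
    # Assumption: valid skill levels run from 1 (easiest) to 5 (hardest).
    if not 1 <= skill <= 5:
        raise ValueError("skill must be between 1 and 5")
    game.set_doom_skill(skill)
    return skill

def short_episode_fraction(episode_lengths, min_frames=100):
    """Fraction of recorded episodes shorter than min_frames."""
    if not episode_lengths:
        return 0.0
    short = sum(1 for n in episode_lengths if n < min_frames)
    return short / len(episode_lengths)

# Intended use with ViZDoom, left commented out because it needs the
# library and a scenario file:
# import vizdoom
# game = vizdoom.DoomGame()
# game.load_config("take_cover.cfg")  # hypothetical path
# apply_doom_skill(game, 5)
# game.init()
```

Comparing short_episode_fraction for data collected at each difficulty level gives a simple number to track while hunting for the remaining difference between the two environments.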

hardmaru commented 6 years ago

Hi @zmonoid, thanks for the information. I didn't know that the ppaquette_gym_doom env was that different from the vizdoom env (I only used it because it was advertised on OpenAI's gym website).

For a pure implementation using vizdoom only, you may want to check out this one:

https://github.com/AdeelMufti/WorldModels

@AdeelMufti was able to (somewhat) replicate the results from scratch in Chainer using his own wrapper over vizdoom, and I think this is what you are trying to do, so perhaps see what he is doing and compare with your approach?

Regards.

zmonoid commented 6 years ago

@hardmaru Thanks for your reply~

To my surprise, your preprocessing of the frames also matters a lot for this problem. The code below actually inverts the image; I just wonder why this helps.

Anyway, thanks very much for your help~

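import numpy as np

# resize, SCREEN_X and SCREEN_Y are defined elsewhere in the original script.
def _process_frame(frame):
  # Crop to the top 400 rows and scale pixel values to [0, 1].
  obs = np.array(frame[0:400, :, :]).astype(np.float64)/255.0
  obs = np.array(resize(obs, (SCREEN_Y, SCREEN_X)))
  # Invert the image and convert back to 8-bit pixel values.
  obs = ((1.0 - obs) * 255).round().astype(np.uint8)
  return obs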
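
The inversion in the last step of the function above can be checked in isolation. This sketch assumes the array entering that step holds floats in [0, 1]; if resize returned 0 to 255 integers instead, the same line would behave differently. The name invert_to_uint8 is illustrative.

```python
import numpy as np

def invert_to_uint8(obs01):
    """Map floats in [0, 1] to inverted 8-bit pixel values."""
    return ((1.0 - obs01) * 255).round().astype(np.uint8)

# Three sample pixel values: black, dark grey and white.
pixels = np.array([0, 51, 255], dtype=np.uint8)
inverted = invert_to_uint8(pixels.astype(np.float64) / 255.0)
```

Under that assumption each output pixel is 255 minus the input pixel, so black becomes white and vice versa.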

AdeelMufti commented 6 years ago

Hello @zmonoid. So sorry for the delay in checking this thread. I wasn't aware either that the difficulty in ppaquette_gym_doom was not set to the default.

If you use the data collection method in my repo, with the custom wrapper around ViZDoom (it's extremely simple), I think it works decently well. When training a full World Models agent, it's able to get decent scores, but not as high as @hardmaru. I suspect some careful manipulation of temperature will help push it higher.

Good luck!
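
For readers unfamiliar with the temperature mentioned above, here is a minimal sketch of one common convention for mixture density sampling: the mixture logits are divided by the temperature and each standard deviation is multiplied by its square root. This is an assumption about the convention, not a copy of any code in the repositories linked in this thread, and adjust_temperature is a hypothetical name.

```python
import math

def adjust_temperature(log_weights, log_stds, temperature):
    """Rescale mixture parameters by a temperature tau > 0."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Dividing the logits by tau flattens (tau > 1) or sharpens (tau < 1) the weights.
    scaled = [lw / temperature for lw in log_weights]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(s - m) for s in scaled))
    weights = [math.exp(s - log_norm) for s in scaled]
    # Scaling each standard deviation by sqrt(tau) widens or narrows the components.
    stds = [math.exp(ls) * math.sqrt(temperature) for ls in log_stds]
    return weights, stds

weights, stds = adjust_temperature([math.log(0.9), math.log(0.1)], [0.0, 0.0], 4.0)
```

With a temperature of 4 the dominant component's weight drops from 0.9 to roughly 0.63 and both standard deviations double, which makes the dream environment noisier.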