nicklashansen / tdmpc2

Code for "TD-MPC2: Scalable, Robust World Models for Continuous Control"
https://www.tdmpc2.com
MIT License

Error during training #3

Closed: hanshuo-shuo closed this issue 10 months ago

hanshuo-shuo commented 11 months ago

Hi, thanks for your code! I am trying to apply it to my own environment, and during training I ran into the following error:

[screenshots of the error]

Did you also run into issues like this? Is it caused by some missing module?

hanshuo-shuo commented 10 months ago

I used this to solve the problem, but I am not sure it is correct:

[screenshot of the code change]

aditya-spood commented 10 months ago

The vmap needs to be re-initialized after setting the dropout. Try this:

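    # Assumes: import torch, import torch.nn as nn,
    # and from functorch import combine_state_for_ensemble
    def set_drop_out(self, dropout):
        # Update the dropout probability on the original (non-vmapped) modules
        for m in self.original_modules.modules():
            if isinstance(m, nn.Dropout):
                m.p = dropout
        # fn wraps its own copy of the modules, so rebuild it (and the vmap
        # around it) for the new dropout probability to take effect
        fn, params, _ = combine_state_for_ensemble(self.original_modules)
        self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness='different', **self.kwargs)
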
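A minimal pure-Python sketch of why the rebuild matters, assuming (as functorch's implementation appears to do) that combine_state_for_ensemble builds fn around a deep copy of the modules. The Dropout class and make_fn helper are hypothetical stand-ins, not TD-MPC2 or PyTorch code, and the probabilities are arbitrary.

```python
import copy


class Dropout:
    """Hypothetical stand-in for nn.Dropout: only the probability p matters here."""

    def __init__(self, p):
        self.p = p


def make_fn(module):
    # Hypothetical: like combine_state_for_ensemble (assumed), capture a deep copy.
    snapshot = copy.deepcopy(module)
    # The returned function reports the dropout probability it would use.
    return lambda: snapshot.p


original = Dropout(p=0.01)
fn = make_fn(original)
original.p = 0.1        # change the original module only
stale = fn()            # still 0.01: the copy was taken before the change
fn = make_fn(original)  # rebuild, as set_drop_out does
fresh = fn()            # now 0.1
print(stale, fresh)     # 0.01 0.1
```

Changing p on the original has no effect on a function built earlier; only rebuilding picks up the new value.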
hanshuo-shuo commented 10 months ago

@aditya-spood Thanks a lot!

nicklashansen commented 10 months ago

Hi @hanshuo-shuo @aditya-spood , thanks for initiating the discussion. Is this an issue with the current code, or does it only occur with some changes? I'd like to dig into this and fix any issues with the codebase (or anything that would make it easier to use/extend). It is possible that some behaviors are inconsistent across package versions, in which case I'd like for the code to be less sensitive to the exact versions installed.

hanshuo-shuo commented 10 months ago

@nicklashansen Hi, thanks for your reply.

https://github.com/hanshuo-shuo/tdmpc2-prey/tree/main

    - override hydra/launcher: basic

Then, during training, it produces these errors. I checked the PyTorch documentation, and I don't think torch.vmap has a __wrapped__ attribute. But after all the changes I listed above, I get a really good training result with the smallest model size setting after training for 100,000 steps.

nicklashansen commented 10 months ago

Thanks for sharing. In that case, I suspect that it might be due to package versions. Would you mind copy/pasting the output of conda list or conda env export here? I'll see if I can reproduce the error on my end.

hanshuo-shuo commented 10 months ago

@nicklashansen Hi, here is the output:

channels:
  - pytorch
  - conda-forge
dependencies:
  - aom=3.6.1=hb765f3a_0
  - brotli-python=1.1.0=py39hb198ff7_1
  - bzip2=1.0.8=h93a5062_5
  - ca-certificates=2023.7.22=hf0a4a13_0
  - cairo=1.18.0=hd1e100b_0
  - dav1d=1.2.1=hb547adb_0
  - expat=2.5.0=hb7217d7_1
  - ffmpeg=6.0.0=gpl_h1ceb99f_105
  - filelock=3.13.1=pyhd8ed1ab_0
  - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
  - font-ttf-inconsolata=3.000=h77eed37_0
  - font-ttf-source-code-pro=2.038=h77eed37_0
  - font-ttf-ubuntu=0.83=hab24e00_0
  - fontconfig=2.14.2=h82840c6_0
  - fonts-conda-ecosystem=1=0
  - fonts-conda-forge=1=0
  - freetype=2.12.1=hadb7bae_2
  - fribidi=1.0.10=h27ca646_0
  - gettext=0.21.1=h0186832_0
  - gmp=6.2.1=h9f76cd9_0
  - gmpy2=2.1.2=py39h0b4f9c6_1
  - gnutls=3.7.8=h9f1a10d_0
  - graphite2=1.3.13=h9f76cd9_1001
  - harfbuzz=8.2.1=hf1a6348_0
  - icu=73.2=hc8870d7_0
  - idna=3.4=pyhd8ed1ab_0
  - jinja2=3.1.2=pyhd8ed1ab_1
  - lame=3.100=h1a8c8d9_1003
  - lcms2=2.15=hf2736f0_3
  - lerc=4.0.0=h9a09cb3_0
  - libass=0.17.1=hf7da4fe_1
  - libblas=3.9.0=19_osxarm64_openblas
  - libcblas=3.9.0=19_osxarm64_openblas
  - libcxx=16.0.6=h4653b0c_0
  - libdeflate=1.19=hb547adb_0
  - libexpat=2.5.0=hb7217d7_1
  - libffi=3.4.2=h3422bc3_5
  - libgfortran=5.0.0=13_2_0_hd922786_1
  - libgfortran5=13.2.0=hf226fd6_1
  - libglib=2.78.1=hd9b11f9_0
  - libiconv=1.17=he4db4b2_0
  - libidn2=2.3.4=h1a8c8d9_0
  - libjpeg-turbo=3.0.0=hb547adb_1
  - liblapack=3.9.0=19_osxarm64_openblas
  - libopenblas=0.3.24=openmp_hd76b1f2_0
  - libopus=1.3.1=h27ca646_1
  - libpng=1.6.39=h76d750c_0
  - libsqlite=3.43.2=h091b4b1_0
  - libtasn1=4.19.0=h1a8c8d9_0
  - libtiff=4.6.0=ha8a6c65_2
  - libunistring=0.9.10=h3422bc3_0
  - libvpx=1.13.1=hb765f3a_0
  - libwebp-base=1.3.2=hb547adb_0
  - libxcb=1.15=hf346824_0
  - libxml2=2.11.5=h25269f3_1
  - libzlib=1.2.13=h53f4e23_5
  - llvm-openmp=17.0.4=hcd81f8e_0
  - markupsafe=2.1.3=py39h0f82c59_1
  - mpc=1.3.1=h91ba8db_0
  - mpfr=4.2.1=h9546428_0
  - mpmath=1.3.0=pyhd8ed1ab_0
  - ncurses=6.4=h7ea286d_0
  - nettle=3.8.1=h63371fa_1
  - openh264=2.3.1=hb7217d7_2
  - openjpeg=2.5.0=h4c1507b_3
  - openssl=3.1.4=h0d3ecfb_0
  - p11-kit=0.24.1=h29577a5_0
  - pcre2=10.40=hb34f9b4_0
  - pillow=10.1.0=py39h755f0b7_0
  - pip=23.2.1=pyhd8ed1ab_0
  - pixman=0.42.2=h13dd4ca_0
  - pthread-stubs=0.4=h27ca646_1001
  - pysocks=1.7.1=pyha2e5f31_6
  - python=3.9.7=hc0da0df_3_cpython
  - python_abi=3.9=4_cp39
  - pytorch=2.1.0=py3.9_0
  - pyyaml=6.0.1=py39h0f82c59_1
  - readline=8.2=h92ec313_1
  - requests=2.31.0=pyhd8ed1ab_0
  - setuptools=68.2.2=pyhd8ed1ab_0
  - sqlite=3.43.2=hf2abe2d_0
  - svt-av1=1.7.0=hb765f3a_0
  - sympy=1.12=pypyh9d50eac_103
  - tk=8.6.13=hb31c410_0
  - torchaudio=2.1.0=py39_cpu
  - torchvision=0.16.0=py39_cpu
  - typing_extensions=4.8.0=pyha770c72_0
  - wheel=0.41.2=pyhd8ed1ab_0
  - x264=1!164.3095=h57fd34a_2
  - x265=3.5=hbc6ce65_3
  - xorg-libxau=1.0.11=hb547adb_0
  - xorg-libxdmcp=1.1.3=h27ca646_0
  - xz=5.2.6=h57fd34a_0
  - yaml=0.2.5=h3422bc3_2
  - zlib=1.2.13=h53f4e23_5
  - zstd=1.5.5=h4f39d0f_0
  - pip:
    - absl-py==2.0.0
    - antlr4-python3-runtime==4.9.3
    - astunparse==1.6.3
    - cachetools==5.3.1
    - cellworld==0.0.376
    - certifi==2023.7.22
    - charset-normalizer==3.3.0
    - chex==0.1.83
    - cloudpickle==2.2.1
    - contourpy==1.1.1
    - crafter==1.8.1
    - cv==1.0.0
    - cycler==0.12.1
    - decorator==5.1.1
    - dm-tree==0.1.8
    - farama-notifications==0.0.4
    - flatbuffers==23.5.26
    - fonttools==4.43.1
    - fsspec==2023.10.0
    - gast==0.5.4
    - google-auth==2.23.3
    - google-auth-oauthlib==1.0.0
    - google-pasta==0.2.0
    - grpcio==1.59.0
    - gym==0.26.2
    - gym-notices==0.0.8
    - gymnasium==0.29.1
    - h5py==3.10.0
    - huggingface-hub==0.17.3
    - hydra-core==1.3.2
    - imageio==2.31.5
    - importlib-metadata==6.8.0
    - importlib-resources==6.1.0
    - jax==0.4.18
    - jaxlib==0.4.18
    - json-cpp==1.0.91
    - keras==2.14.0
    - kiwisolver==1.4.5
    - libclang==16.0.6
    - markdown==3.5
    - markdown-it-py==3.0.0
    - matplotlib==3.8.0
    - mdurl==0.1.2
    - ml-dtypes==0.2.0
    - networkx==3.1
    - numpy==1.26.1
    - oauthlib==3.2.2
    - omegaconf==2.3.0
    - opensimplex==0.4.5
    - opt-einsum==3.3.0
    - optax==0.1.7
    - packaging==23.2
    - pandas==2.1.2
    - pettingzoo==1.24.1
    - protobuf==4.24.4
    - pyasn1==0.5.0
    - pyasn1-modules==0.3.0
    - pygments==2.16.1
    - pyparsing==3.1.1
    - python-dateutil==2.8.2
    - pytz==2023.3.post1
    - regex==2023.10.3
    - requests-oauthlib==1.3.1
    - rich==13.6.0
    - rsa==4.9
    - ruamel-yaml==0.17.35
    - ruamel-yaml-clib==0.2.8
    - safetensors==0.4.0
    - scipy==1.11.3
    - six==1.16.0
    - stable-baselines3==2.1.0
    - supersuit==3.9.0
    - tcp-messages==1.0.45
    - tensorboard==2.14.1
    - tensorboard-data-server==0.7.1
    - tensordict==0.2.1
    - tensordict-nightly==2023.6.8
    - tensorflow==2.14.0
    - tensorflow-estimator==2.14.0
    - tensorflow-io-gcs-filesystem==0.34.0
    - tensorflow-macos==2.14.0
    - tensorflow-probability==0.22.0
    - termcolor==2.3.0
    - tinyscaler==1.2.7
    - tokenizers==0.14.1
    - toolz==0.12.0
    - torch==2.1.0
    - torchrl==0.2.1
    - tqdm==4.66.1
    - transformers==4.34.1
    - typing-extensions==4.5.0
    - tzdata==2023.3
    - urllib3==2.0.6
    - werkzeug==3.0.0
    - wrapt==1.14.1
    - zipp==3.17.0

Sorry, my env is a mess.
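When comparing exports like the one above, it can help to pull out only the versions relevant to this bug (here, the conda pytorch package and the pip torch package). The package_versions function below is a hypothetical helper, a sketch that assumes the layout shown above: conda lines as name=version=build and pip lines as name==version.

```python
import re

# Matches conda lines like "  - pytorch=2.1.0=py3.9_0" and pip lines like
# "    - torch==2.1.0"; channel lines ("  - pytorch") have no "=" and are skipped.
_PKG_LINE = re.compile(r"^\s*-\s+([A-Za-z0-9_.\-]+)(==|=)([^=\s]+)")


def package_versions(export_text, names):
    """Hypothetical helper: map each requested package name to the versions found."""
    found = {name: [] for name in names}
    for line in export_text.splitlines():
        match = _PKG_LINE.match(line)
        if match and match.group(1) in found:
            found[match.group(1)].append(match.group(3))
    return found


sample = """channels:
  - pytorch
dependencies:
  - pytorch=2.1.0=py3.9_0
  - pip:
    - torch==2.1.0
    - tensordict==0.2.1
    - tensordict-nightly==2023.6.8
"""
versions = package_versions(sample, ["pytorch", "torch", "tensordict"])
print(versions)  # {'pytorch': ['2.1.0'], 'torch': ['2.1.0'], 'tensordict': ['0.2.1']}
```

Names are matched exactly, so tensordict-nightly is reported separately from tensordict rather than merged into it.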

purewater0901 commented 10 months ago

Hello! Thank you for sharing your great work with all of us.

I also got the same errors when I ran the following command, which is from the README instructions. I did not change anything in the code.

python train.py task=dog-run steps=7000000

The problem is solved with the advice in this thread.

I wonder whether these revisions change the final results.

VitaLemonTea1 commented 10 months ago

@purewater0901 I get the same error too, but it still doesn't work after I edited the code as above. Could you please tell me how you fixed it?

hanshuo-shuo commented 10 months ago

@VitaLemonTea1 I only made this change, and it works quite well (just overwrite the previous code in the layer with it). Quoting my earlier comment:

"I used this to solve the problem, but I am not sure it is correct: [screenshot of the code change]"

purewater0901 commented 10 months ago

@VitaLemonTea1 @hanshuo-shuo Yes, I changed the code as @hanshuo-shuo showed here.

You can also add the new function suggested earlier in this thread to this class, but I'm not sure it's necessary.

VitaLemonTea1 commented 10 months ago

@purewater0901 @hanshuo-shuo Thanks a lot!

nicklashansen commented 10 months ago

Hi all,

Thank you for your patience. I spent a few hours investigating this today, and it appears that the existence of vmap.__wrapped__ in torch/functorch was very short-lived. I have issued a fix here https://github.com/nicklashansen/tdmpc2/commit/f3139291e2dc8e47480184a4a1bce05e8980caa3 which removes the modules() function call altogether. This specific part of the implementation is not critical, and I have verified that results are reproduced on a handful of tasks.

This commit also updates environment.yaml to circumvent the broken gym setup mentioned here https://github.com/openai/gym/issues/3176 and here https://github.com/nicklashansen/tdmpc2/issues/2.

I'll close this issue, but please feel free to re-open if you run into any other issues.
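As the discussion above shows, whether an object carries vmap.__wrapped__ can depend on the installed version. Code that must reach through such a wrapper can feature-detect the attribute instead of assuming it. The sketch below is an alternative illustration, not what the linked commit does (the commit removes the call altogether); underlying, plain, and decorate are hypothetical names, and it uses only the standard library's inspect.unwrap and functools.wraps.

```python
import functools
import inspect


def underlying(fn):
    """Return the innermost function behind fn, or fn itself if nothing is wrapped."""
    # inspect.unwrap follows the __wrapped__ chain and returns fn unchanged when
    # the attribute is absent, so this does not depend on a library setting it.
    return inspect.unwrap(fn)


def plain(x):
    return x + 1


def decorate(f):
    @functools.wraps(f)  # functools.wraps sets wrapper.__wrapped__ = f
    def wrapper(x):
        return f(x)
    return wrapper


wrapped = decorate(plain)
print(underlying(wrapped) is plain)  # True
print(underlying(plain) is plain)    # True: no __wrapped__, returned as-is
```

The same call works whether or not the wrapper exposes __wrapped__, which is the kind of version insensitivity discussed earlier in the thread.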