Closed hanshuo-shuo closed 10 months ago
I use this to solve the problem, but I am not sure this is correct though:
The vmap needs to be initialized again after setting the dropout. Try doing this
def set_drop_out(self, dropout):
    for m in self.original_modules.modules():
        if isinstance(m, nn.Dropout):
            m.p = dropout
    fn, params, _ = combine_state_for_ensemble(self.original_modules)
    self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness='different', **self.kwargs)
@aditya-spood Thanks a lot !
Hi @hanshuo-shuo @aditya-spood , thanks for initiating the discussion. Is this an issue with the current code, or does it only occur with some changes? I'd like to dig into this and fix any issues with the codebase (or anything that would make it easier to use/extend). It is possible that some behaviors are inconsistent across package versions, in which case I'd like for the code to be less sensitive to the exact versions installed.
@nicklashansen Hi, thanks for your reply.
https://github.com/hanshuo-shuo/tdmpc2-prey/tree/main
- override hydra/launcher: basic
cpu: simply change every device to self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
I also store the memory on the CPU only. Then, during training, it produces such errors. I checked the PyTorch documentation, and I don't think torch.vmap has the __wrapped__ attribute. But after all the changes I listed above, I get a really good training result with the smallest model parameter setting after training for 100000 steps.
Thanks for sharing. In that case, I suspect that it might be due to package versions. Would you mind copy/pasting the output of conda list or conda env export here? I'll see if I can reproduce the error on my end.
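For anyone who would rather not paste a full environment export, a short report of the few relevant packages may be enough to compare setups. The following is only a sketch: the helper name report_versions and the package list are assumptions chosen for illustration, not something the repository provides.

```python
from importlib.metadata import PackageNotFoundError, version


def report_versions(packages):
    """Map each distribution name to its installed version, or 'not installed'."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # The package is missing from the active environment.
            found[name] = "not installed"
    return found


if __name__ == "__main__":
    # Hypothetical selection of packages that seem relevant to this issue.
    for name, ver in report_versions(["torch", "tensordict", "torchrl", "gym"]).items():
        print(f"{name}=={ver}")
```

Running this inside the activated conda environment prints one line per package, which is easier to compare across machines than a full export.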
@nicklashansen Hi, the output is
channels:
- pytorch
- conda-forge
dependencies:
- aom=3.6.1=hb765f3a_0
- brotli-python=1.1.0=py39hb198ff7_1
- bzip2=1.0.8=h93a5062_5
- ca-certificates=2023.7.22=hf0a4a13_0
- cairo=1.18.0=hd1e100b_0
- dav1d=1.2.1=hb547adb_0
- expat=2.5.0=hb7217d7_1
- ffmpeg=6.0.0=gpl_h1ceb99f_105
- filelock=3.13.1=pyhd8ed1ab_0
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
- font-ttf-inconsolata=3.000=h77eed37_0
- font-ttf-source-code-pro=2.038=h77eed37_0
- font-ttf-ubuntu=0.83=hab24e00_0
- fontconfig=2.14.2=h82840c6_0
- fonts-conda-ecosystem=1=0
- fonts-conda-forge=1=0
- freetype=2.12.1=hadb7bae_2
- fribidi=1.0.10=h27ca646_0
- gettext=0.21.1=h0186832_0
- gmp=6.2.1=h9f76cd9_0
- gmpy2=2.1.2=py39h0b4f9c6_1
- gnutls=3.7.8=h9f1a10d_0
- graphite2=1.3.13=h9f76cd9_1001
- harfbuzz=8.2.1=hf1a6348_0
- icu=73.2=hc8870d7_0
- idna=3.4=pyhd8ed1ab_0
- jinja2=3.1.2=pyhd8ed1ab_1
- lame=3.100=h1a8c8d9_1003
- lcms2=2.15=hf2736f0_3
- lerc=4.0.0=h9a09cb3_0
- libass=0.17.1=hf7da4fe_1
- libblas=3.9.0=19_osxarm64_openblas
- libcblas=3.9.0=19_osxarm64_openblas
- libcxx=16.0.6=h4653b0c_0
- libdeflate=1.19=hb547adb_0
- libexpat=2.5.0=hb7217d7_1
- libffi=3.4.2=h3422bc3_5
- libgfortran=5.0.0=13_2_0_hd922786_1
- libgfortran5=13.2.0=hf226fd6_1
- libglib=2.78.1=hd9b11f9_0
- libiconv=1.17=he4db4b2_0
- libidn2=2.3.4=h1a8c8d9_0
- libjpeg-turbo=3.0.0=hb547adb_1
- liblapack=3.9.0=19_osxarm64_openblas
- libopenblas=0.3.24=openmp_hd76b1f2_0
- libopus=1.3.1=h27ca646_1
- libpng=1.6.39=h76d750c_0
- libsqlite=3.43.2=h091b4b1_0
- libtasn1=4.19.0=h1a8c8d9_0
- libtiff=4.6.0=ha8a6c65_2
- libunistring=0.9.10=h3422bc3_0
- libvpx=1.13.1=hb765f3a_0
- libwebp-base=1.3.2=hb547adb_0
- libxcb=1.15=hf346824_0
- libxml2=2.11.5=h25269f3_1
- libzlib=1.2.13=h53f4e23_5
- llvm-openmp=17.0.4=hcd81f8e_0
- markupsafe=2.1.3=py39h0f82c59_1
- mpc=1.3.1=h91ba8db_0
- mpfr=4.2.1=h9546428_0
- mpmath=1.3.0=pyhd8ed1ab_0
- ncurses=6.4=h7ea286d_0
- nettle=3.8.1=h63371fa_1
- openh264=2.3.1=hb7217d7_2
- openjpeg=2.5.0=h4c1507b_3
- openssl=3.1.4=h0d3ecfb_0
- p11-kit=0.24.1=h29577a5_0
- pcre2=10.40=hb34f9b4_0
- pillow=10.1.0=py39h755f0b7_0
- pip=23.2.1=pyhd8ed1ab_0
- pixman=0.42.2=h13dd4ca_0
- pthread-stubs=0.4=h27ca646_1001
- pysocks=1.7.1=pyha2e5f31_6
- python=3.9.7=hc0da0df_3_cpython
- python_abi=3.9=4_cp39
- pytorch=2.1.0=py3.9_0
- pyyaml=6.0.1=py39h0f82c59_1
- readline=8.2=h92ec313_1
- requests=2.31.0=pyhd8ed1ab_0
- setuptools=68.2.2=pyhd8ed1ab_0
- sqlite=3.43.2=hf2abe2d_0
- svt-av1=1.7.0=hb765f3a_0
- sympy=1.12=pypyh9d50eac_103
- tk=8.6.13=hb31c410_0
- torchaudio=2.1.0=py39_cpu
- torchvision=0.16.0=py39_cpu
- typing_extensions=4.8.0=pyha770c72_0
- wheel=0.41.2=pyhd8ed1ab_0
- x264=1!164.3095=h57fd34a_2
- x265=3.5=hbc6ce65_3
- xorg-libxau=1.0.11=hb547adb_0
- xorg-libxdmcp=1.1.3=h27ca646_0
- xz=5.2.6=h57fd34a_0
- yaml=0.2.5=h3422bc3_2
- zlib=1.2.13=h53f4e23_5
- zstd=1.5.5=h4f39d0f_0
- pip:
- absl-py==2.0.0
- antlr4-python3-runtime==4.9.3
- astunparse==1.6.3
- cachetools==5.3.1
- cellworld==0.0.376
- certifi==2023.7.22
- charset-normalizer==3.3.0
- chex==0.1.83
- cloudpickle==2.2.1
- contourpy==1.1.1
- crafter==1.8.1
- cv==1.0.0
- cycler==0.12.1
- decorator==5.1.1
- dm-tree==0.1.8
- farama-notifications==0.0.4
- flatbuffers==23.5.26
- fonttools==4.43.1
- fsspec==2023.10.0
- gast==0.5.4
- google-auth==2.23.3
- google-auth-oauthlib==1.0.0
- google-pasta==0.2.0
- grpcio==1.59.0
- gym==0.26.2
- gym-notices==0.0.8
- gymnasium==0.29.1
- h5py==3.10.0
- huggingface-hub==0.17.3
- hydra-core==1.3.2
- imageio==2.31.5
- importlib-metadata==6.8.0
- importlib-resources==6.1.0
- jax==0.4.18
- jaxlib==0.4.18
- json-cpp==1.0.91
- keras==2.14.0
- kiwisolver==1.4.5
- libclang==16.0.6
- markdown==3.5
- markdown-it-py==3.0.0
- matplotlib==3.8.0
- mdurl==0.1.2
- ml-dtypes==0.2.0
- networkx==3.1
- numpy==1.26.1
- oauthlib==3.2.2
- omegaconf==2.3.0
- opensimplex==0.4.5
- opt-einsum==3.3.0
- optax==0.1.7
- packaging==23.2
- pandas==2.1.2
- pettingzoo==1.24.1
- protobuf==4.24.4
- pyasn1==0.5.0
- pyasn1-modules==0.3.0
- pygments==2.16.1
- pyparsing==3.1.1
- python-dateutil==2.8.2
- pytz==2023.3.post1
- regex==2023.10.3
- requests-oauthlib==1.3.1
- rich==13.6.0
- rsa==4.9
- ruamel-yaml==0.17.35
- ruamel-yaml-clib==0.2.8
- safetensors==0.4.0
- scipy==1.11.3
- six==1.16.0
- stable-baselines3==2.1.0
- supersuit==3.9.0
- tcp-messages==1.0.45
- tensorboard==2.14.1
- tensorboard-data-server==0.7.1
- tensordict==0.2.1
- tensordict-nightly==2023.6.8
- tensorflow==2.14.0
- tensorflow-estimator==2.14.0
- tensorflow-io-gcs-filesystem==0.34.0
- tensorflow-macos==2.14.0
- tensorflow-probability==0.22.0
- termcolor==2.3.0
- tinyscaler==1.2.7
- tokenizers==0.14.1
- toolz==0.12.0
- torch==2.1.0
- torchrl==0.2.1
- tqdm==4.66.1
- transformers==4.34.1
- typing-extensions==4.5.0
- tzdata==2023.3
- urllib3==2.0.6
- werkzeug==3.0.0
- wrapt==1.14.1
- zipp==3.17.0
Sorry, my env is a mess.
Hello! Thank you for sharing your great work with all of us.
I also got the same errors when I ran the following command from the README instructions. I did not change anything in the code.
python train.py task=dog-run steps=7000000
The problem is solved with the advice in this thread.
I wonder whether this revision changes the final result.
@purewater0901 I get the same error too, but it still doesn't work after I edit the code as above. Could you please tell me how you fixed it?
@VitaLemonTea1 I only did this part and it works quite well (just overwrite the previous layer's code with it)
I use this to solve the problem, but I am not sure this is correct though:
@VitaLemonTea1 @hanshuo-shuo Yes, I changed the code as @hanshuo-shuo showed here.
You can also add the new function suggested in this thread to the class, but I'm not sure if it's necessary.
@purewater0901 @hanshuo-shuo Thanks a lot!
Hi all,
Thank you for your patience. I spent a few hours investigating this today, and it appears that the existence of vmap.__wrapped__ in torch/functorch was very short-lived. I have issued a fix here https://github.com/nicklashansen/tdmpc2/commit/f3139291e2dc8e47480184a4a1bce05e8980caa3 which removes the modules() function call altogether. This specific part of the implementation is not critical, and I have verified that results are reproduced on a handful of tasks.
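As a general illustration of why relying on __wrapped__ is fragile, here is a small sketch in plain Python. It does not use torch and is not the code from the fix; it only assumes that the failure was an attribute lookup on a wrapper that does not expose the original callable. All function names below are hypothetical.

```python
import functools


def with_wraps(fn):
    """Wrap fn and expose the original through functools.wraps."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


def without_wraps(fn):
    """Wrap fn without exposing the original callable."""
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


def unwrap_or_none(wrapped):
    """Return the original callable if the wrapper exposes it, else None."""
    return getattr(wrapped, "__wrapped__", None)


def double(x):
    return 2 * x


if __name__ == "__main__":
    print(unwrap_or_none(with_wraps(double)) is double)  # True
    print(unwrap_or_none(without_wraps(double)) is None)  # True
```

Whether a given wrapper sets __wrapped__ is an implementation detail of that wrapper, so code that needs the original object is usually safer keeping its own reference to it, or avoiding the lookup entirely as the fix does.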
This commit also updates environment.yaml to circumvent the broken gym setup mentioned here https://github.com/openai/gym/issues/3176 and here https://github.com/nicklashansen/tdmpc2/issues/2.
I'll close this issue, but please feel free to re-open if you run into any other issues.
Hi, thanks for your code! I am trying to apply it to my environment. During training, I found out
I was wondering whether you have also had issues like this. Is it due to some uninstalled module?