alex-petrenko / sample-factory

High throughput synchronous and asynchronous reinforcement learning
https://samplefactory.dev
MIT License
835 stars 113 forks

Cannot reproduce scores on dmlab-30 #109

Closed sungwoong closed 3 years ago

sungwoong commented 3 years ago

Hi @alex-petrenko,

I ran the code on dmlab-30 with exactly the same arguments/configuration as in the README. However, as shown in the figure below, the obtained scores (mean capped) are about 10% lower than the reference scores in the paper (Fig. 5). Is this performance gap still within a reasonable range given the randomness? Screen Shot 2021-05-06 at 11 29 28 PM

In addition, when I increase the rollout/recurrence length to 100, or feed the previous action and reward into the current LSTM inputs as in the original IMPALA, the scores decrease further. Did you observe similar results, or is there a reason why your setting/architecture differs from the original IMPALA in these respects?

alex-petrenko commented 3 years ago

Thank you for reporting this! Worst case scenario, it might be a regression, or perhaps a default parameter value change. Did you run on the latest version or on the version of code tagged ICML 2020? https://github.com/alex-petrenko/sample-factory/releases/tag/1.0.0

We'll be looking into this. @BoyuanLong we might need to re-run this on one of our servers.

sungwoong commented 3 years ago

@alex-petrenko Thank you for your answer. I ran the latest code from the master branch. Also, is there a reason for the differences from the original IMPALA, such as the short rollout/recurrence length (32) and the exclusion of the previous actions and rewards from the core LSTM inputs?

alex-petrenko commented 3 years ago

@sungwoong I am looking into this now. I believe I found the original experimental run that we did for the paper. https://drive.google.com/file/d/1WcGGOHeok9BroxXLg5zIXhsAHgYEcdK2/view?usp=sharing

If you could similarly package your run and send it to me, it'd be very helpful! We can compare cfg.json to see if there are parameter changes, and besides, differences in performance on individual envs and metrics such as entropy and losses can point to potential problems.

It can also be just a lucky/unlucky run... PBT is usually very consistent and the difference between results is quite big, but it is not entirely unlikely. These experiments take a substantial amount of time to run, and I'd be lying if I say that I ran enough of them to get a good feel for the degree of random variation.

Regarding the rollout length, it is a tricky parameter. By increasing the rollout from 32 to 100 you increase the policy lag substantially, which has implications for learning even with Vtrace enabled. To keep policy lag the same you'd need to decrease the total number of envs by a similar factor (~3), so for example change the number of envs per worker from 12 to 4. It is also hard to predict how backpropagation through time is going to affect learning, since you will backprop through 100 steps instead of 32. DMLab-30 has very few envs where this is beneficial, so it might harm performance on easier envs? Another implication of a longer rollout is that Monte-Carlo return estimates will be noisier (higher variance). With short rollouts they are bootstrapped by the critic. But gamma 0.99 is quite low... so this should not be a big problem? These are all just hypotheses, you never know with RL. From my experience longer rollouts/bptt usually work but are rarely beneficial. RNNs just suck in general with very long rollouts, especially in RL.
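To put rough numbers on two points in the paragraph above, here is a small sketch. It assumes, as the comment suggests, that policy lag grows roughly in proportion to rollout length times the number of envs, so the env count is scaled by the inverse ratio. The helper names are made up for illustration and are not part of Sample Factory.

```python
def scaled_envs_per_worker(envs_per_worker, old_rollout, new_rollout):
    """Scale envs per worker so that rollout * envs stays roughly constant.

    Assumption: policy lag is roughly proportional to rollout * num_envs.
    """
    scaled = envs_per_worker * old_rollout / new_rollout
    return max(1, round(scaled))


def bootstrap_weight(gamma, rollout):
    """Discount applied to the critic's bootstrap value after `rollout` steps."""
    return gamma ** rollout


if __name__ == "__main__":
    # 12 envs per worker, rollout going from 32 to 100
    print(scaled_envs_per_worker(12, 32, 100))
    for n in (32, 100):
        print(n, round(bootstrap_weight(0.99, n), 3))
```

With 12 envs per worker and the rollout going from 32 to 100, the first helper gives 4, matching the suggestion above. The second shows how much weight the critic's bootstrap value keeps at the end of the rollout with gamma 0.99: about 0.72 after 32 steps and about 0.37 after 100, so the longer rollout leans noticeably more on sampled rewards.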

Regarding adding actions/rewards as policy inputs, I have no idea. I have never tried this with DMLab-30. Should not harm the results in theory, but I never found a setting where this actually helped. Did you run a separate experiment where everything else is the same, but actions/rewards are added? It would be odd if this significantly harms the results.

sungwoong commented 3 years ago

@alex-petrenko

Thank you again for your kind answers! Here are my training logs: https://drive.google.com/file/d/1xFW8xIG6KlsUcHxc33LBLpYuVc2_h57Z/view?usp=sharing

At the moment, I have only increased the rollout and recurrence length, without any other changes. I will perform several ablations on the rollout length as you suggested and let you know.

I also ran an experiment where the previous action and reward are added to the core LSTM inputs and everything else is the same, following the "scalable agent" implementation, and observed that this significantly reduces the scores (by about half). I will re-check the code and summaries.

alex-petrenko commented 3 years ago

Ok, here's some more analysis.

First of all, our command lines are identical, but there are some differences in the cfg.json file because we're using different versions of the code (some new parameters were introduced, and some parameters changed their default values).

Notably: obs_subtract_mean 128 -> 0, obs_scale 128 -> 255. These just mean that we went from scaling images to (-1, 1) to scaling them to (0, 1). I don't think this is very important; it has a low chance of affecting the result.

ppo_clip_value 0.2 -> 1.0. Also unlikely to affect the result. In fact 1.0 is a more reasonable value; 0.2 is too conservative.
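The observation scaling change can be checked with a quick sketch. It assumes the normalization is (pixel - obs_subtract_mean) / obs_scale, which is a guess based on the parameter names rather than something stated in this thread. The function names are hypothetical.

```python
def normalize_pixel(pixel, obs_subtract_mean, obs_scale):
    """Assumed normalization: subtract the mean, then divide by the scale."""
    return (pixel - obs_subtract_mean) / obs_scale


def normalized_range(obs_subtract_mean, obs_scale):
    """Range covered by 8-bit pixel values (0..255) after normalization."""
    return (
        normalize_pixel(0, obs_subtract_mean, obs_scale),
        normalize_pixel(255, obs_subtract_mean, obs_scale),
    )


if __name__ == "__main__":
    print("old defaults:", normalized_range(128, 128))  # roughly (-1, 1)
    print("new defaults:", normalized_range(0, 255))    # (0, 1)
```

Under that assumption the old defaults map pixels to roughly (-1, 1) and the new ones to (0, 1), which is the change described above.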

Your system also has different performance characteristics, look at these two graphs: image

In your case actors almost never wait for policy to finish inference, and policy workers tend to wait for actors. Which means that actors are much slower. In itself, it is not a problem.

Your average policy-lag (version_diff) is about the same, but for some reason it spikes much higher (version_diff_max): image

What's important is that it sometimes spikes higher than 35, which is the policy-lag threshold for this experiment (see --max_policy_lag=35 in the command line). So some of the experience gets discarded.

You can clearly see it here: image

That said, the total number of discarded rollouts is still very small (22000?) which amounts to just a few million transitions. In a 10^10 timestep experiment this should hardly matter.
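A back-of-the-envelope version of this estimate, as a sketch. It takes the rollout length of 32 from --rollout=32 and the rough count of 22000 discarded rollouts from above. The frames_per_step factor is an assumption that is not stated in this thread (DMLab agents commonly repeat each action for several frames); with the default of 1 the count is in agent steps.

```python
def discarded_steps(num_rollouts, rollout_len, frames_per_step=1):
    """Number of transitions lost when whole rollouts are discarded."""
    return num_rollouts * rollout_len * frames_per_step


def discarded_fraction(num_rollouts, rollout_len, total_steps, frames_per_step=1):
    """Share of the whole experiment that the discarded rollouts represent."""
    return discarded_steps(num_rollouts, rollout_len, frames_per_step) / total_steps


if __name__ == "__main__":
    total = 10 ** 10  # experiment length in timesteps
    for frames in (1, 4):  # 4 is an assumed action repeat, for illustration only
        steps = discarded_steps(22000, 32, frames)
        frac = discarded_fraction(22000, 32, total, frames)
        print(frames, steps, frac)
```

Either way the discarded share stays far below 0.1% of the experiment, which supports the conclusion that it should hardly matter.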

All in all, I don't have a clear answer why your results are worse, although we identified some differences.

I see the following options:
1) Re-run my experiment from the original version of the code with the exact same command line. I believe this is commit 3fe6bd4e146e67085051ce8a10650657916cdde0.
2) Re-run the experiment with the latest version of the code while reverting all the changed parameters (obs scale, obs subtract mean, ppo clip).

I personally prefer option 2, or both if you have resources for it. Please let me know if you decide to do it, I might want to do my runs on my own to make sure there is no regression in the latest code base.

sungwoong commented 3 years ago

@alex-petrenko

Thanks a lot for your kind answers!

I have started running both commit 3fe6bd4 and the latest code with the same parameters that you used (obs_scale:128, obs_subtract_mean:128, ppo_clip_value:0.2). I will let you know when the runs are completed.

alex-petrenko commented 3 years ago

I re-ran on one of the latest versions with the old params and I got similar results: image

The new score is about 50 instead of close to 60. @sungwoong did you get similar results in your re-run?

I noticed that in particular, it could not solve the select_nonmatching_object environment image

This environment is pretty distinct, it requires strong memory. @BoyuanLong if you have time, could we train a DMLab agent JUST on this environment alone just to make sure we can still learn it and nothing is broken? Chat on Slack for details :)

donghoonlee04 commented 3 years ago

Hi @alex-petrenko, I am @sungwoong 's colleague. Thanks a lot for your kind help.

We are running three experiments as follows:
exp 1: 3fe6bd4 commit, 4 pbt, exact same command line with your checkpoint's cfg.json (pink)
exp 2: latest commit, 4 pbt, obs_subtract_mean=128, obs_scale=128, ppo_clip_value=0.2 (orange)
exp 3: latest commit, 4 pbt, obs_subtract_mean=128, obs_scale=128, ppo_clip_value=0.2, max_policy_lag=50 (blue)
ref 1: your log (red)
ref 2: our 10B train log (cyan)
Please note that only .summary/0 are subsampled and displayed (too large log file).

image The model is still training, but the results are not satisfactory. They seem to converge to 50 instead of 60.

image image As you mentioned, there is a difference in select_nonmatching_object. Language levels are also not solved.

Here is our other progress.

  1. On LSTM initialization. The core is nn.LSTM (https://github.com/alex-petrenko/sample-factory/blob/beb120e5b2b8761852c6de5a4c186a63e220716e/sample_factory/algorithms/appo/model_utils.py#L340), and initialize_weights is defined here: https://github.com/alex-petrenko/sample-factory/blob/beb120e5b2b8761852c6de5a4c186a63e220716e/sample_factory/algorithms/appo/model.py#L179. Thus the core (and also the instruction LSTM in dmlab_model) may not be initialized properly. (In commit 3fe6bd4 the core is nn.LSTMCell, so it is initialized properly.)

  2. Commit 3fe6bd4 cannot load your checkpoint. image

    • This can easily be fixed by renaming model keys during load_checkpoint, but it implies that your checkpoint might have been trained with some other commit.

  3. wait_policy and wait actor. image Your log and our 10B train log show different performance characteristics; this seems to depend on the commit, not the system. Our experiment using commit 3fe6bd4 shows performance characteristics similar to your log.

  4. Experiments with with_pbt=False. In experiments without PBT, the results are closer to your log.
     exp 4: latest commit, NO pbt, obs_subtract_mean=128, obs_scale=128, ppo_clip_value=0.2, dmlab instruction model device=gpu (blue)
     exp 5: latest commit, NO pbt, obs_subtract_mean=128, obs_scale=128, ppo_clip_value=0.2, dmlab instruction model device=cpu (default) (orange) https://github.com/alex-petrenko/sample-factory/blob/beb120e5b2b8761852c6de5a4c186a63e220716e/sample_factory/envs/dmlab/dmlab_model.py#L44
     ref 1: your log (green)
     ref 2: our 10B train log (gray)
     image
     Is the reproduction problem caused by PBT randomness, or by a bug? I will have a look at the PBT code.

alex-petrenko commented 3 years ago

@donghoonlee04 thank you for the update! This is very insightful!

on LSTM initialization

I believe in the current version this initialization code is actually inactive, since we switched from LSTMCell to LSTM for performance reasons. LSTM initializes itself as stated in the docs:

image

Bottom line, I don't think either version has a bug here, although LSTM weight initialization has likely changed. Since you re-ran the old commit and also got a lower result, I don't think this is the problem. But good find.
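For reference, the PyTorch nn.LSTM documentation states that all weights and biases are initialized from a uniform distribution U(-sqrt(k), sqrt(k)) with k = 1/hidden_size. A tiny sketch of the resulting bound follows; the helper name is made up for illustration, and 256 is the --hidden_size value used in this experiment.

```python
import math


def lstm_init_bound(hidden_size):
    """Half-width of the uniform init range documented for torch.nn.LSTM."""
    return math.sqrt(1.0 / hidden_size)


if __name__ == "__main__":
    bound = lstm_init_bound(256)
    print(f"weights and biases start in [-{bound}, {bound}]")
```

For hidden_size=256 the bound is 1/16, so the default initialization keeps all LSTM parameters in a small symmetric range around zero.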

3fe6bd4 commit cannot load your checkpoint.

I just double-checked and I was able to continue training from that checkpoint. Commit 3fe6bd4e146e67085051ce8a10650657916cdde0 from April 10th, command line is:

python -m algorithms.appo.train_appo --env=dmlab_30 --train_for_seconds=3600000 --algo=APPO --gamma=0.99 --use_rnn=True --num_workers=10 --num_envs_per_worker=12 --ppo_epochs=1 --rollout=32 --recurrence=32 --batch_size=2048 --benchmark=False --ppo_epochs=1 --max_grad_norm=0.0 --dmlab_renderer=software --decorrelate_experience_max_seconds=120 --reset_timeout_seconds=300 --encoder_custom=dmlab_instructions --encoder_type=resnet --encoder_subtype=resnet_impala --encoder_extra_fc_layers=1 --hidden_size=256 --nonlinearity=relu --rnn_type=lstm --dmlab_extended_action_set=True --num_policies=1 --pbt_replace_reward_gap=0.05 --pbt_replace_reward_gap_absolute=5.0 --pbt_period_env_steps=10000000 --pbt_start_mutation=100000000 --with_pbt=True --experiment=dmlab_30_resnet_4pbt_mode2_90_12_v86 --dmlab_one_task_per_worker=True --set_workers_cpu_affinity=True --max_policy_lag=35 --pbt_target_objective=dmlab_target_objective --train_for_env_steps=11000000000

(I reduced the number of workers since I ran on a smaller machine, but this should not affect the weight loading)

Are you sure you're running the right commit? Action parameterizations only appeared in commit 6684d23308ecd33a61ea0f28434fc532983276fb (its commit message says "version V88", while the correct one is "version V86"). If you were running the wrong commit, can you tell me which one?

wait_policy and wait actor

Interesting... Yes, the difference seems to be due to the commit, not the system. There were performance improvements which might have changed the profile, but then again Version V86 re-runs should match. I will try to take a closer look.

experiments with_pbt=False

Interesting find. I didn't try to run this without PBT, I just assumed PBT would be better. Could be the randomness, although it does not feel like it. The PBT code has changed quite a bit, and this might be one of the reasons for the difference in performance with the latest version, but the old "Version V83" results should still be consistent. Maybe PBT is not very stable with such a low number of policies, and an unlucky start leads to inferior performance.

donghoonlee04 commented 3 years ago

@alex-petrenko regarding "3fe6bd4 commit cannot load your checkpoint": I made a mistake while deploying the code from my Mac to the server when I tried to resume from the checkpoint. I can resume training on commit 3fe6bd4. (The training from scratch did use commit 3fe6bd4.)

Sorry for causing you unnecessary work with my mistake.

donghoonlee04 commented 3 years ago

By the way, the correct version is "version 86", not 83, right?

alex-petrenko commented 3 years ago

Oh, I'm so sorry, yes it is V86. I got confused by another run. The experiment folder I sent you is V86 (3fe6bd4)

donghoonlee04 commented 3 years ago

Hi @alex-petrenko, I would like to share my progress and an issue I found. Yesterday, we focused on analyzing the effect of the environment (pytorch versions, dmlab versions, etc.) when resuming training from your checkpoint. I tried to resume training using 3fe6bd4 from your checkpoint with

image image image Level select_nonmatching_object seems OK (this does not mean that learning from scratch will work), but all language levels fail.

I think this is because of the use of "hash(s)": https://github.com/alex-petrenko/sample-factory/blob/beb120e5b2b8761852c6de5a4c186a63e220716e/sample_factory/envs/dmlab/dmlab_gym.py#L54

The Python documentation says: On Python 3.3 and greater, hash randomization is turned on by default. Although they remain constant within an individual Python process, they are not predictable between repeated invocations of Python.

Simple reproduction: image To fix this, run export PYTHONHASHSEED=$SEED before the python script (setting os.environ["PYTHONHASHSEED"]=str(seed) inside the script does not work), image or modify string_to_hash_bucket() to use another hash function. For example,

import hashlib
...
def string_to_hash_bucket(s, vocabulary_size):
    return (int(hashlib.md5(s.encode('utf-8')).hexdigest(), 16) % (vocabulary_size - 1)) + 1

I prefer to modify the function, and I am currently testing a modified version on a single language task.

I have checked that the hash seed is shared among the learner processes and policy worker processes spawned from the same run (thus, this may not affect the reproduction issue). However, different runs get different hash seeds, so testing or resuming from a checkpoint may not be possible if the seed is not known and python>=3.3 was used. Did you set the seed number for your experiment?
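The behavior described above can be checked with a short sketch. It launches fresh interpreters with different PYTHONHASHSEED values to show that the built-in hash() of a string changes between runs, while the md5-based bucket does not. The word "blue" and the vocabulary size of 1000 are arbitrary example values, and builtin_hash_in_fresh_process is a hypothetical helper written only for this demonstration.

```python
import hashlib
import os
import subprocess
import sys


def string_to_hash_bucket(s, vocabulary_size):
    """md5-based bucket in 1..vocabulary_size-1, stable across processes."""
    digest = hashlib.md5(s.encode("utf-8")).hexdigest()
    return (int(digest, 16) % (vocabulary_size - 1)) + 1


def builtin_hash_in_fresh_process(s, hash_seed):
    """Return hash(s) as computed by a new interpreter with the given seed."""
    env = dict(os.environ, PYTHONHASHSEED=str(hash_seed))
    result = subprocess.run(
        [sys.executable, "-c", f"print(hash({s!r}))"],
        env=env,
        capture_output=True,
        text=True,
        check=True,
    )
    return int(result.stdout)


if __name__ == "__main__":
    word = "blue"
    # Different seeds (i.e. different runs) give different built-in hashes.
    print(builtin_hash_in_fresh_process(word, 1))
    print(builtin_hash_in_fresh_process(word, 2))
    # The md5-based bucket depends only on the string itself.
    print(string_to_hash_bucket(word, 1000))
```

The seed has to be set in the environment before the interpreter starts, because string hashing is keyed at startup. That is why assigning os.environ["PYTHONHASHSEED"] inside an already running script has no effect on that process.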

alex-petrenko commented 3 years ago

This is a very good observation. I believe I modeled this after the original Impala implementation, although they used a Tensorflow function which I replaced with a Python method. Didn't know it can backfire. https://github.com/deepmind/scalable_agent/blob/6c0c8a701990fab9053fb338ede9c915c18fa2b1/experiment.py#L127 I am pretty sure Tensorflow method will not have this problem.

Very good fix suggestion as well! I will fix it in master branch ASAP.

I was also able to reproduce the issue. I continued training from the checkpoint on just a single level, and select_nonmatching_object has the same performance on tensorboard, while language_answer_quantitative_question has very low score. I do believe this is because of the hash issue.

Note that you can continue training on a single level like this:

 python -m algorithms.appo.train_appo --env=dmlab_contributed/dmlab30/rooms_select_nonmatching_object --train_for_seconds=3600000 --algo=APPO --gamma=0.99 --use_rnn=True --num_workers=20 --num_envs_per_worker=2 --ppo_epochs=1 --rollout=32 --recurrence=32 --batch_size=2048 --benchmark=False --ppo_epochs=1 --max_grad_norm=0.0 --dmlab_renderer=software --decorrelate_experience_max_seconds=1 --reset_timeout_seconds=300 --encoder_custom=dmlab_instructions --encoder_type=resnet --encoder_subtype=resnet_impala --encoder_extra_fc_layers=1 --hidden_size=256 --nonlinearity=relu --rnn_type=lstm --dmlab_extended_action_set=True --num_policies=1 --pbt_replace_reward_gap=0.05 --pbt_replace_reward_gap_absolute=5.0 --pbt_period_env_steps=10000000 --pbt_start_mutation=100000000 --with_pbt=True --experiment=dmlab_30_repro --dmlab_one_task_per_worker=True --set_workers_cpu_affinity=True --max_policy_lag=35 --pbt_target_objective=dmlab_target_objective

Sadly, this does not explain why we were unable to reproduce the results when starting from scratch with both the old and the new version. What is the best hypothesis now? Was it just a lucky PBT run? How are your non-PBT experiments doing? We're also running some experiments on single levels to make sure there is no regression @BoyuanLong

alex-petrenko commented 3 years ago

https://github.com/alex-petrenko/sample-factory/commit/fca68b5e80ab484ec56daec7f0783f700d7aa4b6

fixed

donghoonlee04 commented 3 years ago

I have two hypotheses now.

  1. PBT randomness (even though it seems unlikely). How can I conduct experiments to check this with limited resources?

  2. Language level fps. While testing the modified hash on a single language task (language_select_described_object), I found that fps on language tasks is extremely low for both hash() and hashlib.md5(). We used the same machine (48 cores) for all experiments, with num_workers=90 and num_envs_per_worker=12. Here are the results.

Benchmark mode (DummySampler): ~13k fps
Training mode (appo): starts at ~10k fps, then quickly drops to 2k

This happened in every variant I tried: hash(), hashlib.md5(), and bypassing the learner loop by setting self.with_training = False.

Bypassing the forward loop in policy_worker gave ~11k fps (so the bottleneck may be in the policy_worker forward loop).

Language tasks are also very slow in the multi-task setting (in some experiments on dmlab_30, I added the number of collected episodes stat to tensorboard). image

Thus, the language scores may be low (this is true in our experiments) because these levels do not get enough samples. Some mismatch in the environment (pytorch version?) between the reproduction runs and the reference run could reduce the fps of language levels and affect language scores. I will look into this.

Is training a single language task also slow on your system?

Also, I attached our logs on 4pbt, nopbt experiment. (all logs are subsampled to make tensorboard comfortable) https://drive.google.com/file/d/1hFg0K7DbBjskiKI9KHYMU5KEDULQfP0z/view?usp=sharing

alex-petrenko commented 3 years ago

I also observed very slow FPS on these envs. Can you please attach the profiling output? I.e. run your training on a single language level for a few minutes (i.e. --train_for_seconds=600) and attach the end of the console output, or sf_log.txt in newer versions. It should contain the performance breakdown. It is better to disable experience decorrelation while measuring performance. Could it be that in newer PyTorch versions these language LSTM models became very slow for some reason?

@BoyuanLong we should also look if our language experiments are slow compared to non-language envs. Can you please measure and post the results here?

@donghoonlee04 BTW how do you subsample TB logs?

donghoonlee04 commented 3 years ago

I ran the experiment with --train_for_seconds=7200 (it seems to take about 2 hours for the fps to settle at the low value) and will share the results (stderr output, tensorboard, cpu utilization).

I am not sure about the effect of the pytorch version. I need to figure it out. (I remember that language experiments using seed_rl (TF) and scalable_agent (TF) are not slower than the others.)

I used https://github.com/velikodniy/tbparser package to subsample.

alex-petrenko commented 3 years ago

Two hours seems very weird... hard to think of a reason why it takes so long. Is it possible that you run out of cache and it starts generating levels (which is pretty slow)?

Also, perhaps a smaller number of envs_per_worker would allow it to converge faster, since each env will go through episodes quicker this way.

donghoonlee04 commented 3 years ago

I remember that language levels do not use the cache, right? https://github.com/alex-petrenko/sample-factory/blob/83f515ef4614b5cc5d242b198e3ab5f6d3249187/sample_factory/envs/dmlab/dmlab30.py#L169

Did you mean running with an empty cache dir, or running with dmlab_use_level_cache=False?

By the way, we changed https://github.com/alex-petrenko/sample-factory/blob/0bfa7e0bedde2419b56fe12ea72ea73f4b1149b7/sample_factory/envs/dmlab/dmlab_level_cache.py#L97 to

        lvl_seed_files = Path(os.path.join(cache_dir, '_contributed')).rglob(f'*.{LEVEL_SEEDS_FILE_EXT}')

to prevent rglob from searching an unnecessary directory (the .level_cache directory, which contains 1.4M map cache files). This makes cache initialization much faster.
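The effect of narrowing the rglob root can be illustrated with a self-contained sketch on a throwaway directory. The directory layout, file names, and the value of LEVEL_SEEDS_FILE_EXT below are assumptions made up for the demonstration; only the `_contributed` and `.level_cache` directory names come from the comment above.

```python
import os
import tempfile
from pathlib import Path

LEVEL_SEEDS_FILE_EXT = "seeds_example"  # hypothetical extension for this demo


def find_seed_files(cache_dir, subdir=None):
    """Recursively list seed files, optionally starting from a subdirectory."""
    root = Path(cache_dir) if subdir is None else Path(os.path.join(cache_dir, subdir))
    return sorted(root.rglob(f"*.{LEVEL_SEEDS_FILE_EXT}"))


def build_demo_cache(cache_dir):
    """Create a tiny fake cache: one seed file plus many unrelated map files."""
    contributed = Path(cache_dir) / "_contributed" / "dmlab30"
    level_cache = Path(cache_dir) / ".level_cache"
    contributed.mkdir(parents=True)
    level_cache.mkdir(parents=True)
    (contributed / f"example_level.{LEVEL_SEEDS_FILE_EXT}").write_text("1\n2\n")
    for i in range(100):  # stands in for the 1.4M cached map files
        (level_cache / f"map_{i}.bin").write_text("")


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        build_demo_cache(tmp)
        everywhere = find_seed_files(tmp)                # walks .level_cache too
        narrowed = find_seed_files(tmp, "_contributed")  # skips it entirely
        print(everywhere == narrowed, len(narrowed))
```

Both calls find the same seed files as long as they all live under `_contributed`, but the narrowed search never has to enumerate the large map cache directory.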

donghoonlee04 commented 3 years ago

@alex-petrenko 2 hours train log on language_select_described_object

cpu & gpu utilization image

stderr output, tensorboard, cfg.json https://drive.google.com/file/d/1vYshb0tGU1We71TznbQWUp9bQUb_qr-G/view?usp=sharing

alex-petrenko commented 3 years ago

Thank you, great fix! @BoyuanLong can you please fix it in master?

alex-petrenko commented 3 years ago

@donghoonlee04 thank you for sending the logs. Looks like most of the time it is just environment simulation that's so slow... in the logs: env_step: 6838.0344, so 6838 seconds just to step through the env.

@BoyuanLong can we reproduce this very slow simulation? If yes, could you please profile it further? Like, what part takes the longest, is it just calling step on the DMLab C++ env, or is it the end of the episode or smth else?

donghoonlee04 commented 3 years ago

cc. @alex-petrenko @sungwoong @BoyuanLong I found the reason why language_select_described_object is slow. After several experiments, I noticed that the env step is especially slow when it returns a reward, i.e. when the agent gathers the right object (mainly due to re-spawning objects(?)).

see https://www.youtube.com/watch?v=JCJOgEcNgKY

and simple additional profiling, done by modifying: https://github.com/alex-petrenko/sample-factory/blob/2b14f40fced7a0178f987848831c244626aab4e0/sample_factory/algorithms/appo/actor_worker.py#L616

            start = time.time()
            with timing.add_time('env_step'):
                actions = [s.curr_actions() for s in self.actor_states[env_i]]
                new_obs, rewards, dones, infos, raw_rewards = e.step(actions)
                log.error(f'reward: {raw_rewards}, dones: {dones}, time_elpased: {time.time()-start}')

results are the following:

[2021-05-30 23:13:12,427][02697] reward: [0.0], dones: [False], time_elpased: 0.008475780487060547
[2021-05-30 23:13:12,428][03783] reward: [0.0], dones: [False], time_elpased: 0.004273891448974609
[2021-05-30 23:13:12,428][06416] reward: [0.0], dones: [False], time_elpased: 0.007490873336791992
[2021-05-30 23:13:12,429][00522] reward: [10.0], dones: [False], time_elpased: 1.9994854927062988
[2021-05-30 23:13:12,429][00294] reward: [0.0], dones: [False], time_elpased: 0.004384040832519531
[2021-05-30 23:13:12,430][01524] reward: [0.0], dones: [False], time_elpased: 0.004186868667602539
[2021-05-30 23:13:12,430][00927] reward: [0.0], dones: [False], time_elpased: 0.007602214813232422
[2021-05-30 23:13:12,430][06154] reward: [0.0], dones: [False], time_elpased: 0.008571386337280273

alex-petrenko commented 3 years ago

Just to confirm, this only happens with SampleFactory and not with the IMPALA implementation?

donghoonlee04 commented 3 years ago

@alex-petrenko I have tested IMPALA, and it also shows an fps drop when the language task returns some reward.

BoyuanLong commented 3 years ago

I would like to share our results on the select_nonmatching_object environment.

We tested it on several commits (the latest, 3d3eff8, and 3fe6bd4), and the performances are similar across different commits. Only one seed in 3d3eff8 converged to 50% in this case, so it doesn't seem there's an obvious regression since v86.

Attached the tensorboard view below:

image

image

github-actions[bot] commented 3 years ago

This issue is stale because it has been open for 30 days with no activity.