ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] PPO torch over 5X slower than tensorflow on atari and uses up all RAM #6962

Closed pmacalpine closed 4 years ago

pmacalpine commented 4 years ago

What is the problem?

The recently added torch implementation of PPO #6826 is over 5X slower when training on atari (breakout) and also ends up slowly consuming all the system RAM (perf/ram_util_percent) before crashing. Looking at logs (attached) it seems as if a possible source of the slowdown is sampler_perf/mean_inference_ms where torch is over 40ms while tensorflow is under 2ms. Perhaps the torch PPO implementation is not utilizing the GPU nearly as much as the tensorflow version as on my systems torch is using 1394MB of GPU memory while tensorflow is using 4446MB.

I have attached logs of running both torch and tensorflow versions of PPO on atari breakout as well as the yaml files for the runs.

Ray version and other system information (Python version, TensorFlow version, OS):

System: Azure Standard_NC12 with DSVM image
python: 3.7.6
ray: 0.9.0.dev0 (built from source rev 81238945b993fe8c41331b3cf85a55129e7e4267)
OS: Ubuntu 18.04.03 LTS
torch: 1.3.0
tensorflow: 2.0.0

Note that I see similar slow speed and consumption of RAM with PPO torch on other system setups (Azure NC_6 and NC_24) and with other environments (procgen), networks (pytorch version of cnn impala network ported from baselines), and versions of torch (1.4.0).

Reproduction (REQUIRED)

Please provide a script that can be run to reproduce the issue. The script should have no external library dependencies (i.e., use fake or mock data / environments):

For pytorch: rllib train -v -f atari-ppo-pt.yaml
For tensorflow: rllib train -v -f atari-ppo-tf.yaml

If we cannot run your script, we cannot fix your issue.

ericl commented 4 years ago

Interesting, previously we saw more on the order of a 20-30% slowdown with torch only. It looks like the configs you're using are the same ones we benchmarked.

By the way, the inference for rollouts is done on CPU only, so it's odd it's much slower at 40ms.

cc @sven1977

ericl commented 4 years ago

I did a bit of digging into this and it seems to happen when OMP_NUM_THREADS is not set. Here blue is pytorch with OMP_NUM_THREADS unset, and orange is torch with OMP_NUM_THREADS=1.

image

This patch should auto set it to avoid it: https://github.com/ray-project/ray/pull/6998/files
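A minimal sketch of the workaround discussed here, assuming it runs at the very top of the entry point, before torch or any other OpenMP-backed library is imported. The helper name `pin_omp_threads` is hypothetical and is not the code from the linked PR; only the `OMP_NUM_THREADS` variable comes from this thread.

```python
import os


def pin_omp_threads(default: int = 1) -> int:
    """Set OMP_NUM_THREADS only if the user has not already chosen a value.

    Hypothetical helper for illustration. It has to run before importing
    libraries that read the variable when they are loaded.
    """
    # setdefault leaves an explicit user setting untouched.
    value = os.environ.setdefault("OMP_NUM_THREADS", str(default))
    return int(value)


if __name__ == "__main__":
    print("OMP_NUM_THREADS =", pin_omp_threads())
```

Using `setdefault` rather than a plain assignment keeps any value the user exported deliberately. With torch specifically, `torch.set_num_threads(1)` is another way to cap intra-op CPU threads after import.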

edoakes commented 4 years ago

To summarize, the issue is that if OMP_NUM_THREADS is too low, ray.put() becomes slow. If OMP_NUM_THREADS is high, then we see performance degradation like the above due to thrashing.

@ericl and I discussed, we think the most reasonable solution is:

@pcmoritz @robertnishihara any thoughts?

pmacalpine commented 4 years ago

Thanks for looking into this issue! The fix in #6998 of setting OMP_NUM_THREADS=1 speeds things up and makes it much closer in performance to the tensorflow version of PPO.

The other equally problematic bug in this issue is that torch PPO is slowly eating up all the system RAM until all RAM is exhausted and the process is killed. I still see this problem with OMP_NUM_THREADS=1.

Can this issue be reopened to also address the RAM issue? Or, if preferred, I can open a new issue just addressing the RAM problem.

ericl commented 4 years ago

@pmacalpine do you know which process is eating the RAM, is it the trainer or the workers?

pmacalpine commented 4 years ago

@pmacalpine do you know which process is eating the RAM, is it the trainer or the workers?

I think it is the workers. In the log error.txt it says the following:

ray.memory_monitor.RayOutOfMemoryError: More than 95% of the memory on node patbert-dsvm-nc12-2 is used (104.66 / 110.14 GB). The top 10 memory consumers are:

PID     MEM     COMMAND
28591   8.73GiB ray::RolloutWorker
28588   8.71GiB ray::RolloutWorker
28593   8.69GiB ray::RolloutWorker
28592   8.67GiB ray::RolloutWorker
28590   8.65GiB ray::RolloutWorker
28586   8.62GiB ray::RolloutWorker
28585   8.62GiB ray::RolloutWorker
28587   8.61GiB ray::RolloutWorker
28595   8.59GiB ray::RolloutWorker
28584   8.57GiB ray::RolloutWorker

In addition, up to 13.08 GiB of shared memory is currently being used by the Ray object store. You can set the object store size with the `object_store_memory` parameter when starting Ray, and the max Redis size with `redis_max_memory`. Note that Ray assumes all system memory is available for use by workers. If your system has other applications running, you should manually set these memory limits to a lower value.
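To make the "it is the workers" reading of this report concrete, the sketch below sums the MEM column per command. The rows are copied from the error message above; the function name `memory_by_command` and the regular expression are illustrative assumptions, not part of Ray.

```python
import re

# Rows copied from the RayOutOfMemoryError message in this thread.
REPORT = """\
28591   8.73GiB ray::RolloutWorker
28588   8.71GiB ray::RolloutWorker
28593   8.69GiB ray::RolloutWorker
28592   8.67GiB ray::RolloutWorker
28590   8.65GiB ray::RolloutWorker
28586   8.62GiB ray::RolloutWorker
28585   8.62GiB ray::RolloutWorker
28587   8.61GiB ray::RolloutWorker
28595   8.59GiB ray::RolloutWorker
28584   8.57GiB ray::RolloutWorker
"""

# PID, memory in GiB, command name.
ROW = re.compile(r"^(\d+)\s+([\d.]+)GiB\s+(\S+)$")


def memory_by_command(report: str) -> dict:
    """Sum the MEM column (in GiB) for each distinct COMMAND value."""
    totals = {}
    for line in report.splitlines():
        match = ROW.match(line.strip())
        if match:
            _pid, gib, command = match.groups()
            totals[command] = totals.get(command, 0.0) + float(gib)
    return totals


if __name__ == "__main__":
    for command, gib in memory_by_command(REPORT).items():
        print(f"{command}: {gib:.2f} GiB")
```

The ten listed rollout workers add up to roughly 86 GiB, which is the bulk of the usage the monitor reports for the node.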

pmacalpine commented 4 years ago

Is anyone actively looking into this issue and/or have an idea of what the cause of this memory leak might be? If not I might look into it a bit more as the issue is blocking me, but would prefer not to if someone who has more knowledge of the codebase is actively doing so.

ericl commented 4 years ago

I've taken a cursory look at torch_policy.compute_actions(), but didn't see anything suspicious there. I think it is originating from something in compute_actions(), but we do have torch.no_grad() there so I'm not sure what would be leaking.

pmacalpine commented 4 years ago

I was looking at https://github.com/pytorch/pytorch/issues/13246 to see if the memory issue/leak with pytorch reported there might be related. Not surprisingly we also see the same memory leak with pytorch A2C.

pmacalpine commented 4 years ago

I've taken a cursory look at torch_policy.compute_actions(), but didn't see anything suspicious there. I think it is originating from something in compute_actions(), but we do have torch.no_grad() there so I'm not sure what would be leaking.

I don't think the source of the problem is torch_policy.compute_actions(). I changed https://github.com/ray-project/ray/blob/6745459f966a9d9e68210aab19d948ead0c9f563/rllib/evaluation/sampler.py#L549 to

eval_results[policy_id] = (np.array([0]), [], {'action_logp': np.array([-0.0084]), 'behaviour_logits': np.array([np.array([5.15, -2.9, 0.33, -3.35])]), 'vf_preds': np.array([-0.21])})

so that fake values are used instead of calling compute_actions(). I still see the memory leak and eventually run out of CPU RAM even though compute_actions() is never called. I see this problem with both torch 1.3 and 1.4.

ericl commented 4 years ago

Interesting. What if you also comment out set_weights() in the rollout worker class?

pmacalpine commented 4 years ago

Interesting. What if you also comment out set_weights() in the rollout worker class?

The memory leak is still there with RAM utilization climbing from around 10% to over 50% and still steadily rising. Any other suggestions for possible code to comment out that might be causing a memory leak?

ericl commented 4 years ago

Hmm, with both those replaced I don't think we are even calling into pytorch any more beyond initializing the model. Perhaps go through TorchPolicy and replace code there with noops in case they are used.

It's also possible just deserializing (torch) weights is leaking memory somehow. You can have get_weights return an empty dict to see if that changes anything.

pmacalpine commented 4 years ago

Hmm, with both those replaced I don't think we are even calling into pytorch any more beyond initializing the model. Perhaps go through TorchPolicy and replace code there with noops in case they are used.

It's also possible just deserializing (torch) weights is leaking memory somehow. You can have get_weights return an empty dict to see if that changes anything.

I still see the memory leak even after I've replaced pretty much everything in TorchPolicy to be noops as well as get_weights() and set_weights() in RolloutWorker, and also the change I mentioned above to sampler.py such that torch_policy.compute_actions() is never called.

I have however found a way to stop the memory leak with no changes to code other than not always calling the line at https://github.com/ray-project/ray/blob/fae99ecb8e8d750bddcb3674f720f068541dc15d/rllib/evaluation/rollout_worker.py#L481 that reads batches from the input reader, but instead setting batches = [self.last_batch] assuming that self.last_batch is not None (I purposely always set self.last_batch at https://github.com/ray-project/ray/blob/fae99ecb8e8d750bddcb3674f720f068541dc15d/rllib/evaluation/rollout_worker.py#L521). With this change, which is kind of like using a fake sampler, RAM slowly increases from around 10% usage to just over 30% before plateauing (I think I saw similar RAM usage with tensorflow).
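The bisection step described above (reuse the last batch instead of reading a new one) can be sketched in isolation. `CachedBatchReader` is a hypothetical stand-in written for illustration; it is not RLlib's input reader or the actual change made to rollout_worker.py.

```python
from typing import Any, Callable, Optional


class CachedBatchReader:
    """Return the first batch read over and over instead of sampling anew.

    Hypothetical wrapper used only to bisect a leak: if memory stops
    growing with reuse enabled, the leak is tied to producing fresh batches.
    """

    def __init__(self, read_batch: Callable[[], Any], reuse: bool = True):
        self._read_batch = read_batch
        self._reuse = reuse
        self._last_batch: Optional[Any] = None
        self.reads = 0  # how many times the real reader was invoked

    def next(self) -> Any:
        # Serve the cached batch once one exists and reuse is enabled.
        if self._reuse and self._last_batch is not None:
            return self._last_batch
        self.reads += 1
        self._last_batch = self._read_batch()
        return self._last_batch


if __name__ == "__main__":
    counter = iter(range(1000))
    reader = CachedBatchReader(lambda: {"obs": [next(counter)]})
    batches = [reader.next() for _ in range(5)]
    print(reader.reads, batches[-1])
```

As the follow-up question in this comment points out, a stale batch also skips any per-batch postprocessing, so a plateau under this test narrows the search without pinpointing the reader itself.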

I've also tried to isolate the effect of reading batches from the input reader by keeping all of the noop code and setting self._fake_sampler = True in rollout_worker.py, but not returning self.last_batch from RolloutWorker.sample() until right after reading batches from the input reader, and still see the memory leak. If I run this same test, but without reading batches from the input reader, RAM usage stays pretty much constant at right below 10%.

RAM

In the above image the higher orange line is when reading batches from the input reader and the lower red line is when not doing so.

Does it make sense for a memory leak to happen when reading from the input reader, or maybe something else that would cause a memory leak isn't happening because new data isn't being sampled? When not always refreshing batches I see "WARNING:root:NaN or Inf found in input tensor", "lib/python3.7/site-packages/numpy/core/fromnumeric.py:3335: RuntimeWarning: Mean of empty slice", and "lib/python3.7/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars" messages making me wonder if some other computation that might be causing a memory leak is being skipped.

pmacalpine commented 4 years ago

I've been using muppy to try and better determine what memory is being leaked. The following is the diff of added/leaked memory between 100 iterations of learning for a single worker (roughly 50,000 time steps for each worker) which shows the main accumulation of memory being torch.Tensor:

types | # objects | total size
torch.Tensor | 498 | 38.91 KB
int | 116 | 3.18 KB
float | 106 | 2.48 KB
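The same kind of before/after diff can be approximated with the standard library alone. This sketch swaps in `gc.get_objects()` for muppy, so it only sees objects tracked by the garbage collector (plain ints and floats are not) and reports counts, not sizes. The function names are illustrative assumptions.

```python
import gc
from collections import Counter


def count_by_type() -> Counter:
    """Count live, GC-tracked objects by type name."""
    return Counter(type(obj).__name__ for obj in gc.get_objects())


def growth(before: Counter, after: Counter, top: int = 5):
    """Return the types whose instance count grew the most."""
    delta = after - before  # Counter subtraction keeps positive counts only
    return delta.most_common(top)


if __name__ == "__main__":
    class Leaky:
        """Stand-in for whatever object type is accumulating."""

    held = []
    before = count_by_type()
    held.extend(Leaky() for _ in range(100))  # simulate a leak
    after = count_by_type()
    print(growth(before, after))
```

Taking the two snapshots some number of training iterations apart, as done above with muppy, turns the growth per type into a rate that can be compared across configurations.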

Also, maybe of interest, if I change sample_batch_size from 100 to 500 (with still the same 10 workers and train_batch_size 5000) torch.Tensor memory is leaked at about 1/5 the rate:

types | # objects | total size
torch.Tensor | 100 | 7.81 KB
float | 159 | 3.73 KB
int | 78 | 2.14 KB
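A quick worked check of these two measurements, assuming each of the 10 workers contributes an equal share of train_batch_size per iteration (consistent with the "roughly 50,000 time steps for each worker" figure above). The function name is hypothetical.

```python
def sample_batches_per_worker(train_batch_size: int, num_workers: int,
                              sample_batch_size: int, iterations: int) -> int:
    """Number of sample batches one worker produces over `iterations`."""
    steps_per_worker_per_iter = train_batch_size // num_workers
    batches_per_iter = steps_per_worker_per_iter // sample_batch_size
    return batches_per_iter * iterations


if __name__ == "__main__":
    # (sample_batch_size, leaked torch.Tensor count reported in this thread)
    for sample_batch_size, leaked in ((100, 498), (500, 100)):
        batches = sample_batches_per_worker(5000, 10, sample_batch_size, 100)
        print(sample_batch_size, batches, round(leaked / batches, 2))
```

Under that assumption a worker produces 500 sample batches in the first run and 100 in the second, so both measurements come out at about one leaked tensor per sample batch. That points at something executed once per sample batch rather than once per time step.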

ericl commented 4 years ago

That suggests it's something to do with serializing Torch tensors, since each sample batch gets serialized and sent back to the learner process. What's odd is that the sample batch should only contain numpy arrays; there shouldn't be any tensors there.

pmacalpine commented 4 years ago

That suggests it's something to do with serializing Torch tensors, since each sample batch gets serialized and sent back to the learner process. What's odd is that the sample batch should only contain numpy arrays; there shouldn't be any tensors there.

Yes, I only see numpy arrays as the values stored in the dictionary for the batch.

ericl commented 4 years ago

One other thing we do approximately once per sample batch is calling postprocess_ppo_gae(), which among other things calls policy._value().

pmacalpine commented 4 years ago

One other thing we do approximately once per sample batch is calling postprocess_ppo_gae(), which among other things calls policy._value().

Preventing policy._value() from being called by setting completed = True at https://github.com/ray-project/ray/blob/c957ed58edfbfb9d9574e8ea1c73d12c1002d7c0/rllib/agents/ppo/ppo_tf_policy.py#L175 stops the memory leak of torch tensors.
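Why the `completed` flag decides whether the value function is called can be sketched with a plain-Python version of the bootstrap plus Generalized Advantage Estimation. This is an illustrative reimplementation under assumed names and default `gamma`/`lam` values, not RLlib's postprocess_ppo_gae.

```python
from typing import Callable, List, Sequence


def bootstrap_value(completed: bool, value_fn: Callable[[], float]) -> float:
    """Return 0 for a finished episode, otherwise query the value function."""
    return 0.0 if completed else float(value_fn())


def gae_advantages(rewards: Sequence[float], values: Sequence[float],
                   last_value: float, gamma: float = 0.99,
                   lam: float = 0.95) -> List[float]:
    """Generalized Advantage Estimation over one trajectory fragment."""
    advantages = [0.0] * len(rewards)
    next_value = last_value
    running = 0.0
    for t in reversed(range(len(rewards))):
        # TD residual: r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value - values[t]
        running = delta + gamma * lam * running
        advantages[t] = running
        next_value = values[t]
    return advantages


if __name__ == "__main__":
    last_value = bootstrap_value(completed=False, value_fn=lambda: 0.3)
    print(gae_advantages([1.0, 1.0], [0.5, 0.4], last_value))
```

With `completed = True` the value function is never invoked, which is the code path that made the tensor accumulation disappear in the test above.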

ericl commented 4 years ago

Oh, maybe we need to do with torch.no_grad() in "def value" of ppo_torch_policy.

pmacalpine commented 4 years ago

Oh, maybe we need to do with torch.no_grad() in "def value" of ppo_torch_policy.

Yes, that stops the torch tensor leaks! There are a few other small things still accumulating memory, but maybe they'll be taken care of by the garbage collector later. I'll test with this change to see if it fixes the overall problem.

sven1977 commented 4 years ago

Hmm, that's not the solution, though. We need the gradient for the vf_loss. This is all very strange... I'm experimenting with A2C as that's also showing the same behavior.

sven1977 commented 4 years ago

I'm printing out the number of Torch tensors and it's always constant now with my new fix (no garbage collection necessary anymore): 46 tensors with tuned_examples/atari-ppo.yml + --torch option on a single env. But memory consumption still creeps up in PPO torch Atari.

pmacalpine commented 4 years ago

I can confirm the change of adding with torch.no_grad() to def value (attached) fixes this issue.

ppo_torch_policy.zip

LeakFixed

While there are some extra floats and ints accumulated as I mentioned earlier, RAM usage eventually plateaus (I assume the garbage collector kicks in) at a little over 30% usage in the same way as what I believe I've also seen with tensorflow although I can double check that. The jump in RAM usage at 6M is due to starting tensorboard.

I can also confirm that this fix also fixes the GPU memory leak I reported in #7182. I'm happy to submit a PR for this fix to resolve both of these issues.

While a colleague saw a rise in RAM usage with A2C, I'd have to double check if A2C/A3C is also leaking tensors or if RAM usage eventually plateaus as that's following a different code path.

pmacalpine commented 4 years ago

Hmm, that's not the solution, though. We need the gradient for the vf_loss. This is all very strange... I'm experimenting with A2C as that's also showing the same behavior.

I see your point about the gradient needed for the vf_loss (previously I thought this wasn't necessary but I think you're right). Interestingly learning does seem to be progressing well however.

AvgReward

I just talked with my colleague and A2C RAM usage does plateau and so seems to not be leaking tensors (haven't verified that however).

ericl commented 4 years ago

@sven1977 we should insert the no_grad at a higher level, in the postprocess_trajectory function of the policy. We never differentiate through that, since it is called during experience collection only.

sven1977 commented 4 years ago

@ericl Yes, you are right. We do not differentiate through the post processing (value_function), only for the loss. I'll move the no_grad up.

The torch.Tensor leak is a different issue, which - I hope - is fixed by this one here: https://github.com/ray-project/ray/pull/7237 All unnecessary Tensors are however always garbage collected in the end and this PR does not fix the memory leak (which is also happening in A2C torch).

I have a minimal version of PPO, which does not leak anymore and the key is the compute_advantages function (which also is used in A2C). So if we move the no_grads above that call, it may fix it. I'll try all this now ...

sven1977 commented 4 years ago

Ok, I moved the no_grad all the way up into postprocess_ppo_gae (leaving it only around self._value was not enough; the call to compute_advantages seems to require being in no_grad as well, which nicely explains A2C's leak). It looks very good now with a fully functional PPO. Also, the torch.Tensor count is now constant with my fix from yesterday (#7237). I'll keep it running for a while ... then confirm vs A2C as well. Thanks a lot for your help everyone!
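A toy illustration of the idea, assuming torch is installed. `value_net`, `value` and `postprocess` are hypothetical stand-ins and not the code from the PR: the point is only that a value computed outside `torch.no_grad()` carries an autograd graph, while results produced inside it do not.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a value network; not RLlib's model.
value_net = nn.Linear(4, 1)


def value(obs: torch.Tensor) -> torch.Tensor:
    """Value estimate that participates in autograd unless grad is disabled."""
    return value_net(obs).squeeze(-1)


def postprocess(obs: torch.Tensor, rewards, gamma: float = 0.99) -> torch.Tensor:
    """Toy postprocessing: bootstrap from the last observation, then
    accumulate discounted returns. Everything runs under no_grad, so no
    result keeps a reference to an autograd graph."""
    with torch.no_grad():
        running = value(obs[-1])
        returns = []
        for reward in reversed(rewards):
            running = reward + gamma * running
            returns.append(running)
        returns.reverse()
        return torch.stack(returns)


if __name__ == "__main__":
    obs = torch.randn(5, 4)
    rewards = [1.0] * 5
    print(value(obs[-1]).requires_grad)             # graph attached
    print(postprocess(obs, rewards).requires_grad)  # no graph attached
```

Wrapping the whole postprocessing step, rather than only the value call, mirrors the observation above that the advantage computation also had to sit inside no_grad. Gradients for the value loss are unaffected in this sketch because the loss would call `value` again outside this block.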

sven1977 commented 4 years ago

This PR closes this issue: https://github.com/ray-project/ray/pull/7238 Will be merged into master in a few days.