ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

A3C + pytorch seems to crash on jenkins #2610

Closed ericl closed 4 years ago

ericl commented 6 years ago

I've seen this multiple times, though never before with the current Ray backend.

+ docker run -e RAY_USE_XRAY=1 --rm --shm-size=10G --memory=10G cfca7c87523a7d361ba790a1c243ee28904b7148b9b0bd0029ed2da28d2e28eb python /ray/python/ray/rllib/train.py --env CartPole-v1 --run A3C --stop '{"training_iteration": 2}' --config '{"num_workers": 2, "use_pytorch": true}'
WARNING: Your kernel does not support swap limit capabilities or the cgroup is not mounted. Memory limited without swap.
Detected environment variable 'RAY_USE_XRAY'.
Process STDOUT and STDERR is being redirected to /tmp/raylogs/.
Waiting for redis server at 127.0.0.1:15291 to respond...
Waiting for redis server at 127.0.0.1:46790 to respond...
Warning: Reducing object store memory because /dev/shm has only 10737418240 bytes available. You may be able to free up space by deleting files in /dev/shm. If you are inside a Docker container, you may need to pass an argument with the flag '--shm-size' to 'docker run'.

======================================================================
View the web UI at http://localhost:8888/notebooks/ray_ui78021.ipynb?token=83f0de448b2c4722ea862e3a0d9b766b34f716a75fd787a1
======================================================================

WARNING: Serializing objects of type <class 'ray.rllib.evaluation.sample_batch.SampleBatch'> by expanding them as dictionaries of their fields. This behavior may be incorrect in some cases.
WARN: gym.spaces.Box autodetected dtype as <type 'numpy.float32'>. Please provide explicit dtype.
Constructing fcnet [256, 256] <class 'torch.nn.modules.activation.Tanh'>
/ray/thirdparty/build/arrow/cpp/src/plasma/client.cc:628 Check failed: object_entry != objects_in_use_.end() 
A worker died or was killed while executing task 000000003130127e1a001b624d57a9d8f7792bc3.
Remote function train failed with:

Traceback (most recent call last):
  File "/ray/python/ray/worker.py", line 934, in _process_task
    *arguments)
  File "/ray/python/ray/actor.py", line 261, in actor_method_executor
    method_returns = method(actor, *args)
  File "/ray/python/ray/tune/trainable.py", line 137, in train
    result = self._train()
  File "/ray/python/ray/rllib/agents/a3c/a3c.py", line 102, in _train
    self.optimizer.step()
  File "/ray/python/ray/rllib/optimizers/async_gradients_optimizer.py", line 44, in step
    gradient, info = ray.get(fut)
  File "/ray/python/ray/worker.py", line 2493, in get
    raise RayGetError(object_ids, value)
RayGetError: Could not get objectid ObjectID(010000003130127e1a001b624d57a9d8f7792bc3). It was created by remote function <unknown> which failed with:

Remote function <unknown> failed with:

Invalid return value: likely worker died or was killed while executing the task.

A worker died or was killed while executing task 00000000d24e47fe917ec9734fcbe8b2229a24b4.
A worker died or was killed while executing task 00000000077dc670da5ee1bcfe07be2fda36bdb0.
== Status ==
Using FIFO scheduling algorithm.

Created LogSyncer for /root/ray_results/default/A3C_CartPole-v1_0_2018-08-08_22-12-14o1hPr0 -> 
== Status ==
Using FIFO scheduling algorithm.
Resources requested: 3/48 CPUs, 0/0 GPUs
Result logdir: /root/ray_results/default
RUNNING trials:
 - A3C_CartPole-v1_0:   RUNNING

Error processing event: Traceback (most recent call last):
  File "/ray/python/ray/tune/trial_runner.py", line 265, in _process_events
    result = ray.get(result_id)
  File "/ray/python/ray/worker.py", line 2493, in get
    raise RayGetError(object_ids, value)
RayGetError: Could not get objectid ObjectID(01000000f86621c5f95317bb8f95d539d6e6a24c). It was created by remote function train which failed with:

Remote function train failed with:

Traceback (most recent call last):
  File "/ray/python/ray/worker.py", line 934, in _process_task
    *arguments)
  File "/ray/python/ray/actor.py", line 261, in actor_method_executor
    method_returns = method(actor, *args)
  File "/ray/python/ray/tune/trainable.py", line 137, in train
    result = self._train()
  File "/ray/python/ray/rllib/agents/a3c/a3c.py", line 102, in _train
    self.optimizer.step()
  File "/ray/python/ray/rllib/optimizers/async_gradients_optimizer.py", line 44, in step
    gradient, info = ray.get(fut)
  File "/ray/python/ray/worker.py", line 2493, in get
    raise RayGetError(object_ids, value)
RayGetError: Could not get objectid ObjectID(010000003130127e1a001b624d57a9d8f7792bc3). It was created by remote function <unknown> which failed with:

Remote function <unknown> failed with:

Invalid return value: likely worker died or was killed while executing the task.

Worker ip unknown, skipping log sync for /root/ray_results/default/A3C_CartPole-v1_0_2018-08-08_22-12-14o1hPr0
== Status ==
Using FIFO scheduling algorithm.
Resources requested: 0/48 CPUs, 0/0 GPUs
Result logdir: /root/ray_results/default
ERROR trials:
 - A3C_CartPole-v1_0:   ERROR, 1 failures: /root/ray_results/default/A3C_CartPole-v1_0_2018-08-08_22-12-14o1hPr0/error_2018-08-08_22-12-21.txt

== Status ==
Using FIFO scheduling algorithm.
Resources requested: 0/48 CPUs, 0/0 GPUs
Result logdir: /root/ray_results/default
ERROR trials:
 - A3C_CartPole-v1_0:   ERROR, 1 failures: /root/ray_results/default/A3C_CartPole-v1_0_2018-08-08_22-12-14o1hPr0/error_2018-08-08_22-12-21.txt
ericl commented 6 years ago

cc @guoyuhong @richardliaw

ericl commented 6 years ago

The error is:

/ray/thirdparty/build/arrow/cpp/src/plasma/client.cc:628 Check failed: object_entry != objects_in_use_.end() 
guoyuhong commented 6 years ago

See https://github.com/ray-project/ray/issues/2494. At first I thought this could be caused by some concurrency issue, but after debugging, that turned out not to be the reason. I believe this crash happens on the worker side, not in the backend. Could it be caused by some misuse in the script, similar to using fetch in raylet mode?

ericl commented 6 years ago

I think the only difference is that we return PyTorch gradients converted to numpy. These might have a different structure than the TensorFlow values.
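For illustration, here is a minimal sketch of the kind of conversion I mean (hypothetical code assuming a CPU tensor; this is not the actual RLlib policy code). Converting a gradient with .numpy() yields an array that shares memory with the torch tensor rather than an independent copy:

import torch

# Hypothetical sketch (not the actual RLlib policy code) of returning torch
# gradients as numpy arrays from a worker.
params = torch.nn.Parameter(torch.ones(4))
loss = (params * 2.0).sum()
loss.backward()

# .numpy() on a CPU tensor is zero-copy: the returned array shares memory
# with params.grad rather than owning an independent buffer.
grad_array = params.grad.detach().numpy()
print(type(grad_array), grad_array)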

guoyuhong commented 6 years ago

This error is not introduced by raylet; legacy Ray has the same problem. However, it is very hard to reproduce in my local environment. Do you have any clues?

richardliaw commented 6 years ago

If you can install docker on your local env, it is quite easy to reproduce most of these tests in my experience.

You can run run_multi_node_tests.sh, which will create a near-exact replica of the Docker environment.

guoyuhong commented 6 years ago

This failure happens in 25% of all Jenkins test runs, and my local MacBook cannot support 5 Docker containers.

ericl commented 6 years ago

Hmm, I also cannot reproduce locally. I've never seen it fail on legacy Ray though; do you have a link?


guoyuhong commented 6 years ago

Yes, here: https://github.com/ray-project/ray/issues/2494 . There are many errors in the description. Actually, at the time that issue was posted, there were two major issues in the Jenkins tests:

  1. https://github.com/ray-project/ray/issues/2503
  2. https://github.com/ray-project/ray/issues/2494

I thought raylet could fix both, but it didn't; I only fixed the timeout issue, and the plasma client check failure continues...
robertnishihara commented 6 years ago

There's probably a bug or misuse of the Plasma API in these lines https://github.com/ray-project/ray/blob/df7ee7ff1eaba5afbfa8e84c47b5a727a66f5ec7/src/ray/raylet/node_manager.cc#L776-L786

Do we have a backtrace? Is the fatal error happening when the data buffer destructor runs?

cc @pcmoritz

guoyuhong commented 6 years ago

When I debugged the misuse of fetch in raylet mode in python/ray/rllib/utils/actors.py in https://github.com/ray-project/ray/pull/2567, the log came from the worker instead of the backend. If a node_manager failed, many other error messages would be posted. Therefore, I believe this is only a script-side failure in the PlasmaClient.

guoyuhong commented 6 years ago

For the code snippet that @robertnishihara posted, I changed it to the following, but it still did not work.

if (status.ok()) {
  ARROW_CHECK_OK(store_client_.Seal(object_id.to_plasma_id()));
}
guoyuhong commented 6 years ago

I added some test code in Arrow and there is a call stack now: https://amplab.cs.berkeley.edu/jenkins/job/Ray-PRB/8009/console . The output is:

WARNING: Serializing objects of type <class 'ray.rllib.evaluation.sample_batch.SampleBatch'> by expanding them as dictionaries of their fields. This behavior may be incorrect in some cases.
F0906 07:00:24.489765   299 client.cc:637]  Check failed: object_entry != objects_in_use_.end() 
*** Check failure stack trace: ***
    @     0x7fead4739ced  google::LogMessage::Fail()
    @     0x7fead473b648  google::LogMessage::SendToLog()
    @     0x7fead47397fb  google::LogMessage::Flush()
    @     0x7fead4739a19  google::LogMessage::~LogMessage()
    @     0x7fead4734f20  arrow::ArrowLog::~ArrowLog()
    @     0x7fead470ca42  plasma::PlasmaClient::Impl::PerformRelease()
    @     0x7fead470cd21  plasma::PlasmaClient::Impl::Release()
    @     0x7fead4713edd  std::_Sp_counted_ptr_inplace<>::_M_dispose()
    @     0x7fead498efe1  arrow::Buffer::~Buffer()
    @     0x7fead5ec3731  __pyx_tp_dealloc_7pyarrow_3lib_Buffer()
    @     0x7feaee527aa9  array_dealloc
    @     0x7feaaa9c35ac  _ZNSt17_Function_handlerIFvPvEZN5torch5utils17tensor_from_numpyEP7_objectEUlS0_E_E9_M_invokeERKSt9_Any_dataS0_
    @     0x7fead7f3cc9c  at::deleteInefficientStdFunctionContext()
    @     0x7fead81e7314  THStorage_free
    @     0x7fead8286a1f  THTensor_free
    @     0x7fead80c98b9  at::CPUFloatTensor::~CPUFloatTensor()
    @     0x7feaaad1cbd1  torch::autograd::Variable::Impl::~Impl()
    @     0x7feaaad1cd49  torch::autograd::Variable::Impl::~Impl()
    @     0x7feaaad33e89  THPVariable_clear()
    @     0x7feaf0e9df72  collect
    @     0x7feaf0e9e971  _PyObject_GC_Malloc
    @     0x7feaf0e9e9be  _PyObject_GC_New
    @     0x7feaf0dee095  PyGen_New
    @     0x7feaf0df828a  function_call
    @     0x7feaf0dd37a3  PyObject_Call
    @     0x7feaf0de263d  instancemethod_call
    @     0x7feaf0dd37a3  PyObject_Call
    @     0x7feaf0e2c650  slot_tp_iter
    @     0x7feaf0dd490f  PyObject_GetIter
    @     0x7feaf0e68cfe  PyEval_EvalFrameEx
    @     0x7feaf0e6df9e  PyEval_EvalFrameEx
    @     0x7feaf0e6f4e9  PyEval_EvalCodeEx
WARN: gym.spaces.Box autodetected dtype as <type 'numpy.float32'>. Please provide explicit dtype.
Constructing fcnet [256, 256] <class 'torch.nn.modules.activation.Tanh'>
The worker with ID 359663a0a9229ff83a2de6b4a931d737f96580c3 died or was killed while executing the task with ID 92b019464da617fb5f5243277bd4de3d81d2474d
Remote function train failed with:

It looks like this is called by PyTorch. @ericl, do you have an idea which script calls this?

ericl commented 6 years ago

That looks like pyarrow interpreting a numpy array as a tensor. Somehow that is deallocating the array, which crashes?


robertnishihara commented 6 years ago

@guoyuhong @ericl @pcmoritz interesting, I would expect tensor_from_numpy to create a PyTorch Tensor, which holds a reference to the numpy array that it was created from, which in turn should hold a reference to the pyarrow buffer backed by shared memory.

However, it looks like the pyarrow buffer is getting released, which makes me think that isn't happening..

I can reproduce the issue in docker, but somehow not when I add print statements..
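For reference, a small sketch of the reference chain being described (using a plain numpy array for simplicity; in Ray the array would be backed by a pyarrow buffer in shared memory). torch.from_numpy is zero-copy, and the resulting tensor keeps the source array alive:

import numpy as np
import torch

# Sketch of the expected ownership chain (plain numpy here; in Ray the
# array's base would be a pyarrow buffer backed by the plasma store).
arr = np.arange(8, dtype=np.float32)
tensor = torch.from_numpy(arr)   # zero-copy: the tensor shares arr's memory
arr[0] = 42.0
assert tensor[0].item() == 42.0  # the tensor observes writes to the array

# torch.from_numpy keeps a reference to arr, so dropping our handle must not
# free the storage; only when the tensor itself is destroyed should the array
# (and, in Ray's case, the plasma buffer) finally be released.
del arr
assert tensor[1].item() == 1.0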

guoyuhong commented 6 years ago

FYI: From the following error message:

Remote function train failed with:

Traceback (most recent call last):
  File "/ray/python/ray/worker.py", line 931, in _process_task
    *arguments)
  File "/ray/python/ray/actor.py", line 261, in actor_method_executor
    method_returns = method(actor, *args)
  File "/ray/python/ray/rllib/agents/agent.py", line 212, in train
    return Trainable.train(self)
  File "/ray/python/ray/tune/trainable.py", line 146, in train
    result = self._train()
  File "/ray/python/ray/rllib/agents/a3c/a3c.py", line 106, in _train
    self.optimizer.step()
  File "/ray/python/ray/rllib/optimizers/async_gradients_optimizer.py", line 44, in step
    gradient, info = ray.get(fut)
  File "/ray/python/ray/worker.py", line 2514, in get
    raise RayGetError(object_ids, value)

It looks like compute_gradients fails in torch.from_numpy in the worker, so the caller that later tries to get the output of this function fails.

richardliaw commented 6 years ago

Is this a pyarrow issue? Should we open an issue on PyArrow to track?

guoyuhong commented 6 years ago

I have some new findings. I need to confirm. Let's hold for a second.

guoyuhong commented 6 years ago

I added some more test code. At the beginning, I suspected that this could be caused by some race while killing a worker, but the experiments didn't indicate that. The issue only happens in the PyTorch test; no TensorFlow tests have this problem. Moreover, the PyTorch crash log is always the earliest log in the output. Therefore, it could be caused by a PyTorch or PyArrow problem. @robertnishihara Is the release not expected behavior? @richardliaw Could you please open an issue on PyArrow to track this?

robertnishihara commented 6 years ago

@guoyuhong Release definitely should be called when objects go out of scope. It's not obvious to me that it should be called in tensor_from_numpy, but I'm not sure.

guoyuhong commented 6 years ago

@robertnishihara Thanks. Due to our team's arrangements, I may turn to other work items and put this problem aside for now.

robertnishihara commented 6 years ago

@guoyuhong ok, thanks for looking into it, and please let me know if you have other ideas.

robertnishihara commented 6 years ago

I'm going to look into this a little bit now.

pcmoritz commented 6 years ago

Given the stack trace, I wonder if it is possible that Release is called from a different thread; that might cause problems.

ericl commented 6 years ago

@pcmoritz @guoyuhong I've managed to reproduce this by having a plasma object released by a different thread than the one that created it (confirming the above hypothesis).

Unfortunately I couldn't reduce the bug to a small script, but if you take this PR here: https://github.com/ray-project/ray/pull/3147/files#diff-47c1b54f34a8f5fe87bc33effe2818d5R155

and change self.outqueue.put((ra, replay, prio_dict, replay.count)) to self.outqueue.put((ra, None, prio_dict, replay.count)) on that line, then running the following from the rllib dir will consistently crash within about 10 seconds:

./train.py --env Pendulum-v0 --run APEX_DDPG --ray-num-cpus 8 --stop '{"training_iteration": 2}' --config '{"num_workers": 2, "optimizer": {"num_replay_buffer_shards": 1}, "learning_starts": 100, "min_iter_time_s": 1}'
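To make the pattern concrete, here is a hedged sketch of the shape of the problem (illustrative only, not a confirmed minimal repro; the APEX_DDPG command above is the reliable reproduction): an object-store-backed array is fetched on one thread, but the last reference to it is dropped on another, so the plasma client's Release() runs on a different thread than the one that fetched the object.

import threading
import numpy as np
import ray

# Hedged sketch of the cross-thread release pattern described above; this is
# illustrative, not a confirmed minimal repro.
ray.init()

oid = ray.put(np.ones(10**6, dtype=np.float32))
arr = ray.get(oid)  # numpy array backed by a plasma buffer, fetched on the main thread

def consume(a):
    _ = a.sum()
    # When this function returns, this thread drops the last reference to the
    # array, so the plasma buffer's Release() runs here rather than on the
    # thread that fetched the object.

worker = threading.Thread(target=consume, args=(arr,))
del arr        # the worker thread now holds the only reference
worker.start()
worker.join()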

robertnishihara commented 6 years ago

Oh nice!

robertnishihara commented 6 years ago

The plasma client is not thread-safe, so whenever we use it, we protect it with a lock. However, the plasma client gets used implicitly whenever a numpy array backed by the object store goes out of scope (because it is used to call "release" on the object store).

That may be the problem.
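As a rough sketch of that situation (the WorkerSketch class and plasma_client_lock name are illustrative, not the actual worker internals): explicit plasma client calls go through a lock, but the release triggered by a buffer destructor does not.

import threading

class WorkerSketch(object):
    """Illustrative only; mirrors the locking pattern described above."""

    def __init__(self, plasma_client):
        self.plasma_client = plasma_client
        self.plasma_client_lock = threading.Lock()

    def retrieve(self, object_ids):
        # Explicit uses of the non-thread-safe plasma client are serialized
        # behind a lock.
        with self.plasma_client_lock:
            return self.plasma_client.get(object_ids)

# The implicit path is not serialized: when a numpy array backed by one of
# those buffers is garbage collected, the buffer destructor calls the plasma
# client's release() directly, from whichever thread happens to drop the last
# reference, without acquiring plasma_client_lock.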