RobinKa opened this issue 12 months ago
I encountered the same problem when running a custom environment and an LSTM-based model on Ray 2.8.0. I had it print out the batch shapes, and with a bit of cleanup, my error is:
Failure # 1 (occurred at 2023-11-10_19-43-36)
ray::PPO.train() (pid=10060, ip=10.91.0.26, actor_id=c9854846c60d32b8fc828da401000000, repr=PPO)
File "ray/tune/trainable/trainable.py", line 342, in train
raise skipped from exception_cause(skipped)
File "ray/tune/trainable/trainable.py", line 339, in train
result = self.step()
File "ray/rllib/algorithms/algorithm.py", line 853, in step
results, train_iter_ctx = self._run_one_training_iteration()
File "ray/rllib/algorithms/algorithm.py", line 2854, in _run_one_training_iteration
results = self.training_step()
File "ray/rllib/algorithms/ppo/ppo.py", line 429, in training_step
train_batch = synchronous_parallel_sample(
File "ray/rllib/execution/rollout_ops.py", line 101, in synchronous_parallel_sample
full_batch = concat_samples(all_sample_batches)
File "ray/rllib/policy/sample_batch.py", line 1580, in concat_samples
return concat_samples_into_ma_batch(samples)
File "ray/rllib/policy/sample_batch.py", line 1731, in concat_samples_into_ma_batch
out[key] = concat_samples(batches)
File "ray/rllib/policy/sample_batch.py", line 1651, in concat_samples
raise ValueError(
ValueError: Cannot concat data under key 'obs', b/c sub-structures under that key don't match. `samples`=[SampleBatch(150 (seqs=5): [... snip ...]), SampleBatch(150 (seqs=7): []), SampleBatch(150 (seqs=4): []), SampleBatch(150 (seqs=2): []), SampleBatch(150 (seqs=3): []), SampleBatch(150 (seqs=3): []), SampleBatch(150 (seqs=2): []), SampleBatch(150 (seqs=3): []), SampleBatch(150 (seqs=6): []), SampleBatch(150 (seqs=1): []), SampleBatch(150 (seqs=7): []), SampleBatch(150 (seqs=3): []), SampleBatch(150 (seqs=6): []), SampleBatch(150 (seqs=2): []), SampleBatch(150 (seqs=5): []), SampleBatch(150 (seqs=4): []), SampleBatch(150 (seqs=2): []), SampleBatch(150 (seqs=4): []), SampleBatch(150 (seqs=2): []), SampleBatch(150 (seqs=2): []), SampleBatch(150 (seqs=10): []), SampleBatch(150 (seqs=2): []), SampleBatch(150 (seqs=2): []), SampleBatch(150 (seqs=5): []), SampleBatch(150 (seqs=7): []), SampleBatch(150 (seqs=4): []), SampleBatch(150 (seqs=8): [])]
Original error:
all the input array dimensions for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 99 and the array at index 1 has size 38
Batch shapes: [(5, 99, 1430), (7, 38, 1430), (4, 63, 1430), (2, 122, 1430), (3, 69, 1430), (3, 102, 1430), (2, 104, 1430), (3, 88, 1430), (6, 64, 1430), (1, 150, 1430), (7, 44, 1430), (3, 110, 1430), (6, 36, 1430), (2, 126, 1430), (5, 70, 1430), (4, 99, 1430), (2, 96, 1430), (4, 79, 1430), (2, 123, 1430), (2, 148, 1430), (10, 21, 1430), (2, 132, 1430), (2, 117, 1430), (5, 55, 1430), (7, 51, 1430), (4, 120, 1430), (8, 49, 1430)]
The relevant piece of code is concat_samples in ray/rllib/policy/sample_batch.py (the ValueError raised at line 1651 in the traceback above).
My guess is that it assumes the T dimension matches across all the tensors when concatenating along B. My understanding is that this would hold if s.zero_padded were True, since all batches would then have T = max_seq_len, but at least in my case it isn't. Should the sample batches then be zero-padded to a common T dimension (not necessarily max_seq_len, but perhaps the max T across the batches)? Or do I have some configuration wrong that is causing this mismatch?
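To make the failure mode concrete, here is a toy NumPy sketch (hypothetical shapes, much smaller than the 1430-dim observations above, and not RLlib's actual implementation): concatenating along the batch axis requires every other axis to agree, and right-padding each batch along T to a shared maximum makes the concat succeed.

import numpy as np

# Toy stand-ins for per-worker batches shaped (num_seqs, T, feature_dim).
# The mismatched T values mirror the shapes in the error above.
batches = [np.ones((5, 99, 4)), np.ones((7, 38, 4)), np.ones((2, 122, 4))]

try:
    np.concatenate(batches, axis=0)  # fails: axis 1 (T) differs per batch
except ValueError as e:
    print(e)  # "all the input array dimensions ... must match exactly ..."

# Zero-pad every batch along T to the longest T present, then concat works.
max_t = max(b.shape[1] for b in batches)
padded = [
    np.pad(b, ((0, 0), (0, max_t - b.shape[1]), (0, 0)))  # zeros on the right of T
    for b in batches
]
out = np.concatenate(padded, axis=0)
print(out.shape)  # (14, 122, 4)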
@RobinKa Thanks for posting this. I can replicate the error of @jfurches on ray==2.7.1. However, I cannot replicate it on ray-nightly. With the nightly install I can run the example for a long time without any error, even after the win rate threshold has been exceeded for a while. I guess this error has already been fixed.
Could you try the latest release or the nightly build?
What happened + What you expected to happen
Running self_play_with_open_spiel with use_lstm=True fails when a new policy gets added after the win rate threshold is exceeded.
Versions / Dependencies
Ray 2.7.1, Python 3.10.6
Reproduction script
Add "use_lstm": True to the model config of self_play_with_open_spiel (a minimal sketch of the change is below), run it, and wait for the win rate threshold to be exceeded so that a new policy is created (you can lower the threshold for quicker reproduction). This also happens with other environments; I first noticed it in my own environment, where I had copied most of the example's code, and the same thing happens there.
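For reference, a minimal sketch of the change, assuming the example builds its algorithm via a PPOConfig (the env name below is a placeholder; the actual example registers its own OpenSpiel env and multi-agent setup):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("open_spiel_env")  # placeholder; the example registers its own env
    .training(
        model={
            "use_lstm": True,     # the one-line change that triggers the error
            # "max_seq_len": 20,  # RLlib default; sequences are chunked to this length
        }
    )
)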
Issue Severity
High: It blocks me from completing my task.