pfnet / pfrl

PFRL: a PyTorch-based deep reinforcement learning library
MIT License

CUDA error when sampling an action from the SoftmaxCategoricalHead distribution #118

Closed tkelestemur closed 3 years ago

tkelestemur commented 3 years ago

I have a custom environment implemented with the Gym API. It has 3-channel image observations and 4 actions. I'm training PPO with a CNN-based policy network. I get a CUDA error when sampling from the SoftmaxCategoricalHead. The error happens at a different step each run even though I'm using pfrl.utils.set_random_seed(args.seed).

The error is below with CUDA_LAUNCH_BLOCKING=1:

/pytorch/aten/src/ATen/native/cuda/MultinomialKernel.cu:190: sampleMultinomialOnce: block: [0,0,0], thread: [3,0,0] Assertion `val >= zero` failed.
/pytorch/aten/src/ATen/native/cuda/MultinomialKernel.cu:190: sampleMultinomialOnce: block: [3,0,0], thread: [0,0,0] Assertion `val >= zero` failed.
/pytorch/aten/src/ATen/native/cuda/MultinomialKernel.cu:190: sampleMultinomialOnce: block: [3,0,0], thread: [1,0,0] Assertion `val >= zero` failed.
/pytorch/aten/src/ATen/native/cuda/MultinomialKernel.cu:190: sampleMultinomialOnce: block: [3,0,0], thread: [2,0,0] Assertion `val >= zero` failed.
/pytorch/aten/src/ATen/native/cuda/MultinomialKernel.cu:190: sampleMultinomialOnce: block: [3,0,0], thread: [3,0,0] Assertion `val >= zero` failed.
/pytorch/aten/src/ATen/native/cuda/MultinomialKernel.cu:190: sampleMultinomialOnce: block: [4,0,0], thread: [0,0,0] Assertion `val >= zero` failed.
/pytorch/aten/src/ATen/native/cuda/MultinomialKernel.cu:190: sampleMultinomialOnce: block: [4,0,0], thread: [1,0,0] Assertion `val >= zero` failed.
/pytorch/aten/src/ATen/native/cuda/MultinomialKernel.cu:190: sampleMultinomialOnce: block: [4,0,0], thread: [2,0,0] Assertion `val >= zero` failed.
/pytorch/aten/src/ATen/native/cuda/MultinomialKernel.cu:190: sampleMultinomialOnce: block: [4,0,0], thread: [3,0,0] Assertion `val >= zero` failed.
THCudaCheck FAIL file=/pytorch/torch/csrc/generic/serialization.cpp line=31 error=710 : device-side assert triggered
Traceback (most recent call last):
  File "/home/tarik/projects/pfrl/pfrl/experiments/train_agent_batch.py", line 71, in train_agent_batch
    actions = agent.batch_act(obss)
  File "/home/tarik/projects/pfrl/pfrl/agents/ppo.py", line 654, in batch_act
    return self._batch_act_train(batch_obs)
  File "/home/tarik/projects/pfrl/pfrl/agents/ppo.py", line 712, in _batch_act_train
    batch_action = action_distrib.sample().cpu().numpy()
  File "/home/tarik/venvs/research/lib/python3.8/site-packages/torch/distributions/categorical.py", line 107, in sample
    samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T
RuntimeError: CUDA error: device-side assert triggered
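
(For context, this is the assert one would expect when a Categorical distribution is sampled with NaN logits on CUDA. A minimal sketch, not from the original report; depending on the PyTorch version, argument validation may reject the NaN logits even earlier.)

import torch
from torch.distributions import Categorical

# NaN logits yield NaN probabilities, which trip the `val >= zero`
# assertion inside sampleMultinomialOnce when torch.multinomial runs on CUDA.
logits = torch.full((16, 4), float("nan"), device="cuda")
Categorical(logits=logits, validate_args=False).sample()
# RuntimeError: CUDA error: device-side assert triggered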

My network:

model = nn.Sequential(
    IMPALACNN(),
    pfrl.nn.Branched(
        # Policy head: logits over the 4 discrete actions.
        nn.Sequential(
            lecun_init(nn.Linear(512, n_actions), 1e-2),
            SoftmaxCategoricalHead(),
        ),
        # Value head: scalar state-value estimate.
        lecun_init(nn.Linear(512, 1)),
    ),
)

where IMPALACNN can be seen here.

As far as I understand, the issue comes from sampling with infinite logits, but I don't know why the last linear layer would produce infinite values.
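
(One way to find out where the non-finite values first appear is to register a forward hook on every submodule and fail as soon as any output stops being finite. A hypothetical debugging helper; add_nan_hooks is not part of PFRL.)

import torch
import torch.nn as nn

def add_nan_hooks(model: nn.Module) -> None:
    # Raise as soon as any submodule produces NaN/inf, so the offending
    # layer (or its input) can be identified.
    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                raise RuntimeError(f"non-finite output from module {name!r}")
        return hook

    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))

Calling add_nan_hooks(model) before training would then point at the first module whose output blows up.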

Edit: I printed out the probs and the logits of the Categorical distribution, and it seems that the policy network produces NaN values for a single batch of observation data. I double-checked my environment and it doesn't return any NaN or inf values in the observations. The strange thing is that the policy network returns NaN for every environment. I'm training with 16 environments on the GPU. Here is the output of the probs and logits right before the error:

probs tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]], device='cuda:0')
logits tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]], device='cuda:0')
muupan commented 3 years ago

Those NaNs suggest that either an observation or the network's parameters contain NaN. If the latter is the case, I would recommend checking the loss values for each update to see how training becomes unstable and leads to NaN.
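
(A hedged sketch of those checks, reusing the names obss, model, and agent from the snippets above; the exact statistic names depend on the PFRL version.)

import torch

def all_finite(t) -> bool:
    # True if every element of t is finite (no NaN/inf).
    return bool(torch.isfinite(torch.as_tensor(t).float()).all())

# 1) Observations handed to the agent at each step.
assert all(all_finite(o) for o in obss), "non-finite observation"

# 2) Network parameters after each update.
assert all(all_finite(p) for p in model.parameters()), "non-finite parameter"

# 3) Loss-related statistics recorded by the PPO agent.
print(agent.get_statistics())  # e.g. average_value, average_policy_loss, ...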

tkelestemur commented 3 years ago

It turns out that I had NaNs in my observations. Closing.
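
(For anyone hitting the same problem: a Gym observation wrapper can catch this at the source. AssertFiniteObservation and MyCustomEnv below are illustrative placeholders, not part of PFRL or Gym.)

import gym
import numpy as np

class AssertFiniteObservation(gym.ObservationWrapper):
    # Fail fast if the environment emits NaN/inf observations.
    def observation(self, obs):
        assert np.isfinite(obs).all(), "environment returned a non-finite observation"
        return obs

# Usage: env = AssertFiniteObservation(MyCustomEnv())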