salesforce / warp-drive

Extremely Fast End-to-End Deep Multi-Agent Reinforcement Learning Framework on a GPU (JMLR 2022)
BSD 3-Clause "New" or "Revised" License
465 stars, 78 forks

found invalid values error #74

Closed: MatPoliquin closed this issue 1 year ago.

MatPoliquin commented 1 year ago

I run into an error when running the example notebook at https://github.com/salesforce/warp-drive/blob/master/tutorials/simple-end-to-end-example.ipynb

Note:

- Repo commit: b5d46d4
- These tests passed successfully:
  - python warp_drive/utils/unittests/run_unittests_pycuda.py
  - python warp_drive/utils/unittests/run_trainer_tests.py
- GPU: NVIDIA P104-100 (8 GB)

I get this output:

Device: 0
Iterations Completed : 1 / 50

Speed performance stats

Mean policy eval time per iter (ms) : 196.94
Mean action sample time per iter (ms) : 37.12
Mean env. step time per iter (ms) : 85.96
Mean training time per iter (ms) : 123.15
Mean total time per iter (ms) : 453.86
Mean steps per sec (policy eval) : 50775.87
Mean steps per sec (action sample) : 269373.56
Mean steps per sec (env. step) : 116335.91
Mean steps per sec (training time) : 81202.92
Mean steps per sec (total) : 22033.34

Metrics for policy 'runner'

VF loss coefficient : 0.01000
Entropy coefficient : 0.05000
Total loss : 0.09430
Policy loss : 0.33186
Value function loss : 0.20734
Mean rewards : 0.00085
Max. rewards : 1.00000
Min. rewards : -1.00000
Mean value function : 0.04290
Mean advantages : 0.06929
Mean (norm.) advantages : 0.06929
Mean (discounted) returns : 0.11219
Mean normalized returns : 0.11219
Mean entropy : 4.79267
Variance explained by the value function: 0.01151
Std. of action_0 over agents : 3.13083
Std. of action_0 over envs : 3.14615
Std. of action_0 over time : 3.14577
Std. of action_1 over agents : 3.17047
Std. of action_1 over envs : 3.18386
Std. of action_1 over time : 3.18446
Current timestep : 10000.00000
Gradient norm : 0.00000
Learning rate : 0.00500
Mean episodic reward : 1.71000
Mean episodic steps : 100.00000

Metrics for policy 'tagger'

VF loss coefficient : 0.01000
Entropy coefficient : 0.05000
Total loss : 1.78037
Policy loss : 2.01399
Value function loss : 0.59261
Mean rewards : 0.01810
Max. rewards : 1.00000
Min. rewards : 0.00000
Mean value function : 0.06817
Mean advantages : 0.42039
Mean (norm.) advantages : 0.42039
Mean (discounted) returns : 0.48856
Mean normalized returns : 0.48856
Mean entropy : 4.79084
Variance explained by the value function: -0.00882
Std. of action_0 over agents : 3.06860
Std. of action_0 over envs : 3.17762
Std. of action_0 over time : 3.17566
Std. of action_1 over agents : 3.05678
Std. of action_1 over envs : 3.16503
Std. of action_1 over time : 3.16620
Current timestep : 10000.00000
Gradient norm : 0.00000
Learning rate : 0.00200
Mean episodic reward : 9.05000
Mean episodic steps : 100.00000

[Device 0]: Saving the results to the file '/tmp/continuous_tag/example/1679065351/results.json'
[Device 0]: Saving the 'runner' torch model to the file: '/tmp/continuous_tag/example/1679065351/runner_10000.state_dict'.
[Device 0]: Saving the 'tagger' torch model to the file: '/tmp/continuous_tag/example/1679065351/tagger_10000.state_dict'.
Traceback (most recent call last):
  File "wd_test.py", line 84, in <module>
    trainer.train()
  File "/home/warp/github/warp-drive/warp_drive/training/trainer.py", line 415, in train
    metrics = self._update_model_params(iteration)
  File "/home/warp/github/warp-drive/warp_drive/training/trainer.py", line 710, in _update_model_params
    perform_logging=logging_flag,
  File "/home/warp/github/warp-drive/warp_drive/training/algorithms/policygradient/a2c.py", line 102, in compute_loss_and_metrics
    m = Categorical(action_probabilities_batch[idx])
  File "/home/warp/anaconda3/envs/warp_drive/lib/python3.7/site-packages/torch/distributions/categorical.py", line 64, in __init__
    super(Categorical, self).__init__(batch_shape, validate_args=validate_args)
  File "/home/warp/anaconda3/envs/warp_drive/lib/python3.7/site-packages/torch/distributions/distribution.py", line 56, in __init__
    f"Expected parameter {param} "
ValueError: Expected parameter probs (Tensor of shape (100, 100, 5, 11)) of distribution Categorical(probs: torch.Size([100, 100, 5, 11])) to satisfy the constraint Simplex(), but found invalid values:
tensor([[[[ 1.2426e+00, -1.2945e+00,  4.1014e-01,  ...,  5.5622e-01,
       -6.7214e-01, -1.2349e+00],
      [-1.7248e-01,  6.4287e-02, -7.4881e-01,  ...,  4.6214e-01,
        7.5912e-01,  1.8682e-01],
      [ 6.3147e-01,  4.5790e-01, -3.2810e-01,  ...,  3.1173e-01,
        2.7938e-01,  3.7275e-01],
      [ 1.9841e+00,  7.4553e-01, -6.1727e-01,  ..., -8.2579e-01,
       -1.8078e+00, -5.4283e-01],
      [ 4.3695e-01,  1.6643e-02, -1.7423e-01,  ...,  6.6712e-01,
       -5.9217e-01, -7.6138e-01]],

     [[ 1.5075e+00, -3.4445e+00, -2.6291e+00,  ...,  8.3555e-01,
        2.2945e+00, -5.3965e-02],
      [-2.0584e-01,  2.0469e+00,  3.3165e-01,  ..., -8.6847e-01,
        1.9418e-01,  9.3736e-01],
      [ 1.3200e+01, -7.2059e+00, -7.2506e-01,  ...,  3.0440e+00,
        2.1811e+00, -1.9477e+00],
      [ 8.5257e-01, -5.0674e-01,  2.2239e+00,  ...,  3.3165e-01,
        5.6384e-01,  3.9233e-01],
      [ 4.9769e-01, -3.7434e-01,  9.3445e-02,  ..., -4.1689e+00,
        1.1729e+00, -1.9694e+00]],

     [[ 2.8376e-01,  4.2903e+00, -1.6735e+00,  ..., -4.2179e-01,
        2.0839e+00,  2.7028e-01],
      [ 1.2377e+00,  3.3001e+00,  1.7086e+00,  ..., -1.5953e+00,
        4.2362e-01, -1.8910e+00],
      [ 1.2417e-01,  1.8337e+00,  2.2134e+00,  ..., -4.1177e-01,
       -2.0779e+00, -6.5099e-01],
      [ 2.7558e-01,  6.8455e-01,  2.6409e-01,  ...,  8.5417e-01,
       -3.8728e-01, -2.9246e-01],
      [ 1.8247e+00,  5.1626e-01, -1.3113e+00,  ...,  1.2756e+00,
       -9.6114e-01,  4.7476e-01]],

     ...,

     [[ 4.9671e-01,  1.1875e+00, -3.7837e-01,  ...,  1.8017e+00,
        9.7922e-02,  9.0826e-01],
      [ 4.0105e-01, -6.0412e-01, -2.3511e+00,  ..., -1.4305e-01,
        3.4514e+00, -3.4735e-01],
      [-1.6592e+00, -3.1455e+00,  3.4837e+00,  ...,  3.0970e+00,
        1.4824e+00, -1.6712e+00],
      [ 2.2480e+00,  2.1861e+00,  2.0778e+00,  ..., -4.3291e+00,
        4.2960e+00,  8.5239e-02],
      [ 4.7889e-01,  4.5080e-01,  5.7547e-01,  ..., -6.0323e-01,
       -3.5546e-01, -3.4922e-01]],

     [[ 6.6736e-01, -1.9598e-01,  4.9250e-01,  ...,  1.0089e+00,
       -6.2808e-01,  3.1778e-01],
      [-6.8272e-01,  1.0324e+00, -2.4395e+00,  ..., -8.7298e-01,
       -8.6526e-01,  6.1645e-01],
      [ 3.4440e-01, -7.2880e-01,  2.2529e-01,  ..., -5.1057e-02,
       -2.4368e-01,  2.5231e-01],
      [ 6.9983e-01, -4.9519e-01,  4.1547e-01,  ...,  4.5329e-01,
       -4.5527e-01,  7.5715e-02],
      [ 6.6658e-01,  9.1576e-02,  5.3606e-01,  ..., -5.4163e-01,
        1.7202e+00,  8.6003e-02]],

     [[-3.9657e+00, -1.4984e+00,  2.5009e+01,  ..., -1.9405e+01,
        2.4046e+01, -1.0344e+01],
      [ 7.6318e-01,  6.2132e-01, -1.6545e-01,  ...,  4.0838e-01,
       -1.0056e+00,  4.0224e+00],
      [-1.6237e-01,  7.6003e-01, -1.5722e+00,  ...,  9.8846e-01,
       -7.1669e-01, -5.8512e-01],
      [ 2.2718e+00,  1.0969e+00,  5.8637e-01,  ..., -1.7726e+00,
       -8.3147e+00,  2.6415e+00],
      [ 4.5803e-01,  1.9375e-01,  6.8278e-01,  ..., -6.5108e-01,
       -7.5676e-02,  1.1231e+00]]],

    [[[-6.6555e-01,  5.1884e-01,  3.0740e+00,  ..., -2.3698e+00,
        2.3828e+00,  1.4393e+00],
      [ 6.6651e-01, -2.8901e-01,  1.0364e-01,  ...,  7.0320e-02,
        7.2539e-01,  2.4268e-01],
      [ 4.0248e-01,  4.1511e-01,  2.9718e-01,  ...,  7.2417e-01,
        6.9254e-02, -7.5353e-01],
      [-1.6744e+01, -2.1953e+01,  1.3020e+01,  ...,  2.0297e+01,
       -1.2317e+01,  1.7973e+01],
      [ 6.4031e-01, -2.7768e-01, -1.1653e-01,  ...,  2.4048e-01,
        2.4544e-01,  2.7243e-01]],

     [[ 1.6415e+00,  8.1022e-01,  1.0484e+00,  ...,  2.1372e+00,
       -4.6998e-01, -6.2497e-01],
      [ 2.9397e-01, -2.7979e-02,  1.0541e+00,  ..., -7.6941e-01,
       -1.6443e-01,  2.7037e-01],
      [ 9.9372e-01, -3.7063e-01,  4.6666e-01,  ...,  1.0245e+00,
        2.1319e-01, -3.5092e-01],
      [ 2.4373e-02, -8.9471e+00,  6.7321e+00,  ..., -4.2705e+00,
        1.5750e+00,  2.9782e+00],
      [ 2.0529e-01,  1.6676e+00,  6.5215e-01,  ..., -6.6583e-01,
        1.9530e-01, -1.3662e-01]],

     [[ 3.3094e-01, -4.1690e-01,  1.4120e-01,  ...,  9.2762e-01,
        3.4857e-01, -4.6941e-01],
      [ 7.5478e+00,  6.1822e+00, -3.2198e+00,  ..., -8.0159e-01,
       -6.3349e-02, -6.2434e+00],
      [ 1.3450e+00, -1.4053e+00, -1.1709e+00,  ..., -5.2519e-01,
       -9.9739e-01, -1.1001e+00],
      [-2.4680e-01,  6.5710e-01, -1.4691e+00,  ...,  1.4205e+00,
        1.8427e+00,  1.1796e-01],
      [ 4.1336e-01,  8.5022e-02,  9.2853e-02,  ...,  3.2010e-01,
        2.9028e-01,  1.8401e-01]],

     ...,

     [[-1.1544e+00,  1.5420e+00,  1.7679e+00,  ..., -6.7182e-01,
       -3.6764e-02, -2.2441e+00],
      [ 1.7255e+00, -9.9439e-01, -1.2645e+00,  ...,  1.1284e+00,
        9.0534e-02,  1.7442e+00],
      [-2.0325e+00,  1.5082e+00,  3.8115e+00,  ..., -4.3420e+00,
        9.2502e-01, -6.1594e+00],
      [ 4.1424e-01, -4.8287e-01,  1.5418e-01,  ..., -2.5157e-01,
        4.0651e-01,  9.1724e-01],
      [ 1.0154e+00, -5.3099e-01,  1.0433e+00,  ...,  1.1691e+00,
       -6.0312e-01, -3.2987e-02]],

     [[-2.6123e-02,  7.3750e-01, -1.7682e-01,  ...,  9.0202e-01,
        1.9137e-01,  1.4555e+00],
      [ 6.1389e-01,  2.6924e-01,  1.4631e+00,  ...,  7.3595e-01,
        8.0151e-01, -6.3749e-01],
      [ 5.2367e-01, -1.0993e+00,  4.3189e-01,  ..., -1.8716e+00,
        5.6906e-01, -4.0940e-01],
      [ 4.9998e-01, -1.3751e+00, -8.6454e-01,  ..., -1.3774e+00,
       -2.5627e-01, -3.8065e-02],
      [-7.3349e-01,  4.3705e-01,  3.5770e+00,  ..., -1.2184e+00,
       -9.5621e-01, -2.6853e+00]],

     [[ 1.7939e+00, -1.2353e+00,  1.3519e+00,  ...,  6.9516e-01,
        1.2708e+00,  7.4168e-01],
      [-6.7631e-01, -1.2218e+00, -2.7282e-01,  ..., -7.6156e-01,
        9.4981e-01, -3.5000e-01],
      [ 5.5513e-01,  7.5819e-01,  1.6906e-01,  ...,  5.1887e-01,
       -1.3181e-02,  8.2934e-02],
      [ 7.8362e-01, -2.5555e-01, -2.0774e+00,  ..., -4.8269e-01,
        1.4186e+00, -3.9613e-01],
      [ 1.4990e+00,  1.3719e+00, -3.4172e+00,  ..., -2.6224e+00,
       -6.0584e+00,  9.6879e+00]]],

    [[[ 5.7833e-01,  4.5936e+00,  1.2302e+01,  ..., -1.4435e+01,
        9.4017e+00,  9.5336e-01],
      [-5.4774e-02,  1.1083e+00,  3.0625e-01,  ..., -1.9470e-01,
       -3.0002e-01, -4.2560e-01],
      [ 8.9458e-01, -3.2593e-01,  7.6646e-02,  ..., -8.6837e-02,
       -6.0646e-01,  6.9351e-01],
      [-3.1247e-01,  5.2223e-01, -3.3424e+00,  ..., -1.9283e+00,
       -1.3650e+00,  1.7507e+00],
      [ 3.2635e-01, -8.6947e-01,  8.3852e-02,  ..., -3.4544e-01,
       -8.9523e-02,  2.5219e-01]],

     [[ 7.7817e-01, -8.5250e-01,  2.8413e-01,  ...,  8.7812e-01,
       -1.0428e+00, -1.3697e-01],
      [-3.6143e+00,  4.5287e+00, -1.2355e+01,  ...,  1.1151e+01,
       -1.1201e+01, -1.5772e+01],
      [ 4.4204e-01,  4.5941e-01,  1.4200e+00,  ..., -1.4958e+00,
        9.9831e-01, -3.0248e-01],
      [ 8.4145e-01, -5.4955e-02, -8.5717e-01,  ...,  1.9797e+00,
       -4.0240e-01,  1.7958e+00],
      [-9.7922e-02,  1.1579e+00, -1.1985e+00,  ...,  9.4564e-01,
        6.3346e-02, -4.4363e-01]],

     [[ 1.4148e+00,  1.0272e+00, -1.5679e-01,  ..., -9.9173e-01,
       -4.1617e+00,  9.9793e-01],
      [ 2.2756e-01, -5.4218e-02,  5.5711e-01,  ..., -5.2833e-01,
       -4.2351e-02,  8.4122e-01],
      [-2.4643e+00, -8.9988e-01,  3.7735e+00,  ..., -2.7195e+00,
        4.3069e+00,  1.3947e+00],
      [ 6.3190e-01, -1.7120e-01,  2.0941e-01,  ..., -1.2354e-01,
        5.8095e-01,  1.0262e-01],
      [ 5.8950e-01,  3.5717e-01,  4.0016e-01,  ...,  2.1093e-01,
        1.1742e-01, -7.6782e-01]],

     ...,

     [[-1.9296e+00,  2.0762e-01,  8.3194e-01,  ...,  5.6747e-02,
        2.8068e+00,  5.9143e-01],
      [ 1.0056e+00,  9.8933e-02,  1.0071e+00,  ...,  1.0832e+00,
        5.7345e-02, -1.9594e+00],
      [-6.6266e-03, -3.1431e-01, -2.9666e-01,  ...,  4.4572e-01,
        9.9407e-01, -6.2249e-02],
      [-2.6428e+00, -7.0332e-01,  9.8558e-01,  ...,  1.9927e+00,
        2.0943e+00,  2.2718e+00],
      [ 3.6267e+00, -4.7667e-01,  1.4884e+00,  ..., -1.1679e+00,
        7.7638e-02,  1.2573e+00]],

     [[ 1.1881e+00,  1.8042e-01,  4.6982e-01,  ...,  7.4913e-01,
       -7.1513e-01,  2.7843e-02],
      [ 1.1181e-01,  6.3832e-01, -2.0892e-01,  ...,  9.0355e-02,
       -8.3373e-02, -3.4605e-01],
      [ 9.5217e-01, -4.3010e-01, -3.9673e-01,  ...,  1.0727e+00,
       -2.3093e-01,  5.4155e-01],
      [ 5.2008e-02,  3.7345e-01,  1.9147e-01,  ...,  7.9313e-01,
       -1.0457e+00, -2.7206e-01],
      [ 1.0753e+00,  5.5819e-01,  7.0341e-01,  ..., -5.4942e-01,
       -5.0333e-01,  6.3525e-01]],

     [[ 5.7962e-02,  4.2277e-01,  3.8236e-01,  ..., -5.1404e-01,
        8.8852e-01,  1.8998e-01],
      [ 5.3058e-01,  2.2198e-01, -4.9676e-02,  ..., -1.4092e-01,
        2.2237e-01, -2.0980e-01],
      [-4.8962e+01,  1.1259e+01,  5.6544e+01,  ..., -4.0250e+01,
       -2.7558e+01, -3.0411e+01],
      [-1.0543e+01, -4.6906e+00,  1.3036e+01,  ..., -2.7036e-02,
       -2.4293e+00, -1.0603e+01],
      [ 1.4160e+00,  5.3356e-01, -1.9352e+00,  ...,  1.2995e+00,
        3.9623e-01,  8.5555e-01]]],

    ...,

    [[[ 9.4254e-01,  5.1876e-01,  3.7956e-01,  ..., -4.1182e-01,
        1.4662e+00,  1.4871e+00],
      [ 3.7801e-02,  1.1907e+00,  7.0360e-01,  ...,  2.1115e-01,
        2.9789e-01,  5.0459e-01],
      [-6.5351e+00, -3.1801e+00,  2.4453e+00,  ..., -1.1528e+00,
        7.1053e+00,  3.1471e+00],
      [ 9.9387e-01,  4.7417e-01,  4.0640e-01,  ...,  8.3128e-01,
       -1.5710e-01,  3.1348e-01],
      [ 3.0644e-02,  2.1654e+00, -3.2200e-01,  ...,  2.9643e+00,
       -5.0624e-01,  6.8427e+00]],

     [[-4.5878e+00,  3.3556e+00, -7.1328e+00,  ..., -7.6168e+00,
        1.5838e+01, -1.4438e+01],
      [ 1.3205e+01,  6.6298e+00,  4.7031e+01,  ..., -2.0646e+01,
       -1.8688e+01,  6.8704e+01],
      [ 1.1237e+00,  2.3903e+00, -1.2417e+00,  ...,  1.7067e+00,
       -3.1532e+00, -1.9685e+00],
      [-4.7536e-02, -7.7557e-01, -5.1901e-01,  ...,  7.7719e-01,
        1.7095e+00,  2.2881e+00],
      [ 1.5270e+00,  2.2445e+00,  1.3476e+00,  ..., -3.8061e-01,
        5.2737e-01, -3.4140e-01]],

     [[-4.1516e-01, -6.1130e-01,  1.7781e+00,  ..., -4.2010e-01,
        2.4555e+00,  1.3414e+00],
      [ 4.4015e-01,  3.5935e-01,  6.7883e-01,  ..., -4.6652e-01,
       -1.2027e+00,  5.6941e-02],
      [ 7.6944e-01, -6.9579e-01,  5.1049e-01,  ...,  6.6508e-01,
        1.4991e-01, -3.0746e-01],
      [ 9.5676e-01,  1.0666e+00,  2.7229e+00,  ..., -1.3453e+00,
       -2.7133e+00,  9.0167e-04],
      [ 1.9745e-02, -1.6052e+00,  2.6452e+00,  ..., -1.0596e+00,
        8.1113e-01,  7.7127e-01]],

     ...,

     [[-4.9620e+00,  2.8767e+00, -1.5517e+00,  ..., -6.7731e+00,
       -3.9575e+00,  3.0952e-01],
      [ 5.0870e-01, -3.2030e-01,  4.9990e-01,  ...,  3.4475e-01,
        2.8955e-02,  2.0354e-01],
      [ 6.9547e-01,  5.1287e-01, -1.4381e+00,  ..., -5.0634e-01,
        4.5152e-01,  1.8553e-01],
      [ 4.1317e-01,  7.2432e-01,  1.3053e+00,  ..., -6.4745e-01,
        4.6445e-01,  5.1297e-01],
      [ 6.1733e-01, -1.9964e+00,  1.5222e+00,  ..., -2.9912e-01,
        4.2244e-01,  1.2825e+00]],

     [[ 4.7489e-01,  4.7182e-01,  6.6222e-01,  ...,  1.1634e+00,
        4.2618e-01, -2.2452e+00],
      [ 1.8447e+00, -3.4247e-01,  4.3219e+00,  ..., -2.6348e+00,
       -7.0686e-02, -5.1091e+00],
      [ 1.6583e-01,  9.2552e-01,  1.2414e+00,  ..., -6.8017e-01,
        3.9580e-01,  2.7543e-01],
      [ 4.1585e+00,  2.1752e+00,  2.9503e+00,  ...,  5.6982e-01,
        6.1948e-01,  9.9855e-01],
      [ 5.5671e-01,  3.4761e-01, -1.2956e+00,  ..., -3.1286e-01,
        1.2305e+00,  1.0559e+00]],

     [[ 2.3741e+01, -5.9836e+01,  5.4721e+01,  ..., -3.8232e+00,
        4.5851e+00,  3.2750e+01],
      [ 1.0989e+00, -2.8977e+00,  3.5956e+00,  ...,  2.5251e-01,
       -2.2446e+00, -1.3616e+00],
      [-1.3798e-01,  2.1394e+00,  2.0645e-01,  ...,  4.6853e-01,
        3.7650e+00, -5.4336e-01],
      [ 5.8165e-01, -1.2998e+00,  2.1851e-01,  ..., -1.6847e+00,
       -6.0834e-01,  7.0311e-01],
      [ 3.9262e-01,  1.8806e+00,  4.1490e+00,  ...,  1.0796e+00,
       -8.9924e-01, -4.6079e-01]]],

    [[[ 2.2265e-01, -2.0570e-01,  3.1797e-02,  ...,  3.7660e-02,
       -4.1366e-01,  2.0542e-01],
      [ 5.8108e-01,  7.6933e-01, -5.1310e-01,  ..., -1.6667e-01,
        2.4480e-01,  4.2667e-01],
      [-3.0719e+00,  2.4038e+00,  3.3761e+00,  ...,  8.7145e-01,
        1.2536e+00, -5.6292e+00],
      [ 8.9647e-01,  9.3441e-01,  3.8667e-01,  ..., -1.1763e-01,
        1.0418e-01, -2.4992e-01],
      [ 1.0936e+01, -6.2251e+01,  4.3816e+01,  ..., -4.3446e+01,
       -2.9938e+01,  6.8241e+01]],

     [[-1.8638e-01, -2.4923e-01, -1.1012e+00,  ...,  1.8603e-01,
       -3.2592e-01,  2.3948e-01],
      [ 3.9760e+00, -4.1706e+00,  5.0572e-01,  ...,  6.0632e+00,
       -3.4053e+00,  1.1460e+01],
      [ 9.2532e-01,  1.3921e+00, -1.7843e-01,  ..., -1.2403e+00,
       -9.8856e-01,  1.0621e+00],
      [ 4.3452e-02, -1.6717e-01, -5.7880e-02,  ...,  1.4633e+00,
        1.0175e-01, -9.2251e-01],
      [ 1.2661e+00,  3.1550e-02,  4.6298e-02,  ..., -7.8469e-02,
       -1.8855e+00, -4.5504e-01]],

     [[ 4.6267e-01, -9.4152e-01,  6.4885e-01,  ..., -3.4818e-01,
       -7.6115e-01,  2.2274e+00],
      [ 7.2543e-01,  2.6525e+00,  4.4636e-01,  ..., -9.3224e-01,
        4.5972e+00, -3.1881e+00],
      [-2.1488e+00, -5.1438e-01, -8.9369e+00,  ..., -1.2142e+00,
        2.1087e+00, -2.0650e+01],
      [-1.3089e-01,  8.0855e-01,  3.3616e-01,  ...,  2.5241e-01,
        8.3287e-02,  3.3158e-01],
      [ 2.5745e+00, -1.1222e+00, -3.4501e+00,  ...,  1.9027e+00,
       -2.6546e+00, -4.7814e-01]],

     ...,

     [[-8.4899e-01,  4.5336e-01, -5.7752e-01,  ...,  6.3741e-01,
       -2.6161e-01,  1.2684e+00],
      [ 6.0437e-01,  6.3276e-02, -1.3153e-01,  ..., -3.1852e-03,
        1.8475e-01,  1.0924e-01],
      [-3.0929e+00, -3.1014e+00,  5.9858e-01,  ..., -4.4175e-01,
        1.2333e+00,  7.4185e+00],
      [ 1.6183e+00,  1.0251e+00, -1.3564e-01,  ..., -1.1713e+00,
       -6.7476e-01,  9.0628e-01],
      [-1.6322e+00,  2.2092e+00, -1.2780e+00,  ...,  3.3654e+00,
        3.6797e+00, -5.7386e+00]],

     [[ 7.6706e-01, -2.6295e-01,  8.4968e-01,  ...,  6.1631e-01,
       -3.7476e-01, -1.2049e+00],
      [-2.4539e-01,  1.2085e-01,  3.7377e-01,  ...,  1.1437e+00,
        1.1793e+00, -1.1422e-01],
      [ 6.3429e-01,  8.0557e-01,  5.9677e-01,  ..., -2.5139e-01,
       -3.2989e-01, -1.5825e-01],
      [ 1.0353e+00, -5.4504e-01,  1.1037e+00,  ..., -5.5962e-01,
        1.4610e-01, -1.9254e+00],
      [ 8.9514e-02,  1.0989e+00,  1.5606e+00,  ..., -7.3987e-01,
        1.9311e+00, -2.2786e+00]],

     [[ 1.2851e+00, -1.9851e+00, -4.5494e+00,  ...,  3.7657e+00,
        4.5665e+00, -9.2168e-01],
      [-1.7239e-01,  4.6246e-01,  3.7013e-01,  ..., -6.1598e-01,
        6.5582e-01, -1.4461e+00],
      [-1.1701e+00,  4.2045e-01,  3.5514e-01,  ..., -2.9292e+00,
       -2.5521e+00,  2.6618e-01],
      [ 1.1595e+00, -8.0284e-01,  1.5254e+00,  ...,  9.7177e-01,
        9.2073e-01,  4.3168e-01],
      [-5.2398e+00,  1.4340e+01, -6.2417e+01,  ..., -9.0160e+00,
        1.3309e+01,  2.2005e+01]]],

    [[[-4.4581e+00, -6.7095e+01, -3.2797e+01,  ...,  1.1565e+01,
       -1.3261e+01, -1.4874e+01],
      [ 4.1974e-01,  3.6623e+00, -3.5843e+00,  ...,  1.0202e+00,
       -1.4485e+00, -1.9347e+00],
      [-4.6281e+00,  1.2632e+01, -8.6199e+00,  ...,  1.7171e+01,
       -1.8479e+01, -1.8846e+01],
      [ 3.5012e-01, -8.2658e-01,  8.9016e-01,  ..., -8.7550e-01,
        1.2175e-01, -1.5617e+00],
      [-3.3812e+00,  6.0930e+00,  5.5390e+00,  ..., -3.2472e+00,
       -2.8270e+00,  4.2411e-01]],

     [[-3.7637e+00, -1.7642e+00, -2.3450e+00,  ..., -3.4761e+00,
        4.4729e+00, -1.9583e+00],
      [ 1.1669e+00,  9.0970e-01, -2.6114e+00,  ..., -1.8751e-01,
        1.1157e+00,  1.8243e+00],
      [ 1.0297e-01,  5.3322e-01, -3.3761e-01,  ...,  5.7520e-02,
       -5.7774e-02, -3.6703e-01],
      [ 1.5873e+00, -2.0804e+00,  1.0169e+00,  ..., -6.6074e-01,
       -1.9426e+00, -1.2487e-01],
      [-6.2900e-01,  1.2562e+00, -3.7006e-01,  ...,  2.1023e-01,
        3.7004e-02,  1.8923e+00]],

     [[ 8.4332e-01, -4.3401e-02, -3.0546e-01,  ..., -8.0304e-01,
       -6.1729e-01,  5.2937e-01],
      [-2.1099e+01, -2.7131e+01, -2.8109e+02,  ..., -1.0522e+02,
        2.8456e+01,  2.6699e+01],
      [-2.3551e-01, -8.7806e-02,  9.8973e-01,  ..., -4.2571e-01,
       -9.5759e-01,  1.3006e-01],
      [ 6.8565e-01,  3.3882e-01, -6.1284e-01,  ..., -2.1239e-01,
       -1.9207e-02,  1.1479e+00],
      [-9.2754e-01,  8.7619e-01,  6.3741e-01,  ...,  7.7548e-01,
        2.8069e+00, -2.3306e+00]],

     ...,

     [[ 2.2937e+00,  8.0051e+00,  3.9795e+00,  ..., -2.5096e+00,
        5.5701e+00, -2.5075e+00],
      [-3.2691e-01, -1.3848e+00, -1.8447e+00,  ...,  5.6993e-01,
       -1.7544e+00,  1.1677e+00],
      [ 1.3561e+00,  3.8136e-01,  1.4107e+00,  ...,  1.9896e+00,
       -6.5320e-01, -4.6294e+00],
      [-2.3471e-01, -5.2743e-01, -5.6733e-01,  ...,  4.3915e-01,
        5.0340e-01,  1.1045e+00],
      [ 4.1318e+00,  5.7137e-01, -1.5194e+00,  ...,  2.8802e+00,
       -7.9121e+00,  4.0104e+00]],

     [[ 4.8319e-01,  2.7313e-01,  3.9362e-01,  ...,  3.3351e-02,
       -2.2658e-01, -2.1984e-01],
      [-6.3220e+00, -9.0551e+00,  6.8837e+00,  ...,  5.3650e+00,
       -5.2190e+00,  1.1687e+00],
      [ 2.1542e+00,  1.1602e+00, -8.8548e-02,  ...,  9.0859e-01,
       -2.5196e+00, -1.3161e-01],
      [ 5.4348e-01, -7.5089e-01,  2.0165e+00,  ...,  4.7323e-01,
        4.6918e-01,  8.5174e-01],
      [ 1.2587e+00, -5.3498e+00,  2.2615e+00,  ...,  4.3795e-02,
       -1.9938e+00,  3.1597e+00]],

     [[-2.8105e-01, -3.4248e+00,  1.7655e+00,  ..., -8.7769e-01,
        3.2186e+00, -6.3708e+00],
      [ 3.4623e-01,  3.0659e-01, -1.3735e-01,  ...,  5.6169e-02,
       -1.6627e+00,  1.1547e+00],
      [-3.5326e+00,  6.8798e-01, -9.7799e-01,  ...,  4.0571e+00,
       -2.0529e-01,  5.0493e-01],
      [ 4.6537e-01,  3.3993e-01,  3.9738e-01,  ..., -4.7107e-02,
        1.0712e-01, -3.3944e-01],
      [ 1.1307e+00,  2.1601e+00,  1.0504e-01,  ...,  6.5564e-01,
       -1.1295e+00, -2.6074e+00]]]], device='cuda:0',
   grad_fn=<DivBackward0>)

Exception ignored in: <function PyCUDASampler.__del__ at 0x7f66e6914830>
Traceback (most recent call last):
  File "/home/warp/github/warp-drive/warp_drive/managers/pycuda_managers/pycuda_function_manager.py", line 510, in __del__
  File "/home/warp/anaconda3/envs/warp_drive/lib/python3.7/site-packages/pycuda/driver.py", line 480, in function_call
pycuda._driver.LogicError: cuFuncSetBlockShape failed: invalid resource handle
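For reference, the Simplex() constraint named in the ValueError can be reproduced by hand. As far as we know, in recent PyTorch releases it accepts a tensor only if, along the last dimension, every entry is non-negative and the entries sum to 1 within an absolute tolerance of 1e-6 (check torch/distributions/constraints.py in your installed version to confirm). The plain-Python sketch below mirrors that rule for a single row; satisfies_simplex and diagnose are hypothetical helper names, not PyTorch APIs. It separates the two ways a row can fail: a sum that drifts slightly away from 1, which is the round-off case discussed in the reply below, and negative entries, such as -1.2945e+00 in the first printed row of the tensor.

```python
import math

# Absolute tolerance on the row sum; assumed to match PyTorch's Simplex() check.
SIMPLEX_ATOL = 1e-6


def satisfies_simplex(row, atol=SIMPLEX_ATOL):
    """Hypothetical helper: True if all entries are >= 0 and |sum - 1| < atol."""
    # math.fsum returns a correctly rounded sum, so this helper adds no round-off of its own.
    return all(p >= 0 for p in row) and abs(math.fsum(row) - 1.0) < atol


def diagnose(row, atol=SIMPLEX_ATOL):
    """Hypothetical helper: report which part of the rule a row breaks."""
    if any(p < 0 for p in row):
        return "negative entries"
    drift = math.fsum(row) - 1.0
    if abs(drift) >= atol:
        return "sum off by %.3g" % drift
    return "ok"


valid = [0.1, 0.2, 0.3, 0.4]
drifted = [0.1, 0.2, 0.3, 0.4 + 2e-6]   # sum misses 1 by slightly more than atol
negative = [1.2426, -1.2945, 0.41014]   # visible values from the first printed row

for name, row in [("valid", valid), ("drifted", drifted), ("negative", negative)]:
    print(name, satisfies_simplex(row), diagnose(row))
```

Running it prints "valid True ok", "drifted False sum off by 2e-06" and "negative False negative entries". Only the middle case is the kind of failure that a tolerance or round-off explanation covers.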

Emerald01 commented 1 year ago

This seems to be a pure PyTorch Categorical issue, in which round-off error raises the exception; see https://discuss.pytorch.org/t/distributions-categorical-fails-with-constraint-simplex-but-manual-check-passes/163209/3

On V100 and A100 GPUs I do not see this issue, so it may come from slight differences between GPUs.
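One detail helps tell round-off apart from other causes. In the PyTorch source we are aware of, Categorical(probs=...) divides probs by their sum along the last dimension before validating them, which is consistent with the grad_fn=<DivBackward0> shown on the printed tensor. Dividing by the sum absorbs small drift in the total, but it cannot turn negative entries into non-negative ones. A numerically stable softmax always returns non-negative values that sum to 1, and PyTorch applies the equivalent normalization when raw scores are passed as Categorical(logits=...); passing validate_args=False, by contrast, only silences the check. The plain-Python sketch below contrasts the two normalizations on one row. The helper names are hypothetical, and whether the warp-drive policy output should be treated as logits is an assumption this thread does not confirm.

```python
import math


def normalize_by_sum(scores):
    """Divide every entry by the row sum (what Categorical appears to do with probs)."""
    total = math.fsum(scores)
    return [s / total for s in scores]


def stable_softmax(scores):
    """Softmax that subtracts the row maximum first, so math.exp cannot overflow."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = math.fsum(exps)
    return [e / total for e in exps]


# Illustrative scores: visible values from the first printed row of the log.
row = [1.2426, -1.2945, 0.41014]

by_sum = normalize_by_sum(row)
by_softmax = stable_softmax(row)
print(min(by_sum) < 0)                            # True: the negative entry survives
print(min(by_softmax) >= 0)                       # True
print(abs(math.fsum(by_softmax) - 1.0) < 1e-12)   # True

# Scores as large in magnitude as -2.8109e+02 (also in the log) stay finite and valid.
print(stable_softmax([-281.09, 28.456, 26.699]))
```

Subtracting the maximum does not change the softmax result, since it cancels between numerator and denominator; it only keeps math.exp within floating-point range.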