LaurentMazare / tch-rs

Rust bindings for the C++ API of PyTorch.

Embedding gradient seems to be incorrect when using batch values #271

Closed · bumblepie closed this issue 4 years ago

bumblepie commented 4 years ago

I've been trying to set up a basic DQN agent to learn the OpenAI Taxi environment. The network is just a single embedding layer, which effectively acts as a Q-table, but it only learns when updates are done on batches of size one: with any larger batch size, it adjusts winning scores in the wrong direction and never learns the environment properly. I've distilled it down to the basics in the following gists:

Rust: https://gist.github.com/bumblepie/e67ab01118d3ccc561a38edaa37d219a
Python: https://gist.github.com/bumblepie/3e619c48fc9e09d4cd9a5f851ab5edca

The problem seems to be that the gradient for the embedding layer's weights isn't calculated correctly: there's a mismatch between PyTorch and tch-rs. Since the gradient is positive when it should be negative, that explains the behaviour described above. I've taken a look at the source code and I'm not entirely sure where the error is, but it's probably something to do with the embedding layer, backpropagation, and reduction. It also occurs when using mean squared error, so it's not just the Huber loss function that's wrong.
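For reference, a minimal sketch of the kind of update being described (invented names and values, not the linked gists themselves), using the tch-rs API of the time: an embedding used as a Q-table, gathered Q-values per batch element, and a squared-error loss against a target whose shape turns out to matter, as diagnosed below.

use tch::{nn, nn::Module, Device, Reduction, Tensor};

fn main() {
    // Hypothetical stand-in for the gist: an embedding used as a Q-table
    // over the Taxi environment's 500 states x 6 actions.
    let vs = nn::VarStore::new(Device::Cpu);
    let q_table = nn::embedding(&vs.root(), 500, 6, Default::default());

    // A batch of three states and the actions taken in them.
    let states = Tensor::of_slice(&[0i64, 1, 2]);
    let actions = Tensor::of_slice(&[0i64, 0, 1]).view([-1, 1]);

    // Q(s, a) for each batch element -> shape [3, 1].
    let state_action_values = q_table.forward(&states).gather(1, &actions, false);

    // Invented target values; note the shape is [3], not [3, 1].
    let expected_state_action_values = Tensor::of_slice(&[1.0f32, 2.0, 3.0]);

    // With mismatched shapes the loss broadcasts to [3, 3], so the gradient
    // printed below no longer matches the PyTorch version of the same code.
    let loss = state_action_values.mse_loss(&expected_state_action_values, Reduction::Mean);
    loss.backward();
    q_table.ws.grad().print();
}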

LaurentMazare commented 4 years ago

Thanks for reporting the issue and trimming it down to something easy to run/debug. I think there is a small typo in your snippets: the indices differ between the two of them.

Python:

expected_state_action_values = torch.Tensor(
    [lookup[0][0], lookup[1][0], lookup[2][1]]).view([-1,1])

Rust:

let expected_state_action_values =
    Tensor::of_slice(&[lookup[1][0], lookup[2][1], lookup[0][0]]);
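Presumably the intended Rust line mirrors the Python index order (just the index fix; the reshape is a separate issue, discussed next):

let expected_state_action_values =
    Tensor::of_slice(&[lookup[0][0], lookup[1][0], lookup[2][1]]);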

However, even after fixing this, I agree that the results differ. I think it comes from state_action_values having shape 3x1 while expected_state_action_values has shape 3; the loss then broadcasts the two against each other into a 3x3 difference rather than the intended elementwise one. Adding .view([-1, 1]) to the Rust definition of this last variable seems to bring the results in line, if I'm not mistaken.
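To make the broadcasting concrete, here is a small sketch with invented values, comparing the gradient with and without the reshape:

use tch::{Reduction, Tensor};

fn main() {
    let make_preds = || {
        // Toy stand-ins for the gathered Q-values, shape [3, 1].
        Tensor::of_slice(&[0.5f32, 1.5, 2.5])
            .view([-1, 1])
            .set_requires_grad(true)
    };
    let target = Tensor::of_slice(&[1.0f32, 2.0, 3.0]);

    // [3, 1] vs [3]: the difference broadcasts to [3, 3], so each
    // prediction is compared against every target, not just its own.
    let preds = make_preds();
    preds.mse_loss(&target, Reduction::Mean).backward();
    preds.grad().print();

    // [3, 1] vs [3, 1]: the elementwise loss that was intended.
    let preds = make_preds();
    preds.mse_loss(&target.view([-1, 1]), Reduction::Mean).backward();
    preds.grad().print();
}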

bumblepie commented 4 years ago

That seems to have fixed it: the DQN agent is working now. Thanks for the help!