MisaOgura / flashtorch

Visualization toolkit for neural networks in PyTorch! Demo: https://youtu.be/18Iw4qYqfPo
MIT License

Fix/cuda transfer error #3

Closed MisaOgura closed 5 years ago

MisaOgura commented 5 years ago

This PR fixes the error reported in #2. The error occurred because the target tensor was not being transferred to the GPU.

This PR also introduces an API change to the Backprop.calculate_gradient method. Users can now explicitly choose which device to use via the use_gpu flag, which defaults to False.
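
For illustration, here is a minimal sketch of the pattern described above. It is not the actual flashtorch code: the function name gradient_wrt_input, its signature, and the toy model are assumptions made up for this example. The only thing taken from this PR is the idea of a use_gpu flag that defaults to False. The sketch resolves the device once and creates the target tensor on the same device as the model output, which avoids the kind of device mismatch this PR fixes.

```python
import torch
import torch.nn as nn


def gradient_wrt_input(model, input_, target_class, use_gpu=False):
    """Hypothetical helper (not flashtorch's implementation).

    Returns the gradient of the target class score with respect to the input.
    """
    # Use CUDA only when requested and available; otherwise stay on the CPU.
    device = torch.device(
        "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
    )

    # Move the model and a fresh copy of the input to the chosen device.
    model = model.to(device)
    input_ = input_.detach().clone().to(device).requires_grad_(True)

    model.zero_grad()
    output = model(input_)  # shape: (1, num_classes)

    # Build the one-hot target on the same device as the output.
    # A target left on the CPU while the output lives on the GPU
    # would raise a device mismatch error in backward().
    target = torch.zeros_like(output)
    target[0, target_class] = 1.0

    output.backward(gradient=target)
    return input_.grad.detach().cpu()


# Usage with a toy linear classifier (an assumption for this sketch).
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
image = torch.rand(1, 3, 8, 8)
grads = gradient_wrt_input(model, image, target_class=3, use_gpu=False)
print(grads.shape)
```

Using torch.zeros_like(output) means the target inherits the device and dtype of the output, so the same code path works on both CPU and GPU without a separate transfer step.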