jbloomAus / DecisionTransformerInterpretability

Interpreting how transformers simulate agents performing RL tasks
https://jbloomaus-decisiontransformerinterpretability-app-4edcnc.streamlit.app/
MIT License
61 stars, 15 forks

Swapped device to str instead of torch.device to fix PyArrow crash problem #102

Closed: JayBaileyCS closed this 11 months ago

JayBaileyCS commented 11 months ago

Tests performed:

Caveat: It's entirely possible that this change causes errors in places that neither I nor the tests exercised. On the plus side, any such errors should be easy to fix, because I can't find any reason the device has to be a torch.device. The _get_device_index util in PyTorch, which torch.cuda.device uses, accepts a string just fine: it simply parses the string into a torch.device internally before doing its thing. So if we run into further errors (e.g., if we want to train a new model later), we should be able to swap everything over to strings without causing arcane BS.

In short, I think it's better to make sure the code we're using right now works, as long as any errors that show up later will be obvious and easy to fix. Any side effects that do arise shouldn't take long to fix, and they won't cause subtle issues that we fail to notice.

Docs:
_get_device_index: https://github.com/pytorch/pytorch/blob/main/torch/cuda/_utils.py
torch.cuda.device: https://pytorch.org/docs/stable/_modules/torch/cuda.html#device