stanfordnlp / pyvene

Stanford NLP Python Library for Understanding and Improving PyTorch Models via Interventions
http://pyvene.ai
Apache License 2.0
589 stars 55 forks

device of unit_locations should follow tensor_input #171

Closed: aryopg closed this 1 month ago

aryopg commented 1 month ago

Description

In modeling_utils.py, part of the code raises an error when running on a GPU. The cause appears to be that the unit_locations variable is forced onto the CPU instead of following the device of tensor_input.
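The general fix the title describes is to build the location index on the same device as the input tensor. Below is a minimal sketch of that idea in plain PyTorch, not pyvene's actual code in modeling_utils.py. `gather_units` is a hypothetical helper, and the assumption that unit_locations holds one list of positions per batch item (shape `(batch, num_units)`) is for illustration only. `torch.gather` requires its index to be on the same device as its input, so a CPU index combined with a CUDA input fails.

```python
import torch


def gather_units(tensor_input: torch.Tensor, unit_locations) -> torch.Tensor:
    """Hypothetical helper: select positions along dim 1 of tensor_input.

    Assumes tensor_input has shape (batch, seq_len, hidden) and
    unit_locations has shape (batch, num_units), as a list or a tensor.
    """
    # Place the index on tensor_input's device rather than forcing it to CPU
    idx = torch.as_tensor(unit_locations, dtype=torch.long, device=tensor_input.device)
    # Expand (batch, num_units) -> (batch, num_units, hidden) for torch.gather
    idx = idx.unsqueeze(-1).expand(-1, -1, tensor_input.shape[-1])
    return torch.gather(tensor_input, 1, idx)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.arange(24, dtype=torch.float32, device=device).reshape(2, 3, 4)
    # A CPU tensor of locations works because it is moved to x.device first
    out = gather_units(x, torch.tensor([[0, 2], [1, 1]]))
    print(out.shape, out.device)
```

Using `torch.as_tensor(..., device=tensor_input.device)` handles both Python lists and existing tensors. It copies only when the data is not already on the target device.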

Testing Done

No tests have been run yet; this is just to flag the issue and bring it to the attention of the main contributors.


frankaging commented 1 month ago

@PinetreePantry, can you take a look at this one? Thanks!