BlackSamorez / tensor_parallel

Automatically split your PyTorch models on multiple GPUs for training & inference
MIT License

fix recursion error when setting tp_wrapped_module #122 #123

Open Ar-Kareem opened 1 year ago

Ar-Kareem commented 1 year ago

This fixes the error reported in #122, which happens whenever the attribute `tp_wrapped_module` is reassigned (for example inside LoRA or other PEFT methods).
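For context, here is a minimal sketch of the triggering pattern. It assumes the public `tensor_parallel` API (`tp.tensor_parallel(model, device_ids)`); the loop is purely illustrative of what LoRA/PEFT injection does internally, not actual PEFT code, and the device names and layer sizes are arbitrary:

```python
import torch.nn as nn
import tensor_parallel as tp

# Toy model split across two devices (device names are illustrative).
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
model_tp = tp.tensor_parallel(model, ["cuda:0", "cuda:1"])

# PEFT-style in-place swap of a wrapped layer, roughly what LoRA injection does.
# Before this fix, after the assignment below any delegated attribute lookup on
# the wrapper recursed forever in __getattr__ (the error reported in #122).
for module in list(model_tp.modules()):
    wrapped = getattr(module, "tp_wrapped_module", None)
    if isinstance(wrapped, nn.Linear):
        module.tp_wrapped_module = nn.Linear(wrapped.in_features, wrapped.out_features)
```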

I am not 100% sure this works, since I have not started training yet, but it has certainly made my example in #122 work as expected.

BlackSamorez commented 1 year ago

Fix the style, please: `black . && isort .`

Ar-Kareem commented 1 year ago

Fixed styling, although I'm not sure why the tests are failing.

BlackSamorez commented 1 year ago

The failing tests are not related to this issue: the Falcon-40B tests fail to load the model from the web. I'll look into it. I'll also look into your changes in more detail, since `__getattr__`/`__setattr__` handling is always tricky.

Ar-Kareem commented 1 year ago

Sure. All I did was make sure `tp_wrapped_module` is re-added to `__dict__`. This is necessary for any use case where you need to do `wrapper.tp_wrapped_module = some_module` (which is what I wanted LoRA to do). When that assignment executes, `nn.Module.__setattr__` is called, and specifically the line at module.py#L1726 clears `tp_wrapped_module` from `__dict__`. After that, the wrapper's `__getattr__` causes a recursion error, because evaluating `self.tp_wrapped_module` recursively calls `__getattr__`.
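To illustrate the mechanism, here is a self-contained sketch with a stand-in wrapper (not the actual `tensor_parallel` class); the `__setattr__` override mirrors the idea of re-adding the entry to `__dict__` after `nn.Module.__setattr__` has popped it:

```python
import torch.nn as nn

class Wrapper(nn.Module):
    """Minimal stand-in for the tensor_parallel wrapper: it keeps the wrapped
    module reachable as self.tp_wrapped_module and forwards unknown attribute
    lookups to it."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.tp_wrapped_module = module  # routed through __setattr__ below

    def __setattr__(self, name, value):
        # nn.Module.__setattr__ registers submodules in self._modules and pops
        # any same-named entry from self.__dict__ ...
        super().__setattr__(name, value)
        if name == "tp_wrapped_module":
            # ... so re-add it to __dict__, letting self.tp_wrapped_module resolve
            # through normal attribute lookup instead of falling into __getattr__
            # and recursing (the idea behind this PR).
            self.__dict__[name] = value

    def __getattr__(self, name):
        # Only reached for names missing from __dict__ and the class;
        # delegate to the wrapped module.
        return getattr(self.tp_wrapped_module, name)


w = Wrapper(nn.Linear(4, 4))
w.tp_wrapped_module = nn.Linear(4, 8)  # PEFT-style swap that used to trigger the recursion
print(w.out_features)                  # delegation still works: prints 8
```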