openai / weak-to-strong

MIT License
2.5k stars, 305 forks

Bugfix: last minibatch padding #31

Closed: ojh31 closed this 8 months ago

ojh31 commented 8 months ago

This fixes the error `TypeError: TransformerWithHead.forward() missing 1 required positional argument: 'input_ids'`, which is similar to https://github.com/pytorch/pytorch/issues/15161. The error appears when using DataParallel and the last minibatch ends up smaller than a full batch. To avoid it, we right-pad input_ids up to the full batch size and then strip the extra logits/labels that the padding creates.
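For illustration, here is a minimal framework-free sketch of the pad-then-trim idea, not the actual diff in this PR. It assumes the padding is applied along the batch dimension (extra all-pad rows), and the names `PAD_ID`, `pad_batch`, and `forward_padded` are hypothetical. Plain Python lists stand in for tensors so the logic is easy to follow.

```python
from typing import Callable, List, Sequence, Tuple

PAD_ID = 0  # hypothetical padding token id, chosen only for this sketch

Batch = List[List[int]]


def pad_batch(
    input_ids: Sequence[Sequence[int]], batch_size: int, pad_id: int = PAD_ID
) -> Tuple[Batch, int]:
    """Append all-pad rows on the right until the batch has `batch_size` rows.

    Returns the padded batch and the number of real (unpadded) rows.
    """
    n_real = len(input_ids)
    if n_real == 0:
        raise ValueError("cannot pad an empty batch")
    seq_len = len(input_ids[0])
    # No rows are added when the batch is already full.
    n_missing = max(batch_size - n_real, 0)
    padding = [[pad_id] * seq_len for _ in range(n_missing)]
    return [list(row) for row in input_ids] + padding, n_real


def forward_padded(
    model: Callable[[Batch], list],
    input_ids: Sequence[Sequence[int]],
    batch_size: int,
) -> list:
    """Run `model` on a full-size batch, then drop outputs for the pad rows."""
    padded, n_real = pad_batch(input_ids, batch_size)
    logits = model(padded)
    # Only the first n_real rows correspond to real examples.
    return logits[:n_real]
```

With tensors, the same two steps would be a pad along dimension 0 before the forward pass and a slice of the logits (and labels) afterwards, so every replica receives at least one row.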