AnswerDotAI / bert24


Add weight decay filtering and StableAdamW #57

Closed warner-benjamin closed 3 months ago

warner-benjamin commented 3 months ago

This PR adds the StableAdamW optimizer as an option. It requires installing optimi: `pip install torch-optimi`. StableAdamW removes the need for gradient clipping, and I've found it to be a Pareto improvement over AdamW.
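As a rough illustration (not this PR's exact integration code), swapping StableAdamW in for AdamW looks something like the sketch below; since StableAdamW bounds its own updates, the usual gradient clipping step can be dropped:

```python
import torch
from torch import nn
from optimi import StableAdamW

model = nn.Linear(10, 2)
# Drop-in replacement for torch.optim.AdamW; hyperparameters here are
# illustrative defaults, not values from this PR.
optimizer = StableAdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

inputs = torch.randn(8, 10)
targets = torch.randint(0, 2, (8,))
loss = nn.functional.cross_entropy(model(inputs), targets)
loss.backward()
optimizer.step()        # no torch.nn.utils.clip_grad_norm_ call needed
optimizer.zero_grad()
```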

This PR also adds the `filter_bias_and_bn` option, which prevents weight decay from being applied to linear bias terms and normalization layers. I left it as `false` to match the current defaults (except in a test), but given it's a best practice, we should use it for all of our training.
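For reference, a minimal sketch of the filtering idea (the helper name and norm-type list are illustrative, not the PR's implementation): biases and normalization parameters go into a zero-decay group, while everything else keeps weight decay.

```python
import torch
from torch import nn

def param_groups_no_decay_on_bias_and_norm(model: nn.Module, weight_decay: float):
    """Split parameters into decay / no-decay groups for the optimizer."""
    norm_types = (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d)
    decay, no_decay = [], []
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            # Biases and all parameters of normalization layers skip decay.
            if name.endswith("bias") or isinstance(module, norm_types):
                no_decay.append(param)
            else:
                decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

model = nn.Sequential(nn.Linear(10, 10), nn.LayerNorm(10), nn.Linear(10, 2))
groups = param_groups_no_decay_on_bias_and_norm(model, weight_decay=0.01)
optimizer = torch.optim.AdamW(groups, lr=1e-3)
```

The usual rationale: biases and norm scales/offsets have few parameters and regularizing them toward zero tends to hurt rather than help, which is why filtering them out is considered a best practice.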