Pull Request Summary
Collator Class:
The previous collate_fn occasionally caused issues with PyTorch Lightning. It has been replaced by a new collator class, which should make batching more reliable.
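The PR does not say what the collate_fn issues were or what the new collator looks like. The sketch below is therefore hypothetical. It shows the general pattern of a collator written as a callable class instead of a bare function. The names PadCollator, key and pad_value are made up, and batches are assumed to be dicts holding variable-length integer sequences. One plausible, unconfirmed benefit of this pattern is that an instance of a module-level class can be pickled and sent to DataLoader worker processes, whereas a lambda cannot.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence


@dataclass
class PadCollator:
    """Hypothetical collator: pads the sequences stored under `key` to a common length."""
    key: str = "tokens"
    pad_value: int = 0

    def __call__(self, batch: Sequence[Dict[str, Any]]) -> Dict[str, List[Any]]:
        if not batch:
            raise ValueError("cannot collate an empty batch")
        lengths = [len(item[self.key]) for item in batch]
        max_len = max(lengths)
        # Right-pad every sequence with pad_value up to the longest one in the batch.
        padded = [
            list(item[self.key]) + [self.pad_value] * (max_len - n)
            for item, n in zip(batch, lengths)
        ]
        # 1 marks a real element, 0 marks padding.
        mask = [[1] * n + [0] * (max_len - n) for n in lengths]
        return {self.key: padded, "mask": mask, "lengths": lengths}
```

To use it, pass an instance wherever a function would otherwise go, e.g. `DataLoader(dataset, batch_size=32, collate_fn=PadCollator())`. In a real PyTorch pipeline the lists would normally be converted with `torch.tensor` before being returned.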
Loss Function Handling:
Removed the separate loss_function from eval_step as it was too confusing.
Variational loss is now recorded manually.
Shape loss will also need to be manually added to loss_dict if it is to be recorded.
Although mutating objects in place feels a bit icky, it is the cleanest approach for readability.
Ideally, a dedicated Loss object that inherits from ModelOutput would be used instead.
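The following is a minimal sketch of recording losses manually in an eval step. It assumes a VAE-style model whose variational term is the Gaussian KL divergence against a standard normal prior. The function names and the plain-float inputs are hypothetical and stand in for the project's real tensors and signatures. Each term that should be logged is written into loss_dict explicitly, and any term left out is simply not recorded.

```python
import math
from typing import Dict, List, Optional, Tuple


def gaussian_kl(mu: List[float], logvar: List[float]) -> float:
    """KL(N(mu, exp(logvar)) || N(0, 1)), summed over latent dimensions."""
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def eval_step(
    recon_loss: float,
    mu: List[float],
    logvar: List[float],
    shape_loss: Optional[float] = None,
) -> Tuple[float, Dict[str, float]]:
    """Hypothetical eval step: returns the total loss and the dict of recorded terms."""
    loss_dict: Dict[str, float] = {"reconstruction": recon_loss}

    # The variational loss is recorded manually.
    kl = gaussian_kl(mu, logvar)
    loss_dict["variational"] = kl
    total = recon_loss + kl

    # The shape loss only shows up in loss_dict if it is added here by hand.
    if shape_loss is not None:
        loss_dict["shape"] = shape_loss
        total += shape_loss

    return total, loss_dict
```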
Evaluation Step:
eval_step must still call super() for mixins to function properly.
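The need for super() comes from Python's cooperative multiple inheritance. Each eval_step in the method resolution order (MRO) runs only if the method before it delegates onward. The classes below, along with the (batch, loss_dict) signature, are hypothetical stand-ins for the project's own base model and mixins.

```python
from typing import Any, Dict, Sequence


class BaseModel:
    def eval_step(self, batch: Sequence[Any], loss_dict: Dict[str, float]) -> Dict[str, float]:
        # Last link in the chain: nothing further to delegate to.
        return loss_dict


class ShapeLossMixin:
    def eval_step(self, batch: Sequence[Any], loss_dict: Dict[str, float]) -> Dict[str, float]:
        # Hypothetical mixin contribution: record a shape term, then pass control on.
        loss_dict["shape"] = float(len(batch))
        return super().eval_step(batch, loss_dict)


class GoodModel(ShapeLossMixin, BaseModel):
    def eval_step(self, batch: Sequence[Any], loss_dict: Dict[str, float]) -> Dict[str, float]:
        loss_dict["variational"] = 0.0
        # MRO is GoodModel -> ShapeLossMixin -> BaseModel, so this reaches the mixin.
        return super().eval_step(batch, loss_dict)


class BrokenModel(ShapeLossMixin, BaseModel):
    def eval_step(self, batch: Sequence[Any], loss_dict: Dict[str, float]) -> Dict[str, float]:
        loss_dict["variational"] = 0.0
        # No super() call: ShapeLossMixin.eval_step never runs.
        return loss_dict
```

Calling GoodModel().eval_step records both "variational" and "shape". BrokenModel silently drops the mixin's "shape" entry.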
Test Improvements:
Tests have been updated to be more future-proof and sensible. They now build on the configuration signatures, so they keep working if those signatures change.
The tests now better reflect how users should interact with the code.
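One way to let tests follow configuration signatures instead of hardcoding every parameter is to read defaults with the standard-library inspect.signature. The sketch below assumes dataclass-based configs. ModelConfig, default_kwargs and make_config are hypothetical names, not the project's actual code.

```python
import inspect
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class ModelConfig:
    """Hypothetical config standing in for the project's real config classes."""
    hidden_dim: int = 64
    latent_dim: int = 8
    beta: float = 1.0


def default_kwargs(cls: type) -> Dict[str, Any]:
    """Read default values from the constructor signature instead of hardcoding them."""
    params = inspect.signature(cls).parameters
    return {
        name: p.default
        for name, p in params.items()
        if p.default is not inspect.Parameter.empty
    }


def make_config(**overrides: Any) -> ModelConfig:
    """Current defaults plus overrides; a stale override name raises TypeError."""
    return ModelConfig(**{**default_kwargs(ModelConfig), **overrides})


def test_small_latent_config() -> None:
    cfg = make_config(latent_dim=2)
    assert cfg.latent_dim == 2
    # Untouched fields track whatever the current defaults are.
    assert cfg.hidden_dim == default_kwargs(ModelConfig)["hidden_dim"]
```

With this setup, adding a new field with a default does not break the test. Renaming a field that a test overrides fails loudly instead of passing silently.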
Jupytext & Notebooks:
Added a jupytext.toml for notebooks and included a basic notebook file (simple.py) in the scripts directory.
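The contents of jupytext.toml and simple.py are not shown in this PR. The following is an illustrative guess only. A common jupytext.toml setting is `formats = "ipynb,py:percent"`, which pairs each notebook with a Python script in the "percent" format. In that format every `# %%` line starts a cell, and `# %% [markdown]` marks a markdown cell. A basic scripts/simple.py could look like this:

```python
# %% [markdown]
# # Simple example (hypothetical contents)
# Each "# %%" marker below starts a new notebook cell.

# %%
values = [1, 2, 3]
total = sum(values)

# %%
print(total)
```

The file runs as an ordinary Python script. It can also be opened as a notebook, or converted with `jupytext --to notebook scripts/simple.py`.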