What does this PR do?
Currently, SequentialDict always tries to find a corresponding value in a DotDict. This PR changes SequentialDict to look up a value in the DotDict only when that value is a str; any other value is passed directly to the underlying module.
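To make the lookup rule concrete, here is a minimal, dependency-free Python sketch. It assumes that each step's arguments are either dotted names of earlier outputs or literal values, and it models the DotDict as nested plain dicts. The names dot_lookup, resolve, call_step, and weighted_sum are hypothetical and used only for illustration; they are not MART's API. The usage at the end shows a weighted sum of losses: the str entries are looked up, while the list of weights reaches the module unchanged.

```python
from typing import Any, Callable, Dict, List, Mapping


def dot_lookup(namespace: Mapping[str, Any], key: str) -> Any:
    """Resolve a dotted name such as "losses.ce" against nested mappings.

    Hypothetical stand-in for a DotDict lookup; not MART's actual DotDict.
    """
    value: Any = namespace
    for part in key.split("."):
        value = value[part]  # raises KeyError if a segment is missing
    return value


def resolve(value: Any, namespace: Mapping[str, Any]) -> Any:
    """Rule sketched from this PR: only str values are looked up.

    Anything else (floats, ints, lists, ...) is returned unchanged, so it
    reaches the module as a literal.
    """
    if isinstance(value, str):
        return dot_lookup(namespace, value)
    return value


def call_step(
    module: Callable[..., Any],
    args: List[Any],
    kwargs: Dict[str, Any],
    namespace: Mapping[str, Any],
) -> Any:
    """Hypothetical single step of a SequentialDict-like pipeline."""
    resolved_args = [resolve(a, namespace) for a in args]
    resolved_kwargs = {k: resolve(v, namespace) for k, v in kwargs.items()}
    return module(*resolved_args, **resolved_kwargs)


def weighted_sum(*losses: float, weights: List[float]) -> float:
    """Hypothetical loss combiner: sum of weight * loss pairs."""
    if len(losses) != len(weights):
        raise ValueError("need exactly one weight per loss")
    return sum(w * loss for w, loss in zip(weights, losses))


# Outputs of earlier steps, addressed by dotted names.
outputs = {"losses": {"ce": 2.0, "adv": 4.0}}

# The loss names are looked up; the weight list is passed through as-is.
total = call_step(
    weighted_sum,
    ["losses.ce", "losses.adv"],
    {"weights": [1.0, 0.5]},
    outputs,
)
print(total)  # 1.0 * 2.0 + 0.5 * 4.0 -> 4.0
```

One consequence of this rule, at least in the sketch, is that a str can no longer be passed through as a literal: every str is treated as a name to look up.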
This is useful if you want to compute a weighted sum of losses. Right now, you would have to pre-define the number of weights and their values; this change lets you pass the literal weight values via a sequence instead.
Type of change
Please check all relevant options.
[x] Improvement (non-breaking)
[ ] Bug fix (non-breaking)
[ ] New feature (non-breaking)
[ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
[ ] This change requires a documentation update
Testing
Please describe the tests that you ran to verify your changes. Consider listing any relevant details of your test configuration.
pytest
CUDA_VISIBLE_DEVICES=0 python -m mart experiment=CIFAR10_CNN_Adv trainer=gpu trainer.precision=16
reports 70% (21 sec/epoch).
CUDA_VISIBLE_DEVICES=0,1 python -m mart experiment=CIFAR10_CNN_Adv trainer=ddp trainer.precision=16 trainer.devices=2 model.optimizer.lr=0.2 trainer.max_steps=2925 datamodule.ims_per_batch=256 datamodule.world_size=2
reports 70% (14 sec/epoch).
Before submitting
pre-commit run -a command without errors
Did you have fun?
Make sure you had fun coding 🙃