laggui closed this 8 months ago.
@antimora, tagging you specifically regarding the pattern we might want to adopt for pre-trained weights, as discussed in #16.
Thank you! I'll review it soon.
Also fixed some minor issues with the bigger ResNet models.
This PR adds new methods to directly initialize a ResNet-{18, 34, 50, 101, 152} with ImageNet pre-trained weights from torchvision. The weights are automatically downloaded from the web to a default `~/.cache/resnet-burn/` folder using `download_file_as_bytes`, which provides a progress bar.

Changes:
- `init` and `init_with` methods following good practice
- `pretrained` feature flag
- `weights` module
- `resnet*_pretrained` methods

Because loading the pre-trained weights requires a fix not yet in a released version of `candle-core`, the current burn dependency is pinned to a specific revision that pins the correct `candle-core` dependency.

TODO:
- Fixed `candle-core` version is released and included in a burn release/patch
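The caching behavior described above (weights stored under `~/.cache/resnet-burn/`) can be sketched with the Rust standard library alone. This is a minimal sketch, not the PR's code: the helper names `cache_file`, `default_cache_file`, and `needs_download` are hypothetical, it assumes a Unix-style `HOME` environment variable, and it leaves out the actual download step (handled by `download_file_as_bytes` in the PR).

```rust
use std::env;
use std::path::{Path, PathBuf};

/// Hypothetical helper: builds `<home>/.cache/resnet-burn/<file_name>`.
fn cache_file(home: &Path, file_name: &str) -> PathBuf {
    home.join(".cache").join("resnet-burn").join(file_name)
}

/// Hypothetical helper: resolves the cache location from `HOME`
/// (assumes a Unix-style home directory; returns None if `HOME` is unset).
fn default_cache_file(file_name: &str) -> Option<PathBuf> {
    env::var_os("HOME").map(|home| cache_file(Path::new(&home), file_name))
}

/// Hypothetical helper: the weights only need downloading if no cached copy exists.
fn needs_download(path: &Path) -> bool {
    !path.exists()
}
```

For example, with `HOME=/home/user`, `default_cache_file("resnet18.pth")` (a hypothetical file name) resolves to `/home/user/.cache/resnet-burn/resnet18.pth`. Checking for an existing file before downloading means the weights would only be fetched once, with later calls reusing the local copy.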