tracel-ai / models

Models and examples built with Burn
Apache License 2.0

Add ImageNet pre-trained weights under `pretrained` feature flag #18

Closed: laggui closed this 8 months ago

laggui commented 8 months ago

This PR adds new methods to directly initialize a ResNet-{18, 34, 50, 101, 152} with ImageNet pre-trained weights from torchvision.
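For reference, the five depths listed above differ only in how many residual blocks each of the four stages contains, and in whether they use two-layer basic blocks or three-layer bottleneck blocks. The sketch below encodes the standard configurations from the original ResNet paper. The `ResNetDepth` enum and its methods are hypothetical names used for illustration, not this crate's API.

```rust
/// Hypothetical enum naming the five ResNet depths covered by the pre-trained weights.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResNetDepth {
    R18,
    R34,
    R50,
    R101,
    R152,
}

impl ResNetDepth {
    /// Residual blocks in each of the four stages, per the original ResNet paper.
    pub fn blocks_per_stage(self) -> [usize; 4] {
        match self {
            ResNetDepth::R18 => [2, 2, 2, 2],
            ResNetDepth::R34 => [3, 4, 6, 3],
            ResNetDepth::R50 => [3, 4, 6, 3],
            ResNetDepth::R101 => [3, 4, 23, 3],
            ResNetDepth::R152 => [3, 8, 36, 3],
        }
    }

    /// ResNet-50 and deeper use bottleneck blocks; ResNet-18/34 use basic blocks.
    pub fn uses_bottleneck(self) -> bool {
        matches!(self, ResNetDepth::R50 | ResNetDepth::R101 | ResNetDepth::R152)
    }

    /// Weighted layers: convs per block times total blocks, plus the stem conv and the final FC layer.
    pub fn depth(self) -> usize {
        let convs_per_block = if self.uses_bottleneck() { 3 } else { 2 };
        let blocks: usize = self.blocks_per_stage().iter().sum();
        convs_per_block * blocks + 2
    }
}
```

Checking that `depth()` reproduces the number in each model's name (for example 3 × 33 + 2 = 101 for ResNet-101) is a quick sanity check on the stage configuration.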

The weights are downloaded automatically to a default `~/.cache/resnet-burn/` folder using `download_file_as_bytes`, which displays a progress bar.
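The caching behaviour described above amounts to a load-or-download step. Below is a minimal std-only sketch, assuming the cache directory is derived from the `HOME` environment variable and that the downloader can be passed in as a closure. `default_cache_dir` and `load_or_download` are hypothetical helpers. The closure only stands in for `download_file_as_bytes`, whose real signature may differ.

```rust
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

/// Resolve the default cache directory `~/.cache/resnet-burn/`.
/// Assumption: the home directory is read from `HOME` (Unix-style environments).
fn default_cache_dir() -> Option<PathBuf> {
    std::env::var_os("HOME").map(|home| PathBuf::from(home).join(".cache").join("resnet-burn"))
}

/// Return the bytes of `file_name` from `cache_dir`, calling `fetch` only on a cache miss.
/// `fetch` is a hypothetical stand-in for a downloader such as `download_file_as_bytes`.
fn load_or_download<F>(cache_dir: &Path, file_name: &str, fetch: F) -> io::Result<Vec<u8>>
where
    F: FnOnce() -> io::Result<Vec<u8>>,
{
    let path = cache_dir.join(file_name);
    if path.is_file() {
        // Cache hit: reuse the previously downloaded weights.
        return fs::read(&path);
    }
    // Cache miss: download, then persist so later calls skip the network.
    let bytes = fetch()?;
    fs::create_dir_all(cache_dir)?;
    fs::write(&path, &bytes)?;
    Ok(bytes)
}
```

Taking the downloader as a closure keeps the cache logic testable without network access.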

Changes:

Because loading the pre-trained weights requires a fix that is not yet in a released version of `candle-core`, the `burn` dependency is currently pinned to a specific revision, which in turn pins the required `candle-core` version.

TODO:

laggui commented 8 months ago

@antimora Tagging you specifically regarding the pattern we might want to adopt for pre-trained weights, as discussed in #16.

antimora commented 8 months ago

Thank you! I'll review it soon.

laggui commented 8 months ago

Also fixed some minor issues with the bigger ResNet models.