This PR adds a Burn implementation for the ResNet family of computer vision models. Implementation and weights are based on torchvision.
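For reviewers: the five variants differ only in block type and in the number of blocks per stage. The following sketch restates torchvision's configuration in plain Rust. The `Block` enum and the `config`, `weighted_layers` and `fc_in_features` helpers are hypothetical names for illustration, not the types used in this PR.

```rust
/// Residual block type (hypothetical enum mirroring torchvision's BasicBlock / Bottleneck).
#[derive(Debug, Clone, Copy, PartialEq)]
enum Block {
    /// Two 3x3 convolutions (ResNet-18/34).
    Basic,
    /// 1x1 -> 3x3 -> 1x1 convolutions (ResNet-50/101/152).
    Bottleneck,
}

impl Block {
    /// Number of convolutions on the main path of one block.
    fn convs(self) -> usize {
        match self {
            Block::Basic => 2,
            Block::Bottleneck => 3,
        }
    }

    /// Output channels of a block relative to its stage width (64/128/256/512).
    fn expansion(self) -> usize {
        match self {
            Block::Basic => 1,
            Block::Bottleneck => 4,
        }
    }
}

/// Block type and number of blocks in each of the four stages (layer1..layer4).
fn config(depth: usize) -> Option<(Block, [usize; 4])> {
    match depth {
        18 => Some((Block::Basic, [2, 2, 2, 2])),
        34 => Some((Block::Basic, [3, 4, 6, 3])),
        50 => Some((Block::Bottleneck, [3, 4, 6, 3])),
        101 => Some((Block::Bottleneck, [3, 4, 23, 3])),
        152 => Some((Block::Bottleneck, [3, 8, 36, 3])),
        _ => None,
    }
}

/// Depth as counted in the model name: stem conv + main-path convs + final fc.
/// The 1x1 downsample convs on the shortcut are not counted.
fn weighted_layers(block: Block, blocks: [usize; 4]) -> usize {
    1 + block.convs() * blocks.iter().sum::<usize>() + 1
}

/// Input features of the final `fc` layer (512 for ResNet-18/34, 2048 for deeper variants).
fn fc_in_features(block: Block) -> usize {
    512 * block.expansion()
}
```

For example, `weighted_layers` applied to `config(101)` gives 1 + 3 × 33 + 1 = 101, which is a quick sanity check that the stage sizes match the model name.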
TO-DO
- [x] ResNet implementation for resnet-{18, 34, 50, 101, 152}
- [x] Inference example
- [x] Use `PyTorchFileRecorder` to import weights
- [ ] Switch to the next Burn release version

  The `Cargo.toml` file currently points to a local workspace copy of Burn. Parsing the ResNet weights with `with_key_remap` requires a fix that landed after the 0.12.1 release, as well as a fix on Candle's side to handle module parameters with pickle.
- [ ] Automatic download of pre-trained weights (`download_file_as_bytes` needs to be re-exported for common use in Burn)

It would be great to have a Burn release that includes these fixes, but that also depends on Candle releasing a new version with the pickle module-parameters fix.
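On the key remapping: torchvision stores each stage (`layer1`..`layer4`) and each shortcut `downsample` as an `nn.Sequential`, so its state-dict keys contain positional indices such as `layer1.0.conv1.weight` or `layer2.0.downsample.1.running_mean`. These keys have to be rewritten to the field names of the Burn modules, which `with_key_remap` does with regex pattern/replacement pairs. The std-only sketch below expresses the same kind of mapping as plain string logic. The target names (`blocks`, `downsample.conv`, `downsample.bn`) are hypothetical and may not match the module structure in this PR.

```rust
/// Rewrite a torchvision ResNet state-dict key to a hypothetical Burn module path.
fn remap_key(key: &str) -> String {
    let parts: Vec<&str> = key.split('.').collect();
    let mut out: Vec<String> = Vec::with_capacity(parts.len() + 1);
    let mut i = 0;
    while i < parts.len() {
        let p = parts[i];
        let next = parts.get(i + 1).copied();

        // `layerN.<idx>`: torchvision's nn.Sequential of residual blocks.
        // Assumption: the Burn stage keeps its blocks in a field named `blocks`.
        if p.starts_with("layer") && next.map_or(false, |n| n.parse::<usize>().is_ok()) {
            out.push(p.to_string());
            out.push("blocks".to_string());
            out.push(next.unwrap().to_string());
            i += 2;
            continue;
        }

        // `downsample.0` / `downsample.1`: the 1x1 conv and batch norm of the shortcut.
        // Assumption: the Burn downsample module names these fields `conv` and `bn`.
        if p == "downsample" {
            if let Some(n) = next {
                out.push("downsample".to_string());
                out.push(match n {
                    "0" => "conv".to_string(),
                    "1" => "bn".to_string(),
                    other => other.to_string(),
                });
                i += 2;
                continue;
            }
        }

        // All other segments (conv1, bn1, fc, weight, bias, ...) are kept as-is.
        out.push(p.to_string());
        i += 1;
    }
    out.join(".")
}
```

With this sketch, `remap_key("layer2.0.downsample.1.running_mean")` yields `layer2.blocks.0.downsample.bn.running_mean`, while top-level keys such as `fc.bias` pass through unchanged.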
Issue #11