tracel-ai / models

Models and examples built with Burn
Apache License 2.0

General pattern for weight download and conversion #16

Open antimora opened 8 months ago

antimora commented 8 months ago

This ticket is a twofold request:

  1. Further enhance the ResNet-Burn model, which was recently added.
  2. Come up with general requirements and a solution for the models added to the models repo.

Now that we are adding popular models to the burn-model repo, we should consider the end-user experience and come up with some basic top-level requirements for what is expected when a user adopts or uses a migrated model. This can evolve into a standard across other models.

Here is my proposal:

  1. Each model should offer automatic weight download from a known source. The source can be overridden if needed. We should offer this both as a library and as a binary executable under the bin folder. The destination can default to some cache location or be specified by the user.
  2. If the source file is in a non-Burn format, we convert the file, and subsequent loads use the Burn-native file.
  3. (Optional) The converted file is uploaded to the Hugging Face portal under the Burn organization.
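
To make points 1 and 2 concrete, here is a minimal sketch of how the flow could look, using only the standard library. All names (`WeightsConfig`, `ensure_native_weights`, the file names) are hypothetical and not part of Burn; the download and conversion steps are passed in as closures because the real implementations would depend on the model and on whichever recorder is used.

```rust
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

/// Hypothetical settings for fetching pre-trained weights.
pub struct WeightsConfig {
    /// Known default source; callers may override it.
    pub source_url: String,
    /// Directory where the converted, native file is stored.
    pub dest_dir: PathBuf,
    /// File name of the native weights file (name chosen for illustration).
    pub native_file: String,
}

impl WeightsConfig {
    pub fn new(default_url: &str, default_dir: &Path, native_file: &str) -> Self {
        Self {
            source_url: default_url.to_string(),
            dest_dir: default_dir.to_path_buf(),
            native_file: native_file.to_string(),
        }
    }

    /// Override the known source.
    pub fn with_source(mut self, url: &str) -> Self {
        self.source_url = url.to_string();
        self
    }

    /// Override the default destination.
    pub fn with_dest_dir(mut self, dir: &Path) -> Self {
        self.dest_dir = dir.to_path_buf();
        self
    }

    pub fn native_path(&self) -> PathBuf {
        self.dest_dir.join(&self.native_file)
    }
}

/// Returns the path of the native weights file, downloading and converting
/// only when that file does not exist yet.
pub fn ensure_native_weights<D, C>(
    cfg: &WeightsConfig,
    download: D,
    convert: C,
) -> io::Result<PathBuf>
where
    D: FnOnce(&str) -> io::Result<Vec<u8>>,
    C: FnOnce(&[u8]) -> io::Result<Vec<u8>>,
{
    let native = cfg.native_path();
    if native.exists() {
        // Subsequent loads reuse the already converted file.
        return Ok(native);
    }
    fs::create_dir_all(&cfg.dest_dir)?;
    let raw = download(&cfg.source_url)?;
    let converted = convert(&raw)?;
    fs::write(&native, converted)?;
    Ok(native)
}
```

The same function could back both the library API and a small executable under the bin folder; only the first call pays for the download and the conversion.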

antimora commented 8 months ago

I am inviting @laggui @ashdtu @nathanielsimard @louisfd @Luni-4, and others to share your input.

laggui commented 8 months ago

Funny you mention that, I was just working on adding automatic loading of pre-trained weights to the ResNet models 😄 So great timing!

Since I haven't pushed any of my changes yet (PR should come soon), I'll summarize the way I am currently approaching this.

By default, the models support no_std. I've added a pretrained feature flag that requires std and pulls in optional dependencies: the burn-import crate, to use the PyTorchFileRecorder, and burn/network (new since this PR), to use the download_file_as_bytes function with a download progress bar.
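
As a rough illustration of that kind of feature gating (not the actual code from the upcoming PR), the sketch below shows items that always compile next to an item that only exists when a `pretrained` feature is on. The manifest excerpt in the comment, the function names, and the URL are assumptions made for the example.

```rust
// Possible Cargo manifest side, shown as a comment (an assumption, not the
// real manifest of the ResNet crate):
//
//   [features]
//   default = []
//   std = []
//   pretrained = ["std", "dep:burn-import", "burn/network"]

/// Always available, including in builds without `std`.
pub fn model_name() -> &'static str {
    "resnet"
}

/// Only compiled when the `pretrained` feature is enabled, so builds
/// without it never pull in the download or import dependencies.
#[cfg(feature = "pretrained")]
pub fn pretrained_weights_url() -> &'static str {
    // Placeholder URL, for illustration only.
    "https://example.invalid/resnet.pth"
}

/// Lets callers check at compile time whether the feature is active.
pub const fn pretrained_enabled() -> bool {
    cfg!(feature = "pretrained")
}
```

With this layout a user opts in through the feature flag in their own manifest, and code that does not need pre-trained weights keeps the smaller dependency set.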

Regarding your specific points:

  1. For storing the downloaded weights, right now I followed the default pattern I observed in burn: put them in the ~/.cache directory under the model name (e.g., ~/.cache/resnet-burn).
  2. For loading the weights I currently added resnet*_pretrained methods that do exactly as you described: download the .pth checkpoint and use the PyTorchFileRecorder to load them.
  3. Haven't done anything in that regard yet, but we briefly talked about something like that with @nathanielsimard
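
For reference, the cache location from point 1 could be resolved roughly as follows. This is a standard-library sketch with hypothetical function names; it assumes a Unix-style `HOME` environment variable and does not reflect how Burn itself resolves the directory.

```rust
use std::env;
use std::path::{Path, PathBuf};

/// Builds `<home>/.cache/<model_name>`, e.g. `~/.cache/resnet-burn`.
pub fn model_cache_dir(home: &Path, model_name: &str) -> PathBuf {
    home.join(".cache").join(model_name)
}

/// Resolves the cache directory from the `HOME` environment variable.
/// Returns `None` when `HOME` is not set.
pub fn default_model_cache_dir(model_name: &str) -> Option<PathBuf> {
    env::var_os("HOME").map(|home| model_cache_dir(Path::new(&home), model_name))
}

/// Path of a downloaded checkpoint inside the model's cache directory.
pub fn checkpoint_path(cache_dir: &Path, file_name: &str) -> PathBuf {
    cache_dir.join(file_name)
}
```

A `resnet*_pretrained` style constructor could then check whether the checkpoint path exists before downloading again.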

nathanielsimard commented 8 months ago

Something that I would also like to see is exporting models without a specified backend, so users can choose the backend.
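
The idea can be sketched with a model type that is generic over its backend. The `Backend` trait below is a simplified stand-in written for this example, not Burn's own trait, and the two backend types are hypothetical; the point is only that the model definition never names a concrete backend.

```rust
use std::marker::PhantomData;

/// Stand-in for a backend abstraction (hypothetical and much simpler than
/// the trait a real framework would expose).
pub trait Backend {
    fn name() -> &'static str;
}

pub struct CpuBackend;
pub struct GpuBackend;

impl Backend for CpuBackend {
    fn name() -> &'static str {
        "cpu"
    }
}

impl Backend for GpuBackend {
    fn name() -> &'static str {
        "gpu"
    }
}

/// The exported model is generic: the user picks `B` at the call site.
pub struct Model<B: Backend> {
    weights: Vec<f32>,
    _backend: PhantomData<B>,
}

impl<B: Backend> Model<B> {
    pub fn new(weights: Vec<f32>) -> Self {
        Self {
            weights,
            _backend: PhantomData,
        }
    }

    pub fn describe(&self) -> String {
        format!("{} parameters on {}", self.weights.len(), B::name())
    }
}
```

A user would then write `Model::<CpuBackend>::new(...)` or `Model::<GpuBackend>::new(...)` without any change to the model crate.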