FluxML / FastAI.jl

Repository of best practices for deep learning in Julia, inspired by fastai
https://fluxml.ai/FastAI.jl
MIT License

Add a feature registry for models #267

Status: Open. lorenzoh opened this pull request 1 year ago.

lorenzoh commented 1 year ago

Implements #246.

PR Checklist

Usage examples

From #269:


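using FastAI                    # brings `FastAI` and its exports (incl. `load`) into scope, needed for `FastAI.ConvFeatures`
using FastAI: models            # `models()` returns the model registry
using Colors: Gray              # color type for the single-channel task below
using FixedPointNumbers: N0f8
# loading FastVision adds its models to the registry
using FastVision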

# Load original model, 1000 output classes, no weights (`ResNet(18)`):
load(models()["metalhead/resnet18"]);

# Load original model, 1000 output classes, with weights (`ResNet(18; pretrain = true)`):
load(models()["metalhead/resnet18"], pretrained = true);

# Load only backbone, without weights:
load(models()["metalhead/resnet18"], variant = "backbone");

# Load only backbone, with weights:
load(models()["metalhead/resnet18"], pretrained = true, variant = "backbone");

# Load model for task, adapting layers as necessary:
task = ImageClassificationSingle((256, 256), 1:5, C = Gray{N0f8}) # input with 1 color channel, 5 classes
load(models()["metalhead/resnet18"], input = task.blocks.x, output = task.blocks.y)
# Also works with pretrained weights
load(models()["metalhead/resnet18"], pretrained = true, input = task.blocks.x, output = task.blocks.y)

# Correct variants are selected automatically given the blocks:
load(models()["metalhead/resnet18"], output = FastAI.ConvFeatures)  # uses backbone variant

# Support for multiple checkpoints, selectable by name:
load(models()["metalhead/resnet18"], checkpoint = "imagenet1k")
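The automatic variant selection in the examples above (asking for `FastAI.ConvFeatures` yields the backbone) can be pictured with a toy sketch. The types `ToyLabel` and `ToyConvFeatures`, the table `TOY_VARIANTS` and the function `selectvariant` are hypothetical and not part of FastAI.jl; the PR's actual matching logic may well differ.

```julia
# Hypothetical sketch, not FastAI.jl code: choose a model variant based on
# the kind of output block the caller asks for.
abstract type ToyBlock end
struct ToyLabel <: ToyBlock end          # stands in for a classification target
struct ToyConvFeatures <: ToyBlock end   # stands in for backbone feature maps

# Each variant name is paired with the output block type it produces.
const TOY_VARIANTS = [
    "classifier" => ToyLabel,
    "backbone" => ToyConvFeatures,
]

# Return the name of the first variant whose output type matches `output`.
function selectvariant(output::Type{<:ToyBlock})
    for (name, T) in TOY_VARIANTS
        output <: T && return name
    end
    throw(ArgumentError("no variant produces $(output)"))
end

selectvariant(ToyConvFeatures)  # returns "backbone"
selectvariant(ToyLabel)         # returns "classifier"
```

Dispatching on the block type keeps the caller's request declarative: the user states what output they need, and the registry entry decides which variant satisfies it.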

Docs

The proposed interface is well described by the registry description, pasted below:

A FeatureRegistry for models. Allows you to find and load models for various learning tasks using a unified interface. Call models() to see a table view of available models:

using FastAI: models
models()

Which models are available depends on the loaded packages. For example, FastVision.jl adds vision models from Metalhead to the registry. Index the registry with a model ID to get more information about that model:

using FastAI: models
using FastVision  # loading the package extends the list of available models

models()["metalhead/resnet18"]
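The idea that loading a package extends the registry can be illustrated with a self-contained toy in plain Julia. `ToyModelEntry`, `register_model!`, `toymodels`, `toyload` and `toy_resnet18` are hypothetical names for illustration only, and this is not necessarily how FastAI.jl's `FeatureRegistry` is implemented; per the description here, the real registry additionally offers table views, variants and checkpoints.

```julia
# Hypothetical toy registry, not FastAI.jl's FeatureRegistry: it only shows
# how loading a package can add entries that users then look up by ID.
struct ToyModelEntry
    id::String
    description::String
    constructor::Function  # called with keyword arguments by `toyload`
end

const TOY_REGISTRY = Dict{String,ToyModelEntry}()

toymodels() = TOY_REGISTRY

# A package would call this when it is loaded (e.g. from its `__init__`).
register_model!(entry::ToyModelEntry) = (TOY_REGISTRY[entry.id] = entry)

# Forward keyword arguments such as `pretrained` to the entry's constructor.
toyload(entry::ToyModelEntry; kwargs...) = entry.constructor(; kwargs...)

# Placeholder constructor; a real package would build a Flux model here.
toy_resnet18(; pretrained::Bool = false) = (name = "resnet18", pretrained = pretrained)

register_model!(ToyModelEntry("toy/resnet18", "Placeholder for ResNet-18", toy_resnet18))

model = toyload(toymodels()["toy/resnet18"]; pretrained = true)
```

Because entries are added at package load time, the set of IDs visible through `toymodels()` depends on which packages have been loaded, mirroring the behavior described above.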

Once you've selected a model, call load to instantiate it:

model = load(models()["metalhead/resnet18"])

By default, load instantiates the default variant of the model without pretrained weights.

load also accepts keyword arguments for choosing a variant of the model and the weight checkpoint to load.

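Loading a checkpoint of pretrained weights: pass `pretrained = true`, for example `load(models()["metalhead/resnet18"], pretrained = true)`, and select a specific checkpoint by name with `checkpoint` (e.g. `checkpoint = "imagenet1k"`).

Loading a model variant for a specific task: pass the task's blocks as `input` and `output`, for example `load(models()["metalhead/resnet18"], input = task.blocks.x, output = task.blocks.y)`, and the matching variant is selected, with layers adapted as necessary. A variant can also be requested by name, e.g. `variant = "backbone"`.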

github-actions[bot] commented 1 year ago

A documentation preview has been successfully built, view it here: Documentation preview PR-267

lorenzoh commented 1 year ago

@darsnack @theabhirath I'd love to get feedback on the API. If anything is unclear or an important feature is missing, let me know!