FluxML / FastAI.jl

Repository of best practices for deep learning in Julia, inspired by fastai
https://fluxml.ai/FastAI.jl
MIT License

Model registry #246

Open lorenzoh opened 2 years ago

In addition to existing feature registries, FastAI.jl will be getting a model registry.

The model registry should allow loading models, searching for models, and plugging them into the training workflow.

Examples

Some code examples to show how the model registry can be used:

Loading models

Load a pretrained ResNet implemented in Metalhead.jl for transfer learning:

load(models()["metalhead/resnet18/head"], pretrained=true)

Load the ResNet as an untrained backbone for a different task:

load(models()["metalhead/resnet18/backbone"], pretrained=false)
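To make the loading idea concrete, here is a minimal, self-contained sketch of a registry keyed by string IDs. It is not the proposed FastAI.jl implementation: `ModelEntry`, `TOY_REGISTRY`, `register!` and `loadmodel` are hypothetical names made up for illustration, and the "model" is just a named tuple standing in for a real network.

```julia
# Hypothetical sketch, not the FastAI.jl API: a toy registry keyed by string IDs.
struct ModelEntry
    id::String
    backend::Symbol
    loadfn::Function   # called as `loadfn(; pretrained)` and returns a model
end

const TOY_REGISTRY = Dict{String, ModelEntry}()

# Store an entry under its ID and return it.
register!(entry::ModelEntry) = (TOY_REGISTRY[entry.id] = entry; entry)

# The entry only stores a closure, so looking up or listing entries never
# constructs a model; that happens when `loadmodel` is called.
loadmodel(entry::ModelEntry; pretrained = false) = entry.loadfn(; pretrained)

# Register a stand-in "model" (a named tuple instead of a real network).
register!(ModelEntry("toy/linear/backbone", :flux,
    (; pretrained = false) -> (name = "linear", pretrained = pretrained)))

model = loadmodel(TOY_REGISTRY["toy/linear/backbone"]; pretrained = true)
```

Storing a loading closure rather than the model itself is one possible design: it keeps registry lookups cheap and defers any weight download to the moment a model is actually requested.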

Searching for models

Find models that take in preprocessed images:

filter(models(), input=ImageTensor{2})

Or find a suitable model for a supervised learning task directly:

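# any supervised task works here, e.g. SupervisedTask(_), ImageSegmentation(_)
# or TabularClassificationSingle(_); `_` is a placeholder for the task's arguments
task = ImageSegmentation(_)
filter(models(), input=task.blocks.x, output=task.blocks.y)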

List models implemented in PyTorch:

filter(models(), backend=:pytorch)

Find models of a certain size:

filter(models(), input=ImageTensor{2}, nparams=<(1000000))
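The queries above mix plain values (`backend=:pytorch`) with predicates (`nparams=<(1000000)`). One way such keyword filtering could work is sketched below, under the assumption that each registry entry exposes its metadata as fields. `filtermodels`, `matches` and the entries themselves are hypothetical and only illustrate the matching rule.

```julia
# Hypothetical sketch of keyword-based filtering; entries and names are made up.
entries = [
    (id = "toy/small", backend = :flux,    nparams = 50_000),
    (id = "toy/large", backend = :flux,    nparams = 11_000_000),
    (id = "toy/torch", backend = :pytorch, nparams = 500_000),
]

# A query is either a predicate such as `<(1_000_000)`, which is applied to the
# field value, or a plain value, which is compared with `==`.
matches(query::Function, value) = query(value)
matches(query, value) = query == value

# Keep the entries for which every keyword query matches the field of the same name.
function filtermodels(entries; queries...)
    filter(entries) do entry
        all(matches(query, getproperty(entry, field)) for (field, query) in pairs(queries))
    end
end

small = filtermodels(entries; nparams = <(1_000_000))   # "toy/small" and "toy/torch"
torch = filtermodels(entries; backend = :pytorch)       # only "toy/torch"
```

Dispatching on `Function` works here because `<(1_000_000)` returns a `Base.Fix2`, which is a subtype of `Function`.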

Training workflow

Since models in the registry are associated with block information, we can use them to automatically construct task-specific models using the taskmodel API (possibly extended by an additional backend argument).

config = models()["torchvision/resnet18/backbone"]
backbone = load(config)

task = ImageSegmentation(_)
# build the task-specific model
model = taskmodel(task,           # includes info about required input and target block for task
                  config.backend, # dispatch on the DL library used, here :pytorch
                  backbone,
                  config.input,   # backbone input block: `ImageTensor{2}(3)`
                  config.output)  # backbone output block: `ConvFeatures{2}(512)`

learner = tasklearner(task, data; model)
fit!(learner, 10)
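Because `config.backend` is a `Symbol` such as `:pytorch`, dispatching on it would need the value lifted into the type domain. The sketch below shows one common Julia pattern for this, using `Val`. It is an assumption about how a backend argument might be handled, not the actual `taskmodel` implementation. `buildmodel` is a hypothetical stand-in that returns strings instead of models.

```julia
# Hypothetical sketch: dispatching on a backend symbol via `Val`.
# `buildmodel` is a stand-in for a backend-aware model constructor.

# Entry point: lift the runtime symbol into a type so methods can dispatch on it.
buildmodel(backend::Symbol, backbone) = buildmodel(Val(backend), backbone)

# One method per supported backend.
buildmodel(::Val{:flux}, backbone) = "Flux head on $backbone"
buildmodel(::Val{:pytorch}, backbone) = "PyTorch head on $backbone"

# Fallback for backends without a dedicated method.
buildmodel(::Val{B}, backbone) where {B} =
    throw(ArgumentError("no model builder registered for backend :$B"))

description = buildmodel(:pytorch, "resnet18")
```

With this pattern, support for a new backend can be added by defining one more `Val` method, without touching the existing ones.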