jemus42 closed this issue 2 years ago
You can find an example of using a pre-trained model with luz here: https://mlverse.github.io/luz/articles/examples/dogs-vs-cats-binary-classification.html
```r
net <- torch::nn_module(
  initialize = function(num_classes) {
    self$model <- model_alexnet(pretrained = TRUE)
  },
  forward = function(x) {
    self$model(x)[, 1]
  }
)
```
Luz `setup()`, `fit()`, etc. only work with module generators, i.e. the object returned by `torch::nn_module()` (those objects, when called like `torch::nn_linear(10, 10)`, generate an `nn_module`). This ensures luz always owns the initialization of its module, so it can safely use in-place operations, like moving parameters to a specific device or modifying parameter values, without affecting the global environment.
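As a minimal sketch of that distinction (assuming only the torch package is installed):

```r
library(torch)

# nn_module() returns a module *generator* ...
Net <- nn_module(
  initialize = function(in_features, out_features) {
    self$fc <- nn_linear(in_features, out_features)
  },
  forward = function(x) self$fc(x)
)

# ... and calling the generator produces an instantiated nn_module.
net <- Net(10, 2)

inherits(Net, "nn_module_generator")  # the generator, what luz expects
inherits(net, "nn_module")            # the instance, what luz creates itself
```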
While we could make luz work with `nn_module`s (as opposed to `nn_module` generators) to make the code less verbose, in general when using pre-trained models like those from torchvision you will need to modify the model head, freeze parameters, and so on, and thus need a custom `torch::nn_module` anyway.
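For example, a hedged sketch (assuming torch and torchvision are installed, and a hypothetical `FrozenAlexNet` name) of a custom `nn_module` that wraps the pretrained model, freezes its parameters, and swaps in a new head, all inside `initialize`:

```r
library(torch)
library(torchvision)

FrozenAlexNet <- nn_module(
  initialize = function(num_classes) {
    self$model <- model_alexnet(pretrained = TRUE)
    # Freeze every pretrained parameter ...
    for (par in self$model$parameters) {
      par$requires_grad_(FALSE)
    }
    # ... then replace the head; the fresh layer's parameters
    # require gradients by default, so only the head is trained.
    self$model$classifier$`6` <- nn_linear(
      in_features = self$model$classifier$`6`$in_features,
      out_features = num_classes
    )
  },
  forward = function(x) {
    self$model(x)
  }
)
```

Because `FrozenAlexNet` is a generator, it can be handed directly to luz, which will instantiate it itself.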
Ah, thanks a lot, that clears it up.
For using a pretrained model I had so far resorted to doing something like
```r
# Freeze all pretrained parameters
for (par in model$parameters) {
  par$requires_grad_(FALSE)
}

# Replace the final classifier layer with a fresh 10-class head
model$classifier$`6` <- torch::nn_linear(
  in_features = model$classifier$`6`$in_features,
  out_features = 10
)
```
but I see how wrapping a pretrained model in a new `nn_module` is preferred.
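For completeness, a hedged sketch of how such a wrapper would then be passed to luz (assuming torch, torchvision, and luz are installed; `train_dl` is a hypothetical dataloader standing in for real data):

```r
library(torch)
library(torchvision)
library(luz)

Net <- nn_module(
  initialize = function(num_classes) {
    self$model <- model_alexnet(pretrained = TRUE)
    # Swap in a fresh head sized for our task
    self$model$classifier$`6` <- nn_linear(
      in_features = self$model$classifier$`6`$in_features,
      out_features = num_classes
    )
  },
  forward = function(x) self$model(x)
)

# Pass the generator `Net` itself, not an instance like `Net(10)`:
fitted <- Net %>%
  setup(loss = nn_cross_entropy_loss(), optimizer = optim_adam) %>%
  set_hparams(num_classes = 10) %>%
  fit(train_dl, epochs = 1)  # `train_dl` is a placeholder dataloader
```

`set_hparams()` forwards its arguments to `initialize()` when luz instantiates the module.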
I wanted to use a prebuilt architecture with luz, either pretrained or "fresh". In `luz::setup()` I ran into an error which I followed to this line, which checks for a `forward()` method in the module. I was then trying to figure out the difference between the pre-made models and self-defined `torch::nn_module()` models, which I had used successfully with luz before, so I extracted the code for AlexNet and tried it manually, which works fine. The only difference I can point to is the classes of each model:
Given the code for `model_alexnet` here, I don't understand why `model_alexnet(pretrained = FALSE)` and `torchvision:::alexnet` should differ at all, as the function just returns `torchvision:::alexnet` if `pretrained = FALSE`. I am not sure if this is an issue with the model setup in torchvision or with `luz::setup`, but I thought I'd start here.

Created on 2022-02-07 by the reprex package (v2.0.1)