mlverse / torchvision

R interface to torchvision
https://torchvision.mlverse.org

Freezing/unfreezing weights error #33

Closed: kailex closed this issue 3 years ago

kailex commented 3 years ago

To freeze the weights of a layer, we have to disable gradient computation for them:

library(torch)
library(torchvision)

m_nn <- model_resnet18(pretrained = FALSE)
# freeze the first convolution, keep the final fully connected layer trainable
m_nn$parameters$conv1.weight$requires_grad <- FALSE
m_nn$parameters$fc.weight$requires_grad <- TRUE

This snippet produces errors, even though the requires_grad attributes do end up with the intended values:

Error in (function () : unused argument (base::quote(list(, , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , )))

> m_nn$parameters$conv1.weight$requires_grad
[1] FALSE
> m_nn$parameters$fc.weight$requires_grad
[1] TRUE

dfalbel commented 3 years ago

The error message should be clearer now. It's not possible to modify a parameter through the parameters list; you need to access it through the module instead. For example:

m_nn <- model_resnet18(pretrained = FALSE)
m_nn$conv1$weight$requires_grad <- FALSE
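Applied to the snippet from the original report, the module-based access looks roughly like the sketch below. It assumes a torch version that already includes mlverse/torch#419 (so `$<-` on `requires_grad` is allowed) and that resnet18 exposes its layers as `conv1` and `fc`, as the parameter names `conv1.weight` and `fc.weight` suggest.

```r
library(torch)
library(torchvision)

m_nn <- model_resnet18(pretrained = FALSE)

# Go through the module tree (m_nn$<submodule>$<parameter>) instead of the
# m_nn$parameters list, so the modified tensor is assigned back to the module.
m_nn$conv1$weight$requires_grad <- FALSE   # freeze the first convolution
m_nn$fc$weight$requires_grad <- TRUE       # keep the classifier trainable
m_nn$fc$bias$requires_grad <- TRUE

m_nn$conv1$weight$requires_grad
m_nn$fc$weight$requires_grad
```

Parameters are trainable by default, so the two `TRUE` assignments only mirror the intent of the original snippet.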

Also, modifying the requires_grad attribute with $<- is only allowed since mlverse/torch#419. Before that, you would need to use something like:

m_nn$conv1$weight$requires_grad_(FALSE)

That in-place call also works on tensors taken from the parameters list, e.g. m_nn$parameters$conv1.weight$requires_grad_(FALSE).
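Because `requires_grad_()` changes the flag in place, it can be used in a loop over the parameters list. Below is a minimal sketch of one typical use (not taken from this thread): freeze the whole backbone and leave only the final `fc` layer trainable. It assumes `fc` is the name of resnet18's last layer in torchvision.

```r
library(torch)
library(torchvision)

model <- model_resnet18(pretrained = FALSE)

# requires_grad_() flips the flag on the underlying tensor in place,
# so it works even on tensors taken from the model$parameters list.
for (p in model$parameters) {
  p$requires_grad_(FALSE)   # freeze everything ...
}
for (p in model$fc$parameters) {
  p$requires_grad_(TRUE)    # ... then unfreeze the final fc layer
}

# Names of the parameters that will still receive gradients
trainable <- Filter(function(p) p$requires_grad, model$parameters)
print(names(trainable))
```

Looping with `requires_grad_()` avoids `$<-` on the list entirely, so it works with torch versions from both before and after mlverse/torch#419.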