Closed cgarciae closed 3 years ago
Hi @cgarciae! This is a good question, and not something I have given significant thought to yet. I'm open to suggestions!
Right now, it is possible to get the activations from the second-to-last layer using the `backbone_only` option:
```python
from jax_resnet import pretrained_resnet
import jax.numpy as jnp

ResNet50, variables = pretrained_resnet(50)
model = ResNet50(backbone_only=True)
out = model.apply(variables, jnp.ones((1, 224, 224, 3)),
                  train=False, mutable=False)
print(out.shape)  # (1, 7, 7, 2048)
```
I'm not sure how to extract earlier layers. In a PyTorch sequential model, it's easy to grab a prefix/slice of the model, but Flax doesn't have such an abstraction. Do you think it would be worth creating a lightweight Sequential wrapper to make that easier here?
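One possible shape for such a wrapper, as a framework-agnostic sketch (plain callables standing in for Flax modules; a real version would subclass `nn.Module` and thread `variables` through `apply`): it composes a list of layers and supports slicing, so a prefix of the model can be taken to expose earlier activations.

```python
class Sequential:
    """Hypothetical lightweight Sequential: composes layers in order
    and supports slicing to take a prefix/suffix of the model."""

    def __init__(self, layers):
        self.layers = list(layers)

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def __getitem__(self, idx):
        # Slicing returns a new Sequential over a subset of the layers.
        return Sequential(self.layers[idx])


# Toy usage with plain functions in place of real modules:
model = Sequential([lambda x: x + 1, lambda x: x * 2, lambda x: x - 3])
print(model(1))       # full model: ((1 + 1) * 2) - 3 = 1
backbone = model[:2]  # drop the final "head" layer
print(backbone(1))    # (1 + 1) * 2 = 4
```

The slice-returns-a-model design mirrors what PyTorch's `nn.Sequential` makes easy, which is the abstraction being discussed here.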
Here's a mini transfer learning project I worked on recently with jax_resnet. In the training section, you can see how I manage frozen/trainable parameters and such.
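The frozen/trainable split can be illustrated framework-agnostically: only the head's parameters receive gradient updates, while the pretrained backbone passes through unchanged. A minimal sketch with plain dicts standing in for Flax variable collections (the `sgd_update` helper and the group names are made up for illustration):

```python
def sgd_update(params, grads, lr, trainable):
    # Apply SGD only to parameter groups whose top-level name is marked
    # trainable; frozen groups (e.g. a pretrained backbone) are returned
    # unchanged.
    return {
        name: ({k: w - lr * grads[name][k] for k, w in group.items()}
               if name in trainable else group)
        for name, group in params.items()
    }


params = {'backbone': {'w': 1.0}, 'head': {'w': 2.0}}
grads  = {'backbone': {'w': 0.5}, 'head': {'w': 0.5}}

new_params = sgd_update(params, grads, lr=0.1, trainable={'head'})
print(new_params)
# backbone unchanged, head updated:
# {'backbone': {'w': 1.0}, 'head': {'w': 1.95}}
```

In practice the same effect can be had by masking the optimizer update rather than writing a custom update loop, but the partitioning idea is the same.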
> Do you think it would be worth creating a lightweight Sequential wrapper to make that easier here?

I like the idea as it exposes a simple API. You would also need to prune unused parameters from `variables`, right?
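On the pruning point: if a Sequential prefix drops the final layers, the corresponding entries in `variables` would have to be dropped as well, so that the parameter tree matches the sliced model. A rough sketch with plain nested dicts (Flax actually uses immutable `FrozenDict`s, and the layer names here are hypothetical):

```python
def prune_params(params, keep):
    # Keep only the parameter sub-trees for modules still in the model.
    return {name: subtree for name, subtree in params.items()
            if name in keep}


variables = {
    'params': {
        'conv_block_1': {'kernel': [1, 2]},
        'conv_block_2': {'kernel': [3, 4]},
        'dense_head':   {'kernel': [5, 6]},
    }
}

# Suppose the sliced model keeps only the two conv blocks:
pruned = {'params': prune_params(variables['params'],
                                 keep={'conv_block_1', 'conv_block_2'})}
print(sorted(pruned['params']))  # ['conv_block_1', 'conv_block_2']
```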
> Here's a mini transfer learning project I worked on recently with jax_resnet.

This is very nice! Do you mind if I port this to Elegy?
> Do you think it would be worth creating a lightweight Sequential wrapper to make that easier here?
>
> I like the idea as it exposes a simple API. You would also need to prune unused parameters from `variables`, right?
Right. I'll think more about a clean API for this in the coming week.
> Here's a mini transfer learning project I worked on recently with jax_resnet.
>
> This is very nice! Do you mind if I port this to Elegy?
I'd be honoured! My only ask is that you use the official COVIDx dataset instead of my modified split: https://github.com/lindawangg/COVID-Net/blob/master/docs/COVIDx.md. Some of the custom optax code should be unnecessary soon: the selective additive weight decay will be possible with just optax (pr), and hopefully the scheduled optimizer as well (issue).
Hey @n2cholas!
I was wondering how to properly do transfer learning. Maybe this feature is not implemented yet, but is it possible to select the second-to-last layer? More generally, can you select other intermediate layers?
I want to make an example of doing transfer learning in Elegy, and these pre-trained models look perfect for the task.