n2cholas / jax-resnet

Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
https://pypi.org/project/jax-resnet/
MIT License

Transfer Learning API #1

Closed · cgarciae closed this 3 years ago

cgarciae commented 3 years ago

Hey @n2cholas !

I was wondering how to properly do transfer learning. Maybe this feature isn't implemented yet, but is it possible to select the second-to-last layer? More generally, can you select other intermediate layers?

I want to make an example of doing transfer learning in Elegy, and these pre-trained models look perfect for the task.

n2cholas commented 3 years ago

Hi @cgarciae! This is a good question, and not something I have given significant thought to yet. I'm open to suggestions!

Right now, it is possible to get the activations from the second-to-last layer using the `backbone_only` option:

```python
from jax_resnet import pretrained_resnet
import jax.numpy as jnp

# Rebuild the pretrained ResNet-50 with only the convolutional backbone,
# i.e. without the final pooling/classifier head.
ResNet50, variables = pretrained_resnet(50)
model = ResNet50(backbone_only=True)
out = model.apply(variables, jnp.ones((1, 224, 224, 3)),
                  train=False, mutable=False)
print(out.shape)  # (1, 7, 7, 2048)
```

I'm not sure how to extract earlier layers. In a PyTorch sequential model, it's easy to grab a prefix/slice of the model, but Flax doesn't have such an abstraction. Do you think it would be worth creating a lightweight Sequential wrapper to make that easier here?
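
A rough sketch of what such a wrapper could look like (illustrative only, not something that exists in jax_resnet at this point):

```python
import flax.linen as nn
from typing import Any, Callable, Sequence, Union

# Sketch of a lightweight Sequential wrapper: it just applies a list of
# modules in order, so a truncated model could be built from a slice of the
# ResNet's blocks. Illustrative, not part of jax_resnet yet.
class Sequential(nn.Module):
    layers: Sequence[Union[nn.Module, Callable[..., Any]]]

    @nn.compact
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
```

With something like this, grabbing a prefix of the network would just mean constructing a `Sequential` from a slice of the block list (assuming the ResNet exposes its blocks as a list).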

Here's a mini transfer learning project I worked on recently with jax_resnet. In the training section, you can see how I manage frozen/trainable parameters and such.
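
One common way to freeze parameters with optax (a hedged sketch, not necessarily what that project does; the `'backbone'` naming below is a placeholder) is to label each parameter and let a real optimizer handle the trainable group while updates for the frozen group are zeroed out:

```python
import jax
import optax

# Illustrative sketch: label each parameter 'frozen' or 'trainable' from its
# path in the tree (the 'backbone' substring is a placeholder, not the actual
# jax_resnet naming), then route each group to a different transform.
def make_labels(params):
    return jax.tree_util.tree_map_with_path(
        lambda path, _: 'frozen' if 'backbone' in jax.tree_util.keystr(path)
        else 'trainable',
        params,
    )

tx = optax.multi_transform(
    {'trainable': optax.adam(1e-3), 'frozen': optax.set_to_zero()},
    make_labels,
)
# opt_state = tx.init(variables['params'])
```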

cgarciae commented 3 years ago

> Do you think it would be worth creating a lightweight Sequential wrapper to make that easier here?

I like the idea, as it exposes a simple API. You would also need to prune unused parameters from `variables`, right?

> Here's a mini transfer learning project I worked on recently with jax_resnet.

This is very nice! Do you mind if I port this to Elegy?

n2cholas commented 3 years ago

> > Do you think it would be worth creating a lightweight Sequential wrapper to make that easier here?
>
> I like the idea, as it exposes a simple API. You would also need to prune unused parameters from `variables`, right?

Right. I'll think more about a clean API for this in the coming week.
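
In the meantime, pruning could be done by filtering the flattened variables dict, roughly like this (a sketch; the `'classifier'` key is a placeholder, not necessarily the real layer name):

```python
from flax import traverse_util
from flax.core import freeze, unfreeze

# Rough pruning sketch: drop every entry whose key path mentions the
# classifier head so that only backbone parameters/batch stats remain.
# 'classifier' is a placeholder for whatever the head is actually called.
def prune_head(variables):
    flat = traverse_util.flatten_dict(unfreeze(variables))
    kept = {k: v for k, v in flat.items() if 'classifier' not in k}
    return freeze(traverse_util.unflatten_dict(kept))
```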

> > Here's a mini transfer learning project I worked on recently with jax_resnet.
>
> This is very nice! Do you mind if I port this to Elegy?

I'd be honoured! My only ask is that you use the official COVIDx dataset instead of my modified split: https://github.com/lindawangg/COVID-Net/blob/master/docs/COVIDx.md. Some of the custom optax code should be unnecessary soon: the selective additive weight decay will be possible with just optax (pr), and hopefully the scheduled optimizer as well (issue).
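
For context, selective weight decay with plain optax can be expressed roughly like the sketch below (the mask rule and hyperparameters are placeholders, not the project's actual configuration):

```python
import jax
import optax

# Illustrative sketch of selective additive weight decay: decay only the
# 'kernel' leaves (not biases or norm scales), chained with SGD. Both the
# mask rule and the hyperparameters are placeholders.
def kernel_mask(params):
    return jax.tree_util.tree_map_with_path(
        lambda path, _: 'kernel' in jax.tree_util.keystr(path),
        params,
    )

tx = optax.chain(
    optax.add_decayed_weights(1e-4, mask=kernel_mask),
    optax.sgd(learning_rate=0.1, momentum=0.9),
)
```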