n2cholas / jax-resnet

Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
https://pypi.org/project/jax-resnet/
MIT License
103 stars · 8 forks

Implement Transfer Learning API #2

Closed: n2cholas closed this 3 years ago

n2cholas commented 3 years ago

Closes #1. Implements a `Sequential` combinator as well as a slicing helper to easily extract portions of the models.
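The combinator idea can be sketched framework-free (plain Python; the names here are illustrative, and the actual module is built on flax.linen submodules):

```python
# Minimal sketch of a Sequential combinator: hold a list of layers and
# apply them in order. (Plain-Python analogue, not the real flax module.)
class Sequential:
    def __init__(self, layers):
        self.layers = list(layers)

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# Slicing the layer list yields a new model over a sub-range of layers.
model = Sequential([lambda x: x + 1, lambda x: x * 2, lambda x: x - 3])
head = Sequential(model.layers[:2])
```

Because the model is just an ordered list of callables, extracting a backbone for transfer learning reduces to ordinary list slicing.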

TODO:

n2cholas commented 3 years ago

@cgarciae any thoughts on this initial design?

cgarciae commented 3 years ago

Hey @n2cholas! We had a discussion in this PR about ways to approach this: https://github.com/poets-ai/elegy/pull/169

My thoughts:

n2cholas commented 3 years ago

Elegy's approach looks very flexible, I'll definitely give it a try in the future. A blog post would be very nice!

Since this project's scope is ResNet-style architectures, I'll iterate on this slice API a bit and stick to it.

I'm curious to see how Flax will address this problem. nn.compact gives a great boost to ease of implementation and readability, but somewhat sacrifices code reusability (since, for now, you have to treat the Module like a black box).

I actually quite liked jax.experimental.stax's combinator philosophy: it made dataflow explicit and enabled simple, arbitrary model surgery.
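That stax style can be illustrated with a toy, dependency-free sketch (all names here are hypothetical): each layer is an `(init_fn, apply_fn)` pair, and a `serial` combinator threads data through them explicitly, so model surgery is just list slicing.

```python
def Scale(c):
    """Toy layer in the stax style: an (init_fn, apply_fn) pair."""
    init_fn = lambda: c                    # returns this layer's "parameters"
    apply_fn = lambda param, x: param * x  # applies the layer to an input
    return init_fn, apply_fn

def serial(*layers):
    """Compose layer pairs into a single (init_fn, apply_fn) pair."""
    init_fns, apply_fns = zip(*layers)

    def init_fn():
        return [init() for init in init_fns]

    def apply_fn(params, x):
        for param, apply in zip(params, apply_fns):
            x = apply(param, x)  # dataflow is explicit: x threads through
        return x

    return init_fn, apply_fn

layers = [Scale(2), Scale(3), Scale(5)]
init_fn, apply_fn = serial(*layers)
params = init_fn()

# "Model surgery": slice the layer and parameter lists in lockstep.
tail_init_fn, tail_apply_fn = serial(*layers[1:])
```

Since parameters live in a plain list parallel to the layer list, slicing a model and its parameters together requires no knowledge of the layers' internals.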

n2cholas commented 3 years ago

@cgarciae we'll provide a `Sequential` module and a `slice_variables` function. It's trivial to slice a sequential model yourself (`sliced_model = Sequential(model.layers[start:end])`), and `slice_variables` will give you the corresponding variables dict (`sliced_variables = slice_variables(variables, start, end)`).
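A rough sketch of what `slice_variables` could do, assuming (hypothetically) that the `Sequential` stores submodule parameters under keys like `layers_0`, `layers_1`, …; the real implementation and key naming may differ:

```python
def slice_variables(variables, start, end):
    """Keep only the entries for layers_{start}..layers_{end-1} in each
    collection, renumbering keys to line up with the sliced model.
    (Hypothetical re-implementation; the key scheme is an assumption.)"""
    sliced = {}
    for collection, entries in variables.items():
        kept = {}
        for i in range(start, end):
            key = f'layers_{i}'
            if key in entries:
                kept[f'layers_{i - start}'] = entries[key]
        sliced[collection] = kept
    return sliced

# Toy variables dict mimicking a 3-layer Sequential's params collection.
variables = {'params': {'layers_0': {'w': 0.0},
                        'layers_1': {'w': 1.0},
                        'layers_2': {'w': 2.0}}}
sliced_variables = slice_variables(variables, 1, 3)
```

Renumbering the keys matters: a sliced `Sequential` names its layers from zero, so the sliced variables must match for `apply` to find them.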

Once Flax has its own Sequential module (PR), I'll switch to that.

Thanks again for opening this up!