paganpasta / eqxvision

A Python package of computer vision models for the Equinox ecosystem.
https://eqxvision.readthedocs.io
MIT License

Adding missing classification networks to the model deck. #3

Closed: paganpasta closed this issue 2 years ago

paganpasta commented 2 years ago

Missing models listed in no particular order.

paganpasta commented 2 years ago

Tagging @oarriaga for possible interest as per https://github.com/patrick-kidger/equinox/issues/143

oarriaga commented 2 years ago

Hey, great work with the repo! I can make a PR. Have you thought about where to host the weights? Currently, for https://github.com/oarriaga/paz, I have a separate repo hosting the pre-trained weights.

paganpasta commented 2 years ago

Hi,

The weights are the tricky bit. I was initially planning not to train them myself, but to leverage existing weights from, say, torchvision (or timm) for this task (https://github.com/paganpasta/eqxvision/issues/4).

Thoughts?

oarriaga commented 2 years ago

Yes, I think training is not a viable option. I meant to ask where to host the weights of pre-trained models that have already been ported. For example, for the VGG model I coded in Equinox, I have already ported the weights too and serialized them to an Equinox file. You could host that file in another repo and download it automatically when the user instantiates the model, i.e., Keras-style.

I suspect the ported models won't deviate from the originals, especially for networks that only use Conv2D, MaxPool, and Dropout. Maybe BatchNorm could be an issue in Equinox? But even then, the forward pass should be simple. There were a few caveats when porting the VGG weights, mostly the Flatten layer, which required moving the channel axis before the reshape. For the other models you have coded, I don't think porting weights should be an issue, especially when porting from another channels-first framework like PyTorch.
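To make the Flatten caveat concrete, here is a minimal NumPy sketch, assuming a Keras-style Dense kernel of shape (H*W*C, out) that follows a Flatten over an (H, W, C) feature map. The helper `dense_after_flatten_hwc_to_chw` is hypothetical and is not the porting script discussed in this thread; it only illustrates the axis reordering needed before the weights can be used after a channels-first Flatten.

```python
import numpy as np


def dense_after_flatten_hwc_to_chw(kernel, h, w, c):
    """Reorder a Dense kernel that follows Flatten on an (H, W, C) feature map
    so it gives the same outputs after flattening the features in (C, H, W) order.

    kernel: Keras-style Dense kernel of shape (H*W*C, out).
    Returns a weight of shape (out, C*H*W), the (out_features, in_features)
    layout used by PyTorch and Equinox Linear layers.
    """
    out = kernel.shape[1]
    # Channels-last flatten: row (h*W + w)*C + c belongs to pixel (h, w), channel c.
    k = kernel.reshape(h, w, c, out)
    # Move the channel axis in front of the spatial axes, then flatten in (C, H, W) order.
    k = k.transpose(2, 0, 1, 3).reshape(c * h * w, out)
    # Linear layers store (out_features, in_features), the transpose of a Keras kernel.
    return k.T


rng = np.random.default_rng(0)
H, W, C, OUT = 2, 3, 4, 5
feat_hwc = rng.normal(size=(H, W, C))      # channels-last feature map
kernel = rng.normal(size=(H * W * C, OUT))  # Keras-style Dense kernel

y_keras = feat_hwc.reshape(-1) @ kernel     # Flatten -> Dense, channels-last

feat_chw = feat_hwc.transpose(2, 0, 1)      # same features, channels-first
weight = dense_after_flatten_hwc_to_chw(kernel, H, W, C)
y_port = weight @ feat_chw.reshape(-1)      # Flatten -> Linear, channels-first

print(np.allclose(y_keras, y_port))         # True
```

Simply using `kernel.T` without the reorder would still have the right shape, which is why this kind of mismatch tends to show up only as wrong predictions rather than as an error.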

paganpasta commented 2 years ago

Ahh right. I should have explained it better.

What I meant by conversion using torchvision or timm was that we can rely directly on the weights hosted by these libraries and do something like URLs -> download -> convert -> load. This way, we don't have to host the .eqx weights as well. I doubt the conversion will be time-consuming. Thoughts?
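The URLs -> download -> convert -> load flow could be sketched with the standard library alone. This is a hypothetical helper, not eqxvision's API: `convert` and `load` stand in for the per-model steps, which in practice would parse a torchvision/timm checkpoint and build the Equinox model.

```python
import urllib.request


def fetch_convert_load(url, convert, load):
    """URL -> download -> convert -> load, without hosting any .eqx file.

    `convert` turns the downloaded bytes into the target parameter layout and
    `load` places them into a model; both are hypothetical per-model hooks.
    """
    with urllib.request.urlopen(url) as response:  # download
        raw = response.read()
    params = convert(raw)                           # convert
    return load(params)                             # load
```

Nothing is written to disk here, so no converted weights need hosting; the trade-off is that the conversion runs again on every call.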

On the converter front, perhaps it is not feasible to have a common converter function across the package, and we can start small with a per-model converter. Or a mix of the two: common conversion rules, with the assembly handled per model. You have more experience with this, so you'll be able to weigh in better. The starting point can be the conversion code you already have!
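The "common rules plus per-model assembly" split could look roughly like this NumPy sketch, assuming a PyTorch-style state dict (parameter name -> array). `COMMON_RULES`, `convert_state_dict`, and the destination names are hypothetical. The rules also assume the target keeps PyTorch's channels-first layouts and stores a conv bias as (out, 1, 1), as `equinox.nn.Conv2d` does; that should be checked against the Equinox version in use.

```python
import numpy as np

# Shared conversion rules, keyed by (layer kind, parameter name).
# Assumption: the target keeps PyTorch's channels-first layouts, except that
# conv biases carry trailing singleton spatial axes: (out, 1, 1).
COMMON_RULES = {
    ("conv", "weight"): lambda t: t,                  # (out, in, kh, kw) unchanged
    ("conv", "bias"): lambda t: t.reshape(-1, 1, 1),  # (out,) -> (out, 1, 1)
    ("linear", "weight"): lambda t: t,                # (out_features, in_features)
    ("linear", "bias"): lambda t: t,                  # (out_features,)
}


def convert_state_dict(state_dict, layer_map):
    """Per-model assembly: apply the shared rules layer by layer.

    layer_map maps a source prefix (e.g. "features.0") to (kind, dest_prefix);
    only this map is model-specific, the rules are shared across models.
    """
    converted = {}
    for src_prefix, (kind, dst_prefix) in layer_map.items():
        for param in ("weight", "bias"):
            key = f"{src_prefix}.{param}"
            if key in state_dict:  # layers built with bias=False have no bias entry
                rule = COMMON_RULES[(kind, param)]
                converted[f"{dst_prefix}.{param}"] = rule(np.asarray(state_dict[key]))
    return converted
```

Supporting a new layer kind means adding entries to the shared table, while each new model only needs its own `layer_map`.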

One thing I wanted to add to the VGG implementation was support for the _bn variants, to be in line with the models supported in torchvision. Perhaps we should discuss this in more detail in a PR.
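In torchvision, the `_bn` variants reuse the plain VGG configurations and only insert a BatchNorm between each convolution and its ReLU. The sketch below mirrors that layer ordering with plain descriptors instead of real layers; `make_feature_layers` is a hypothetical name, and building actual Equinox modules from the descriptors is left out.

```python
# VGG-11 ("A") configuration as used by torchvision: ints are conv output
# channels, "M" is a 2x2 max-pool with stride 2.
VGG11_CFG = [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"]


def make_feature_layers(cfg, batch_norm=False, in_channels=3):
    """Return (kind, settings) descriptors for VGG's feature extractor.

    With batch_norm=True each 3x3 conv is followed by a BatchNorm before the
    ReLU, which is the only structural difference of the *_bn variants.
    """
    layers = []
    for v in cfg:
        if v == "M":
            layers.append(("maxpool", {"kernel_size": 2, "stride": 2}))
        else:
            layers.append(("conv", {"in": in_channels, "out": v, "kernel_size": 3, "padding": 1}))
            if batch_norm:
                layers.append(("batchnorm", {"channels": v}))
            layers.append(("relu", {}))
            in_channels = v
    return layers


plain = make_feature_layers(VGG11_CFG)
with_bn = make_feature_layers(VGG11_CFG, batch_norm=True)
```

Because of the extra BatchNorm layers, the convolution positions shift (the second conv sits at `features.3` in `vgg11` but at `features.4` in `vgg11_bn`), so weight conversion for the `_bn` variants would need its own name mapping.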

oarriaga commented 2 years ago

I see, yes, this makes sense. I am still thinking about whether it might be computationally inefficient to have a function that translates the weights every single time a user requests them, rather than doing it only once and hosting the "duplicated" weights somewhere else. It seems like a memory vs. time complexity debate.
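One middle ground in the memory vs. time debate is to convert once per machine and reuse a local copy. A minimal sketch, assuming the raw checkpoint has already been downloaded; `convert_once` and its cache-key scheme (file name plus size) are hypothetical, and `pickle` only keeps the example dependency-free. For real models, Equinox's `eqx.tree_serialise_leaves` / `eqx.tree_deserialise_leaves` would be the more natural way to store the converted result.

```python
import pickle
from pathlib import Path


def convert_once(raw_path, convert, cache_dir):
    """Convert a downloaded checkpoint on first use; reuse the cached result after.

    Trades extra disk space (a second, converted copy) for skipping the
    conversion on later loads.
    """
    raw_path = Path(raw_path)
    # Hypothetical cache key: file name plus size. A content hash would be
    # safer, but it would require reading the whole file on every load.
    cached = Path(cache_dir) / f"{raw_path.stem}-{raw_path.stat().st_size}.pkl"
    if cached.exists():
        with cached.open("rb") as f:
            return pickle.load(f)
    params = convert(raw_path.read_bytes())
    cached.parent.mkdir(parents=True, exist_ok=True)
    with cached.open("wb") as f:
        pickle.dump(params, f)
    return params
```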

My translation code is not generic enough, and it goes from Keras -> Equinox. I think you are considering going from timm's repo, so it will not work directly.

I can gladly make the PR with the VGG model that I have, but let me know how you want to proceed with the translation or weight hosting.

paganpasta commented 2 years ago

Yes, that is true regarding the memory vs. time complexity debate. Any insights into how long a single conversion takes? If you wish to share the conversion script, I can tweak it and check it with PyTorch weights. I am not familiar enough with how Keras saves and loads weights to comment on it. At the moment, I am willing to increase the load time by up to 10% (e.g., 60 seconds to download the weights, then 6 seconds to convert and load). Once we see some numbers on the time spent at each step, it'll be easier to settle the debate. What do you say?
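To get the numbers asked for here, each step can be timed separately. A minimal sketch: `time_stages` and `within_budget` are hypothetical helpers, and the budget check reads the 60 s / 6 s example as convert-plus-load time compared with download time (6 s is 10% of 60 s).

```python
import time


def time_stages(stages):
    """Run (name, fn) stages in order, feeding each result into the next
    stage, and record the wall-clock seconds spent in each one."""
    timings = {}
    result = None
    for name, fn in stages:
        start = time.perf_counter()
        result = fn(result)
        timings[name] = time.perf_counter() - start
    return result, timings


def within_budget(timings, budget=0.10):
    """True if convert + load costs at most `budget` times the download time."""
    overhead = timings["convert"] + timings["load"]
    return overhead <= budget * timings["download"]
```

In use, the three stages would be the real download, conversion, and loading functions, e.g. `time_stages([("download", ...), ("convert", ...), ("load", ...)])`.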

Let's start with the VGG integration and pursue weight conversion/hosting in a separate thread?

oarriaga commented 2 years ago

OK, great. A basic PR with the model is up now: https://github.com/paganpasta/eqxvision/pull/10