paganpasta / eqxvision

A Python package of computer vision models for the Equinox ecosystem.
https://eqxvision.readthedocs.io
MIT License

Removing equinox experimental #71

Closed: stepp1 closed this 1 year ago

stepp1 commented 1 year ago

This PR updates the use of Equinox's BatchNorm, since it is no longer an experimental feature.

Fixes #70

paganpasta commented 1 year ago

Thanks for the PR! It seems that a few pre-commit checks are failing. Let me know if you'll be able to fix them and push; otherwise, I'll try to pick this up as soon as I can.

stepp1 commented 1 year ago

Hi there! There were other ways of calling BatchNorm that I didn't catch at first, so I updated the PR. Now, when I run the pre-commit hook, everything seems OK.

However, I'm getting a lot of errors (100% of the tests!) when running pytest.

paganpasta commented 1 year ago

OK, I'll take a look at why they might be failing.

stepp1 commented 1 year ago

Hi there! I was looking through the errors, and the culprit might be the change in how stateful ops are now handled. I'm not very familiar with Equinox, but I'll keep looking.

Edit: I think I've narrowed it down: the new BatchNorm call breaks inside a ConvNormActivation, because the layer now requires both x and state when called. Maybe @patrick-kidger can help here.

patrick-kidger commented 1 year ago

The new BatchNorm has a different API to the previous experimental version. (Incidentally the other normalisation layers -- LayerNorm, GroupNorm -- also support this new API if needed, so they are interchangeable.)

It should be called as output, state = norm_layer(input, state); see the stateful tutorial.

I think this should be straightforward enough -- simply thread the additional state through.