stepp1 closed this 1 year ago.
Thanks for the PR! It seems that a few pre-commit checks are failing. Let me know if you will be able to fix and push them; otherwise, I'll try to pick this up as soon as I can.
Hi there! There were other ways of calling BatchNorm that I didn't catch at first, so I updated the PR. Now, when running the pre-commit hook, all checks pass.
However, I'm getting errors on every test (100%!) when running pytest.
Ok, I'll take a look at why they might be failing.
Hi there! I was looking through the errors, and the culprit might be the change in how stateful ops are now handled. I'm not very familiar with Equinox, but I'll keep looking.
Edit: I think I narrowed it down: the new BatchNorm call breaks inside a ConvNormActivation, because the layer now requires both `x` and `state` when called. Maybe @patrick-kidger can help here.
The new BatchNorm has a different API to the previous experimental version. (Incidentally, the other normalisation layers, LayerNorm and GroupNorm, also support this new API if needed, so they are interchangeable.)
It should be called as `output, state = norm_layer(input, state)`; see the stateful tutorial.
I think this should be straightforward enough: simply thread the additional state through.
This PR updates the code to the new Equinox BatchNorm, as it is no longer an experimental feature.
Fixes #70