QUVA-Lab / escnn

Equivariant Steerable CNNs Library for Pytorch https://quva-lab.github.io/escnn/

Instance Norm as normalization? #69

Open psteinb opened 1 year ago

psteinb commented 1 year ago

Dear @Gabri95, sorry to bug you. I am currently trying to build an equivariant U-Net architecture that stays very close to a "standard" U-Net I use as a reference. In doing so, I ran into the question of which normalization scheme to use. Looking at your implementations here, you appear to focus on batch norm only.

However, I was wondering whether anything speaks against implementing InstanceNorm. The difference is that the mean/variance is not computed across the entire batch, but separately for each sample in the batch.
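To make the difference concrete, here is a minimal NumPy sketch of where each scheme pools its statistics for a tensor in the usual PyTorch (B, C, H, W) layout. It uses plain channels (no equivariance) and omits the learnable affine parameters; `batch_norm` and `instance_norm` are illustrative helper names, not escnn API.

```python
import numpy as np


def batch_norm(x, eps=1e-5):
    # One mean/variance per channel, pooled over the batch and spatial axes.
    mean = x.mean(axis=(0, 2, 3), keepdims=True)  # shape (1, C, 1, 1)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def instance_norm(x, eps=1e-5):
    # One mean/variance per sample *and* per channel, pooled over the spatial axes only.
    mean = x.mean(axis=(2, 3), keepdims=True)  # shape (B, C, 1, 1)
    var = x.var(axis=(2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3, 16, 16))  # B x C x H x W
x_shifted = x.copy()
x_shifted[1] += 10.0  # perturb sample 1 only

# Sample 0's instance-norm output is independent of the other samples;
# its batch-norm output is not.
print(np.allclose(instance_norm(x)[0], instance_norm(x_shifted)[0]))  # True
print(np.allclose(batch_norm(x)[0], batch_norm(x_shifted)[0]))        # False
```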

Gabri95 commented 1 year ago

Hey @psteinb,

No, I don't see any issue with that!

You can try to adapt IIDBatchNormnD into an IIDInstanceNormnD: I think changing the dimensions over which the mean and std are computed should be sufficient to implement InstanceNorm.
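A rough sketch of what "changing the dimensions" could look like, written in NumPy rather than as an escnn module. Assumptions, all hypothetical: the input holds `m` copies of a single `d`-dimensional field type, the `d` components of each field share one variance (my reading of the "IID" in the name), and `P` is an optional `d x d` matrix applied to the mean before centering. `iid_instance_norm` is an illustrative name, not the library's API, and this is not the actual IIDBatchNormnD code.

```python
import numpy as np


def iid_instance_norm(x, field_size, P=None, eps=1e-5):
    """Hypothetical per-sample analogue of an IID batch norm (not escnn API).

    x: array of shape (B, m * d, H, W) holding m copies of a d-dimensional field.
    P: optional (d, d) matrix applied to the per-sample field mean before centering;
       None subtracts the full mean, np.zeros((d, d)) skips centering entirely.
    """
    B, C, H, W = x.shape
    d = field_size
    if C % d != 0:
        raise ValueError("the channel count must be a multiple of field_size")
    f = x.reshape(B, C // d, d, H, W)

    # Unlike batch norm, the batch axis 0 is not pooled: only the spatial axes (3, 4) are.
    mean = f.mean(axis=(3, 4), keepdims=True)  # (B, m, d, 1, 1)
    if P is not None:
        mean = np.einsum("ij,bmjhw->bmihw", P, mean)
    centered = f - mean

    # One variance per sample and per field, shared by the field's d components.
    var = centered.var(axis=(3, 4)).mean(axis=2)  # (B, m)
    out = centered / np.sqrt(var[:, :, None, None, None] + eps)
    return out.reshape(B, C, H, W)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8, 16, 16))  # 2 samples, each with two 4-dimensional fields
y = iid_instance_norm(x, field_size=4)
# e.g. the invariant-subspace projector of a regular C4 field (assumption)
y_proj = iid_instance_norm(x, field_size=4, P=np.full((4, 4), 0.25))
print(y.shape, y_proj.shape)  # (2, 8, 16, 16) (2, 8, 16, 16)
```

With `field_size=1` and `P=None` this reduces to an ordinary instance norm, which is a convenient sanity check when adapting the real module.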

I'm currently implementing a version of Layer/GroupNorm. You can also take a look at that once I am done!

Best, Gabriele

psteinb commented 1 year ago

Alright, I'll look into IIDBatchNormnD tomorrow then. Hope to send a PR by Wednesday. 🤞

psteinb commented 1 year ago

Ok, I started working on it. I took the IIDBatchNormnD code and began adapting it accordingly. However, the test cases appear to be tailored to continuous groups, which I have no experience with so far, so I got stuck in a few places. Feel free to take over; it may be some time until I can get back to it. My apologies.

psteinb commented 1 year ago

Hey @Gabri95, just to check in. I am looking at this again.

A minor question: for batch normalisation, the escnn library uses a matrix P to split the contributions of the expectation value of a batch across the representations of the input type (see https://github.com/QUVA-Lab/escnn/blob/94380ef401c51841b87a9dd4ed292637aa5883ab/escnn/nn/modules/batchnormalization/iid.py#L188 and also section 4.2 of your thesis).
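For intuition, here is a small NumPy illustration under one assumption that this thread does not confirm: that P acts as the orthogonal projector onto the subspace of a field on which the group acts trivially. For the regular representation of the cyclic group C4, averaging the representation matrices over the group gives exactly that projector. This group-averaging construction is a standard one, but it is not necessarily how escnn builds P; `rho_regular` is a hypothetical helper.

```python
import numpy as np

N = 4  # cyclic group C4, e.g. rotations by multiples of 90 degrees


def rho_regular(k, n=N):
    """Regular representation of C_n: group element k cyclically permutes the n channels."""
    return np.roll(np.eye(n), k, axis=0)


# Averaging rho over the group yields the orthogonal projector onto the invariant subspace.
P = sum(rho_regular(k) for k in range(N)) / N
print(P)  # every entry is 0.25: P replaces each component by the field's average

m = np.array([1.0, 2.0, 3.0, 10.0])  # e.g. a spatially averaged regular field
Pm = P @ m
print(Pm)  # [4. 4. 4. 4.]

# Only the invariant part of the mean survives: rho(g) @ Pm == Pm for every g.
print(all(np.allclose(rho_regular(k) @ Pm, Pm) for k in range(N)))  # True
```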

Since I am implementing an instance norm, for an input batch of, e.g., "images" of shape B×C×W×H, the mean and variance are computed only over W×H, separately for each sample and each channel. The mean therefore has shape (B,C,1,1), and so does the variance.

However, these instance-norm coefficients do not represent an expectation value across the entire batch, only over the signal of a single channel in a single sample. So I wonder: do I actually need to multiply my mean values by P in the first place? (My hunch is "No, I don't need this multiplication", but I would love to be certain.)
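One way to probe this numerically, sketched under explicit assumptions: a single regular C4 field, rotations by 90° on a square grid (one convention for the action: rotate the grid, then cyclically shift the 4 channels), and P taken to be the invariant-subspace projector (1/4)·11ᵀ. The helpers `rotate_regular_field` and `instance_norm_field` are hypothetical, not escnn API. In this setup both variants commute with the rotation, because the per-sample mean is recomputed from the rotated input and transforms along with it. This is a check of one case, not a proof, and it says nothing about running statistics or a learnable bias.

```python
import numpy as np


def rotate_regular_field(x, k=1):
    """Act with the rotation by k*90 degrees on B x 4 x H x W regular C4 fields
    (one convention: rotate the grid, then cyclically shift the 4 channels)."""
    return np.roll(np.rot90(x, k=k, axes=(2, 3)), k, axis=1)


def instance_norm_field(x, P=None, eps=1e-5):
    """Per-sample normalization of one 4-dimensional field (hypothetical sketch)."""
    mean = x.mean(axis=(2, 3), keepdims=True)  # (B, 4, 1, 1)
    if P is not None:
        mean = np.einsum("ij,bjhw->bihw", P, mean)
    centered = x - mean
    var = centered.var(axis=(2, 3)).mean(axis=1)  # one variance per sample and field
    return centered / np.sqrt(var[:, None, None, None] + eps)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 16, 16)) + 3.0  # nonzero mean on purpose
P = np.full((4, 4), 0.25)  # assumed projector onto the invariant subspace

for name, proj in [("without P", None), ("with P", P)]:
    lhs = instance_norm_field(rotate_regular_field(x), proj)
    rhs = rotate_regular_field(instance_norm_field(x, proj))
    print(name, np.allclose(lhs, rhs))  # True in both cases for this setup
```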

Would be cool to hear your thoughts.