QUVA-Lab / e2cnn

E(2)-Equivariant CNNs Library for Pytorch
https://quva-lab.github.io/e2cnn/

fix shape issue in drop_gates #64

Closed: ahyunSeo closed this 1 year ago

ahyunSeo commented 1 year ago

Unlike GatedNonLinearity1, InducedGatedNonLinearity1 reshapes the gates from (b, c, h, w) to (b, -1, self.quotient_size, h, w) in L218.

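When self.drop_gates is False, this causes an error at L226 and L228, where the gates are copied into the output: the gates tensor is still 5-D, while the output slice it is assigned to is 4-D, so the shapes do not match.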

```python
if not self.drop_gates:
    # copy the gates in the output
    if self._contiguous[GATES_ID]:
        output[:, self.gates_indices[0]:self.gates_indices[1], ...] = gates
    else:
        output[:, self.gates_indices, ...] = gates
```
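The mismatch can be reproduced outside e2cnn with plain PyTorch tensors. The sketch below is a minimal illustration under stated assumptions, not the fix merged for this issue: `copy_gates_to_output` is a hypothetical helper, the sizes are made up, and it assumes that at the point of the copy `gates` still has the (b, -1, quotient_size, h, w) shape described above. Flattening the quotient axis back into the channel axis restores the original (b, c, h, w) layout, so the copy works for both the contiguous (slice) and non-contiguous (index tensor) branches.

```python
import torch


def copy_gates_to_output(output, gates, gates_indices, contiguous):
    """Hypothetical helper: copy gate channels into a 4-D `output` tensor.

    `gates` may still carry the extra quotient axis, i.e. have shape
    (b, -1, quotient_size, h, w). It is flattened back to (b, c, h, w)
    so it matches the 4-D slice of `output` it is written into.
    """
    b, h, w = gates.shape[0], gates.shape[-2], gates.shape[-1]
    gates = gates.reshape(b, -1, h, w)  # undo the (-1, quotient_size) split
    if contiguous:
        # gates_indices is a (start, end) pair
        output[:, gates_indices[0]:gates_indices[1], ...] = gates
    else:
        # gates_indices is a tensor of channel indices
        output[:, gates_indices, ...] = gates
    return output


# Illustrative sizes (assumptions, not taken from e2cnn)
b, c, h, w, quotient_size = 2, 6, 4, 4, 3
gates = torch.randn(b, c, h, w)
gates_5d = gates.view(b, -1, quotient_size, h, w)  # shape (2, 2, 3, 4, 4)

# The original copy: a 5-D tensor assigned into a 4-D slice fails.
output = torch.zeros(b, 10, h, w)
mismatch_raised = False
try:
    output[:, 4:10, ...] = gates_5d
except RuntimeError:
    mismatch_raised = True

# Contiguous branch after flattening the quotient axis.
copy_gates_to_output(output, gates_5d, (4, 10), contiguous=True)

# Non-contiguous branch with an index tensor.
idx = torch.tensor([0, 2, 4, 6, 8, 9])
output_nc = torch.zeros(b, 10, h, w)
copy_gates_to_output(output_nc, gates_5d, idx, contiguous=False)

print(mismatch_raised, torch.equal(output[:, 4:10], gates), torch.equal(output_nc[:, idx], gates))
```

Because `view` followed by `reshape` only splits and re-merges the channel axis, the flattened gates come back in their original channel order, which is why the final comparisons against the untouched `gates` tensor hold.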

For those wondering when not to drop the gates: I needed the input and output types of the "ReLU" (the gated nonlinearity) to be identical so that the identity connections in my ResNet work. With vanilla O(2) this doesn't matter; it only matters when I use induced irreps, which I do for better performance.
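Why keeping the gates helps with identity connections can be seen from channel counts alone. The toy below is a hypothetical stand-in, not e2cnn's gating: `toy_gated_relu` assumes the first `n_gated` channels are gated fields and the remaining channels are one gate per field. With `drop_gates=True` the output loses the gate channels, so `x + f(x)` cannot be formed; with `drop_gates=False` the output keeps the input's channel layout and the residual sum works. In e2cnn the requirement is on field types rather than raw channel counts, which this sketch does not model.

```python
import torch


def toy_gated_relu(x, n_gated, drop_gates):
    """Hypothetical toy gate: channels [0, n_gated) are gated fields, the rest are gates.

    Each gated channel is multiplied by the sigmoid of its gate channel
    (a stand-in for a gated nonlinearity, not e2cnn's implementation).
    """
    gated, gates = x[:, :n_gated], x[:, n_gated:]
    out = gated * torch.sigmoid(gates)
    if drop_gates:
        return out  # fewer channels than x
    return torch.cat([out, gates], dim=1)  # same channel count as x


x = torch.randn(2, 8, 5, 5)  # 4 gated channels + 4 gate channels

# Keeping the gates: the identity connection x + f(x) is well defined.
residual = x + toy_gated_relu(x, 4, drop_gates=False)

# Dropping the gates: output has 4 channels, so the residual sum fails.
y_dropped = toy_gated_relu(x, 4, drop_gates=True)
residual_failed = False
try:
    x + y_dropped
except RuntimeError:
    residual_failed = True

print(tuple(residual.shape), tuple(y_dropped.shape), residual_failed)
```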