jatinchowdhury18 / RTNeural

Real-time neural network inferencing

`groups` support for Conv1D? #114

Closed by purefunctor 7 months ago

purefunctor commented 8 months ago

Hi, this is a really awesome project!

I'm trying to port a model that makes use of 1D convolutions, but the immediate thing I ran into was that the Conv1D layer doesn't have a parameter for groups, as present in PyTorch/TensorFlow. Learning resources on low-level NN programming are a little terse, but I'd like to tackle this!

My (high-level) understanding of it goes something along the lines of:

import torch
from torchinfo import summary

x = torch.nn.Conv1d(9, 3, kernel_size=1, groups=1, bias=False)
y = torch.nn.Conv1d(9, 3, kernel_size=1, groups=3, bias=False)

summary(x)
summary(y)

Which gives:

=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Conv1d                                   27
=================================================================
Total params: 27
Trainable params: 27
Non-trainable params: 0
=================================================================
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Conv1d                                   9
=================================================================
Total params: 9
Trainable params: 9
Non-trainable params: 0
=================================================================
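
The parameter counts follow directly from the weight shapes: a Conv1d without bias stores out_channels × (in_channels / groups) × kernel_size weights, so x has 3 × 9 × 1 = 27 parameters while y has 3 × (9 / 3) × 1 = 9.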

In the case of x, with groups=1, every output channel gets its own 1x1 kernel for each of the 9 input channels. This is reflected in the weights:

tensor([[[-0.2372],
         [-0.0259],
         [ 0.0075],
         [ 0.0602],
         [ 0.0205],
         [-0.2290],
         [-0.3304],
         [ 0.1302],
         [ 0.0697]],

        [[ 0.3148],
         [-0.1045],
         [ 0.2087],
         [ 0.2184],
         [ 0.2869],
         [-0.1255],
         [-0.0349],
         [ 0.2754],
         [ 0.0341]],

        [[-0.0065],
         [ 0.0904],
         [ 0.1445],
         [ 0.0337],
         [-0.0661],
         [ 0.2763],
         [-0.1375],
         [ 0.0841],
         [-0.0864]]]) torch.Size([3, 9, 1])

For each output channel, there are 9 1x1 kernels, one for each input channel.

Meanwhile, for the case of y with groups=3, it has the following weights:

tensor([[[-0.0507],
         [ 0.1446],
         [ 0.1827]],

        [[-0.1260],
         [ 0.2465],
         [ 0.5095]],

        [[ 0.4771],
         [ 0.1377],
         [-0.0265]]]) torch.Size([3, 3, 1])

For each output channel, there are 3 1x1 kernels, one for each input channel in that output channel's group.

An intuitive way I've found to see how this works is:

x.weight.requires_grad = False
y.weight.requires_grad = False

# Give each output channel's kernels a distinct constant value.
torch.nn.init.constant_(x.weight[0], 1.0)
torch.nn.init.constant_(x.weight[1], 0.9)
torch.nn.init.constant_(x.weight[2], 0.8)

torch.nn.init.constant_(y.weight[0], 1.0)
torch.nn.init.constant_(y.weight[1], 0.9)
torch.nn.init.constant_(y.weight[2], 0.8)

# All-ones input: 9 channels, length 2.
i = torch.ones(9, 2)
print(i)
print(x(i))
print(y(i))

and this yields:

tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])
tensor([[9.0000, 9.0000],
        [8.1000, 8.1000],
        [7.2000, 7.2000]])
tensor([[3.0000, 3.0000],
        [2.7000, 2.7000],
        [2.4000, 2.4000]])

The result of y(i) yields significantly less "energy" than x(i), as each output channel now has fewer kernels to work with: 3 instead of 9.
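
Concretely, each output channel of x sums 9 weighted all-ones inputs (9 × 1.0 = 9.0, 9 × 0.9 = 8.1, 9 × 0.8 = 7.2), while each output channel of y only sums the 3 inputs in its group (3 × 1.0 = 3.0, 3 × 0.9 = 2.7, 3 × 0.8 = 2.4).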

jatinchowdhury18 commented 8 months ago

Hello!

Thanks for the issue, and for the exploration of the "groups" functionality. I'll just add the description that I found in the PyTorch Conv1d documentation, plus a quick sketch of the last case after the list:

groups controls the connections between inputs and outputs. in_channels and out_channels must both be divisible by groups. For example,

  • At groups=1, all inputs are convolved to all outputs.
  • At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.
  • At groups=in_channels, each input channel is convolved with its own set of filters (of size out_channels/in_channels).
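
To make that last case concrete, here's a small illustrative sketch (my own example, not from the docs), with in_channels=3, out_channels=6, groups=3, so each input channel gets out_channels/in_channels = 2 filters of its own:

import torch

# groups == in_channels: each input channel is convolved with its own
# set of out_channels/in_channels = 2 filters.
conv = torch.nn.Conv1d(3, 6, kernel_size=1, groups=3, bias=False)
print(conv.weight.shape)  # torch.Size([6, 1, 1]): one input channel per filter

with torch.no_grad():
    i = torch.zeros(3, 4)
    i[0] = 1.0  # excite only input channel 0
    print(conv(i))  # only output channels 0 and 1 (channel 0's filters) respond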

I think I have a rough sense of how to implement this, but it may be a few weeks until I have time to really sit down and work on it. With that in mind, I guess I'll list out my workflow for how I would probably approach the problem, and then if you (or someone else) would like to take a shot at it, that would be cool!

Anyway, hopefully this is helpful in case you or anyone wants to tackle this. If not, just give it a little time while I finish up some other things, and I can come back to this. If you do start working on it, feel free to message me with any questions or intermediate progress updates!

purefunctor commented 8 months ago

That's indeed really helpful, yup!

I'm trying to understand what the "state" and "state_cols" are in the convolutions, do you have any pointers on that?

Another good way to think about grouped convolutions is that each group can be thought of as having its own convolution: if I had a 9-in/9-out convolution with 3 groups, I'd have three 3-in/3-out convolutions whose outputs get concatenated back together.
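
That equivalence can be checked directly (an illustrative sketch, with the grouped weights copied into the sub-convolutions by hand):

import torch

# A 9-in/9-out grouped conv (groups=3) equals three independent
# 3-in/3-out convs run on the channel groups and concatenated.
grouped = torch.nn.Conv1d(9, 9, kernel_size=1, groups=3, bias=False)

with torch.no_grad():
    i = torch.randn(9, 5)
    outs = []
    for g in range(3):
        sub = torch.nn.Conv1d(3, 3, kernel_size=1, bias=False)
        # Copy this group's slice of the grouped weights.
        sub.weight.copy_(grouped.weight[3 * g : 3 * (g + 1)])
        outs.append(sub(i[3 * g : 3 * (g + 1)]))
    print(torch.allclose(grouped(i), torch.cat(outs)))  # True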

jatinchowdhury18 commented 8 months ago

Very cool! For state_cols, the idea is that it's a "helper" variable storing only the columns of the state that will be multiplied by the weights (they're not guaranteed to be contiguous, depending on the dilation rate). See how the state_cols are set here.
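
As a rough conceptual sketch of why those columns end up non-contiguous (illustration only, with made-up names, not RTNeural's actual indexing): with a ring-buffer state of past inputs, each kernel tap sits dilation columns away from the previous one, wrapping around the buffer:

# Hypothetical ring-buffer parameters, purely for illustration.
state_size = 16
kernel_size = 3
dilation = 4
state_pos = 5  # current write position in the ring buffer

# Taps sit `dilation` columns apart, wrapping around the buffer,
# so the columns the weights multiply are generally not adjacent.
state_cols = [(state_pos + k * dilation) % state_size for k in range(kernel_size)]
print(state_cols)  # [5, 9, 13]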