KhrulkovV / tt-pytorch


tt_conv is not implemented here #10

Open Silk760 opened 3 years ago

Silk760 commented 3 years ago

I have read the paper; it is very interesting work. I was thinking of using it with Conv layers, but the Conv layer is not implemented.

I have looked through other GitHub repositories, and none of them has a PyTorch implementation of the TT Conv layer. Can you point me to any code or work that shows how it could be implemented in PyTorch?

elena-orlova commented 3 years ago

Yes, convolutional layers are not the focus of our paper. You could have a look at this repository: https://github.com/musco-ai/musco-pytorch , in particular the musco/pytorch/compressor/decompositions/ directory. It has no example of tt_conv, but it does have a quite similar Tucker-2 decomposition for conv layers. Another close example is https://github.com/NVlabs/conv-tt-lstm ; that paper is devoted to compressing convolutional LSTM layers with the TT decomposition.
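For comparison, a Tucker-2 decomposition of a conv kernel factorizes only the two channel modes, which turns one k×k conv into a 1x1 conv (C_in → R_in), a k×k conv (R_in → R_out) and a 1x1 conv (R_out → C_out). A 3-core TT of the kernel with the spatial mode in the middle core has exactly this structure. The sketch below is not the musco-pytorch code: it is a minimal illustration that assumes a plain nn.Conv2d with groups=1 and zero padding, and obtains the factors by truncated HOSVD. The function name tucker2_conv and the ranks r_in, r_out are hypothetical.

```python
import torch
import torch.nn as nn


def tucker2_conv(conv: nn.Conv2d, r_in: int, r_out: int) -> nn.Sequential:
    """Hypothetical helper: replace a Conv2d (groups=1, zero padding) by
    1x1 -> kxk -> 1x1 convs using truncated HOSVD on the two channel modes."""
    w = conv.weight.detach()  # PyTorch layout: (C_out, C_in, kh, kw)
    c_out, c_in, kh, kw = w.shape
    # Leading left singular vectors of the C_out and C_in unfoldings
    u_out = torch.linalg.svd(w.reshape(c_out, -1), full_matrices=False)[0][:, :r_out]
    u_in = torch.linalg.svd(w.transpose(0, 1).reshape(c_in, -1), full_matrices=False)[0][:, :r_in]
    # Core G[a, b, h, w] = sum_{o, i} u_out[o, a] * W[o, i, h, w] * u_in[i, b]
    core = torch.einsum("oihw,oa,ib->abhw", w, u_out, u_in)

    first = nn.Conv2d(c_in, r_in, 1, bias=False)  # projects C_in -> r_in
    middle = nn.Conv2d(r_in, r_out, (kh, kw), stride=conv.stride, padding=conv.padding,
                       dilation=conv.dilation, bias=False)  # spatial core
    last = nn.Conv2d(r_out, c_out, 1, bias=conv.bias is not None)  # expands r_out -> C_out
    with torch.no_grad():
        first.weight.copy_(u_in.t().reshape(r_in, c_in, 1, 1))
        middle.weight.copy_(core)
        last.weight.copy_(u_out.reshape(c_out, r_out, 1, 1))
        if conv.bias is not None:
            last.bias.copy_(conv.bias)  # the original bias is applied once, at the end
    return nn.Sequential(first, middle, last)
```

When the ranks are not truncated (here r_in = C_in and r_out = C_out), the three convs reproduce the original layer up to floating-point error; smaller ranks trade accuracy for fewer parameters.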

In general, it's common to reshape the 4D convolution kernel of shape (C_out, C_in, k, k) (the layout PyTorch uses for nn.Conv2d weights) into a 3D tensor of shape (C_out, C_in, k*k), where C_in is the number of input channels, C_out the number of output channels and k the kernel size. Applying a TT decomposition to this 3D tensor gives 3 cores. I'd suggest implementing this as a custom class (where r1 and r2 are the TT ranks) along these lines:

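import torch
import torch.nn as nn
import torch.nn.functional as F

class tt_conv2d(nn.Module):
    # TT cores: conv3 ~ (C_out, r2), conv1 ~ (r2, C_in, r1), conv2_weight ~ (r1, k, k)
    def __init__(self, c_in, c_out, kernel_size, r1, r2, stride=1, padding=0, dilation=1):
        super().__init__()
        kh, kw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self.r2, self.stride, self.padding, self.dilation = r2, stride, padding, dilation
        # 1x1 conv C_in -> r1*r2 (no bias, so the three convs factorize a single kernel)
        self.conv1 = nn.Conv2d(c_in, r1 * r2, 1, stride=1, padding=0, bias=False)
        # spatial kernel with r1 input channels, shared by all r2 groups
        self.conv2_weight = nn.Parameter(torch.empty(1, r1, kh, kw))
        nn.init.kaiming_uniform_(self.conv2_weight, a=5 ** 0.5)
        # 1x1 conv r2 -> C_out; its bias plays the role of the original conv bias
        self.conv3 = nn.Conv2d(r2, c_out, 1, stride=1, padding=0)

    def forward(self, input):
        out = self.conv1(input)
        # groups=r2: each group maps r1 channels to 1 channel with the same shared kernel
        out = F.conv2d(out, self.conv2_weight.repeat(self.r2, 1, 1, 1), stride=self.stride,
                       groups=self.r2, dilation=self.dilation, padding=self.padding)
        return self.conv3(out)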
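The class above builds the factorized layer from scratch. To compress an already trained convolution instead, one option is a truncated TT-SVD of its kernel viewed as a (C_out, C_in, k*k) tensor, which yields exactly the three pieces that structure uses: a (C_out, r2) core for the last 1x1 conv, an (r2, C_in, r1) core for the first 1x1 conv (output channel g*r1 + j), and an (r1, k*k) core for the shared spatial kernel. This is a minimal sketch under those assumptions; the helper names tt_svd_conv_weight and tt_conv_forward are hypothetical, and the ranks must satisfy r2 ≤ min(C_out, C_in*k*k) and r1 ≤ min(r2*C_in, k*k).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def tt_svd_conv_weight(weight, r1, r2):
    """Hypothetical helper: truncated TT-SVD of a Conv2d weight (C_out, C_in, kh, kw),
    viewed as a (C_out, C_in, kh*kw) tensor. Returns:
      w1: (r2*r1, C_in, 1, 1)  first 1x1 conv, output channel g*r1 + j
      k:  (1, r1, kh, kw)      spatial kernel shared by the r2 groups
      w3: (C_out, r2, 1, 1)    last 1x1 conv
    """
    c_out, c_in, kh, kw = weight.shape
    t = weight.reshape(c_out, c_in * kh * kw)
    # First split: C_out | (C_in, k*k), truncated to rank r2
    u, s, vh = torch.linalg.svd(t, full_matrices=False)
    a = u[:, :r2]                                          # (C_out, r2)
    rest = (s[:r2, None] * vh[:r2]).reshape(r2 * c_in, kh * kw)
    # Second split: (r2, C_in) | k*k, truncated to rank r1
    u, s, vh = torch.linalg.svd(rest, full_matrices=False)
    b = u[:, :r1].reshape(r2, c_in, r1)                    # (r2, C_in, r1)
    c = s[:r1, None] * vh[:r1]                             # (r1, k*k)
    w1 = b.permute(0, 2, 1).reshape(r2 * r1, c_in, 1, 1)
    k = c.reshape(1, r1, kh, kw)
    w3 = a.reshape(c_out, r2, 1, 1)
    return w1, k, w3


def tt_conv_forward(x, w1, k, w3, bias=None, stride=1, padding=0, dilation=1):
    """Same three-step computation as tt_conv2d.forward, with explicit weights."""
    r2 = w3.shape[1]
    out = F.conv2d(x, w1)  # 1x1, C_in -> r2*r1
    out = F.conv2d(out, k.repeat(r2, 1, 1, 1), stride=stride, padding=padding,
                   dilation=dilation, groups=r2)  # r2 groups of r1 -> 1
    return F.conv2d(out, w3, bias=bias)  # 1x1, r2 -> C_out
```

With untruncated ranks the output matches the original conv up to floating-point error. As an illustration of the savings (ignoring biases), for C_in = C_out = 64, k = 3, r1 = 9 and r2 = 16 the full kernel has 64·64·9 = 36,864 weights, while the factorized layer stores 64·144 + 9·9 + 16·64 = 10,321.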