tensorly / torch

TensorLy-Torch: Deep Tensor Learning with TensorLy and PyTorch
http://tensorly.org/torch/
BSD 3-Clause "New" or "Revised" License

Tensor Regression Layer #12

Closed: Silk760 closed this issue 2 years ago

Silk760 commented 2 years ago

I am trying to use tensor regression layers (TRLs) instead of the fully connected layers in my model. The paper claims that I can simply replace the fully connected layer: this removes the need to flatten the tensor and keeps the spatial information, which increases the model's accuracy.
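For reference, here is the computation a TRL performs, sketched in plain PyTorch rather than with tltorch. It assumes a CP (rank-R) weight tensor of shape (C, H, W, O) given by one factor matrix per mode; the function name `cp_trl_forward` and the factor names are hypothetical. The activation is contracted mode by mode with the factors instead of being flattened, and the final check confirms that the result equals a dense linear layer applied to the flattened input, whose weight is the reconstructed CP tensor. In other words, the TRL computes a function a dense layer could also compute; what changes is the low-rank structure imposed along the channel, height and width modes.

```python
import torch


def cp_trl_forward(x, uc, uh, uw, uo, bias):
    """Hypothetical CP-structured TRL (illustration only, not tltorch's code).

    x: (N, C, H, W); uc: (C, R); uh: (H, R); uw: (W, R); uo: (O, R); bias: (O,)
    y[n, o] = sum_r (sum_{c,h,w} x[n,c,h,w] uc[c,r] uh[h,r] uw[w,r]) * uo[o,r] + bias[o]
    """
    return torch.einsum("nchw,cr,hr,wr,or->no", x, uc, uh, uw, uo) + bias


torch.manual_seed(0)
N, C, H, W, O, R = 4, 16, 5, 5, 10, 8
x = torch.randn(N, C, H, W, dtype=torch.float64)
uc, uh, uw, uo = (torch.randn(d, R, dtype=torch.float64) for d in (C, H, W, O))
bias = torch.zeros(O, dtype=torch.float64)

# TRL: contract the unflattened activation with the factors.
y = cp_trl_forward(x, uc, uh, uw, uo, bias)

# Same map as a dense layer: rebuild the full (C, H, W, O) weight,
# reshape it to (C*H*W, O) and multiply the flattened input.
full_w = torch.einsum("cr,hr,wr,or->chwo", uc, uh, uw, uo).reshape(C * H * W, O)
y_dense = x.reshape(N, -1) @ full_w + bias

print(y.shape, torch.allclose(y, y_dense))
```

Here the factorised weight has R·(C+H+W+O) = 8·36 = 288 parameters, against C·H·W·O = 4000 for the dense one.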

I tried this with a small network:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import tltorch


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5), stride=(1, 1), padding=2)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5), stride=(1, 1), padding=0)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        # nn.Linear expects a flattened (N, 16*5*5) input.
        self.linear1 = nn.Linear(16 * 5 * 5, 120)
        # TRL mapping the 120 features to 84 (TT-factorised weight, rank=100).
        self.linear2 = tltorch.TRL(120, 84, factorization='tt', rank=100)
        self.classifier = nn.Linear(84, 10)

    def forward(self, x):
        out = self.conv1(x)
        out = self.pool1(out)
        out = self.conv2(out)
        out = self.pool2(out)        # (N, 16, 5, 5) for 28x28 inputs
        out = torch.flatten(out, 1)  # required before linear1: (N, 400)
        out = self.linear1(out)
        out = self.linear2(out)
        out = self.classifier(out)

        return out
```
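A variant that keeps the spatial structure would feed the (16, 5, 5) output of `pool2` to the regression layer without flattening. The sketch below does this in plain PyTorch with a hypothetical `CPRegressionHead` (a CP-factorised weight), not with tltorch. With tltorch the analogous change would presumably be passing the unflattened shape, e.g. `(16, 5, 5)`, as the TRL's input shape instead of `120`, but check the tltorch documentation for the exact signature. The ReLU activations, the rank of 8 and the initialisation (factor std chosen so each reconstructed weight entry has variance 1/fan_in) are assumptions; the head maps straight to the 10 classes, replacing `linear1`, `linear2` and `classifier`.

```python
import math

import torch
import torch.nn as nn


class CPRegressionHead(nn.Module):
    """Hypothetical tensor regression head with a CP-factorised weight.

    Maps (N, C, H, W) activations to (N, n_out) without flattening.
    """

    def __init__(self, in_shape, n_out, rank):
        super().__init__()
        dims = list(in_shape) + [n_out]
        # Assumed init: each reconstructed weight entry is a sum of `rank` products
        # of len(dims) Gaussian factors, so its variance is rank * std**(2 * len(dims)).
        # Choose std so that this equals 1 / fan_in.
        fan_in = math.prod(in_shape)
        std = (1.0 / (fan_in * rank)) ** (1.0 / (2 * len(dims)))
        self.factors = nn.ParameterList(
            [nn.Parameter(torch.randn(d, rank) * std) for d in dims]
        )
        self.bias = nn.Parameter(torch.zeros(n_out))

    def forward(self, x):
        uc, uh, uw, uo = self.factors
        # Contract the channel, height and width modes, then map to the outputs.
        return torch.einsum("nchw,cr,hr,wr,or->no", x, uc, uh, uw, uo) + self.bias


class Net(nn.Module):
    def __init__(self, rank=8):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.pool2 = nn.MaxPool2d(2)
        # Takes the (16, 5, 5) activation directly: no flatten, no nn.Linear.
        self.trl = CPRegressionHead((16, 5, 5), 10, rank)

    def forward(self, x):
        out = self.pool1(torch.relu(self.conv1(x)))
        out = self.pool2(torch.relu(self.conv2(out)))
        return self.trl(out)


net = Net()
logits = net(torch.randn(2, 1, 28, 28))
print(logits.shape)  # torch.Size([2, 10])
```

The head has 8·(16+5+5+10) + 10 = 298 parameters, compared with 400·120 + 120 for `linear1` alone.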

When I train the model, the loss is always NaN. An example showing how to use tensor regression layers and tensor contraction layers in state-of-the-art models would be very helpful.
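One way to narrow down where the NaN first appears is to check every training step for non-finite values. This is a minimal sketch, not a fix for the model above: `checked_step` is a hypothetical helper, and the tiny stand-in model and random data are assumptions. `torch.isfinite`, `torch.nn.utils.clip_grad_norm_` and `torch.autograd.set_detect_anomaly` are standard PyTorch utilities.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def checked_step(model, optimizer, x, y, max_grad_norm=1.0):
    """Hypothetical training step that fails fast on non-finite values."""
    optimizer.zero_grad()
    logits = model(x)
    if not torch.isfinite(logits).all():
        raise RuntimeError("non-finite logits: check initialisation and inputs")
    loss = F.cross_entropy(logits, y)
    if not torch.isfinite(loss):
        raise RuntimeError("non-finite loss")
    loss.backward()
    # Clipping limits the effect of occasional very large gradients.
    grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    if not torch.isfinite(torch.as_tensor(grad_norm)):
        raise RuntimeError("non-finite gradient norm: try a lower learning rate")
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(16 * 5 * 5, 10))  # stand-in model
opt = torch.optim.SGD(model.parameters(), lr=0.01)
x, y = torch.randn(8, 16, 5, 5), torch.randint(0, 10, (8,))

# Anomaly mode reports the forward op behind a NaN in backward (slow; debug only).
with torch.autograd.set_detect_anomaly(True):
    print(checked_step(model, opt, x, y))
```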

JeanKossaifi commented 2 years ago

It is really hard to debug code based on this small snippet. The layer could be uninitialised, the learning rate could be wrong, etc. I recommend you explicitly initialise your layer with the desired variance.
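As a worked example of initialising a factorised layer with a desired variance, assume a CP weight whose factor entries are drawn i.i.d. from N(0, σ²). Each weight entry is a sum of R products of d independent, zero-mean factor entries (d being the order of the weight tensor), so its variance is R·σ^(2d); solving for a target variance v gives σ = (v/R)^(1/(2d)). This derivation covers the CP case only and the helper name `cp_factor_std` is hypothetical; it is not tltorch's initialisation routine, and Tucker or TT weights need their own calculation.

```python
import torch


def cp_factor_std(target_var, rank, order):
    """Hypothetical helper: std of i.i.d. Gaussian CP factors for a target weight variance.

    Var(W) = rank * std**(2 * order)  =>  std = (target_var / rank) ** (1 / (2 * order))
    """
    return (target_var / rank) ** (1.0 / (2 * order))


# Example: weight of shape (16, 5, 5, 10), i.e. fan_in = 16 * 5 * 5 = 400, CP rank 8.
shape, rank = (16, 5, 5, 10), 8
target = 1.0 / 400
std = cp_factor_std(target, rank, order=len(shape))

# Monte Carlo check: draw many factor sets, rebuild the weights, average W**2.
torch.manual_seed(0)
trials = 500
factors = [torch.randn(trials, d, rank, dtype=torch.float64) * std for d in shape]
weights = torch.einsum("tar,tbr,tcr,tdr->tabcd", *factors)
empirical = weights.pow(2).mean().item()
print(f"std={std:.4f} empirical={empirical:.6f} target={target:.6f}")
```

Using v = 1/fan_in mirrors the usual scaling for a dense layer; a different target variance only changes `target`.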

In your code, why do you flatten the input if you want to use a TRL? You're essentially just using a matrix factorisation in that way. I'll upload an example notebook when I get a moment.
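To make the matrix-factorisation remark concrete: once the input is flattened, the TRL weight in the snippet is a 2-D (120, 84) matrix. A tensor train of a matrix has just two cores, so reconstructing it is a single matrix product, i.e. a rank-R matrix factorisation. The sketch assumes the TT ranks are (1, 100, 1) for `rank=100`; how tltorch actually resolves an integer TT rank is not checked here. At that rank the two cores hold more parameters than the dense matrix, and R = 100 also exceeds 84, the largest rank a 120 × 84 matrix can have.

```python
import torch

# Assumed TT ranks (1, R, 1) for a (120, 84) weight.
n_in, n_out, R = 120, 84, 100

g1 = torch.randn(1, n_in, R)   # first TT core
g2 = torch.randn(R, n_out, 1)  # second TT core

# Contracting the shared rank index gives the full matrix: W = G1 @ G2.
w = g1.squeeze(0) @ g2.squeeze(-1)

tt_params = g1.numel() + g2.numel()  # 120*100 + 100*84
dense_params = n_in * n_out          # 120*84
print(tuple(w.shape), tt_params, dense_params)  # (120, 84) 20400 10080
```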

JeanKossaifi commented 2 years ago

Closing as inactive, feel free to reopen if you still have the issue.

rutujagurav commented 2 years ago

An example of using a TRL in a network would be helpful, because it is proving non-trivial to get working. Any idea where this is on the priority list?