tensorly / torch

TensorLy-Torch: Deep Tensor Learning with TensorLy and PyTorch
http://tensorly.org/torch/
BSD 3-Clause "New" or "Revised" License

Operating on the Decomposed Form? #20

Closed PeaBrane closed 1 year ago

PeaBrane commented 2 years ago

First of all, thank you guys for creating this wonderful library and making it public for us to experiment with. This is likely more of an issue with user error on my end than with the code itself.

I'm trying to get a small convolutional network to work with the CIFAR-100 dataset, with two fully connected layers at the end factorized using blocktt. While the network trains fine and achieves nearly the same accuracy as its non-tensorized counterpart, training is a bit slow, and I am getting the following warning:

    UserWarning: BlockTT(shape=[512, 512], tensorized_shape=((8, 8, 8), (32, 4, 4)), rank=[1, 4, 4, 1]) is being reconstructed into a matrix, consider operating on the decomposed form.
      warnings.warn(f'{self} is being reconstructed into a matrix, consider operating on the decomposed form.')

I don't quite understand this warning message. Am I supposed to decompose the feature tensor first (e.g. into TT or Tucker format) before passing it through the factorized linear layers? I couldn't quite work out from the documentation how to do this inside a torch.nn.Module block.
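From reading about the TT-matrix format, my rough understanding of what "reconstructed into a matrix" means is something like the following sketch, using random cores with the shapes reported in the warning (this is just my mental model, not the actual tltorch internals):

```python
import torch

# Hypothetical random BlockTT (TT-matrix) cores matching the warning:
# tensorized_shape=((8, 8, 8), (32, 4, 4)), rank=[1, 4, 4, 1].
# Core k has shape (rank_k, row_k, col_k, rank_{k+1}).
g1 = torch.randn(1, 8, 32, 4)
g2 = torch.randn(4, 8, 4, 4)
g3 = torch.randn(4, 8, 4, 1)

# Contract the cores back into a dense weight -- this is the
# "reconstruction into a matrix" the warning seems to refer to.
w = torch.einsum('aipb,bjqc,ckrd->ijkpqr', g1, g2, g3)
w = w.reshape(8 * 8 * 8, 32 * 4 * 4)
print(w.shape)  # torch.Size([512, 512])
```

If that reading is right, the warning is saying the matmul happens against this dense 512x512 matrix rather than against the small cores directly.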

I have attached my source code here.

import torch
import torch.nn.functional as F
import tltorch


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # Three strided convs: 3x32x32 CIFAR input -> 32x4x4 features
        self.conv1 = torch.nn.Conv2d(3, 8, 3, 2, 1)
        self.conv2 = torch.nn.Conv2d(8, 16, 3, 2, 1)
        self.conv3 = torch.nn.Conv2d(16, 32, 3, 2, 1)
        # Two 512 -> 512 fully connected layers factorized with BlockTT
        self.fc1 = tltorch.FactorizedLinear((32, 4, 4), (8, 8, 8), factorization='blocktt', rank=(1, 4, 4, 1))
        self.fc2 = tltorch.FactorizedLinear((8, 8, 8), (8, 8, 8), factorization='blocktt', rank=(1, 4, 4, 1))
        self.fc3 = torch.nn.Linear(512, 100)

    def forward(self, inputs):
        outputs = F.relu(self.conv3(F.relu(self.conv2(F.relu(self.conv1(inputs))))))
        outputs = outputs.flatten(-3, -1)  # flatten (32, 4, 4) -> 512
        outputs = F.relu(self.fc1(outputs))
        outputs = F.relu(self.fc2(outputs))
        outputs = self.fc3(outputs)
        return outputs
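For reference, the non-tensorized counterpart I compared against is essentially the same network with the two factorized layers replaced by ordinary dense layers (a sketch of that baseline, plus a shape check with a CIFAR-sized dummy batch):

```python
import torch
import torch.nn.functional as F

# Dense stand-in for the module above: same convs, but fc1/fc2 are
# plain nn.Linear(512, 512) instead of tltorch.FactorizedLinear
# (a sketch of the baseline, not the exact training script).
conv1 = torch.nn.Conv2d(3, 8, 3, 2, 1)
conv2 = torch.nn.Conv2d(8, 16, 3, 2, 1)
conv3 = torch.nn.Conv2d(16, 32, 3, 2, 1)
fc1 = torch.nn.Linear(512, 512)
fc2 = torch.nn.Linear(512, 512)
fc3 = torch.nn.Linear(512, 100)

x = torch.randn(2, 3, 32, 32)  # CIFAR-100 sized dummy batch
h = F.relu(conv3(F.relu(conv2(F.relu(conv1(x))))))
h = h.flatten(-3, -1)          # (2, 32*4*4) = (2, 512)
out = fc3(F.relu(fc2(F.relu(fc1(h)))))
print(out.shape)  # torch.Size([2, 100])
```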
JeanKossaifi commented 1 year ago

This should be straightforward to do since #26, feel free to reopen if you still see the issue!