Closed hello-fri-end closed 4 months ago
The same example works with factorization='tucker', but the from_conv function does not infer dilation from the input conv layer.
Example Code:
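import torch
import tltorch

# Input of shape (batch, channels, depth, height, width)
test_input = torch.randn((1, 1024, 64, 7, 7))
# Reference layer, dilated by 3 along the depth dimension
test_conv3d = torch.nn.Conv3d(1024, 512, (3, 3, 3), padding=(3, 1, 1), dilation=(3, 1, 1))
print(test_conv3d(test_input).shape)
# dilation is not passed here, so from_conv has to infer it from test_conv3d
fact_conv3d = tltorch.FactorizedConv.from_conv(test_conv3d, rank='same', factorization='tucker')
print(fact_conv3d(test_input).shape)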
prints
torch.Size([1, 512, 64, 7, 7])
torch.Size([1, 512, 68, 7, 7])
while passing dilation explicitly,
import torch
import tltorch

test_input = torch.randn((1, 1024, 64, 7, 7))
test_conv3d = torch.nn.Conv3d(1024, 512, (3, 3, 3), padding=(3, 1, 1), dilation=(3, 1, 1))
print(test_conv3d(test_input).shape)
# Same as above, but with dilation passed explicitly
fact_conv3d = tltorch.FactorizedConv.from_conv(test_conv3d, rank='same', factorization='tucker', dilation=(3, 1, 1))
print(fact_conv3d(test_input).shape)
prints
torch.Size([1, 512, 64, 7, 7])
torch.Size([1, 512, 64, 7, 7])
Not sure if this is intentional, but felt it worth mentioning.
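For reference, the 64 vs. 68 difference in the depth dimension matches the output-size formula given in the PyTorch Conv3d documentation, evaluated once with dilation 3 and once with the default dilation 1. A minimal sketch (conv_out_size is a hypothetical helper written for this illustration, not part of torch or tltorch):

```python
def conv_out_size(size, kernel, padding=0, dilation=1, stride=1):
    """Output length along one dimension, per the formula in the torch.nn.Conv3d docs."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# Depth dimension of the example: input 64, kernel 3, padding 3
print(conv_out_size(64, kernel=3, padding=3, dilation=3))  # 64
print(conv_out_size(64, kernel=3, padding=3, dilation=1))  # 68
```

So a depth of 68 is what one would expect if the factorized layer runs with dilation 1 while keeping the original padding of 3.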
Great catch @hello-fri-end, thank you for investigating and flagging! dilation is supported by the conv, but from_conv doesn't pass the argument. Would you be able to open a small PR to fix the issue?
Support for dilated convs in from_conv is now added in 615fbdd
Minimal Code to reproduce the error:
Error:
The error is actually coming from this line in the truncated_svd function of TensorLy. The shape of the matrix passed to the svd function is torch.Size([3, 4718592]). Note that this error is not thrown when torch.svd is used directly. The size of the matrix is only ~54 MB, so it's strange that tl.svd tries to allocate 83 GB for it.
It's also possible that I'm making a very stupid mistake; in any case, looking forward to some solution here :pray: @JeanKossaifi
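A quick back-of-the-envelope check of the sizes involved may help narrow this down. The sketch below assumes float32 storage (4 bytes per element). The second figure assumes, purely as a hypothesis, that a full square right factor of size N x N is materialized, as a full (non-truncated) SVD of a 3 x N matrix would produce; whether tl.svd actually does this is not confirmed here.

```python
BYTES_PER_FLOAT32 = 4  # assumption: the tensor is float32

rows, cols = 3, 4718592  # shape reported above

# Memory for the matrix itself
matrix_bytes = rows * cols * BYTES_PER_FLOAT32
print(f"matrix: {matrix_bytes / 2**20:.0f} MiB")

# Hypothetical: a full SVD would also return a cols x cols right factor
full_v_bytes = cols * cols * BYTES_PER_FLOAT32
print(f"full right factor: {full_v_bytes / 2**30:.0f} GiB")

# A reduced SVD only needs a rows x cols right factor
reduced_v_bytes = rows * cols * BYTES_PER_FLOAT32
print(f"reduced right factor: {reduced_v_bytes / 2**20:.0f} MiB")
```

The first figure agrees with the ~54 MB quoted above. If the allocation in the error message is far larger than that, comparing it against the full-factor figure is one way to test the hypothesis.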