tensorly / torch

TensorLy-Torch: Deep Tensor Learning with TensorLy and PyTorch
http://tensorly.org/torch/
BSD 3-Clause "New" or "Revised" License

`tltorch.FactorizedConv.from_conv` tries to allocate ~83GB memory for an input of shape (1024, 512, 3, 3, 3) #32

Closed: hello-fri-end closed this issue 2 weeks ago

hello-fri-end commented 7 months ago

Minimal code to reproduce the error:

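```python
import torch
import tltorch

# Conv3d weight shape is (out_channels, in_channels, kD, kH, kW) = (512, 1024, 3, 3, 3)
test_conv3d = torch.nn.Conv3d(1024, 512, (3, 3, 3), padding=(3, 1, 1))
print(tltorch.FactorizedConv.from_conv(test_conv3d, rank='same', factorization='cp'))
```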

Error:

```
RuntimeError: [enforce fail at alloc_cpu.cpp:83] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 89060441849856 bytes. Error code 12 (Cannot allocate memory)
```

The error actually comes from this line in TensorLy's `truncated_svd` function. The matrix passed to the SVD has shape `torch.Size([3, 4718592])`. Note that this error is not raised when `torch.svd` is used directly. The matrix itself is only ~54 MB, so it's strange that `tl.svd` tries to allocate 89,060,441,849,856 bytes (roughly 83,000 GiB) for it.
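A quick back-of-the-envelope check, using only the shape and byte count quoted above and assuming float32 weights (4 bytes per element), shows that the requested allocation is exactly the size of a square 4718592 × 4718592 float32 matrix. That is what a *full* SVD of the 3 × 4718592 unfolding would return as its right-singular-vector factor, whereas a reduced SVD stays about the size of the input:

```python
# Shapes from the issue; float32 assumed (4 bytes per element).
BYTES_PER_FLOAT32 = 4
rows = 3
# Unfolding of the (512, 1024, 3, 3, 3) weight along one of its size-3 modes
cols = 512 * 1024 * 3 * 3

input_bytes = rows * cols * BYTES_PER_FLOAT32
# A reduced SVD returns U (rows x rows), S (rows,) and Vh (rows x cols).
reduced_bytes = (rows * rows + rows + rows * cols) * BYTES_PER_FLOAT32
# A full SVD makes Vh square instead: cols x cols.
full_vh_bytes = cols * cols * BYTES_PER_FLOAT32

print(cols)                          # 4718592
print(input_bytes / 2**20)           # 54.0 (MiB)
print(round(reduced_bytes / 2**20))  # 54 (MiB)
print(full_vh_bytes)                 # 89060441849856, the number in the error
print(full_vh_bytes / 2**40)         # 81.0 (TiB)
```

The square factor grows with the square of the column count, which is why a ~54 MiB input turns into an ~81 TiB request.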

It's also possible that I'm making a very stupid mistake; in any case, looking forward to a solution here 🙏 @JeanKossaifi
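The difference between a full and a reduced SVD can be seen on a scaled-down stand-in for the 3 × 4718592 matrix. This is a sketch using `torch.linalg.svd` and its `full_matrices` flag; it assumes, without having traced TensorLy's code path, that the failing call ends up requesting full matrices. For comparison, `torch.svd` defaults to `some=True`, i.e. a reduced SVD, which would be consistent with the direct call succeeding:

```python
import torch

# Small wide matrix standing in for the 3 x 4718592 unfolding (1000 columns instead).
A = torch.randn(3, 1000)

# Full SVD: Vh is square (1000 x 1000) and grows quadratically with the column count.
U_full, S_full, Vh_full = torch.linalg.svd(A, full_matrices=True)

# Reduced ("economy") SVD: Vh has only min(rows, cols) = 3 rows.
U_red, S_red, Vh_red = torch.linalg.svd(A, full_matrices=False)

print(Vh_full.shape)  # torch.Size([1000, 1000])
print(Vh_red.shape)   # torch.Size([3, 1000])

# The reduced factors are enough to reconstruct A.
A_rec = U_red @ torch.diag(S_red) @ Vh_red
print(torch.allclose(A_rec, A, atol=1e-3))
```

Both variants give the same singular values; only the full variant pays for the large square factor.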

hello-fri-end commented 7 months ago

The same example works with `factorization='tucker'`, but `from_conv` does not infer dilation from the input conv layer. Example code:

```python
import torch
import tltorch

test_input = torch.randn((1, 1024, 64, 7, 7))
# Dilation 3 on the depth axis gives an effective kernel depth of 7, so padding 3 keeps depth at 64
test_conv3d = torch.nn.Conv3d(1024, 512, (3, 3, 3), padding=(3, 1, 1), dilation=(3, 1, 1))
print(test_conv3d(test_input).shape)
fact_conv3d = tltorch.FactorizedConv.from_conv(test_conv3d, rank='same', factorization='tucker')
print(fact_conv3d(test_input).shape)
```

prints

```
torch.Size([1, 512, 64, 7, 7])
torch.Size([1, 512, 68, 7, 7])
```

Passing `dilation=(3, 1, 1)` to `from_conv` explicitly does not help either:

```python
import torch
import tltorch

test_input = torch.randn((1, 1024, 64, 7, 7))
test_conv3d = torch.nn.Conv3d(1024, 512, (3, 3, 3), padding=(3, 1, 1), dilation=(3, 1, 1))
print(test_conv3d(test_input).shape)
# Same as above, but with dilation passed to from_conv explicitly
fact_conv3d = tltorch.FactorizedConv.from_conv(test_conv3d, rank='same', factorization='tucker', dilation=(3, 1, 1))
print(fact_conv3d(test_input).shape)
```

prints

```
torch.Size([1, 512, 64, 7, 7])
torch.Size([1, 512, 68, 7, 7])
```

Not sure if this is intentional, but felt it worth mentioning.
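The two depths in these outputs follow from the output-length formula documented for `torch.nn.Conv3d`: out = floor((in + 2·padding − dilation·(kernel − 1) − 1) / stride) + 1. A small plain-Python sketch (`conv_out_len` is a hypothetical helper name) shows that 68 is exactly what you get when the dilation of 3 is dropped:

```python
def conv_out_len(size: int, kernel: int, padding: int = 0,
                 dilation: int = 1, stride: int = 1) -> int:
    """Output length along one spatial axis of a convolution (PyTorch convention)."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1

# Depth axis of the example: size 64, kernel 3, padding 3.
print(conv_out_len(64, 3, padding=3, dilation=3))  # 64: original dilated conv
print(conv_out_len(64, 3, padding=3, dilation=1))  # 68: dilation ignored
# Height/width axes: size 7, kernel 3, padding 1, no dilation.
print(conv_out_len(7, 3, padding=1))               # 7
```

The padding of 3 was chosen to compensate for the dilated kernel, so once dilation is lost the same padding over-pads the depth axis by 4.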

JeanKossaifi commented 3 months ago

Great catch @hello-fri-end, thank you for investigating and flagging! Dilation is supported by the factorized conv, but `from_conv` doesn't pass the argument through. Would you be able to open a small PR to fix the issue?
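The underlying idea is that a conversion routine has to copy every shape-relevant hyperparameter from the source layer, not only the kernel size and padding. A minimal sketch of collecting them from the standard attributes that `torch.nn.Conv3d` exposes; `conv_hyperparams` is a hypothetical helper, not part of tltorch, and not necessarily how tltorch implements the pass-through:

```python
import torch
from torch import nn


def conv_hyperparams(conv: nn.Module) -> dict:
    """Hypothetical helper: settings a replacement layer must copy from an
    existing nn.ConvNd so that its output shape matches."""
    return {
        "kernel_size": conv.kernel_size,
        "stride": conv.stride,
        "padding": conv.padding,
        "dilation": conv.dilation,  # the setting that was being dropped
        "groups": conv.groups,
        "bias": conv.bias is not None,
    }


conv = nn.Conv3d(1024, 512, (3, 3, 3), padding=(3, 1, 1), dilation=(3, 1, 1))
print(conv_hyperparams(conv))
```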

JeanKossaifi commented 2 weeks ago

Support for dilated convolutions in `from_conv` has now been added in 615fbdd.