rentruewang / koila

Prevent PyTorch's `CUDA error: out of memory` in just 1 line of code.
Apache License 2.0
1.82k stars 62 forks source link

unet3d - koila.errors.UnsupportedError #22

Open etienne87 opened 2 years ago

etienne87 commented 2 years ago

I am trying to apply koila's lazy evaluation to a 3D U-Net (`Unet3D`).

# defining the model
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv3(in_channels, out_channels, stride, norm='BatchNorm3d', act='GELU'):
    """Build a 3x3x3 convolution followed by a normalization and an activation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride. With kernel 3 and padding 1, ``1`` keeps the
            spatial size and ``2`` maps each spatial size ``n`` to ``ceil(n / 2)``.
        norm: Name of a ``torch.nn`` class, constructed as ``norm(out_channels)``.
        act: Name of a ``torch.nn`` activation class, constructed with no args.

    Returns:
        ``nn.Sequential(Conv3d, norm, activation)``.
    """
    return nn.Sequential(
        # Honour ``stride`` so that strided stages actually downsample.
        nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
        getattr(nn, norm)(out_channels),
        getattr(nn, act)(),
    )

def double_conv3(in_channels, out_channels, stride):
    """Chain two ``conv3`` units into one ``nn.Sequential``.

    The first unit maps ``in_channels -> out_channels`` at stride 1. The second
    keeps ``out_channels`` and applies the requested ``stride``.
    """
    first = conv3(in_channels, out_channels, 1)
    second = conv3(out_channels, out_channels, stride)
    return nn.Sequential(first, second)

def merge_skip(x, skip):
    """Resize ``x`` to ``skip``'s spatial size and concatenate them on channels.

    Trilinear interpolation requires 5-D input, so both tensors are expected
    to be ``(batch, C, D, H, W)`` with matching batch sizes.

    Args:
        x: Feature map to resize (e.g. a coarser decoder input).
        skip: Skip-connection feature map whose last three dims set the size.

    Returns:
        ``torch.cat([resized_x, skip], dim=1)``: the resized ``x`` channels
        first, then ``skip``'s channels.
    """
    # F.upsample is deprecated; F.interpolate takes the same arguments.
    x = F.interpolate(x, size=skip.shape[-3:], mode='trilinear', align_corners=True)
    # Concatenate along the channel axis so the next stage sees both maps.
    return torch.cat([x, skip], dim=1)

class Unet3D(nn.Module):
    """3D U-Net built from ``double_conv3`` stages with skip connections.

    Args:
        in_channels: Channels of the input volume.
        out_channels: Channels produced by the final decoder stage.
        num_layers: Number of encoder stages (and of decoder stages).
        base: Channel width of the first encoder stage; encoder stage ``i``
            produces ``base * 2**i`` channels.
    """

    def __init__(self, in_channels, out_channels, num_layers=4, base=16):
        # nn.Module must be initialised before submodules can be assigned.
        super().__init__()

        enc_channels = [in_channels] + [base * 2**i for i in range(num_layers)]
        dec_channels = [base * 2**i for i in range(num_layers - 1, -1, -1)] + [out_channels]

        # Each encoder stage uses stride 2 in its second conv.
        self.encoders = nn.ModuleList()
        for cin, cout in zip(enc_channels[:-1], enc_channels[1:]):
            self.encoders.append(double_conv3(cin, cout, 2))

        # Decoder i consumes the resized previous output concatenated with the
        # skip feature taken from encoder level -(i + 2).
        self.decoders = nn.ModuleList()
        for i in range(len(dec_channels) - 1):
            cin_skip = enc_channels[-i - 2]
            cin_up = dec_channels[i]
            cout = dec_channels[i + 1]
            self.decoders.append(double_conv3(cin_skip + cin_up, cout, 1))

    def forward(self, x, return_all=False):
        """Run the encoder/decoder pass.

        Args:
            x: Input volume, presumably ``(batch, in_channels, D, H, W)``.
            return_all: If true, return every intermediate feature map.

        Returns:
            The final decoder output. If ``return_all`` is true, return instead
            the list ``[input, encoder outputs..., decoder outputs...]``.
        """
        out = [x]
        for encoder in self.encoders:
            x = encoder(x)
            out.append(x)
        # Skip indices below refer only to the input and encoder features.
        n = len(out)
        for i, decoder in enumerate(self.decoders):
            skip = out[n - 2 - i]
            x = merge_skip(out[-1], skip)
            x = decoder(x)
            out.append(x)

        if return_all:
            return out
        return out[-1]

# test of koila on unet
def test_lazy():
    """Check that a U-Net forward pass and loss on lazy tensors stay lazy.

    Requires a CUDA device. ``lazy`` and ``LazyTensor`` are not defined in this
    snippet; presumably they are imported from koila at module level.
    """
    # Model parameters must live on the same device as the inputs.
    net = Unet3D(1, 3).cuda()
    s = 64
    b, c, d, h, w = 2, 1, s, s, s
    x = torch.randn(b, c, d, h, w).cuda()
    # Class indices in [0, 3) to match the 3 output channels.
    t = torch.randint(0, 3, (b, d, h, w)).cuda()

    loss_fn = nn.CrossEntropyLoss()

    lazy_x, lazy_t = lazy(x, t, batch=0)
    lazy_out = net(lazy_x)
    lazy_loss = loss_fn(lazy_out, lazy_t)
    assert isinstance(lazy_loss, LazyTensor), type(lazy_loss)

# This fails

This fails and outputs:

tensors = (tensor([[[[[-8.9936e-02, -7.9037e-02, -1.5048e-02,  ...,  2.9969e-01,
             2.9774e-01, -1.0489e-01],
        ...]]], device='cuda:0',
       grad_fn=<UpsampleTrilinear3DBackward1>), <koila.lazy.LazyTensor object at 0x7fa21bf99880>)
dim = 1, args = (), kwargs = {}, shapes = [torch.Size([2, 128, 64, 64, 64]), (2, 64, 64, 64, 64)]
no_dim = [torch.Size([2, 64, 64, 64]), (2, 64, 64, 64)], result_size = torch.Size([2, 64, 64, 64])
size = (2, 64, 64, 64)

    def cat(
        tensors: Sequence[TensorLike], dim: int = 0, *args: Any, **kwargs: Any
    ) -> PrePass:
        mute_unused_args(*args, **kwargs)

        if len(tensors) == 0:
            raise ValueError("Expected a sequence of tensors. Got empty sequence.")

        shapes = [t.size() for t in tensors]
        no_dim = [t[:dim] + t[dim + 1 :] for t in shapes]

        result_size = no_dim[0]
        for size in no_dim[1:]:
            if result_size != size:
                raise ValueError(
                    f"Dimension should be equal outside dim {dim}. Got {shapes}."

        if len(set(interfaces.bat(t) for t in tensors)) != 1:
>           raise UnsupportedError
E           koila.errors.UnsupportedError

../miniconda3/envs/snakes/lib/python3.9/site-packages/koila/…: UnsupportedError

rentruewang commented 2 years ago
rentruewang commented 2 years ago

Hi, that means the batch sizes don't match, and the library doesn't know how to deal with that situation.

Since PyTorch's broadcasting rules are extensive, not all rules are supported yet.

I'll see what I can do about it in the upcoming changes in #18 with a much more modular implementation.