lucidrains / x-unet

Implementation of a U-net complete with efficient attention as well as the latest research findings

How do you train this beast? #8

Open aegonwolf opened 1 year ago

aegonwolf commented 1 year ago

Hi there,

thanks a lot for all your great repos and implementations!

I wanted to try this for a segmentation problem, but I've had trouble training on Colab's 40 GB GPU with 256x256 inputs. The model I want to use is initialized like so:

gen = XUnet(
    dim = target_shape,
    channels = 3,
    dim_mults = (1, 2, 4, 4),
    nested_unet_depths = (4, 3, 2, 1),     # nested unet depths, from unet-squared paper
    consolidate_upsample_fmaps = True,     # whether to consolidate outputs from all upsample blocks, used in unet-squared paper
).to(device)

Is there a trick, or what do you estimate the required memory to be? I set pin_memory to False, which helped a little, but I still wasn't able to get through a single forward pass (batch_size = 1).

I also noticed that most of the memory is reserved rather than allocated, irrespective of the input size (always around 35-38 GB).
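As background on the reserved-vs-allocated observation: PyTorch's caching allocator keeps freed blocks around for reuse, so reserved memory normally sits well above allocated memory. A minimal sketch, independent of x-unet, for telling the two apart; the 1x3x256x256 input is assumed from the description above:

import torch

device = torch.device('cuda')

def report(tag):
    # allocated = memory held by live tensors; reserved = that plus the allocator's cache
    alloc = torch.cuda.memory_allocated() / 1e9
    reserved = torch.cuda.memory_reserved() / 1e9
    print(f'{tag}: allocated {alloc:.2f} GB, reserved {reserved:.2f} GB')

report('start')

x = torch.randn(1, 3, 256, 256, device = device)
y = (x * 2).relu()         # stand-in for a forward pass
del y

report('after work')

torch.cuda.empty_cache()   # hands cached-but-unallocated blocks back to the driver

report('after empty_cache')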

qbeer commented 6 months ago

Better late than never:

class XUnet(nn.Module):

    @beartype
    def __init__(
        self,
        dim,
        init_dim = None,
        out_dim = None,
        frame_kernel_size = 1,
        dim_mults: MaybeTuple(int) = (1, 2, 4, 8),
        num_blocks_per_stage: MaybeTuple(int) = (2, 2, 2, 2),
        num_self_attn_per_stage: MaybeTuple(int) = (0, 0, 0, 1),
        nested_unet_depths: MaybeTuple(int) = (0, 0, 0, 0),
        nested_unet_dim = 32,
        channels = 3,
        use_convnext = False,
        consolidate_upsample_fmaps = True,
        skip_scale = 2 ** -0.5,
        weight_standardize = False,
        attn_heads: MaybeTuple(int) = 8,
        attn_dim_head: MaybeTuple(int) = 32
    ):

Lower the attention heads (attn_heads) and/or the attention head dimension (attn_dim_head).
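For what it's worth, the README instantiates the model with dim = 64, which suggests dim is the base channel width rather than the image resolution; if target_shape above is 256, that alone would quadruple the width of every stage. A minimal sketch of a lighter configuration, with all values illustrative rather than tested:

import torch
from x_unet import XUnet

device = torch.device('cuda')

gen = XUnet(
    dim = 64,                  # base channel width, as in the README example, not the input size
    channels = 3,
    dim_mults = (1, 2, 4, 4),
    nested_unet_depths = (4, 3, 2, 1),
    consolidate_upsample_fmaps = True,
    attn_heads = 4,            # down from the default 8
    attn_dim_head = 16         # down from the default 32
).to(device)

x = torch.randn(1, 3, 256, 256, device = device)
out = gen(x)                   # expected shape (1, 3, 256, 256)

If that still overflows, shrinking nested_unet_depths or nested_unet_dim should cut the nested blocks down further.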