lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
MIT License

Text to video: no attention layers #332

Open axel588 opened 1 year ago

axel588 commented 1 year ago

I don't understand why, in the Unet3D, we don't use attention layers for text conditioning (sorry if this is a dumb question).

That is, this configuration: `layer_attns = (False, False, False, True), layer_cross_attns = False`
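In imagen-pytorch, `layer_attns` and `layer_cross_attns` appear to accept either a single boolean or one flag per resolution level, i.e. one entry per element of `dim_mults`. The commented-out lines in the script further down build such tuples by hand. The sketch below uses a hypothetical helper, `per_level_flags` (not part of imagen-pytorch), to make the "one flag per level" shape explicit; the last level is typically the lowest resolution, which is where attention is cheapest.

```python
from typing import Tuple


def per_level_flags(num_levels: int, first_enabled: int) -> Tuple[bool, ...]:
    """Hypothetical helper: one boolean per U-Net level, True from `first_enabled` on."""
    if not 0 <= first_enabled <= num_levels:
        raise ValueError("first_enabled must be between 0 and num_levels")
    return (False,) * first_enabled + (True,) * (num_levels - first_enabled)


dim_mults = (1, 2, 4, 8)  # example value; one entry per resolution level

# Self-attention only at the deepest level, as in the configuration quoted above.
layer_attns = per_level_flags(len(dim_mults), first_enabled=3)
# Cross-attention (text conditioning) everywhere except the first level.
layer_cross_attns = per_level_flags(len(dim_mults), first_enabled=1)

print(layer_attns)        # (False, False, False, True)
print(layer_cross_attns)  # (False, True, True, True)
```

The second call reproduces the `(False,) + (True,) * (len(config.dim_mults) - 1)` pattern from the script, but ties the tuple length to `dim_mults` so the two cannot drift apart.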

lucidrains commented 1 year ago

@axel588 Hmm, if you are training with text conditioning but have no cross-attention layers set, it should error out (does it not?). I can add that check if you show me a script where it doesn't.
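A guard of the kind described here could look like the sketch below. `check_text_conditioning` is a hypothetical function written for illustration, not imagen-pytorch's actual code; the library's real check, if any, may live elsewhere and behave differently.

```python
from typing import Tuple, Union


def check_text_conditioning(
    cond_on_text: bool,
    layer_cross_attns: Union[bool, Tuple[bool, ...]],
) -> None:
    """Hypothetical guard: fail loudly if text conditioning has no cross-attention layer."""
    # Normalize a single bool to a one-element tuple so both forms are handled alike.
    if isinstance(layer_cross_attns, bool):
        flags = (layer_cross_attns,)
    else:
        flags = tuple(layer_cross_attns)
    if cond_on_text and not any(flags):
        raise ValueError(
            "text conditioning requested, but layer_cross_attns disables "
            "cross attention at every level"
        )


check_text_conditioning(True, (False, True, True, True))  # at least one level: passes
check_text_conditioning(False, False)                     # no text conditioning: passes

try:
    check_text_conditioning(True, False)                  # the configuration from the question
except ValueError as err:
    print(err)
```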

axel588 commented 1 year ago

@lucidrains I enabled the attention layers at first, but even with `dim = 8` (very low, yes) and a batch size of 1, training overflows the memory of my 24 GB graphics card. The configuration below, without attention layers, already takes 23 GB of VRAM with a batch size of 2. How can I solve this memory issue? This code works for text conditioning without attention layers and gives no error, but yes, the samples seem random relative to the prompt:


```python
from imagen_pytorch import Unet3D, ElucidatedImagen

# `config` and `reshaped_m` are defined elsewhere in the original training script
unet = Unet3D(
    dim = config.dim,               # base channel width, i.e. the number of filters output by the first layer
    # cond_dim = config.cond_dim,
    channels = 5,
    dim_mults = config.dim_mults,   # per-level channel multipliers (channels at each level = dim * mult)
    # num_resnet_blocks = config.num_resnet_blocks,
    # layer_attns = (False,) + (True,) * (len(config.dim_mults) - 1),
    # layer_cross_attns = (False,) + (True,) * (len(config.dim_mults) - 1)
)

imagen = ElucidatedImagen(
    unets = (unet,),                # a one-element tuple needs the trailing comma
    image_sizes = reshaped_m,
    cond_drop_prob = 0.1,
    text_encoder_name = 't5-base',
    channels = 5,
    num_sample_steps = 64,          # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are)
    sigma_min = 0.002,              # min noise level
    sigma_max = 80,                 # max noise level, @crowsonkb recommends double the max noise level for upsampler
    sigma_data = 0.5,               # standard deviation of data distribution
    rho = 7,                        # controls the sampling schedule
    P_mean = -1.2,                  # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                    # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                   # parameters for stochastic sampling - depends on dataset, Table 5 in paper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()
```