lucidrains / x-clip

A concise but complete implementation of CLIP with various experimental improvements from recent papers
MIT License
682 stars 46 forks source link

NaN with mock data #10

Open BlinkDL opened 2 years ago

BlinkDL commented 2 years ago

Hi lucidrains,

Try the script below: the loss becomes NaN within 100 steps (using the latest GitHub code). The loss looks fine right up until it turns NaN.

import random

import numpy as np
import torch

# Opt in to TF32 kernels for cuDNN and CUDA matmuls, and enable cuDNN benchmark mode.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# Seed every random number generator the script touches (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same value.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Problem dimensions shared by the mock data and the model configuration.
num_text_tokens = 10000
batch_sz = 12
text_seq_len = 256
visual_image_size = 256

# mock data: a fixed pool of random token sequences and random images,
# generated on the CPU and then moved to the GPU

data_sz = 1000
all_text = torch.randint(0, num_text_tokens, (data_sz, text_seq_len)).cuda()
all_images = torch.randn(data_sz, 3, visual_image_size, visual_image_size).cuda()

# Zero-initialised staging buffers on the GPU, one row per batch element.
text = torch.zeros(batch_sz, text_seq_len, dtype=torch.long, device='cuda')
images = torch.zeros(batch_sz, 3, visual_image_size, visual_image_size, device='cuda')

##########################################################################################

import datetime

import wandb
from x_clip import CLIP

# Name the wandb run after the local wall-clock time at which it starts.
run_name = datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
wandb.init(project="Test", name=run_name, save_code=False)

# Keyword arguments for the CLIP model under test, collected in one place.
clip_config = dict(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = num_text_tokens,
    text_enc_depth = 6,
    text_seq_len = text_seq_len,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = visual_image_size,
    visual_patch_size = 32,
    visual_heads = 8,
    use_all_token_embeds = False,           # whether to use fine-grained contrastive learning (FILIP)
    decoupled_contrastive_learning = True,  # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
    extra_latent_projection = True,         # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
    use_visual_ssl = True,                  # whether to do self supervised learning on images
    visual_ssl_type = 'simclr',             # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
    use_mlm = False,                        # use masked language learning (MLM) on text (DeCLIP)
    text_ssl_loss_weight = 0.05,            # weight for text MLM loss
    image_ssl_loss_weight = 0.05,           # weight for image self-supervised learning loss
)

# Build the model and move it to the GPU.
clip = CLIP(**clip_config).cuda()

import math

optimizer = torch.optim.Adam(clip.parameters(), lr=1e-4, betas=(0.9, 0.99))

# Train on random mock batches until the loss turns NaN (or the step budget
# runs out), logging the loss of every step to wandb and stdout.
for step in range(999999):
    # Draw a batch (with replacement) from the whole mock pool.
    # randrange() already excludes its stop value, so the bound must be
    # data_sz rather than data_sz - 1; otherwise the last sample could
    # never be selected.
    batch_ids = [random.randrange(data_sz) for _ in range(batch_sz)]

    # Refill the staging buffers in place with one gather per tensor
    # instead of copying row by row.
    text.copy_(all_text[batch_ids])
    images.copy_(all_images[batch_ids])

    loss = clip(
        text,
        images,
        freeze_image_encoder = False,   # whether to freeze image encoder if using a pretrained image net, proposed by LiT paper
        return_loss = True              # needs to be set to True to return contrastive loss
    )
    clip.zero_grad()
    loss.backward()
    # Clip the global gradient norm to 1.0 before the optimizer update.
    torch.nn.utils.clip_grad_norm_(clip.parameters(), 1.0)
    optimizer.step()

    now_loss = loss.item()
    wandb.log({"loss": now_loss}, step = step)
    print(step, now_loss)

    # Stop as soon as the loss is NaN; test the float directly rather than
    # searching its string representation.
    if math.isnan(now_loss):
        break
lucidrains commented 2 years ago

@BlinkDL Hey Peng Bo! I quickly checked the script, and the loss does indeed become NaN, but not when visual SSL (`use_visual_ssl`) is turned off.

I suspect it has something to do with augmenting the randomly generated images in the visual SSL, but I'm not completely sure.