hananshafi / vits-for-small-scale-datasets

[BMVC 2022] Official repository for "How to Train Vision Transformer on Small-scale Datasets?"

teacher_output = teacher(images[:2]) # only the 2 global views pass through the teacher, why? #7

Closed: newer7 closed this issue 5 months ago

newer7 commented 6 months ago
```python
for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)):   # add mask here
    # update weight decay and learning rate according to their schedule
    it = len(data_loader) * epoch + it  # global training iteration
    for i, param_group in enumerate(optimizer.param_groups):
        param_group["lr"] = lr_schedule[it]
        if i == 0:  # only the first group is regularized
            param_group["weight_decay"] = wd_schedule[it]

    # move images to gpu
    images = [im.cuda(non_blocking=True) for im in images]

    # teacher and student forward passes + compute loss
    with torch.cuda.amp.autocast(fp16_scaler is not None):
        teacher_output = teacher(images[:2])  # only the 2 global views pass through the teacher
        student_output = student(images)
```

I'm very interested in your work. Since `images` is a batch of image data, why does slicing just the first two elements give the global cropped views of all the images in the batch?

hananshafi commented 5 months ago

Hello, `images` is not a batch of images; it is a list of views. It contains the 2 global views and the 8 local views of each original image. After the DataLoader collates a batch, each element of that list is itself a batched tensor, so `images[:2]` selects the two global-view batches for the whole batch, not the first two images.
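
To make the list structure concrete, here is a minimal sketch of a DINO-style multi-crop pipeline and what the default DataLoader collation does to it. The `MultiCrop` class, the `ToyDataset`, and the crop sizes and scales (224/96) are illustrative assumptions for this sketch, not this repository's actual code (which targets small-scale datasets and may use different sizes):

```python
# Minimal sketch (not the repository's exact code) of a DINO-style
# multi-crop pipeline. Crop sizes, scales, and the class names are
# illustrative assumptions.
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image

class MultiCrop:
    """Return 2 global crops + 8 local crops of a single image."""
    def __init__(self, n_local=8):
        self.global_t = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.4, 1.0)),
            transforms.ToTensor(),
        ])
        self.local_t = transforms.Compose([
            transforms.RandomResizedCrop(96, scale=(0.05, 0.4)),
            transforms.ToTensor(),
        ])
        self.n_local = n_local

    def __call__(self, img):
        # indices 0-1: global views; indices 2-9: local views
        crops = [self.global_t(img), self.global_t(img)]
        crops += [self.local_t(img) for _ in range(self.n_local)]
        return crops

class ToyDataset(Dataset):
    """Dummy dataset: each sample is (list of 10 views, dummy label)."""
    def __init__(self, n=16):
        self.n, self.t = n, MultiCrop()
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        return self.t(Image.new("RGB", (256, 256))), 0

loader = DataLoader(ToyDataset(), batch_size=4)
images, _ = next(iter(loader))
# default_collate transposes the per-sample lists into a list of
# 10 batched tensors, one per view position:
print(len(images))      # 10
print(images[0].shape)  # torch.Size([4, 3, 224, 224]) -> global-view batch
print(images[2].shape)  # torch.Size([4, 3, 96, 96])   -> local-view batch
```

Because collation batches each view position across samples, `teacher(images[:2])` runs the teacher on the two global-view batches only. The quoted training loop follows the DINO recipe, where passing only global views through the teacher encourages the student to match its local views to the teacher's global representation.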