```python
for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)):  # add mask here
    # update weight decay and learning rate according to their schedule
    it = len(data_loader) * epoch + it  # global training iteration
    for i, param_group in enumerate(optimizer.param_groups):
        param_group["lr"] = lr_schedule[it]
        if i == 0:  # only the first group is regularized
            param_group["weight_decay"] = wd_schedule[it]

    # move images to gpu
    images = [im.cuda(non_blocking=True) for im in images]
    # teacher and student forward passes + compute loss
    with torch.cuda.amp.autocast(fp16_scaler is not None):
        teacher_output = teacher(images[:2])  # only the 2 global views pass through the teacher
        student_output = student(images)
```
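For context, my understanding of the first few lines is that `lr_schedule` and `wd_schedule` hold one value per training *iteration*, so `len(data_loader) * epoch + it` converts the epoch-local step into a global index. A minimal sketch of that indexing, ignoring warmup and using made-up sizes (the repo's `utils.cosine_scheduler` builds a similar per-iteration array):

```python
import numpy as np

# Hypothetical sizes, for illustration only.
epochs, niter_per_ep = 100, 1250
base_lr, final_lr = 0.0005, 1e-6

# One lr value per iteration over the whole run (cosine decay, no warmup).
iters = np.arange(epochs * niter_per_ep)
lr_schedule = final_lr + 0.5 * (base_lr - final_lr) * (1 + np.cos(np.pi * iters / len(iters)))

epoch, it = 3, 17               # epoch index and epoch-local step
it = niter_per_ep * epoch + it  # global training iteration, as in the loop above
print(lr_schedule[it])
```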
I'm very interested in your work. Since `images` is a batch of image data, why does slicing just the first two elements (`images[:2]`) give the global cropped views for *all* the images in the batch?
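To make the question concrete, here is a minimal sketch of what I believe `images` looks like after the multi-crop augmentation and the default PyTorch collate: a list with one tensor per crop (2 global crops followed by the local crops), where each tensor holds that crop for *every* image in the batch. All sizes below are illustrative:

```python
import torch

# Illustrative values; batch_size and n_local_crops are made up.
batch_size, n_local_crops = 4, 8

# `images` is a LIST of crop tensors, not a batch of single images:
# 2 global crops first, then the local crops, each batched over all samples.
images = (
    [torch.randn(batch_size, 3, 224, 224) for _ in range(2)]               # global crops
    + [torch.randn(batch_size, 3, 96, 96) for _ in range(n_local_crops)]   # local crops
)

global_views = images[:2]  # 2 tensors, each of shape [batch_size, 3, 224, 224]
print(len(images), global_views[0].shape)  # 10 torch.Size([4, 3, 224, 224])
```

If that understanding is right, `images[:2]` selects the two global-crop tensors, and each of them already contains the whole batch.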