facebookresearch / dinov2

PyTorch code and models for the DINOv2 self-supervised learning method.
Apache License 2.0

Is this the right way to fine-tune DINOv2? #276

Open · namrahrehman opened this issue 8 months ago

namrahrehman commented 8 months ago

I am trying to fine-tune DINOv2 for image classification on a custom dataset (medical images) with the objective of increasing accuracy. The problem is that with linear evaluation I get an adequate accuracy of almost 75%, but when I fine-tune the whole backbone I can never get an accuracy higher than 40%. Is there something semantically wrong with how I am fine-tuning this model? I even tried CIFAR-10 and got excellent performance with linear evaluation but poor performance with fine-tuning. Also, when I loaded the model from the hub and ran the following code snippet, I got "Pre-trained DINO weights are not found in the model's state_dict.", so instead I had to load the model from Hugging Face for fine-tuning the whole backbone:

pretrained_dino_keys = [k for k in model.state_dict() if 'dino' in k]

if pretrained_dino_keys:
    print("Pre-trained DINO weights are present in the model's state_dict.")
else:
    print("Pre-trained DINO weights are not found in the model's state_dict.")

The following is my code for fine-tuning:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from transformers import Dinov2ForImageClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Dinov2ForImageClassification.from_pretrained("facebook/dinov2-small-imagenet1k-1-layer")
for param in model.dinov2.parameters():  # unfreeze the backbone
    param.requires_grad = True
for param in model.classifier.parameters():
    param.requires_grad = True
# Customize the head for the classification task
num_classes = 10  # Number of classes in the dataset
model.classifier = nn.Linear(768, num_classes).to(device)  # 768 = 2 * 384 for ViT-S/14; a linear head, moved to GPU

# Define the loss function 
loss_fn = nn.CrossEntropyLoss()  

weight_decay = 1e-3 
lr = 0.001
step_size = 5
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# Create a learning rate scheduler
scheduler = StepLR(optimizer, step_size=step_size, gamma=0.0001)
def make_classification_eval_transform(
    *,
    resize_size: int = 256,
    interpolation=transforms.InterpolationMode.BICUBIC,
    crop_size: int = 224,
) -> transforms.Compose:
    transforms_list = [
        transforms.Resize(resize_size, interpolation=interpolation),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(transforms_list)

# Use the make_classification_eval_transform function to create the transformation pipeline
transform = make_classification_eval_transform()
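
Note that this applies the eval-style transform to the training set as well; a training transform with standard augmentation (a sketch with common torchvision ops, not something I ran) could look like:

def make_classification_train_transform(
    *,
    crop_size: int = 224,
    interpolation=transforms.InterpolationMode.BICUBIC,
) -> transforms.Compose:
    # Random crop and flip are the usual minimal augmentations for fine-tuning
    return transforms.Compose([
        transforms.RandomResizedCrop(crop_size, interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])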

# Set up data loaders for training, validation, and test
train_dataset = ImageFolder(root=train_dataset_path, transform=transform)
valid_dataset = ImageFolder(root=valid_dataset_path, transform=transform)
test_dataset = ImageFolder(root=test_dataset_path, transform=transform)

# Batches are moved to the model's device inside the training loop below
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)
model = model.to(device)
# Set random seed
torch.manual_seed(1)

# Define the number of epochs
num_epochs = 20

# Initialize lists to store loss and accuracy for each epoch
loss_hist_train = [0.0] * num_epochs
accuracy_hist_train = [0.0] * num_epochs
loss_hist_valid = [0.0] * num_epochs
accuracy_hist_valid = [0.0] * num_epochs

for epoch in range(num_epochs):
    model.train()
    loss_accumulated_train = 0.0
    total_samples_train = 0
    correct_predictions_train = 0

    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        output = model(x_batch)
        logits = output.logits
        loss = loss_fn(logits, y_batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss_accumulated_train += loss.item() * y_batch.size(0)  # Accumulate as a scalar
        total_samples_train += y_batch.size(0)

        # Calculate accuracy
        predicted = torch.max(logits, 1)[1]
        correct_predictions_train += torch.sum(predicted == y_batch).item()  # Accumulate as a scalar

    loss_hist_train[epoch] = loss_accumulated_train / total_samples_train  # average loss per sample
    accuracy_hist_train[epoch] = correct_predictions_train / total_samples_train  # fraction of correct predictions

    scheduler.step()

    model.eval()
    with torch.no_grad():
        loss_accumulated_valid = 0.0
        total_samples_valid = 0
        correct_predictions_valid = 0

        for x_batch, y_batch in valid_loader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            output = model(x_batch)
            logits = output.logits
            loss = loss_fn(logits, y_batch)
            loss_accumulated_valid += loss.item() * y_batch.size(0)  # Accumulate as a scalar
            total_samples_valid += y_batch.size(0)

            # Calculate accuracy
            predicted = torch.max(logits, 1)[1]
            correct_predictions_valid += torch.sum(predicted == y_batch).item()  # Accumulate as a scalar

        loss_hist_valid[epoch] = loss_accumulated_valid / total_samples_valid  # average loss per sample
        accuracy_hist_valid[epoch] = correct_predictions_valid / total_samples_valid  # fraction of correct predictions

    print(f'Epoch {epoch + 1} accuracy: {accuracy_hist_train[epoch]:.4f} val_accuracy: {accuracy_hist_valid[epoch]:.4f} loss: {loss_hist_train[epoch]:.4f} val_loss: {loss_hist_valid[epoch]:.4f}')
qasfb commented 8 months ago

What accuracy do you get on the training set?

namrahrehman commented 8 months ago

@qasfb Almost the same as the validation accuracy. Overfitting is not an issue.

qasfb commented 8 months ago

StepLR(optimizer, step_size=step_size, gamma=0.0001) multiplies your learning rate by 0.0001 every step_size=5 epochs. Is my understanding correct?

namrahrehman commented 8 months ago

Yes, so the learning rate is multiplied by a factor of 0.0001.

qasfb commented 8 months ago

I think this is why it doesn't work: after 5 epochs the learning rate essentially becomes 0. Can you try without that scheduling?
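
To illustrate with a minimal stand-alone sketch (a dummy parameter instead of the real model):

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

opt = AdamW([torch.zeros(1, requires_grad=True)], lr=1e-3)
sched = StepLR(opt, step_size=5, gamma=0.0001)
for epoch in range(12):
    print(epoch, opt.param_groups[0]["lr"])  # 1e-3 for epochs 0-4, 1e-7 for 5-9, 1e-11 after
    opt.step()  # stands in for one epoch of training
    sched.step()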

namrahrehman commented 8 months ago

@qasfb I tried your suggestion on CIFAR-10; the results are below.

With scheduler: [training-curve screenshot]

Without scheduler: [training-curve screenshot]

I trained without the scheduler for 20 more epochs, and the accuracy does still seem to be increasing. Even so, there is no significant difference in final accuracy with or without the scheduler: it stays in the 20s in both cases. With the scheduler, training converges faster.

Here is a link to the Colab notebook for these experiments if you want to take a detailed look: https://drive.google.com/file/d/1LmFgW-A5VzUeI6haFz7JkwAGCoKiDYxW/view?usp=sharing

jack89roberts commented 7 months ago

In case it's helpful (as I came across your issue whilst trying to debug something myself), I was getting similarly poor performance fine-tuning DINOv2 with the HuggingFace trainer defaults and found it was very sensitive to the learning rate. Reducing the learning rate to 5e-6 (from the default of 5e-5) achieved much better results (slightly better than just training a linear classification head on top of a frozen base model). This was with a linear scheduler on the learning rate in both cases (so starting at the initial values quoted above then reducing during training), which is also the HuggingFace default.

The learning rate you have above is much higher (1e-3), so maybe try something a lot smaller and see what happens?
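
Roughly what I mean, as a sketch (dataset wiring omitted; the datasets are assumed to yield dicts with "pixel_values" and "labels", and the names here are placeholders):

from transformers import Dinov2ForImageClassification, Trainer, TrainingArguments

model = Dinov2ForImageClassification.from_pretrained(
    "facebook/dinov2-small-imagenet1k-1-layer",
    num_labels=10,                   # replaces the 1000-class ImageNet head
    ignore_mismatched_sizes=True,
)
args = TrainingArguments(
    output_dir="dinov2-finetune",
    learning_rate=5e-6,              # reduced from the 5e-5 default
    lr_scheduler_type="linear",      # the HF default: linear decay from the initial value
    num_train_epochs=20,
    per_device_train_batch_size=32,
)
trainer = Trainer(model=model, args=args,
                  train_dataset=train_dataset, eval_dataset=valid_dataset)
trainer.train()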

namrahrehman commented 7 months ago

@jack89roberts I will try your suggestions and post my results here soon. Thank you so much.

If fine-tuning is not possible (or not the objective of the authors), then there needs to be some other way to increase DINOv2's performance on medical imaging data.

twmht commented 6 months ago

@namrahrehman

Any update on this?

lombardata commented 5 months ago

> In case it's helpful (as I came across your issue whilst trying to debug something myself), I was getting similarly poor performance fine-tuning DINOv2 with the HuggingFace trainer defaults and found it was very sensitive to the learning rate. Reducing the learning rate to 5e-6 (from the default of 5e-5) achieved much better results (slightly better than just training a linear classification head on top of a frozen base model). This was with a linear scheduler on the learning rate in both cases (so starting at the initial values quoted above then reducing during training), which is also the HuggingFace default.
>
> The learning rate you have above is much higher (1e-3), so maybe try something a lot smaller and see what happens?

Hi @jack89roberts, which DINOv2 model did you use for your training on HF? The facebook/dinov2 models, the models fine-tuned on ImageNet, or the timm/dinov2 models? Do you know the difference between the facebook and the timm models? Thank you in advance and have a good day!

jack89roberts commented 5 months ago

I've used only the facebook/dinov2 ones for HuggingFace transformers (specifically facebook/dinov2-small-imagenet1k-1-layer and facebook/dinov2-base-imagenet1k-1-layer). I've not used the timm ones (or the ones downloadable from the repo/torch hub).

lombardata commented 5 months ago

> I've used only the facebook/dinov2 ones for HuggingFace transformers (specifically facebook/dinov2-small-imagenet1k-1-layer and facebook/dinov2-base-imagenet1k-1-layer). I've not used the timm ones (or the ones downloadable from the repo/torch hub).

Thank you very much for this information. So, if I've understood correctly, you trained the whole model (unfrozen backbone + head) starting at lr = 5e-6 and linearly decreasing the value with the scheduler? Have a good day!

jack89roberts commented 5 months ago

Yes, that's right: basically just the HF trainer defaults with the lower learning rate.

Raspberry-beans commented 5 months ago

Hi @jack89roberts, can you specify the GPU memory required for this process?

I will be training a linear head (with a frozen DINOv2 backbone) on a few custom medical images for segmentation. I have only 8 GB of GPU memory available. Would that be enough, given that the backbone will be kept frozen?

Thanks in advance!

jack89roberts commented 5 months ago

You may be better off asking that elsewhere, but from a quick look at the training jobs I have run with DINOv2 small/base, I think that should be OK, yes. I've not used the large/giant variants.
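
For what it's worth, freezing the backbone in the style of the classification code earlier in this thread looks roughly like this (a sketch; only the head gets gradients and optimizer state, which is what keeps memory low):

import torch.optim as optim

# No gradients or AdamW moment buffers are allocated for frozen parameters
for param in model.dinov2.parameters():
    param.requires_grad = False

# Optimize only the classification head
optimizer = optim.AdamW(model.classifier.parameters(), lr=1e-3, weight_decay=1e-3)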

anonymouslei commented 2 months ago

@namrahrehman Can you please share the linear evaluation code? I'd appreciate it!