facebookresearch / dinov2

PyTorch code and models for the DINOv2 self-supervised learning method.
Apache License 2.0

How to train video action recognition task on Dinov2? #16

Closed steveice closed 1 year ago

steveice commented 1 year ago

Can you provide guidelines on how to fine-tune the model for the video action recognition task?

Thank you!!

patricklabatut commented 1 year ago

If you are trying to reproduce the video action recognition results, these are described in the second paragraph of sub-section 7.2 of the paper. These results were obtained with a linear classifier trained on features from a number of evenly spaced frames, without any fine-tuning.
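For readers who want to try this, here is a minimal sketch of picking evenly spaced frame indices from a video. The helper name `evenly_spaced_indices` and the rounding rule are illustrative assumptions, not the exact sampling used for the paper; the default of 8 frames matches the value the authors confirm later in this thread.

```python
from typing import List


def evenly_spaced_indices(num_frames: int, num_samples: int = 8) -> List[int]:
    """Return num_samples frame indices spread evenly over [0, num_frames - 1]."""
    if num_frames <= 0 or num_samples <= 0:
        raise ValueError("num_frames and num_samples must be positive")
    if num_samples == 1:
        # a single sample: take the middle frame
        return [(num_frames - 1) // 2]
    step = (num_frames - 1) / (num_samples - 1)
    # round to the nearest frame; short videos simply repeat some frames
    return [round(i * step) for i in range(num_samples)]


print(evenly_spaced_indices(100, 8))  # [0, 14, 28, 42, 57, 71, 85, 99]
```

Each selected frame would then go through the frozen backbone, and only the linear classifier on top is trained.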

steveice commented 1 year ago

Thank you for the detailed explanation. Can you provide more details about "For SSv2, we opt for concatenation to retain more temporal information than with feature averaging."? How many frames were picked for the SSv2 task?

ccharest93 commented 1 year ago

For N extracted frames, pass them through the network to get N CLS tokens. There are two ways to go from there:

  1. Average the CLS token embeddings over the N frames, so that the input to your classification head is 1 x Embed_Dim
  2. Concatenate the N CLS tokens, so that the input to your classification head is N x Embed_Dim

Option 1 trains fewer parameters because the input dimension is smaller; option 2 retains more information but has a much higher parameter count in the classification head.

You could also do something in between: average the first N/2 CLS tokens and the last N/2 CLS tokens separately, giving an input dimension of 2 x Embed_Dim for the classification head.
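The three options above can be sketched in PyTorch as follows. This is only an illustration: `aggregate_cls` and its `mode` names are made up for this example, and random tensors stand in for the CLS tokens a DINOv2 backbone would produce.

```python
import torch


def aggregate_cls(cls_tokens: torch.Tensor, mode: str = "mean") -> torch.Tensor:
    """Combine per-frame CLS tokens of shape [B, N, D] into one vector per video."""
    batch, n_frames, dim = cls_tokens.shape
    if mode == "mean":      # option 1: average over frames -> [B, D]
        return cls_tokens.mean(dim=1)
    if mode == "concat":    # option 2: concatenate frames -> [B, N * D]
        return cls_tokens.reshape(batch, n_frames * dim)
    if mode == "halves":    # in between: average each half -> [B, 2 * D]
        if n_frames < 2:
            raise ValueError("'halves' needs at least 2 frames")
        half = n_frames // 2
        first = cls_tokens[:, :half].mean(dim=1)
        second = cls_tokens[:, half:].mean(dim=1)
        return torch.cat([first, second], dim=1)
    raise ValueError(f"unknown mode: {mode}")


# Random stand-in for CLS tokens: 4 videos, N=8 frames, D=384 (ViT-S/14).
tokens = torch.randn(4, 8, 384)
for mode in ("mean", "concat", "halves"):
    print(mode, tuple(aggregate_cls(tokens, mode).shape))
```

With N=8 and D=384, a 101-class linear head has 384 x 101 = 38,784 weights after averaging and 3,072 x 101 = 310,272 weights after concatenation, which is the parameter-count trade-off described above.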

woctezuma commented 1 year ago

Also, as mentioned in the paper, concatenation makes it possible to "retain [...] temporal information" compared to average pooling.

How many frames was picked for SSv2 task?

I imagine it is N=8 frames as well, but I understand that you would like confirmation from the authors.

Paper

pierrefdz commented 1 year ago

Hi, author here! I confirm that what @ccharest93 said is correct, and N=8 in both cases.

The underlying idea is that:

Remarks:

I hope this is useful

pierrefdz commented 1 year ago

Closing as answered, thanks for your interest!

Batwho commented 1 year ago

Could you please also provide the detailed linear classifier structure and related key parameters? I tried using a simple MLP (with just two linear layers) on the UCF dataset but the accuracy is pretty low.

pierrefdz commented 1 year ago

Hi, the results in the paper were obtained with a single layer as the linear classifier. What do you mean by "pretty low"? Would you be able to share your implementation and the hyper-parameters used to optimize the layer?
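As a rough illustration of what "a single layer as linear classifier" means in code, here is a minimal sketch of a linear probe trained on cached features. Everything in it is an assumption made for the example: the features are random stand-ins for frozen DINOv2 outputs, the labels come from a synthetic teacher, and the optimizer, learning rate, and step count are not the paper's settings.

```python
import torch
from torch import nn

torch.manual_seed(0)

# 384 = ViT-S/14 embedding size, 101 = number of UCF101 classes
embed_dim, num_classes, num_videos = 384, 101, 512

# Synthetic stand-ins for cached, frame-averaged features and their labels.
features = torch.randn(num_videos, embed_dim)
teacher = torch.randn(embed_dim, num_classes)
labels = (features @ teacher).argmax(dim=1)

probe = nn.Linear(embed_dim, num_classes)  # the whole classifier is one layer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(probe.parameters(), lr=0.1)

losses = []
for step in range(100):
    optimizer.zero_grad()
    loss = criterion(probe(features), labels)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

Because the backbone is frozen, the features can be extracted once and reused for every epoch, which makes it cheap to try several learning rates.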

Batwho commented 1 year ago

Hi! Thanks for your quick response. The accuracy I got after 10 epochs is around 5%, so I guess there is probably something I didn't do right. I tried a single layer mapping a 384-dim input (using vits14) to 101 classes. The optimizer is Adam with lr=0.001, with a StepLR scheduler with step_size=1 and gamma=0.95.

If this setting should make it work, I could then show my code as well.

Batwho commented 1 year ago

code:

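# Imports needed by the snippet below
import gc

import decord
import numpy as np
import torch
import torch.nn as nn
import torchvision as tv
from einops import reduce
from PIL import Image
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

# Assumed to be defined elsewhere (not shown in the post):
# dataset_dir, video_dir, frames_per_clip, lr, epochs, device

# Dataset Class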
class UCFDataset(torch.utils.data.Dataset):

    def __init__(self, dataset_dir, subset, video_list_file, frames_per_clip=16):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.video_dir = video_dir
        self.subset=subset
        self.video_list_file = video_list_file
        self.video_list = []
        self.labels = []
        self.indices = []

        for i in [1,2,3]:
            with open(f'{dataset_dir}/{video_list_file}{str(i)}.txt') as video_names_file:
                if self.subset=="train":
                    tempvideo_list,templabels = zip(*(files[:-1].split() for files in video_names_file.readlines()))
                    self.video_list += tempvideo_list
                    self.labels += templabels
                else:
                    tempvideo_list = [files[:-1] for files in video_names_file.readlines()]
                    templabels = [None]
                    self.video_list += tempvideo_list
                    self.labels += templabels

        self.frames_per_clip = frames_per_clip

        self.transform = tv.transforms.Compose([
          tv.transforms.GaussianBlur(9, sigma=(0.1, 2.0)),
          tv.transforms.Resize(256,interpolation=tv.transforms.InterpolationMode.BICUBIC),
          tv.transforms.CenterCrop(224),
          tv.transforms.ToTensor(),
          tv.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

    def __len__(self):
        return len(self.video_list)

    def __getitem__(self, idx):
        videoname = f'{self.video_list[idx]}'
        vid = decord.VideoReader(f'{self.video_dir}/{videoname}', ctx=decord.cpu(0))
        nframes = len(vid)

        # if number of frames of video is less than frames_per_clip, repeat the frames
        if nframes <= self.frames_per_clip:
            idxs = np.arange(0, self.frames_per_clip).astype(np.int32)
            idxs[nframes:] %= nframes

        # otherwise (the video has more frames than frames_per_clip), sample uniformly separated frames
        else:
            idxs = np.linspace(0, nframes-1, self.frames_per_clip)
            idxs = np.round(idxs).astype(np.int32)

        imgs = []
        for k in idxs:
            frame = Image.fromarray(vid[k].asnumpy())
            frame = self.transform(frame)
            imgs.append(frame)
        imgs = torch.stack(imgs)

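        # for the train subset, the label comes from the train list (1-based in the file)
        if self.subset=="train":
            label = int(self.labels[idx]) - 1
        # for the test subset, look up the label from the class folder name
        else:
            with open(f'{self.dataset_dir}/classInd.txt') as class_indices:
                # each non-empty line of classInd.txt is "<1-based index> <class name>"
                class_to_idx = {name: int(i) for i, name in (line.split() for line in class_indices if line.strip())}
            label = class_to_idx[videoname.split('/')[0]] - 1
        return imgs,label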

class MLP(nn.Module):

    def __init__(self, dim, inner_dim, n_class, encoder):
        # dim is the size of the image feature output by dinov2; inner_dim is unused
        super().__init__()
        self.encoder = encoder  # frozen DINOv2 backbone
        # a single linear layer (no hidden layer, no activation)
        self.mlp = nn.Sequential(
            nn.Linear(dim, n_class),
        )

    def forward(self, x):
        # x is [batch, frames, 3, 224, 224]
        feats = []

        for i in range(x.shape[1]):
            xi = x[:, i]  # i-th frame of every video: [batch, 3, 224, 224]
            # encode the frames to [batch, dim] without tracking gradients
            with torch.no_grad():
                e = self.encoder(xi).unsqueeze(1)  # [batch, 1, dim]
            feats.append(e)
        feats = torch.cat(feats, dim=1)  # [batch, frames, dim]
        avg = reduce(feats, "b t c -> b c", 'mean')  # average over frames: [batch, dim]
        return self.mlp(avg)

# data loading params (defined before the loaders that use them)
batch_size = 256
test_batch_size = 1
num_workers = 8
pin_memory = True
num_classes=101

# Dataset
train_val_data = UCFDataset( dataset_dir = dataset_dir, subset="train", video_list_file="trainlist0",frames_per_clip=frames_per_clip)

train_len=int(0.85*len(train_val_data))
train_val_split = [ train_len, len(train_val_data) - train_len ] 

train_data , val_data = random_split(train_val_data,train_val_split)
test_data = UCFDataset( dataset_dir = dataset_dir, subset="test", video_list_file="testlist0" ,frames_per_clip=frames_per_clip)

# Dataloaders
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=test_batch_size)
test_loader = DataLoader(test_data, batch_size=test_batch_size)

dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
dinov2_vits14.to(device)
for param in dinov2_vits14.parameters():
    param.requires_grad= False
model = MLP(384,512,101,dinov2_vits14)
#frames, _ = next(iter(train_loader))
#tb_writer.add_graph(model, frames)
model.to(device)

# define the loss and optimizers
loss_criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

# training step for every epoch
def train_step(loader, epoch):

    model.train()
    total_epoch_loss=0

    for batch_id, (video_data,labels) in enumerate(loader):

        # video_data,labels = video_data.to(device), labels.to(device)
        video_data,labels = video_data.to(device), labels.to(device)

        optimizer.zero_grad()

        prediction = model(video_data)

        loss = loss_criterion(prediction,labels)
        total_epoch_loss += loss.item()

        loss.backward()

        optimizer.step()

        del video_data
        del labels

        gc.collect()

        #tb_writer.add_scalar("Train/Loss",loss.item(),((len(loader))*(epoch-1))+batch_id)

        print(f"\n[Train Epoch]: {epoch} Train Loss: {loss.item()}")
    return total_epoch_loss

# validation step for every epoch
def val_step(loader,epoch=None):

    model.eval()
    total_loss=0
    corrects=0

    with torch.no_grad():
        for batch_id, (video_data,labels) in enumerate(loader):

            video_data,labels = (video_data).to(device), labels.to(device)

            prediction = model(video_data)

            loss = loss_criterion(prediction,labels)
            total_loss += loss.item()
            corrects+= (torch.argmax(prediction,dim=1)==labels).sum()

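    # divide by the number of samples actually evaluated, not by the training batch size
    accuracy = float(corrects) / len(loader.dataset)

    # report the mean validation loss over all batches, not just the last batch
    print(f"\n[Val Epoch]: {epoch} , Accuracy: {accuracy}, Valid Loss: {total_loss / len(loader)}")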

    return accuracy

# Driving train test loop
for epoch in tqdm(range(1,epochs+1)):
    train_step(train_loader, epoch)
    val_step(val_loader, epoch)
    scheduler.step()
    torch.save(model,"dino_model.pt")

pierrefdz commented 1 year ago

Thanks, I'll try to have a more thorough look at it in the following days.

Some points that might cause the problem:

  • maybe try switching to SGD with different LRs (try without LR scheduling, it will have a minor impact)
  • remove the Gaussian blur augmentation
  • check that the forward function in your MLP does the right thing (that everything has the right dimension, that you average on the time dimension, etc.). My code snippet for feeding the videos is:

with torch.no_grad():
    B,C,T,H,W = inp.shape
    inp = inp.transpose(1,2).reshape(B*T,C,H,W) # b c t h w -> b t c h w -> b*t c h w
    output = model(inp)
    output = output.reshape(B,T,-1) # b*t d -> b t d
    output = output.mean(dim=-2) # b t d -> b d
output = linear_classifier(output) # b d -> b l

Batwho commented 1 year ago

Could you please also share your number of epochs, batch size, and LR? I suspect I may not have trained long enough because of limited GPU RAM.

Batwho commented 1 year ago

Problem solved: it was due to a bug in the initialization of the validation dataloader. Thank you @pierrefdz and feel free to close this issue.
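The thread does not say what the dataloader bug was. As a general sanity check when accuracy looks implausibly low, it can help to compute accuracy from the number of samples actually evaluated instead of from an assumed batch size. A minimal sketch in plain Python; the `accuracy` helper and the toy batches are illustrative and not taken from the thread.

```python
from typing import Iterable, Sequence, Tuple


def accuracy(batches: Iterable[Tuple[Sequence[int], Sequence[int]]]) -> float:
    """Fraction of correct predictions, counted over the samples actually seen."""
    correct = 0
    total = 0
    for predictions, targets in batches:
        if len(predictions) != len(targets):
            raise ValueError("predictions and targets must have the same length")
        correct += sum(int(p == t) for p, t in zip(predictions, targets))
        total += len(targets)
    if total == 0:
        raise ValueError("no samples were evaluated")
    return correct / total


# Batches of different sizes: 3 of the 5 predictions are right.
batches = [([0, 1, 2], [0, 1, 1]), ([4, 4], [4, 3])]
print(accuracy(batches))  # 0.6
```

Counting samples this way gives the same result whatever batch size the validation loader was built with, including a smaller last batch.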

pierrefdz commented 1 year ago

Thanks for keeping me updated on this @Batwho. Don't hesitate to re-open if you need anything else.