vidarlab / multi-view-hybrid

Official repository for the WACV 2024 paper "Multi-view Classification with Hybrid Fusion and Mutual Distillation"

Generating Predictions for New Samples with MV-HFMD Model #1

Closed Otoliths closed 5 months ago

Otoliths commented 6 months ago

Hello,

I've successfully trained the MV-HFMD model on the Hotels-8k dataset following the instructions in this repository for the WACV 2024 paper "Multi-view Classification with Hybrid Fusion and Mutual Distillation." I'm now interested in applying the trained model to generate predictions for new, unseen samples.

Could you please provide guidance on how to use the trained model to generate predictions for new multi-view collections? Specifically, I'm looking for instructions on how to prepare the input data and how to use the trained model (last.pth and best.pth) to output predictions for both individual images and the entire multi-view collection. Additionally, I would like to know how to visualize the prediction results, similar to the provided example (how-to-predict-new-samples-with-your-pytorch-model).

Thank you for your assistance.

sblack15 commented 6 months ago

Once a model is trained, you can load the saved weights with:

import torch

best_weights = torch.load(f'{args.save_dir}/best.pth')
model.load_state_dict(best_weights['model'])
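
Two small additions worth making, both standard PyTorch rather than anything repo-specific: pass map_location if the checkpoint was saved on GPU but you are predicting on CPU, and put the model in eval mode before inference:

# Optional: load GPU-trained weights on a CPU-only machine.
best_weights = torch.load(f'{args.save_dir}/best.pth', map_location='cpu')
model.load_state_dict(best_weights['model'])
model.eval()  # disables dropout and freezes batch-norm statistics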

You can generate predictions on the validation set by instantiating the dataloader classes found in main.py (note that you need to instantiate the train dataset first so it discovers all the classes, then pass those classes as an argument to the val dataset):

dataset_train = HotelsDataset(args.data_dir, split='train', n=args.n, train=True)
dataset_val = HotelsDataset(args.data_dir, split='val', n=args.n, classes=dataset_train.classes, train=False)

loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=128, shuffle=False, num_workers=args.num_workers, drop_last=False, pin_memory=True)

The images in each batch have shape BxNx3xHxW (B = batch size, N = number of views per collection).
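
To run the model on a brand-new collection instead of the val split, you can build that BxNx3xHxW tensor yourself. Below is a minimal sketch; the resize and normalization values are assumptions (standard ImageNet preprocessing) and should be replaced with whatever HotelsDataset actually applies at val time, and the view filenames are hypothetical:

from PIL import Image
import torch
import torchvision.transforms as T

# Assumed preprocessing -- match this to HotelsDataset's val-time transforms.
preprocess = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# The N views of one new collection (hypothetical paths).
view_paths = ['view_0.jpg', 'view_1.jpg', 'view_2.jpg', 'view_3.jpg']
views = torch.stack([preprocess(Image.open(p).convert('RGB')) for p in view_paths])

batch = views.unsqueeze(0)  # add the batch dim: shape (1, N, 3, H, W)

with torch.no_grad():
    output = model(batch)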

After feeding a batch of images to the model, it outputs a dictionary with two keys, "mv_collection" and "single". These hold the logits for the entire collection (shape BxK, where K is the number of classes) and the logits for each individual image in a collection (shape (B*N)xK), respectively.

with torch.no_grad():  # no gradients needed for inference
    for i, (images, targets, paths) in enumerate(loader_val):
        output = model(images)
        mv_logits = output['mv_collection']['logits']    # (B, K)
        single_view_logits = output['single']['logits']  # (B*N, K)

        # Predicted class index per collection and per individual view.
        mv_pred = torch.argmax(mv_logits, dim=1)
        single_view_pred = torch.argmax(single_view_logits, dim=1)
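
To turn those indices into hotel labels and inspect the results visually, you can index into dataset_train.classes (assuming it is an index-to-label list, as its use in the val dataset constructor above suggests). A rough matplotlib sketch; the per-view predictions are reshaped back to (B, N) so each image lines up with its label:

import matplotlib.pyplot as plt

B, N = images.shape[:2]
single_view_pred = single_view_pred.view(B, N)  # regroup per-view preds by collection

# Show the first collection in the batch with its predicted labels.
fig, axes = plt.subplots(1, N, figsize=(3 * N, 3))
for j in range(N):
    img = images[0, j].permute(1, 2, 0).cpu().numpy()  # CxHxW -> HxWxC
    axes[j].imshow(img.clip(0, 1))  # still normalized, so colors will look off
    axes[j].set_title(dataset_train.classes[single_view_pred[0, j].item()])
    axes[j].axis('off')
fig.suptitle(f'collection prediction: {dataset_train.classes[mv_pred[0].item()]}')
plt.show()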