mlfoundations / wise-ft

Robust fine-tuning of zero-shot models
https://arxiv.org/abs/2109.01903

Added first pass of Fisher computation and merging #2

Closed mmatena closed 2 years ago

mmatena commented 2 years ago

Fisher computation for the fine-tuned model works, and so does merging. However, I run into an error when computing the Fisher for the zero-shot model.

Here is the error. For the zero-shot model, the input to the classification head has shape [1, 3, 224, 224].

Traceback (most recent call last):
  File "src/models/fisher.py", line 135, in <module>
    compute_fisher(args)
  File "src/models/fisher.py", line 105, in compute_fisher
    logits = utils.get_logits(inputs, model)
  File "/home/owner/Desktop/projects/wise-ft/src/models/utils.py", line 76, in get_logits
    return classifier(inputs)
  File "/home/owner/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/owner/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 159, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/owner/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/owner/Desktop/projects/wise-ft/src/models/modeling.py", line 77, in forward
    outputs = self.classification_head(inputs)
  File "/home/owner/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/owner/Desktop/projects/wise-ft/src/models/modeling.py", line 52, in forward
    return super().forward(inputs)
  File "/home/owner/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/owner/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/nn/functional.py", line 1692, in linear
    output = input.matmul(weight.t())
RuntimeError: mat1 dim 1 must match mat2 dim 0

mmatena commented 2 years ago

To give some information on this contribution: I added the src/models/fisher.py script, which computes the Fisher and saves it to a PyTorch checkpoint. Since there is a 1-to-1 correspondence between the diagonal Fisher weights and the model parameters, it makes sense to save it this way.
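As an illustration of that 1-to-1 correspondence, here is a minimal sketch of a diagonal Fisher for a toy softmax-regression model in plain Python. It is not the code in src/models/fisher.py: the function names, the toy model, and the choice to take the exact expectation over labels under the model's own predictions are all assumptions made for the example.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def diagonal_fisher(params, examples):
    """Diagonal Fisher of a toy softmax-regression model (hypothetical helper).

    The result has the same keys and shapes as `params`, which is why it can
    be stored the same way a model checkpoint is stored.
    """
    weight, bias = params["weight"], params["bias"]
    n_classes, n_features = len(weight), len(weight[0])
    fisher = {
        "weight": [[0.0] * n_features for _ in range(n_classes)],
        "bias": [0.0] * n_classes,
    }
    for x in examples:
        logits = [
            sum(w * xi for w, xi in zip(weight[c], x)) + bias[c]
            for c in range(n_classes)
        ]
        probs = softmax(logits)
        # Assumption: exact expectation over labels y ~ p(y | x).
        for y, p_y in enumerate(probs):
            for c in range(n_classes):
                # d log p(y | x) / d logit_c = 1[c == y] - p_c
                g_logit = (1.0 if c == y else 0.0) - probs[c]
                fisher["bias"][c] += p_y * g_logit ** 2
                for d in range(n_features):
                    fisher["weight"][c][d] += p_y * (g_logit * x[d]) ** 2
    n = len(examples)
    fisher["bias"] = [v / n for v in fisher["bias"]]
    fisher["weight"] = [[v / n for v in row] for row in fisher["weight"]]
    return fisher
```

Each entry is the average squared gradient of the log-likelihood with respect to one parameter, so the Fisher dictionary can be indexed with the same names as the parameters it describes.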

I added the --fishers flag to the src/wise_ft.py script. If it is not set, the script does the same isotropic merge as before. If it is set, it should be a comma-separated list of paths to the saved Fisher checkpoints.
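To make the difference between the two merge modes concrete, here is a small sketch in plain Python. It is an assumption about how such a merge can be written, not the code in src/wise_ft.py: the function names are hypothetical, parameters are flattened to lists of floats, and the way alpha scales each model's Fisher is one possible choice.

```python
def isotropic_merge(theta_a, theta_b, alpha):
    """Plain interpolation: (1 - alpha) * theta_a + alpha * theta_b."""
    return {
        name: [(1 - alpha) * a + alpha * b
               for a, b in zip(theta_a[name], theta_b[name])]
        for name in theta_a
    }


def fisher_merge(theta_a, theta_b, fisher_a, fisher_b, alpha):
    """Per-parameter weighted average, weighted by alpha and the diagonal Fisher.

    Assumption: alpha scales each model's Fisher. Where both weights are
    zero, fall back to the isotropic value to avoid dividing by zero.
    """
    merged = {}
    for name in theta_a:
        values = []
        for a, b, fa, fb in zip(theta_a[name], theta_b[name],
                                fisher_a[name], fisher_b[name]):
            wa = (1 - alpha) * fa
            wb = alpha * fb
            denom = wa + wb
            if denom == 0:
                values.append((1 - alpha) * a + alpha * b)
            else:
                values.append((wa * a + wb * b) / denom)
        merged[name] = values
    return merged
```

When both Fishers are equal, the Fisher-weighted merge reduces to the isotropic one. Where one model has a larger Fisher value for a parameter, the merged value moves toward that model's value.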

Here are the commands I was using. The one with the # Errors comment is what produced the stack trace above.

MODELS_DIR=~/Desktop/projects_data/wise-ft/models
FISHERS_DIR=~/Desktop/projects_data/wise-ft/fishers
DATA_LOCATION=~/Desktop/projects_data/wise-ft/data

# Works
python src/models/fisher.py   \
    --train-dataset=CIFAR10  \
    --load=$MODELS_DIR/checkpoint_10.pt  \
    --data-location=$DATA_LOCATION \
    --fisher=$FISHERS_DIR/fisher_checkpoint_10.pt \
    --epochs=1

# Errors
python src/models/fisher.py   \
    --train-dataset=CIFAR10  \
    --model=ViT-B/32  \
    --template=openai_imagenet_template  \
    --load=$MODELS_DIR/zeroshot.pt  \
    --data-location=$DATA_LOCATION \
    --fisher=$FISHERS_DIR/fisher_zeroshot.pt \
    --epochs=1

# Works
python src/wise_ft.py   \
    --eval-datasets=CIFAR10  \
    --load=$MODELS_DIR/zeroshot.pt,$MODELS_DIR/checkpoint_10.pt  \
    --fisher=$FISHERS_DIR/fisher_checkpoint_10.pt,$FISHERS_DIR/fisher_zeroshot.pt \
    --results-db=/tmp/results.jsonl  \
    --save=/tmp/wiseft  \
    --data-location=$DATA_LOCATION \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

gabrielilharco commented 2 years ago

Hey @mmatena, thanks a lot for this! I think the issue here is that the model is trying to process cached features but is receiving raw images as inputs. To tell the model that it should process raw images, you can set model.process_images = True. I see that this is done in your fisher.py file, but only if _TRAIN_PREPROCESSING is set to True, which it is not. Could you try moving model.process_images = True outside of the if statement?
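The control-flow change being suggested can be sketched with a toy stand-in. Everything below is hypothetical and only mirrors the shape of the problem: ToyClassifier, setup_before, and setup_after are invented for illustration and are not the classes or functions in the wise-ft repository.

```python
class ToyClassifier:
    """Hypothetical model: encodes raw images only when process_images is True."""

    FEATURE_DIM = 4

    def __init__(self):
        self.process_images = False

    def _encode(self, image):
        # Pretend image encoder: reduce a 2-D image to FEATURE_DIM features.
        flat = [px for row in image for px in row]
        mean = sum(flat) / len(flat)
        return [mean] * self.FEATURE_DIM

    def _head(self, features):
        # Like a linear layer, the head only accepts FEATURE_DIM inputs.
        if len(features) != self.FEATURE_DIM:
            raise ValueError("feature size mismatch")
        return sum(features)

    def __call__(self, inputs):
        if self.process_images:
            inputs = self._encode(inputs)
        return self._head(inputs)


def setup_before(train_preprocessing):
    """Flag is only set inside the branch, so it stays False when the branch is skipped."""
    model = ToyClassifier()
    if train_preprocessing:
        model.process_images = True
    return model


def setup_after(train_preprocessing):
    """Flag is set unconditionally, outside the if statement."""
    model = ToyClassifier()
    model.process_images = True
    if train_preprocessing:
        pass  # any preprocessing-specific setup would remain here
    return model
```

With train_preprocessing set to False, the model from setup_before passes a raw image straight to the head and fails on the size check, while the model from setup_after encodes the image first.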

mmatena commented 2 years ago

@gabrielilharco Thanks! Zero-shot Fisher computation runs without any issues now.