Closed (mmatena closed 2 years ago)
To give some information on this contribution, I added the src/models/fisher.py script, which computes the Fisher and saves it to a PyTorch checkpoint. Since there is a 1-to-1 correspondence between the diagonal Fisher weights and the model parameters, it makes sense to save it this way.
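As an illustration of the 1-to-1 correspondence mentioned above, here is a minimal sketch of a diagonal Fisher computation. It is not the code from src/models/fisher.py: it assumes a toy logistic-regression model in plain Python, and the names `diagonal_fisher`, `"weight"` and `"bias"` are made up for the example. The point is only that the result has the same keys and shapes as the parameters, so it can be stored the same way a parameter dictionary would be (for instance with `torch.save` in a PyTorch setting).

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def diagonal_fisher(params, inputs):
    """Diagonal Fisher of a toy logistic-regression model (illustrative only).

    params: {"weight": [w_1, ..., w_d], "bias": [b]}
    inputs: list of feature vectors, each of length d
    Returns a dict with the same keys and shapes as params.
    """
    w, b = params["weight"], params["bias"][0]
    fisher = {"weight": [0.0] * len(w), "bias": [0.0]}
    for x in inputs:
        p = sigmoid(sum(wj * xj for wj, xj in zip(w, x)) + b)
        # Take the expectation over labels drawn from the model's own prediction.
        for y, prob in ((1, p), (0, 1.0 - p)):
            err = y - p  # derivative of log p(y | x) with respect to the logit
            for j, xj in enumerate(x):
                fisher["weight"][j] += prob * (err * xj) ** 2
            fisher["bias"][0] += prob * err ** 2
    n = len(inputs)
    # Average the squared gradients over the dataset.
    return {name: [v / n for v in vals] for name, vals in fisher.items()}


params = {"weight": [0.0, 0.0], "bias": [0.0]}
fisher = diagonal_fisher(params, [[1.0, 2.0], [3.0, 4.0]])
print(fisher)
```

Each Fisher entry is the average squared gradient of the log-likelihood for the matching parameter, which is why one non-negative number per parameter is all that needs to be saved.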
To the src/wise_ft.py script, I added the --fishers flag. If not set, then it does the isotropic merge that it was doing before. If present, it should be a comma-separated list of paths to the saved Fisher checkpoints.
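To make the difference between the two merge modes concrete, here is a rough sketch. It is an assumption about how the merge could be written, not the code in src/wise_ft.py: `parse_paths` and `merge_two` are hypothetical helpers, parameters are plain dicts of lists, and the Fisher-weighted formula shown is the usual per-parameter weighted average. Without Fishers it reduces to ordinary interpolation with coefficient alpha.

```python
def parse_paths(arg):
    """Split a comma-separated flag value into a list of paths."""
    if not arg:
        return []
    return [p for p in arg.split(",") if p]


def merge_two(theta0, theta1, alpha, fisher0=None, fisher1=None):
    """Interpolate two models parameter by parameter.

    With no Fishers every parameter gets weight 1 (isotropic merge).
    With Fishers each parameter is weighted by its diagonal Fisher value.
    """
    merged = {}
    for name in theta0:
        n = len(theta0[name])
        f0 = fisher0[name] if fisher0 is not None else [1.0] * n
        f1 = fisher1[name] if fisher1 is not None else [1.0] * n
        values = []
        for x, y, a, b in zip(theta0[name], theta1[name], f0, f1):
            den = (1 - alpha) * a + alpha * b
            if den > 0:
                values.append(((1 - alpha) * a * x + alpha * b * y) / den)
            else:
                # Both Fisher values are zero: fall back to plain interpolation.
                values.append((1 - alpha) * x + alpha * y)
        merged[name] = values
    return merged


theta0 = {"w": [0.0, 2.0]}
theta1 = {"w": [2.0, 4.0]}
print(merge_two(theta0, theta1, alpha=0.5))
print(merge_two(theta0, theta1, 0.5, {"w": [3.0, 0.0]}, {"w": [1.0, 0.0]}))
```

In the second call the first parameter is pulled toward `theta0` because its Fisher value there is larger, while the second parameter has no Fisher information in either model and falls back to the plain average.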
Here are the commands I was using. The one with the # Errors comment is what produced the stack trace above.
MODELS_DIR=~/Desktop/projects_data/wise-ft/models
FISHERS_DIR=~/Desktop/projects_data/wise-ft/fishers
DATA_LOCATION=~/Desktop/projects_data/wise-ft/data
# Works
python src/models/fisher.py \
--train-dataset=CIFAR10 \
--load=$MODELS_DIR/checkpoint_10.pt \
--data-location=$DATA_LOCATION \
--fisher=$FISHERS_DIR/fisher_checkpoint_10.pt \
--epochs=1
# Errors
python src/models/fisher.py \
--train-dataset=CIFAR10 \
--model=ViT-B/32 \
--template=openai_imagenet_template \
--load=$MODELS_DIR/zeroshot.pt \
--data-location=$DATA_LOCATION \
--fisher=$FISHERS_DIR/fisher_zeroshot.pt \
--epochs=1
# Works
python src/wise_ft.py \
--eval-datasets=CIFAR10 \
--load=$MODELS_DIR/zeroshot.pt,$MODELS_DIR/checkpoint_10.pt \
--fisher=$FISHERS_DIR/fisher_checkpoint_10.pt,$FISHERS_DIR/fisher_zeroshot.pt \
--results-db=/tmp/results.jsonl \
--save=/tmp/wiseft \
--data-location=$DATA_LOCATION \
--alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0
Hey @mmatena, thanks a lot for this! I think the issue here is that the model is trying to process cached features, but it's receiving raw images as inputs. To inform the model that it should process raw images, you can set model.process_images = True. I see that was being done in your fisher.py file, but only if _TRAIN_PREPROCESSING is set to True, which it is not. Could you try moving model.process_images = True outside of the if statement here?
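For anyone reading along, the suggested change can be sketched roughly as follows. This is a toy illustration, not the code in fisher.py: `ToyClassifier`, `build_model_buggy` and `build_model_fixed` are hypothetical stand-ins, and the only thing taken from the thread is the `process_images` attribute and the flag that guards it.

```python
class ToyClassifier:
    """Hypothetical stand-in for a model with an encoder and a classification head."""

    def __init__(self, process_images=False):
        self.process_images = process_images

    def encode(self, image):
        # Pretend encoder: reduce a 2-D "image" to a 2-number feature vector.
        pixels = [v for row in image for v in row]
        return [sum(pixels) / len(pixels), max(pixels)]

    def head(self, features):
        # The head only understands flat feature vectors, not raw images.
        if len(features) != 2 or not all(isinstance(v, float) for v in features):
            raise ValueError("classification head expected a 2-d feature vector")
        return features[0] - features[1]

    def __call__(self, inputs):
        if self.process_images:
            inputs = self.encode(inputs)
        return self.head(inputs)


TRAIN_PREPROCESSING = False  # mirrors the flag being off, as described above


def build_model_buggy():
    model = ToyClassifier()
    if TRAIN_PREPROCESSING:
        model.process_images = True  # never reached while the flag is False
    return model


def build_model_fixed():
    model = ToyClassifier()
    model.process_images = True  # set unconditionally, outside the if
    return model


image = [[1.0, 2.0], [3.0, 4.0]]
print(build_model_fixed()(image))
```

With the buggy builder the raw image goes straight to the head and fails, which matches the symptom of the head receiving image-shaped input.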
@gabrielilharco Thanks! Zero-shot Fisher computation runs without any issues now.
Fisher computation for the fine-tuned model works, as does merging. However, I run into errors when computing the Fisher for the zero-shot model.
Here is the error. The input to the classification head has shape [1, 3, 224, 224] for the zero-shot model.