drivendataorg / zamba

A Python package for identifying 42 kinds of animals, training custom models, and estimating distance from camera trap videos
https://zamba.drivendata.org/docs/stable/
MIT License

Don't download model weights and load imagenet weights if using models for inference #144

Open pjbull opened 2 years ago

pjbull commented 2 years ago

We currently use load_from_checkpoint in our ModelManager to initialize models when doing inference. This can cause the models to download the pretrained imagenet weights from the internet even though we don't need them. To address this, we need a parameter that we pass to the __init__ of the model to indicate that we are doing inference/loading from a checkpoint, and then we need to pass this parameter through to load_from_checkpoint in the ModelManager.
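A minimal sketch of the proposed pattern (the class name and the `pretrained` flag are hypothetical, not zamba's actual API):

```python
# Hypothetical sketch: an __init__ flag lets ModelManager skip the
# imagenet weight download when the weights will be overwritten by a
# checkpoint anyway.
class TimeDistributedModel:
    def __init__(self, pretrained: bool = True):
        if pretrained:
            self.backbone = self._download_imagenet_weights()
        else:
            # Inference / checkpoint-loading path: no network call;
            # the checkpoint's state dict fills in the weights.
            self.backbone = None

    @staticmethod
    def _download_imagenet_weights():
        # Stand-in for the network download we want to avoid.
        return "imagenet-weights"

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path: str, **kwargs):
        # ModelManager would pass pretrained=False through here.
        model = cls(**kwargs)
        # ... loading the state dict from checkpoint_path would go here ...
        return model
```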

We should check across all of our models for this behavior, but this is how it works for the time_distributed model:

ejm714 commented 2 years ago

This also arises if we are training a model whose labels are a subset of the zamba labels, which means we "resume training" instead of replacing the head. This stems from the fact that finetune_from is still None in this case; we should instead do model_class(finetune_from={official_ckpt}) rather than load from checkpoint.

https://github.com/drivendataorg/zamba/blob/7986c417f33839c0a8d14ac66201472acbfb393a/zamba/models/model_manager.py#L139-L147
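A hedged sketch of that suggestion (the function and parameter names below are illustrative, not the actual model_manager code):

```python
# Hypothetical sketch: when the requested labels are a subset of the zamba
# labels, construct the model via finetune_from with the official checkpoint
# instead of calling load_from_checkpoint.
def instantiate_model(model_class, labels, zamba_labels, official_ckpt):
    if set(labels) <= set(zamba_labels):
        # "Resume training" case: __init__ handles loading the checkpoint
        # weights, so no separate load_from_checkpoint (and no redundant
        # pretrained-weight download) is needed.
        return model_class(finetune_from=official_ckpt)
    # New label set: replace the head for the custom classes.
    return model_class(num_classes=len(labels))
```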

In addition, we may want to use super().load_from_checkpoint here to avoid re-passing through the init (with its timm weight download): https://github.com/drivendataorg/zamba/blob/7fb2a0f9599bf55bf2f538d6a0a736963cb9d9eb/zamba/models/efficientnet_models.py#L27
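A toy illustration of why delegating to super() helps; this is a stand-in, not PyTorch Lightning's real load_from_checkpoint mechanics:

```python
# Toy stand-in: the base class restores an instance without re-running the
# subclass __init__, so the timm weight download in __init__ never fires.
class BaseModule:
    @classmethod
    def load_from_checkpoint(cls, checkpoint_path: str):
        obj = cls.__new__(cls)  # bypass __init__ entirely
        obj.downloaded_timm_weights = False
        # ... the state dict from checkpoint_path would be loaded here ...
        return obj


class EfficientNetModel(BaseModule):
    def __init__(self):
        # Stand-in for timm.create_model(..., pretrained=True).
        self.downloaded_timm_weights = True

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path: str):
        # Delegate to the base class to avoid the download in __init__.
        return super().load_from_checkpoint(checkpoint_path)
```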

papapizzachess commented 10 months ago

The code has changed a lot since this issue was filed. Has this bug been resolved? I'd like to work on it if it hasn't.

ejm714 commented 10 months ago

@papapizzachess yes this bug still exists and the code sections in the issue description are still correct.