JustinMBrown opened this issue 1 year ago (status: Open)
@JustinMBrown It's a reasonable idea; the only issue is that it ends up being a big change. ALL pretrained checkpoints right now are a bare state_dict with no extra layer in the dict: every key is a param/buffer name and every value a tensor. The train checkpoints (which do have extra keys) are stripped of everything but the state dict before being published. I followed torchvision and other 'original' libs when this decision was made long ago.
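The two checkpoint layouts described above differ only in nesting. Below is a minimal sketch of the publishing step, assuming a train checkpoint that keeps the weights under 'state_dict' (and optionally EMA weights under 'state_dict_ema'). `strip_train_checkpoint` is a hypothetical helper, not timm's own code, and plain values stand in for tensors so the logic can be read without PyTorch.

```python
from collections import OrderedDict
from typing import Any, Dict, Mapping


def strip_train_checkpoint(checkpoint: Mapping[str, Any], prefer_ema: bool = True) -> Dict[str, Any]:
    """Reduce a train checkpoint to a bare state_dict (param/buffer name -> tensor).

    Hypothetical helper: the 'state_dict' / 'state_dict_ema' key names are an
    assumption about the train checkpoint layout; check your own file's keys.
    """
    if prefer_ema and checkpoint.get("state_dict_ema") is not None:
        state_dict = checkpoint["state_dict_ema"]
    elif "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        # Already bare: every key is a param/buffer name, nothing to strip.
        state_dict = checkpoint
    # Drop a DataParallel/DDP 'module.' prefix so keys match the bare model.
    return OrderedDict(
        (k[len("module."):] if k.startswith("module.") else k, v)
        for k, v in state_dict.items()
    )
```

In practice the input would come from `torch.load(...)` and the result would be written back with `torch.save(...)`; everything else in the train checkpoint (optimizer state, epoch, etc.) is simply left behind.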
The timm load functions would support stripping this automatically (and could be modified to extract other specific keys like class maps), but it would break anyone trying to just use torch.load(), which works right now... https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/_helpers.py#L31
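To illustrate the trade-off: a loader that understands a wrapped format can pull the weights and any extra metadata apart, whereas a plain `torch.load()` followed by `load_state_dict()` on the wrapped dict would fail with missing/unexpected keys under strict loading. A sketch, assuming hypothetical metadata key names (`class_to_idx`, `args`, `pretrained_cfg`) that the published timm checkpoints do not currently contain:

```python
from typing import Any, Dict, Mapping, Optional, Tuple

# Hypothetical metadata keys; today's published checkpoints contain none of these.
METADATA_KEYS = ("class_to_idx", "args", "pretrained_cfg")


def split_checkpoint(checkpoint: Mapping[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Return (state_dict, metadata) from either a bare or a wrapped checkpoint."""
    if "state_dict" in checkpoint:
        state_dict = dict(checkpoint["state_dict"])
        metadata = {k: checkpoint[k] for k in METADATA_KEYS if k in checkpoint}
    else:
        # Bare state_dict (the current published format): no metadata to extract.
        state_dict = dict(checkpoint)
        metadata = {}
    return state_dict, metadata


def class_map_from_metadata(metadata: Mapping[str, Any]) -> Optional[Dict[int, str]]:
    """Invert a stored {class_name: index} into {index: class_name}, if present."""
    class_to_idx = metadata.get("class_to_idx")
    if class_to_idx is None:
        return None
    return {idx: name for name, idx in class_to_idx.items()}
```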
I think I should stash the args in the train state dict regardless, though. I've thought about this because I've had numerous instances where I toasted the original train folders in disk cleanup, had only the checkpoint left, and lost the hparams :/
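A minimal sketch of stashing the args next to the weights in a train checkpoint, assuming an argparse-based train script. `build_train_checkpoint` and `recover_args` are hypothetical names. Storing `vars(args)` as a plain dict of built-in types (rather than the `Namespace` object) keeps it easy to inspect, and it should also load under `torch.load(..., weights_only=True)`, which by default refuses arbitrary pickled classes.

```python
import argparse
from typing import Any, Dict, Mapping


def build_train_checkpoint(state_dict: Mapping[str, Any], args: argparse.Namespace, epoch: int) -> Dict[str, Any]:
    """Bundle the weights with the hyperparameters that produced them."""
    return {
        "epoch": epoch,
        "state_dict": dict(state_dict),
        # Plain dict copy of the Namespace so the hparams survive disk cleanup.
        "args": dict(vars(args)),
    }


def recover_args(checkpoint: Mapping[str, Any]) -> argparse.Namespace:
    """Rebuild a Namespace from the stored hparams (raises KeyError if absent)."""
    return argparse.Namespace(**checkpoint["args"])
```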
Although I will point out, I've had multiple occasions where people have been provided with the exact hparams and I get an angry 'it doesn't work' because they don't understand that things change when you change the global batch size, use a different dataset, etc. The highest value is in seeing templates and building an intuition for how to adapt different recipe templates to different uses...
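As one concrete example of why copied hparams don't transfer: when the global batch size changes, the learning rate usually has to change with it. A minimal sketch of two common heuristics, linear and square-root scaling; `scale_lr` and its parameters are illustrative, not a timm API, and neither rule guarantees that a recipe will carry over.

```python
import math


def scale_lr(base_lr: float, base_batch_size: int, batch_size: int,
             world_size: int = 1, grad_accum_steps: int = 1, rule: str = "linear") -> float:
    """Rescale a recipe's learning rate for a different global batch size.

    Global batch = per-device batch * number of devices * gradient accumulation steps.
    'linear' multiplies the LR by the batch ratio, 'sqrt' by its square root.
    """
    global_batch = batch_size * world_size * grad_accum_steps
    ratio = global_batch / base_batch_size
    if rule == "linear":
        return base_lr * ratio
    if rule == "sqrt":
        return base_lr * math.sqrt(ratio)
    raise ValueError(f"unknown rule: {rule!r}")
```

For a recipe tuned at lr 0.4 with a global batch of 1024, rerunning on 4 GPUs with 64 images each (global batch 256) gives 0.1 under linear scaling and 0.2 under square-root scaling.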
I will ponder this some more. FYI, if you publish the weights to the HF hub, the pretrained_cfg has fields for labels: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/_pretrained.py#L41-L42 ... so for the inference script and any side-loading, it'd make sense to serialize/deserialize a pretrained_cfg instance with the weights. The timm inference widget for the Hub loads this, as the cfg is built from the config.json file (https://huggingface.co/timm/nf_resnet50.ra2_in1k/blob/main/config.json). Right now, if you pass 'label_names' and 'label_descriptions' fields to the push_to_hub fn in timm via the model_cfg dict, the HF hub widget will do inference with the correct label names: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/_hub.py#L215-L227
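A library-agnostic sketch of serializing label metadata alongside the weights and folding it back into a pretrained_cfg-style dict on load. The layout (top-level `architecture`, `num_classes`, `label_names`, `label_descriptions`, and a nested `pretrained_cfg`) is an assumption loosely modeled on the linked config.json and `_hub.py` lines; verify it against the timm source before relying on it. `config_to_json` and `pretrained_cfg_from_json` are hypothetical names, and `label_descriptions` is assumed here to map class name to description.

```python
import json
from typing import Any, Dict, Mapping, Optional, Sequence


def config_to_json(architecture: str, pretrained_cfg: Mapping[str, Any], label_names: Sequence[str],
                   label_descriptions: Optional[Mapping[str, str]] = None) -> str:
    """Serialize a config.json-style dict that carries label metadata with the weights."""
    config: Dict[str, Any] = {
        "architecture": architecture,
        "num_classes": len(label_names),
        "label_names": list(label_names),
        "pretrained_cfg": dict(pretrained_cfg),
    }
    if label_descriptions:
        config["label_descriptions"] = dict(label_descriptions)
    return json.dumps(config, indent=2)


def pretrained_cfg_from_json(text: str) -> Dict[str, Any]:
    """Parse the config and fold the label fields into the pretrained_cfg dict."""
    config = json.loads(text)
    cfg = dict(config.get("pretrained_cfg", {}))
    for key in ("label_names", "label_descriptions"):
        if key in config:
            cfg[key] = config[key]
    return cfg
```

Writing the string next to the weights (e.g. with `Path("config.json").write_text(...)`) lets any side-loader recover the labels without knowing how the model was trained.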
See how the inference widget uses it (inference.py should be updated to be closer to this by providing a pretrained cfg JSON, or state dicts with it embedded): https://github.com/huggingface/api-inference-community/blob/main/docker_images/timm/app/pipelines/image_classification.py#L25-L42
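Once label names are available, the final step of an image-classification pipeline is mapping the highest-scoring output indices to those names. A pure-Python sketch of that step (softmax, then top-k); it only approximates what the linked pipeline does and is not its actual code. `top_k_labels` is a hypothetical name.

```python
import math
from typing import List, Sequence, Tuple


def top_k_labels(logits: Sequence[float], label_names: Sequence[str], k: int = 5) -> List[Tuple[str, float]]:
    """Softmax the logits and pair the k highest probabilities with their label names."""
    if len(logits) != len(label_names):
        raise ValueError("need exactly one label name per output index")
    peak = max(logits)  # subtract the max before exp() for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [(label_names[i], probs[i]) for i in order]
```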
Thanks @rwightman for the details. I am also interested in the labels of the classes. I tried the code that you suggested:
```python
import timm

model_id = 'timm/resnetrs350.tf_in1k'
model = timm.create_model(f"hf_hub:{model_id}", pretrained=True)
model.eval()
dataset_info = None
# Label metadata, if any, is carried in the model's pretrained_cfg
label_names = model.pretrained_cfg.get("label_names", None)
label_descriptions = model.pretrained_cfg.get("label_descriptions", None)
print(label_names, label_descriptions)
```
But I get None, None.
Any idea how to get the label for each index?
Thank you
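The None, None result suggests that this Hub config has no `label_names`, so the labels have to come from somewhere else. A minimal, library-agnostic sketch of a lookup order: `label_names` from the pretrained_cfg first, then a caller-supplied `fallback(index)` function, then the bare index. `resolve_label_names` and the `fallback` hook are hypothetical. For ImageNet-1k models, recent timm releases appear to provide `timm.data.ImageNetInfo` and `infer_imagenet_subset`, whose per-index lookups could be passed as `fallback`; check that they exist in your installed version.

```python
from typing import Any, Callable, List, Mapping, Optional


def resolve_label_names(pretrained_cfg: Mapping[str, Any], num_classes: int,
                        fallback: Optional[Callable[[int], str]] = None) -> List[str]:
    """Return one label name per output index."""
    names = pretrained_cfg.get("label_names")
    if names:
        if isinstance(names, Mapping):
            # {index: name}; after a JSON round trip the keys may be strings.
            return [str(names.get(i, names.get(str(i), i))) for i in range(num_classes)]
        return [str(n) for n in names]
    if fallback is not None:
        return [fallback(i) for i in range(num_classes)]
    # Last resort: the index itself, so callers always get num_classes entries.
    return [str(i) for i in range(num_classes)]
```

With the model from the snippet above, the call would be `resolve_label_names(model.pretrained_cfg, model.num_classes, fallback=...)`.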
Is your feature request related to a problem? Please describe.
I'm not using ImageNet, but during inference it loads the ImageNet class_map by default.
Describe the solution you'd like
Instead, the class_to_idx from dataset_train.reader.class_to_idx should just be saved somewhere inside the model and loaded into class_map during inference by default. And of course, if someone still wants to override the class_map for whatever reason, they could still do so.
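A minimal sketch of the proposed behavior, assuming the dataset exposes a torchvision-style `{class_name: index}` dict as `dataset_train.reader.class_to_idx` does. The function names are hypothetical; the resulting list could be stored, for instance, as `label_names` in the model's pretrained_cfg, with an explicit class map still taking precedence at inference time.

```python
from typing import Dict, List, Mapping, Optional


def class_to_idx_to_label_names(class_to_idx: Mapping[str, int]) -> List[str]:
    """Turn {class_name: index} into a list where position == class index."""
    if sorted(class_to_idx.values()) != list(range(len(class_to_idx))):
        raise ValueError("class indices must be exactly 0..N-1")
    names = [""] * len(class_to_idx)
    for name, idx in class_to_idx.items():
        names[idx] = name
    return names


def pick_class_map(saved_label_names: Optional[List[str]],
                   override: Optional[Dict[int, str]] = None) -> Optional[Dict[int, str]]:
    """A user-supplied class map wins; otherwise use the names saved with the model."""
    if override is not None:
        return dict(override)
    if saved_label_names is not None:
        return dict(enumerate(saved_label_names))
    return None
```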
I'd make a PR myself, but y'all probably have other considerations for exactly where to save/load it, so here's a sample solution.
Sample solution:
Describe alternatives you've considered
We could save the class_to_idx into a class map file and ship it alongside the model, but that's cumbersome and tedious. The proposed solution just works by default.
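For comparison, the class-map-file alternative can be sketched as a plain-text file with one class name per line, where the line number is the class index. To my understanding this matches the `.txt` class-map format timm's loader reads, but treat that as an assumption; `format_class_map` and `parse_class_map` are hypothetical helpers.

```python
from typing import Dict, Iterable, Mapping


def format_class_map(class_to_idx: Mapping[str, int]) -> str:
    """One class name per line, in index order (line number == class index)."""
    ordered = sorted(class_to_idx.items(), key=lambda item: item[1])
    if [idx for _, idx in ordered] != list(range(len(ordered))):
        raise ValueError("class indices must be exactly 0..N-1")
    return "".join(f"{name}\n" for name, _ in ordered)


def parse_class_map(lines: Iterable[str]) -> Dict[str, int]:
    """Inverse of format_class_map; accepts an open text file or any iterable of lines."""
    return {line.strip(): idx for idx, line in enumerate(lines)}
```

The file would be written with something like `Path("class_map.txt").write_text(format_class_map(dataset_train.reader.class_to_idx))`, and it is this extra file that has to travel with every copy of the weights.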
Additional context
The same should probably be done with the args.yaml file. There are a ton of timm models on Hugging Face with pretrained weights but no args.yaml file, which makes it nearly impossible to reproduce their results.
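As far as I know, the args.yaml that timm's train script writes is a YAML dump of the argument namespace. Once such a dict is loaded (for example with PyYAML's `yaml.safe_load`), a sketch like the one below could turn it back into a runnable configuration. `rebuild_args` is hypothetical and assumes an argparse parser with no required arguments; saved keys that the current parser no longer knows about are dropped.

```python
import argparse
from typing import Any, List, Mapping, Optional


def rebuild_args(parser: argparse.ArgumentParser, saved: Mapping[str, Any],
                 overrides: Optional[List[str]] = None) -> argparse.Namespace:
    """Re-create a training run's arguments from a saved dict.

    Saved values become the parser defaults (note: this mutates the parser),
    so explicit command-line overrides such as a different --batch-size
    still take precedence.
    """
    known = set(vars(parser.parse_args([])))  # dest names the parser defines
    parser.set_defaults(**{k: v for k, v in saved.items() if k in known})
    return parser.parse_args(overrides or [])
```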