GokulNC opened 1 year ago
Can you please check whether it works with the latest torch versions? If not, could you push a requirements.txt with the appropriate versions pinned? Thanks!
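As a rough sketch of how such pins could be collected, the snippet below reads the installed versions from the current environment using only the standard library's `importlib.metadata`. It assumes the relevant dependencies are `torch` and `torchvision`, and the helper name `pinned_requirements` is hypothetical, not something in the repo.

```python
from importlib import metadata


def pinned_requirements(packages):
    """Return 'name==version' lines for each installed package, skipping missing ones."""
    lines = []
    for name in packages:
        try:
            lines.append(f"{name}=={metadata.version(name)}")
        except metadata.PackageNotFoundError:
            # Not installed in this environment, so there is nothing to pin.
            continue
    return lines


if __name__ == "__main__":
    # Assumed dependency list; print so the output can be redirected into requirements.txt.
    print("\n".join(pinned_requirements(["torch", "torchvision"])))
```

Running it in a known-good environment (e.g. `python pin_versions.py > requirements.txt`, where the file name is hypothetical) would record the exact versions the script was tested with.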
Modifying the MODEL_LIST dict as follows helps:
from torchvision import models

def get_model_refs(all_model_modules):
    # Drop the first entry (the model class, e.g. "DenseNet") and every *_Weights enum,
    # keeping only the model-builder function names.
    return [model_name for model_name in all_model_modules[1:] if not model_name.endswith("Weights")]

MODEL_LIST = {
    models.densenet: get_model_refs(models.densenet.__all__),
    models.mnasnet: get_model_refs(models.mnasnet.__all__),
    # models.mobilenet re-exports both MobileNetV2 and V3, so filter each sub-list separately.
    models.mobilenet: get_model_refs(models.mobilenet.mv2_all) + get_model_refs(models.mobilenet.mv3_all),
    models.resnet: get_model_refs(models.resnet.__all__),
    models.shufflenetv2: get_model_refs(models.shufflenetv2.__all__),
    models.squeezenet: get_model_refs(models.squeezenet.__all__),
    models.vgg: get_model_refs(models.vgg.__all__),
}
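To illustrate what `get_model_refs` keeps, here it is applied to a shortened, hand-written list modelled on the layout of torchvision's `__all__` lists (class name first, then `*_Weights` enums, then builder functions). The sample is illustrative only; the real lists are longer and vary by torchvision version, and the filter relies on the class name always coming first.

```python
def get_model_refs(all_model_modules):
    # Skip the first entry (the model class) and every *_Weights enum,
    # leaving only the builder-function names.
    return [name for name in all_model_modules[1:] if not name.endswith("Weights")]


# Hypothetical, trimmed sample in the style of torchvision.models.resnet.__all__.
SAMPLE_RESNET_ALL = [
    "ResNet",
    "ResNet18_Weights",
    "ResNet50_Weights",
    "resnet18",
    "resnet50",
]

print(get_model_refs(SAMPLE_RESNET_ALL))  # ['resnet18', 'resnet50']
```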
Also, removing the explicit pretrained=False argument helped.
Thanks for letting me know. You can raise a PR for the corrected version, and I'll merge it. Thank you.
Sure, please keep this issue open so that I can document other issues as well before raising a PR.
The data_csv_to_json.py file also seems to need the same MODEL_LIST change.
Also, for some reason the script skips ["mobilenet_v2", "MobileNetV3"]. That skip should be removed.
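One possible reason a hard-coded skip for "MobileNetV3" was ever needed (an assumption; the script's history isn't shown in this thread): `models.mobilenet.__all__` is the concatenation of the V2 and V3 lists, so slicing off only its first entry removes "MobileNetV2" but lets the "MobileNetV3" class name through. Filtering `mv2_all` and `mv3_all` separately, as in the MODEL_LIST above, drops both class names. The lists below are hand-written samples in the style of torchvision's, not copied from a specific release.

```python
def get_model_refs(all_model_modules):
    # Drop the leading class name and any *_Weights entries.
    return [name for name in all_model_modules[1:] if not name.endswith("Weights")]


# Hypothetical samples in the style of torchvision's mobilenetv2 / mobilenetv3 __all__.
MV2_ALL = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]
MV3_ALL = [
    "MobileNetV3",
    "MobileNet_V3_Large_Weights",
    "MobileNet_V3_Small_Weights",
    "mobilenet_v3_large",
    "mobilenet_v3_small",
]

# Filtering the combined list only skips the first class name, so "MobileNetV3" leaks through.
combined = get_model_refs(MV2_ALL + MV3_ALL)
# Filtering each sub-list separately skips both class names.
separate = get_model_refs(MV2_ALL) + get_model_refs(MV3_ALL)

print(combined)  # ['mobilenet_v2', 'MobileNetV3', 'mobilenet_v3_large', 'mobilenet_v3_small']
print(separate)  # ['mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small']
```

With per-submodule filtering, every remaining name is a builder function, so neither entry needs to be excluded by hand.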
Thanks for open-sourcing this. I am trying to run the script, but it throws the following error: