Dipeshtamboli / GPU-Benchmarking

A code to benchmark GPU performance on different models

TypeError: __call__() got an unexpected keyword argument 'pretrained' #1

Open GokulNC opened 1 year ago

GokulNC commented 1 year ago

Thanks for open-sourcing this. I am trying to run the script, but it throws the following error:

Traceback (most recent call last):
  File "/home/username/tmp/GPU-Benchmarking/torch_train_gpu.py", line 183, in <module>
    train_result = train(precision)
  File "/home/username/tmp/GPU-Benchmarking/torch_train_gpu.py", line 81, in train
    model = getattr(model_type, model_name)(pretrained=False)
TypeError: __call__() got an unexpected keyword argument 'pretrained'

GokulNC commented 1 year ago

Could you please check whether it works with the latest torch versions? If not, could you push a requirements.txt with the working versions pinned? Thanks!
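
One way to produce such a file is to record the versions from an environment where the script is known to run. The snippet below is only a sketch: `pin_lines` is a hypothetical helper, and the package list (`torch`, `torchvision`) is an assumption about what the scripts import, not something confirmed in this thread.

```python
from importlib.metadata import PackageNotFoundError, version


def pin_lines(package_names):
    """Return requirements.txt-style lines for the installed versions."""
    lines = []
    for name in package_names:
        try:
            lines.append(f"{name}=={version(name)}")
        except PackageNotFoundError:
            # Leave a visible marker rather than guessing a version.
            lines.append(f"# {name}: not installed in this environment")
    return lines


if __name__ == "__main__":
    # Assumed dependency list; adjust to what the scripts actually import.
    print("\n".join(pin_lines(["torch", "torchvision"])))
```

Redirecting the output to requirements.txt from a working environment would pin whatever versions are installed there.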

GokulNC commented 1 year ago

Modifying the MODEL_LIST dict like this helps:

def get_model_refs(all_model_modules):
    # Skip the first entry of __all__ and drop the names ending in "Weights".
    return [model_name for model_name in all_model_modules[1:] if not model_name.endswith("Weights")]

MODEL_LIST = {
    models.densenet: get_model_refs(models.densenet.__all__),
    models.mnasnet: get_model_refs(models.mnasnet.__all__),
    models.mobilenet: get_model_refs(models.mobilenet.mv2_all) + get_model_refs(models.mobilenet.mv3_all),
    models.resnet: get_model_refs(models.resnet.__all__),
    models.shufflenetv2: get_model_refs(models.shufflenetv2.__all__),
    models.squeezenet: get_model_refs(models.squeezenet.__all__),
    models.vgg: get_model_refs(models.vgg.__all__),
}

GokulNC commented 1 year ago

Also, removing the explicit pretrained=False helped.
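
A likely explanation, not confirmed in this thread: newer torchvision releases list weight enums (names ending in "Weights") next to the builder functions in each module's `__all__`, and calling an enum class with `pretrained=False` fails inside the enum metaclass's `__call__`. The sketch below reproduces that behaviour with stand-in objects rather than torchvision itself; `FakeNet_Weights`, `fakenet`, `fake_module` and `try_build` are hypothetical names used only for illustration.

```python
from enum import Enum
from types import SimpleNamespace


class FakeNet_Weights(Enum):
    """Stand-in for a torchvision weights enum (hypothetical)."""
    DEFAULT = "default"


def fakenet(weights=None):
    """Stand-in for a model builder function (hypothetical)."""
    return f"fakenet(weights={weights})"


# Stand-in for a torchvision model module and its __all__ list.
fake_module = SimpleNamespace(FakeNet_Weights=FakeNet_Weights, fakenet=fakenet)
fake_module_all = ["FakeNet", "FakeNet_Weights", "fakenet"]


def get_model_refs(all_model_names):
    # Same filter as in the comment above: skip the first entry
    # and drop the names ending in "Weights".
    return [name for name in all_model_names[1:] if not name.endswith("Weights")]


def try_build(name, **kwargs):
    """Call the attribute `name` and report a TypeError instead of raising."""
    try:
        return getattr(fake_module, name)(**kwargs)
    except TypeError as exc:
        return f"TypeError: {exc}"


# Without the filter, the enum class is called with pretrained=False and fails.
unfiltered_result = try_build("FakeNet_Weights", pretrained=False)

# With the filter, only the builder remains, and it is called without `pretrained`.
filtered_names = get_model_refs(fake_module_all)
filtered_results = [try_build(name) for name in filtered_names]
```

The exact wording of the TypeError message depends on the Python version, so the sketch only relies on the exception type.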

Dipeshtamboli commented 1 year ago

Thanks for letting me know. You can raise a PR for the corrected version, and I'll merge it. Thank you.

GokulNC commented 1 year ago

Sure, please keep this issue open so that I can document other issues as well before raising a PR.

GokulNC commented 1 year ago

The data_csv_to_json.py file also seems to require the above MODEL_LIST change.

Also, for some reason, the script skips ["mobilenet_v2", "MobileNetV3"]. That skip should be removed.