octavian-ganea / equidock_public

EquiDock: geometric deep learning for fast rigid 3D protein-protein docking
MIT License

Hyperparameters #12

Open PBordesInstadeep opened 2 years ago

PBordesInstadeep commented 2 years ago

Hello! It is a bit unclear to me which hyperparameters you used to train your models. Could you provide the complete list of hyperparameters for your best DIPS and DB5 models? In particular, I am not sure whether node and edge features were used. Moreover, the hyperparameters mentioned in the paper are not the same as those stored in your best models' checkpoints. Thanks :)

zhenpingli commented 1 year ago


I have the same question.

zhenpingli commented 1 year ago


Hi, I found a way to extract the hyperparameters from the checkpoints in `checkpts/`. Here is the main code:

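```python
import torch

# Assumes `parser`, `log`, `create_model` and `EarlyStopping` are imported from the EquiDock training code.
args = parser.parse_args().__dict__
args['input_edge_feats_dim'] = 64
args['device'] = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Available GPUs: {torch.cuda.device_count()}")

stopper = EarlyStopping(mode='lower', patience=00, filename="./checkpts/db5/db5_model_best.pth", log=log)
model = create_model(args, log)
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['w_decay'])
# args2 should hold the arguments that were saved together with the checkpoint.
model, optimizer, args2, epoch = stopper.load_checkpoint(model, optimizer)

# Print the saved hyperparameters (from args2, not the freshly parsed args), skipping runtime-only settings.
for k in args2.keys():
    if k not in ['device', 'debug', 'worker', 'n_jobs', 'toy']:
        print(k, args2[k])
```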
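If you only want to read the hyperparameters, you may not need to rebuild the model and optimizer at all. A minimal sketch, assuming the `.pth` file is a dict written with `torch.save` and that the training arguments sit under a key named `'args'` (a hypothetical name; print `checkpoint.keys()` to find the real one), is to filter the saved arguments directly and, optionally, diff them against the values reported in the paper. `extract_hparams` and `diff_hparams` are illustrative helpers, not part of EquiDock.

```python
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Dict, Iterable, Tuple

# Runtime-only settings that the snippet above also skips (not model hyperparameters).
RUNTIME_KEYS = ('device', 'debug', 'worker', 'n_jobs', 'toy')


def extract_hparams(checkpoint: Mapping[str, Any], args_key: str = 'args',
                    skip: Iterable[str] = RUNTIME_KEYS) -> Dict[str, Any]:
    """Return the saved hyperparameters from an already-loaded checkpoint dict.

    ASSUMPTION: the training args are stored under `args_key`; the default
    'args' is a hypothetical key name, so check checkpoint.keys() first.
    """
    if args_key not in checkpoint:
        raise KeyError(f"{args_key!r} not found; available keys: {sorted(checkpoint)}")
    saved = checkpoint[args_key]
    # The args may have been saved as a plain dict or as an argparse.Namespace.
    if not isinstance(saved, Mapping):
        saved = vars(saved)
    skip = set(skip)
    return {k: v for k, v in saved.items() if k not in skip}


def diff_hparams(saved: Mapping[str, Any],
                 reference: Mapping[str, Any]) -> Dict[str, Tuple[Any, Any]]:
    """Map each key whose value differs to (saved_value, reference_value).

    A key present on only one side is reported with None for the missing value.
    """
    diffs = {}
    for k in sorted(set(saved) | set(reference)):
        a, b = saved.get(k), reference.get(k)
        if a != b:
            diffs[k] = (a, b)
    return diffs
```

For example, load the file with `torch.load("./checkpts/db5/db5_model_best.pth", map_location="cpu")` and pass the result to `extract_hparams`. Recent PyTorch releases default `torch.load` to `weights_only=True`, which may refuse to unpickle non-tensor objects stored in a checkpoint; passing `weights_only=False` works around that but should only be done for files you trust.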