rose4labour opened this issue 8 months ago
Yes, you are correct. The imbalance ratios are used only by the loss function during training. They have no effect on inference, since inference is only a forward pass.
Thank you for reporting this bug!
As a workaround during inference, you can pass an arbitrary array of size n_tasks. The value of n_tasks must be the same as the one used when training the model.
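To illustrate why the ratios do not matter at inference, here is a toy, dependency-free sketch. It is not openpom's code: the weights, features, and the positive-term weighting in `weighted_bce` are hypothetical stand-ins. The forward pass never reads `class_imbalance_ratio`, so any placeholder array of length n_tasks gives the same predictions. Only the training loss changes.

```python
import math

# Toy stand-in for a trained multi-task classifier (hypothetical values, NOT openpom's weights):
# one weight vector per task, applied to a shared feature vector.
TOY_WEIGHTS = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.2]]
TOY_FEATURES = [1.0, 2.0]
TOY_LABELS = [1, 0, 1]


def make_placeholder_ratios(n_tasks, value=1.0):
    """Arbitrary class_imbalance_ratio for inference: only the length (n_tasks) matters."""
    return [value] * n_tasks


def forward(weights, features):
    """Toy forward pass: one sigmoid probability per task. The ratios are never read here."""
    probs = []
    for task_weights in weights:
        logit = sum(w * x for w, x in zip(task_weights, features))
        probs.append(1.0 / (1.0 + math.exp(-logit)))
    return probs


def weighted_bce(probs, labels, class_imbalance_ratio):
    """Toy training loss: binary cross-entropy whose positive term is scaled per task.

    This weighting scheme is an illustrative assumption, not openpom's exact loss.
    """
    total = 0.0
    for p, y, ratio in zip(probs, labels, class_imbalance_ratio):
        total += -(ratio * y * math.log(p) + (1 - y) * math.log(1 - p))
    return total


n_tasks = len(TOY_WEIGHTS)
# Predictions: computed without any ratios, so they are the same whatever ratios you choose.
probs = forward(TOY_WEIGHTS, TOY_FEATURES)
# Loss: this is the only place the ratios enter, so different ratios give different losses.
loss_placeholder = weighted_bce(probs, TOY_LABELS, make_placeholder_ratios(n_tasks))
loss_other = weighted_bce(probs, TOY_LABELS, [0.1, 0.2, 0.3])
```

The same reasoning applies to the real model: the ratios only shape gradients during training, so a placeholder array is enough as long as its length matches n_tasks.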
Thanks for your answer. What about the best training parameters? For example, num_step_message_passing is set to 5; would 4 or 6 be better?
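The thread does not report results for other values, so one generic way to answer this for your own data is a small sweep scored on a validation split. In the sketch below, `train_and_score` is a hypothetical function you would write yourself: build an `MPNNPOMModel` with the given `num_step_message_passing`, train it, and return a validation metric where higher is better (for example mean ROC-AUC). Note that each candidate costs a full training run, and a model restored from an existing weight file should presumably be built with the same value it was trained with.

```python
from typing import Callable, Dict, Iterable, Tuple


def sweep_message_passing_steps(
    train_and_score: Callable[[int], float],
    candidates: Iterable[int] = (4, 5, 6),
) -> Tuple[int, Dict[int, float]]:
    """Train once per candidate value and return the best one by validation score.

    `train_and_score` is a hypothetical user-supplied callable: it receives a value for
    num_step_message_passing, trains a model with it, and returns a higher-is-better score.
    """
    scores = {steps: train_and_score(steps) for steps in candidates}
    best = max(scores, key=scores.get)  # value with the highest validation score
    return best, scores
```

For a quick dry run, pass a stub such as `lambda steps: {4: 0.80, 5: 0.85, 6: 0.83}[steps]`, which returns 5 as the best value; the numbers there are made up purely to exercise the function.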
Below is the model loading code. Regarding the argument `class_imbalance_ratio = train_ratios`, my tests indicate that loading the model requires specifying train_ratios. I only need inference, without a test CSV file, so I specify 138 values for train_ratios, like [0.1, 0.2, ...]. I've tested it several times, and different sets of 138 train_ratios values seem to have no effect on the inference results. Is this correct? If I only need to do inference, which parameters can I modify for a specific model weight file (such as https://github.com/ARY2260/openpom/blob/main/examples/example_model.pt)? Thank you!
from openpom.feat.graph_featurizer import GraphConvConstants
from openpom.models.mpnn_pom import MPNNPOMModel

n_tasks = 138  # must equal the value used to train the weight file
learning_rate = 0.001  # placeholder; only the optimizer uses it, so it does not affect inference
train_ratios = [1.0] * n_tasks  # arbitrary placeholder values; only the length must equal n_tasks

model = MPNNPOMModel(n_tasks=n_tasks, batch_size=128, learning_rate=learning_rate,
                     class_imbalance_ratio=train_ratios, loss_aggr_type='sum',
                     node_out_feats=100, edge_hidden_feats=75, edge_out_feats=100,
                     num_step_message_passing=5, mpnn_residual=True, message_aggregator_type='sum',
                     mode='classification', number_atom_features=GraphConvConstants.ATOM_FDIM,
                     number_bond_features=GraphConvConstants.BOND_FDIM, n_classes=1,
                     readout_type='set2set', num_step_set2set=3, num_layer_set2set=2,
                     ffn_hidden_list=[392, 392], ffn_embeddings=256, ffn_activation='relu',
                     ffn_dropout_p=0.12, ffn_dropout_at_input_no_act=False, weight_decay=1e-5,
                     self_loop=False, optimizer_name='adam', log_frequency=32,
                     model_dir='./example_model.pt')
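One way to make the "different train_ratios give the same results" check explicit is to compare the two prediction arrays numerically instead of by eye. This is a minimal sketch assuming NumPy; the helper names are hypothetical. With DeepChem-style models the arrays would come from `model.predict(dataset)` on two models built with different placeholder ratios and loaded from the same weights, run on the same dataset.

```python
import numpy as np


def max_prediction_gap(preds_a, preds_b):
    """Largest absolute element-wise difference between two prediction arrays."""
    a = np.asarray(preds_a, dtype=float)
    b = np.asarray(preds_b, dtype=float)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    return float(np.max(np.abs(a - b)))


def predictions_match(preds_a, preds_b, atol=1e-6):
    """True if every prediction differs by at most atol (absolute tolerance only)."""
    return max_prediction_gap(preds_a, preds_b) <= atol
```

If the ratios really only enter the loss, `predictions_match` should return True for any two placeholder arrays of length n_tasks; a tiny nonzero gap would most likely come from nondeterministic GPU kernels rather than from the ratios.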