Thank you for your interest in Actfound. Actfound is a meta-learning model designed for few-shot bioactivity prediction, and we only finetune the last linear layer during the finetuning process, which is invoked through `model.run_validation_iter(cur_data)`. The finetuning itself is implemented in the `inner_loop` function of `system_base`:
```python
def inner_loop(self, x_task, y_task, assay_idx, split, is_training_phase, epoch, num_steps):
    task_losses = []
    per_step_loss_importance_vectors = self.get_per_step_loss_importance_vector()
    support_loss_each_step = []
    sup_num = torch.sum(split)
    # Copy of the inner-loop parameters; only these are updated during finetuning.
    names_weights_copy = self.get_inner_loop_parameter_dict(self.regressor.named_parameters())
    use_second_order = self.args.second_order and epoch > self.args.first_order_to_second_order_epoch
    for num_step in range(num_steps):
        # Forward pass on the support set with the current inner-loop weights.
        support_loss, support_preds = self.net_forward(x=x_task,
                                                       y=y_task,
                                                       assay_idx=assay_idx,
                                                       split=split,
                                                       is_support=True,
                                                       weights=names_weights_copy,
                                                       backup_running_statistics=(num_step == 0),
                                                       training=True,
                                                       num_step=num_step)
        support_loss_each_step.append(support_loss.detach().cpu().item())
        # Skip the update if the loss has diverged or the support set is too small.
        if support_loss >= support_loss_each_step[0] * 10 or sup_num <= 5:
            pass
        else:
            names_weights_copy = self.apply_inner_loop_update(loss=support_loss,
                                                              names_weights_copy=names_weights_copy,
                                                              use_second_order=use_second_order if is_training_phase else False,
                                                              current_step_idx=num_step,
                                                              sup_number=torch.sum(split))
        if is_training_phase:
            is_multi_step_optimize = self.args.use_multi_step_loss_optimization and epoch < self.args.multi_step_loss_num_epochs
            is_last_step = num_step == (self.args.num_updates - 1)
            if is_multi_step_optimize or is_last_step:
                # Evaluate the updated weights on the task for the outer-loop loss.
                target_loss, target_preds = self.net_forward(x=x_task,
                                                             y=y_task,
                                                             assay_idx=assay_idx,
                                                             split=split,
                                                             weights=names_weights_copy,
                                                             backup_running_statistics=False,
                                                             training=True,
                                                             num_step=num_step)
                if is_multi_step_optimize:
                    task_losses.append(per_step_loss_importance_vectors[num_step] * target_loss)
                elif is_last_step:
                    task_losses.append(target_loss)
    return names_weights_copy, support_loss_each_step, task_losses
```
The parameters of the last linear layer after finetuning are given in `final_weights`, and the evaluation results after finetuning are given in `per_task_target_preds`. So you only need to run the finetuning and save the results you want; there is no need to run `inference_main()`. Please let me know if there are any other questions~~~
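For example, a minimal finetune-and-save loop might look like the sketch below. The model and dataloader construction, and the assumption that `run_validation_iter` returns a dict containing `per_task_target_preds` and `final_weights`, are inferred from this thread rather than taken from the repository, so please check the actual signature in `system_base` before using it:

```python
import pickle

# Hypothetical setup -- replace with however you actually build the Actfound
# system and the per-assay dataloader in this repository.
model = build_actfound_system(args)        # assumed helper, not the repo API
dataloader = build_assay_dataloader(args)  # assumed helper, yields one assay task at a time

all_preds, all_weights = {}, {}
for assay_name, cur_data in dataloader:
    # run_validation_iter finetunes the last linear layer on this assay's
    # support compounds and predicts its target compounds.
    out = model.run_validation_iter(cur_data)  # return format assumed to be a dict
    all_preds[assay_name] = out["per_task_target_preds"]
    all_weights[assay_name] = out["final_weights"]

# Persist the finetuned predictions for later analysis.
with open("actfound_preds.pkl", "wb") as f:
    pickle.dump(all_preds, f)
```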
Here I'll briefly explain what I did to predict the activity values of unknown compounds. I considered only two types of assays, namely "test" and "valid", and created a dataloader for that data with `split_train_test_val` from a JSON file.
Put simply, this is a task of fine-tuning and then getting predictions for the unknown compounds.
Point-1: You're saying to use `model.run_validation_iter(cur_data)` to finetune and to save the predictions from `per_task_target_preds`.
But my goal here is to finetune the base meta model on my assay data and get activity predictions for similar unknown compounds (this is my core problem).
Point-2: In `model.run_validation_iter(cur_data)`, you're using the actual y labels and getting predictions relative to the actual known values.
Description:
After getting `names_weights_copy, support_loss_each_step, task_losses` from the `inner_loop` function, you send them to `net_forward()` with the updated weights. After that you get some out_value predictions from the `self.regressor.forward()` method, but those values don't make any sense to me. You then take those values, do some calculations with the actual y_task values, and return the predictions.
So why are you taking the actual labels along with the out_values? What is the exact meaning of the predictions from the `self.net_forward` function?
Here, I updated the base model's linear-layer weights with the final weights I got after finetuning, and ran `model.regressor.forward()` on the unknown x_task data to get some values (y predictions ~ activity values)... but I'm not getting valid predictions. How do I proceed, and what is the correct code to get valid activity predictions for unknown compounds after finetuning the base model on a specific assay?
Hi, I think your question is how to predict the activity of compounds with unknown activity in a specific assay (your "test" or "valid" assay), provided that there are some compounds with known activity in this assay (please correct me if I'm wrong). This code was implemented to run a large number of comparative experiments on several datasets, and what you want is not implemented here. I will implement it in the next few days to make Actfound more convenient to use.
As for the `self.regressor.forward()` method: x is the input feature of all compounds (2048-dim), y is their activity, and split is the partition of compounds for this assay [1. means compounds used for finetuning (which we usually call the "support set" in meta-learning), and 0. means compounds used for testing (which we usually call the "target set")]. The Actfound base model is finetuned on the support set and gives predictions on the target set. Actfound is a pairwise learning model, which means it predicts the activity difference between compounds in an assay and then reconstructs the final predicted activity of the compounds in the target set, so it is essential to input y for the compounds in the support set; the y of the compounds in the target set is not used, so please don't worry about a data-leakage problem here.
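To make the split semantics concrete, here is a minimal sketch of one assay task; the compound counts and random features are made up for illustration:

```python
import torch

# One assay task: 16 compounds with known activity (support set, split = 1.)
# and 31 compounds whose activity we want to predict (target set, split = 0.).
n_support, n_target = 16, 31

x_task = torch.randn(n_support + n_target, 2048)  # 2048-dim feature per compound
y_task = torch.randn(n_support + n_target)        # activities; target entries are never read
split = torch.cat([torch.ones(n_support),         # 1. -> support set, used for finetuning
                   torch.zeros(n_target)])        # 0. -> target set, only predicted
```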
The data partitioning in meta-learning is different from that in general machine-learning tasks, which might be a little confusing. Please don't worry about that, because I will provide a convenient method very soon. Please let me know if there are any more questions.
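In the meantime, a workaround consistent with the explanation above is to put your unknown compounds in the target set with placeholder activities. The loading helpers below are hypothetical stand-ins for your own data code, and the task packing must match what the repository's dataloader produces:

```python
import torch

# Hypothetical loaders for your own data -- replace with real code.
x_known, y_known = load_known_compounds()  # compounds with measured activity
x_unknown = load_unknown_compounds()       # compounds to predict

x_task = torch.cat([x_known, x_unknown], dim=0)
y_task = torch.cat([y_known, torch.zeros(len(x_unknown))])  # dummy y for unknowns (unused)
split = torch.cat([torch.ones(len(x_known)),                # finetune on known compounds
                   torch.zeros(len(x_unknown))])            # predict the unknown ones

# Pack (x_task, y_task, split, ...) into cur_data the same way the repository's
# dataloader does; the predictions for the unknown compounds are then the
# per_task_target_preds returned by model.run_validation_iter(cur_data).
```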
Could you please check why I'm not getting valid results (activity predictions) after fine-tuning the meta model?
This is the code I've written separately for finetuning and inference with the validation dataset. The output predictions I got:
[-0.6336, -0.4449, -0.5665, -0.6252, -0.4101, -0.9025, -0.9602, -0.7823, -0.8399, -0.7683, -0.8970, -0.7017, -0.7930, -0.7291, -0.9378, -0.2267, -0.4170, 0.0460, -0.1171, -0.4783, -0.5798, -1.5286, -0.2624, -0.3334, -0.3932, -1.0339, -0.9790, -1.0076, -0.2937, -1.1409, -1.1365]