jaekyeom / MABAS

MIT License

Details about implementation #1

Open akshayv1 opened 3 years ago

akshayv1 commented 3 years ago

Hi, thanks for making the code publicly available. I have some questions about the implementation, and it would be great if you could clarify them.

  1. I had trouble understanding when the classifier is updated. Does the algorithm first tune the embedding function (assuming random classifier weights) and then update the classifier? Or, since the embedding function updates depend on the classifier weights (because of the boundary adversarial examples), does this happen iteratively (update the embedding function, then the classifier, then the embedding function again, and so on)?

  2. I also ran into issues running the code on Python 3 with PyTorch 1.7.1. Training runs smoothly, but when running evaluation_finetune.py, I get "RuntimeError: derivative for floor_divide is not implemented" when backpropagating the loss. I couldn't figure out which operation is causing this error.

Thanks.

jaekyeom commented 3 years ago

Hi, thanks for checking out our code!

  1. If you are asking about the method, it performs multiple iterations of two steps: obtaining the class score function h() and updating the embedding function with adversarial samples. Here, the class score function h() outputs a score for each class given a sample, i.e., the argmax of h() over classes is the predicted class. Since we target meta-learning methods, h() is determined separately for each task, in a way that depends on the method (e.g., in each task, MetaOptNet solves the SVM problem in the embedding space given the labeled samples to determine h()). We don't fine-tune the parameters of h(), since they have to be recomputed from the updated embedding function anyway.

  2. I'm not completely sure given the information, but my guess is that you may have added code to the fine-tuning module that uses torch.floor_divide (or //), and the error occurs because that operation is not differentiable. If you didn't make such modifications, one option is to first check whether it runs without errors in the environment we tested with (Python 2.7 and PyTorch 1.1.0).
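To make the loop in point 1 concrete, here is a minimal toy sketch in plain Python. It is not the MABAS code. The per-feature scaling "embedding", the prototype (class-mean) classifier standing in for MetaOptNet's SVM, the one-step signed-gradient "adversarial" sample, the finite-difference gradients, holding this iteration's h() fixed during the embedding update, and all function names (`embed`, `fit_class_scores`, `adversarial_sample`, `finetune`) are hypothetical choices. The sketch only shows the structure: h() is recomputed from the current embedding at every iteration, and only the embedding parameters receive updates.

```python
import math

def embed(w, x):
    # Hypothetical embedding: per-feature scaling (a stand-in for a CNN backbone).
    return [wi * xi for wi, xi in zip(w, x)]

def fit_class_scores(w, support):
    # Obtain h() from the labeled support set in the current embedding space.
    # Assumption: class prototypes stand in for MetaOptNet's per-task SVM.
    sums, counts = {}, {}
    for x, y in support:
        z = embed(w, x)
        s = sums.setdefault(y, [0.0] * len(z))
        for i, zi in enumerate(z):
            s[i] += zi
        counts[y] = counts.get(y, 0) + 1
    protos = {y: [v / counts[y] for v in s] for y, s in sums.items()}

    def h(z):
        # One score per class: negative squared distance to the class prototype.
        return {y: -sum((zi - pi) ** 2 for zi, pi in zip(z, p))
                for y, p in protos.items()}
    return h

def cross_entropy(scores, y):
    # Numerically stable softmax cross-entropy over the class-score dict.
    m = max(scores.values())
    log_z = m + math.log(sum(math.exp(s - m) for s in scores.values()))
    return log_z - scores[y]

def numerical_grad(f, v, eps=1e-5):
    # Central finite differences, to keep the sketch dependency-free.
    g = []
    for i in range(len(v)):
        up = list(v); up[i] += eps
        down = list(v); down[i] -= eps
        g.append((f(up) - f(down)) / (2 * eps))
    return g

def adversarial_sample(w, h, x, y, step=0.1):
    # Hypothetical: one signed-gradient input step that increases the loss,
    # pushing the sample toward (or past) the decision boundary of the current h().
    g = numerical_grad(lambda xx: cross_entropy(h(embed(w, xx)), y), x)
    return [xi + step * math.copysign(1.0, gi) if gi != 0 else xi
            for xi, gi in zip(x, g)]

def finetune(w, support, iterations=5, lr=0.1):
    for _ in range(iterations):
        # (1) Recompute h() from the current embedding; h() is never gradient-updated.
        h = fit_class_scores(w, support)
        # (2) Generate adversarial samples against this h().
        adv = [(adversarial_sample(w, h, x, y), y) for x, y in support]

        # (3) Update only the embedding parameters, with this iteration's h() held fixed.
        def loss(ww):
            return sum(cross_entropy(h(embed(ww, x)), y) for x, y in adv) / len(adv)

        g = numerical_grad(loss, w)
        w = [wi - lr * gi for wi, gi in zip(w, g)]
    return w
```

For example, `finetune([1.0, 1.0], [([1.0, 0.0], 0), ([0.0, 1.0], 1)], iterations=3)` returns updated embedding weights. There is no separate classifier-training phase: each iteration simply rebuilds h() from the embedding it has just updated.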

I hope this answers your questions, but if it doesn't, please feel free to leave another comment.

Thanks.