Open sachinraja13 opened 1 month ago
It's probably a good idea to add an all zeros vjp w.r.t. indices for scatter like we did for gather for consistency.
Is this for your MLX implementation of DINO? Looking at the PyTorch implementation it seems like they get around this by zeroing out the gradients for the Hungarian matcher. Maybe that will work for you in the meantime?
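For illustration, something along these lines could work (a minimal sketch only; greedy_match and the single classification cost are placeholders for the example, not the actual DINO matcher):

```python
import mlx.core as mx

def greedy_match(pred_logits, tgt_classes):
    # Analogue of the torch.no_grad() decorator on the PyTorch DINO matcher:
    # block gradients through everything the matcher sees, so autodiff never
    # asks for a vjp with respect to the indices it produces.
    pred_logits = mx.stop_gradient(pred_logits)
    cost = -mx.softmax(pred_logits, axis=-1)[:, tgt_classes]  # (queries, targets)
    return mx.argmin(cost, axis=0)  # best query per target (greedy stand-in)

# Example: 8 queries, 5 classes, 3 targets.
indices = greedy_match(mx.random.normal((8, 5)), mx.array([0, 2, 3]))
```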
Hi @barronalex: Yes, this is for the implementation of DINO. Zeroing out the gradient should solve the problem for the simple greedy matcher. I'm not sure it solves the problem for the Hungarian matcher, though, since that uses scipy's linear_sum_assignment. This is a particular problem because it requires evaluating the cost matrix, which in turn prevents me from calling mx.compile on the value_and_grad function for training. Please correct me if I'm mistaken here.
Yes, you would need an MLX implementation of linear_sum_assignment to be able to compile the whole thing. That being said, you should be able to compile the rest of the model if you put @mx.compile on the model forward but not the loss.
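As a rough sketch of what that split could look like (TinyDetector and its layers are toy stand-ins for illustration, not the DINO model; only the placement of mx.compile is the point):

```python
from functools import partial

import mlx.core as mx
import mlx.nn as nn

# Toy stand-in for the detector: only the compile boundary matters here.
class TinyDetector(nn.Module):
    def __init__(self, dim=32, num_classes=5):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.class_head = nn.Linear(dim, num_classes)
        self.box_head = nn.Linear(dim, 4)

    def __call__(self, x):
        h = nn.relu(self.proj(x))
        return self.class_head(h), self.box_head(h)

model = TinyDetector()

# Compile only the forward pass; model.state is passed as implicit
# inputs/outputs so parameter updates are seen by the compiled function.
@partial(mx.compile, inputs=model.state, outputs=model.state)
def forward(x):
    return model(x)

x = mx.random.normal((8, 32))
logits, boxes = forward(x)  # compiled
# Matching and the loss stay outside the compiled function, so scipy
# can see concrete values when it builds the assignment.
```

Since compile composes with the other function transformations, gradients can still flow through the compiled forward even though the matcher and loss stay uncompiled.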
Thanks @barronalex, understood.
Referring to this JAX implementation, there is another challenge in writing an MLX implementation in Python that can be compiled:
Is HungarianMatcher a big performance bottleneck when you're training? If not, it might be worth leaving it uncompiled and using the scipy implementation the way the original PyTorch implementation does.

We could definitely implement something like scipy.optimize.linear_sum_assignment in MLX, but my guess is it would require some custom C++/Metal to be competitive performance-wise.
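In the meantime, the uncompiled path is just a matter of evaluating the cost matrix and handing it to scipy. A rough sketch (hungarian_match and the specific cost terms are assumptions for illustration, not the DINO cost):

```python
import mlx.core as mx
import numpy as np
from scipy.optimize import linear_sum_assignment

def hungarian_match(pred_logits, pred_boxes, tgt_classes, tgt_boxes):
    # Classification cost: negative probability of each target's class.
    prob = mx.softmax(pred_logits, axis=-1)               # (queries, classes)
    class_cost = -prob[:, tgt_classes]                    # (queries, targets)
    # L1 box cost between every query box and every target box.
    box_cost = mx.sum(
        mx.abs(pred_boxes[:, None, :] - tgt_boxes[None, :, :]), axis=-1
    )                                                     # (queries, targets)
    # Converting to numpy forces evaluation; this is the step that cannot
    # happen inside an mx.compile'd function.
    cost = np.array(class_cost + box_cost)
    rows, cols = linear_sum_assignment(cost)
    return mx.array(rows), mx.array(cols)
```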
Thanks for your response, @barronalex!
I realise that HungarianMatcher is not a big performance bottleneck. However, I'm facing memory inflation in the MLX port of the prepare_for_cdn function:
https://github.com/ml-explore/mlx/issues/1432
I thought that compiling the entire computation graph might solve the problem.
Describe the bug
I understand that the HungarianMatching algorithm requires linear_sum_assignment from scipy, which needs the cost matrix to be evaluated, so I cannot compile my train step function. However, if I use the SimpleMatching algorithm and then compile my train step, I get the following error:
Code for matching is as follows:
Also, is there a workaround to get mx.compile working with the Hungarian matching algorithm?
I would greatly appreciate your help in solving this.
Additional context
Using MLX 0.17.3