ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[BUG] ValueError: [scatter] Cannot calculate VJP with respect to indices. #1439

Open sachinraja13 opened 1 month ago

sachinraja13 commented 1 month ago

Describe the bug

I understand that the HungarianMatching algorithm requires linear_sum_assignment from scipy, which needs the cost matrix to be evaluated. Hence, I cannot compile my train step function. However, if I use the SimpleMatching algorithm and then compile my train step, I get the following error:

    (loss_value, loss_dict), grads = train_step_fn(samples, targets, need_tgt_for_training, return_outputs=False)
  File "site-packages/mlx/nn/utils.py", line 35, in wrapped_value_grad_fn
    value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
ValueError: [scatter] Cannot calculate VJP with respect to indices.

Code for matching is as follows:

        # Weighted total matching cost, reshaped to (bs, num_queries, total_targets)
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.reshape(bs, num_queries, -1)

        sizes = [v["num_objects"] for v in targets]
        indices = []

        start_index = 0
        for i, c in enumerate(C):
            # Image i owns target columns [start_index, start_index + sizes[i])
            end_index = start_index + sizes[i]
            cost_matrix = c[:, start_index:end_index]
            size_ = cost_matrix.shape[1]
            # Greedy match: for each target, pick the query with the lowest cost
            idx_i = cost_matrix.argmin(0)
            idx_j = mx.arange(size_)
            indices.append((idx_i, idx_j))
            # Advance to the next image's block of target columns
            start_index = end_index

Also, is there a workaround to get mx.compile working with the Hungarian matching algorithm?

I would greatly appreciate your help in solving this.

Additional context

Using MLX 0.17.3

barronalex commented 1 month ago

It's probably a good idea to add an all-zeros VJP with respect to the indices for scatter, as we did for gather, for consistency.

Is this for your MLX implementation of DINO? Looking at the PyTorch implementation, it seems like they get around this by zeroing out the gradients for the Hungarian matcher. Maybe that will work for you in the meantime?

sachinraja13 commented 1 month ago

Hi @barronalex: Yes, this is for my MLX implementation of DINO. Zeroing out the gradients seems like it should solve the problem for the simple greedy matcher. I'm not sure it would solve the problem for the Hungarian matcher, though, since that uses scipy's linear_sum_assignment. This is a particular problem because it requires evaluating the cost matrix, which in turn prevents me from calling mx.compile on the value_and_grad function for training. Please correct me if I'm mistaken.

barronalex commented 1 month ago

Yes, you would need an MLX implementation of linear_sum_assignment to be able to compile the whole thing. That being said, you should be able to compile the rest of the model if you put @mx.compile on the model's forward pass but not on the loss.

sachinraja13 commented 1 month ago

Thanks @barronalex , understood.

Referring to this JAX implementation, there is another challenge in writing a compilable MLX implementation in Python:

https://github.com/ml-explore/mlx/issues/1441

barronalex commented 1 month ago

Is HungarianMatcher a big performance bottleneck when you're training? If not, it might be worth leaving it uncompiled and using the scipy implementation, as the original PyTorch implementation does.

We could definitely implement something like scipy.optimize.linear_sum_assignment in MLX, but my guess is it would require some custom C++/Metal code to be competitive performance-wise.
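For reference, a minimal pure-Python sketch of the classic Hungarian algorithm (the O(n^2 m) potentials formulation), which computes the same minimum-cost assignment as scipy.optimize.linear_sum_assignment for dense, finite costs. It is not an MLX API, and _hungarian and linear_sum_assignment_py are hypothetical names. Its inner while loops run a data-dependent number of iterations, which is the kind of control flow mx.compile cannot trace into a fixed graph.

```python
def _hungarian(cost):
    """Return (row, col) pairs of a min-cost assignment; requires rows <= cols."""
    n, m = len(cost), len(cost[0])
    INF = float("inf")
    u = [0.0] * (n + 1)  # row potentials (1-based)
    v = [0.0] * (m + 1)  # column potentials (1-based)
    p = [0] * (m + 1)    # p[j]: row assigned to column j, 0 means free
    way = [0] * (m + 1)  # predecessor column on the augmenting path
    for i in range(1, n + 1):
        p[0] = i
        j0 = 0
        minv = [INF] * (m + 1)
        used = [False] * (m + 1)
        while True:  # grow the alternating tree until a free column is reached
            used[j0] = True
            i0, delta, j1 = p[j0], INF, 0
            for j in range(1, m + 1):
                if not used[j]:
                    cur = cost[i0 - 1][j - 1] - u[i0] - v[j]
                    if cur < minv[j]:
                        minv[j], way[j] = cur, j0
                    if minv[j] < delta:
                        delta, j1 = minv[j], j
            for j in range(m + 1):  # update potentials by the smallest slack
                if used[j]:
                    u[p[j]] += delta
                    v[j] -= delta
                else:
                    minv[j] -= delta
            j0 = j1
            if p[j0] == 0:
                break
        while True:  # flip assignments along the augmenting path
            j1 = way[j0]
            p[j0] = p[j1]
            j0 = j1
            if j0 == 0:
                break
    return [(p[j] - 1, j - 1) for j in range(1, m + 1) if p[j] != 0]


def linear_sum_assignment_py(cost):
    """Hypothetical stand-in for scipy.optimize.linear_sum_assignment."""
    if len(cost) <= len(cost[0]):
        pairs = _hungarian(cost)
    else:
        # Tall matrix (e.g. more queries than targets): solve the transpose
        transposed = [list(col) for col in zip(*cost)]
        pairs = [(r, c) for c, r in _hungarian(transposed)]
    pairs.sort()  # row indices in ascending order, like SciPy
    return [r for r, _ in pairs], [c for _, c in pairs]
```

For example, linear_sum_assignment_py([[4, 1, 3], [2, 0, 5], [3, 2, 2]]) returns ([0, 1, 2], [1, 0, 2]), an assignment of total cost 5.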

sachinraja13 commented 1 month ago

Thanks for your response @barronalex !

I realise that HungarianMatcher is not a big performance bottleneck. However, I'm running into memory inflation in my MLX port of the prepare_for_cdn function.

https://github.com/ml-explore/mlx/issues/1432

I thought that compiling the entire computation graph might solve the problem.