MCG-NJU / DEQDet

[ICCV 2023] Deep Equilibrium Object Detection
https://arxiv.org/abs/2308.09564

Code related #1

Open jacksonsc007 opened 1 year ago

jacksonsc007 commented 1 year ago

Thanks for the excellent work. I wonder when the source code will be released. Looking forward to your reply. :)

WANGSSSSSSS commented 1 year ago

NOW!

jacksonsc007 commented 1 year ago

Terrific! Thanks a lot. Could you specify the version of mmdetection btw?

jacksonsc007 commented 1 year ago

And the PyTorch version? It seems like you require pytorch>=2.0.0.

WANGSSSSSSS commented 1 year ago

```
System environment:
    sys.platform: linux
    Python: 3.8.16 (default, Jun 12 2023, 18:09:05) [GCC 11.2.0]
    CUDA available: True
    numpy_random_seed: 1123624972
    GPU 0: NVIDIA GeForce RTX 3090
    CUDA_HOME: /usr/local/cuda
    NVCC: Cuda compilation tools, release 11.8, V11.8.89
    GCC: gcc (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
    PyTorch: 2.0.1+cu117
    PyTorch compiling details: PyTorch built with:
```

WANGSSSSSSS commented 1 year ago

> And the PyTorch version? It seems like you require pytorch>=2.0.0.

Yes, I try to keep up with the fast-moving world.

jacksonsc007 commented 1 year ago

Hi, while going through your code I ran into some issues, hope you could help me out :)

  1. what does "self.grad_accumulation" mean here? And the meaning of "stash gradient"?
  2. for the refinement-aware gradient formulation you proposed in equation (11), it seems that you didn't use this technique in your code implementation to speed up training and save memory, but instead used naive iteration and PyTorch's autograd to handle the backward gradient propagation. Am I right?

WANGSSSSSSS commented 1 year ago

> Hi, while going through your code I ran into some issues, hope you could help me out :)
>
> 1. what does "self.grad_accumulation" mean [here](https://github.com/MCG-NJU/DEQDet/blob/fa72a62b2340a04300424041e9ebd0087a700eba/projects/deqdet/deq_det_roi_head.py#L219C12-L219C35)? And the meaning of "stash gradient"?
>
> 2. for the refinement-aware gradient formulation you proposed in equation (11), it seems that you didn't use this technique in your code implementation to speed up training and save memory, but instead used naive iteration and **PyTorch's autograd** to handle the backward gradient propagation. Am I right?

The refinement-aware gradient is, to some extent, equivalent to truncated BPTT: it cuts off the higher-order terms of the RNN-style iterations. Because each supervision then becomes independent, we can use gradient accumulation between supervisions to avoid the extra memory consumption. However, PyTorch's autograd would push the gradient computed for a single supervision back to every parameter, resulting in several backward passes through the backbone. So I use this hook to stash the gradient at the mlvl features; the last backward pass of the supervision restores the stashed gradient and carries it to the backbone weights.
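To make the stash-gradient idea above concrete, here is a minimal, self-contained sketch. It is not the DEQDet code; `backbone`, `head`, `compute_loss`, and `num_supervisions` are hypothetical names for illustration. Each intermediate supervision backpropagates only up to detached features and its feature gradient is stashed; the final backward restores the stash via a hook, so the backbone sees a single accumulated backward pass.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the detector components (assumption, not DEQDet code).
backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 32))
head = nn.Linear(32, 4)

def compute_loss(pred):
    # Hypothetical per-supervision loss.
    return pred.pow(2).mean()

x = torch.randn(8, 16)
feats = backbone(x)                      # stand-in for the mlvl features
num_supervisions = 4
stashed_grad = torch.zeros_like(feats)   # buffer that stashes feature gradients

for step in range(num_supervisions):
    last = step == num_supervisions - 1
    if not last:
        # Detach so backward stops at the features: no pass through the backbone.
        feats_in = feats.detach().requires_grad_(True)
        loss = compute_loss(head(feats_in))
        loss.backward()
        stashed_grad += feats_in.grad    # stash this supervision's feature gradient
    else:
        # Restore the stashed gradient on the final backward, so the backbone
        # receives the sum of all supervisions' gradients in one pass.
        feats.register_hook(lambda g: g + stashed_grad)
        loss = compute_loss(head(feats))
        loss.backward()
```

With this pattern the head still accumulates gradients from every supervision, while the expensive backbone backward runs only once.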

WANGSSSSSSS commented 1 year ago

> Hi, while going through your code I ran into some issues, hope you could help me out :)
>
> 1. what does "self.grad_accumulation" mean [here](https://github.com/MCG-NJU/DEQDet/blob/fa72a62b2340a04300424041e9ebd0087a700eba/projects/deqdet/deq_det_roi_head.py#L219C12-L219C35)? And the meaning of "stash gradient"?
>
> 2. for the refinement-aware gradient formulation you proposed in equation (11), it seems that you didn't use this technique in your code implementation to speed up training and save memory, but instead used naive iteration and **PyTorch's autograd** to handle the backward gradient propagation. Am I right?

For question 2: yes, the RAG formulation is derived from the 2-step unrolled fixed-point formulation in the paper, and the implementation in the codebase is that 2-step unrolled fixed-point. The equation mainly helps to analyze why two steps work better than the simple gradient estimation used in DEQ-Flow. You can find the pseudocode in the appendix.
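For readers who want to see what a 2-step unrolled fixed-point looks like in code, here is a minimal sketch under assumed names (`Refiner` and the shapes are made up; this is not the DEQDet implementation): the refiner is iterated without gradients toward the equilibrium, and only the last two applications are differentiated, which is what the refinement-aware gradient truncates the backward pass to.

```python
import torch
import torch.nn as nn

class Refiner(nn.Module):
    """Hypothetical stand-in for the query-refinement layer f(z, x)."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Linear(2 * dim, dim)

    def forward(self, z, x):
        return torch.tanh(self.net(torch.cat([z, x], dim=-1)))

dim = 32
f = Refiner(dim)
x = torch.randn(8, dim)          # fixed input features
z = torch.zeros(8, dim)          # initial queries

# Gradient-free iterations toward the equilibrium z* = f(z*, x).
with torch.no_grad():
    for _ in range(10):
        z = f(z, x)

# 2-step unroll with gradients: only these two applications of f are
# differentiated; earlier iterations are treated as a constant estimate.
z = z.detach()
z1 = f(z, x)
z2 = f(z1, x)
loss = z2.pow(2).mean()          # hypothetical supervision on the refined output
loss.backward()
```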