heyoeyo / muggled_dpt

Muggled DPT: Depth estimation without the magic
Apache License 2.0

Confusion about network architecture for Depth Anything V2? #4

Closed: command-z-z closed this issue 2 months ago

command-z-z commented 3 months ago

Thank you for your awesome project! However, I'm confused by the comments you wrote saying that features are extracted from the last 4 blocks in the depth_anything_v2 architecture: https://github.com/heyoeyo/muggled_dpt/blob/1961e575c2e714b7e8c4c5ab6caa969b9aa8ef3e/lib/v2_depthanything/image_encoder_model.py#L40-L41 In the code implementation, though, you average the features from the middle blocks, not the last ones (taking the last blocks is how Depth Anything V1 works). Is this just a case where the comment needs fixing? https://github.com/heyoeyo/muggled_dpt/blob/1961e575c2e714b7e8c4c5ab6caa969b9aa8ef3e/lib/v2_depthanything/image_encoder_model.py#L60-L66 Also, is this feature extraction method better or worse than outputting intermediate tokens from the transformer, as the original DPT model does?

heyoeyo commented 3 months ago

Oops, sorry, ya that's confusing!

The comment is incorrect. I just copied the code from the V1 implementation, which works that way, and forgot to update the comments. Like you said, Depth Anything V1 uses the last-4-blocks structure, while V2 uses 4 equally spaced outputs, like the original DPT structure. As far as I'm aware, this is the only structural difference between V1 and V2, though I'm planning to check this more thoroughly when I write the V2 documentation at some point. I'll be sure to fix that comment as well, so thanks for catching that!
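To make the V1/V2 difference concrete, here is a minimal plain-Python sketch of which block outputs each approach would keep for a 12-block model like vit-small. The function names are hypothetical (they are not from either codebase), and the sketch assumes, as described above, that V1 takes the outputs of the last 4 blocks while V2 takes the last block of each of 4 equal-sized stages.

```python
def last_n_block_indices(num_blocks: int, n: int = 4) -> list[int]:
    """V1-style selection (as described above): outputs of the final n blocks."""
    return list(range(num_blocks - n, num_blocks))


def evenly_spaced_block_indices(num_blocks: int, n: int = 4) -> list[int]:
    """DPT/V2-style selection: the last block of each of n equal-sized stages.

    Assumes num_blocks is divisible by n.
    """
    stage_size = num_blocks // n
    return [(stage + 1) * stage_size - 1 for stage in range(n)]


print(last_n_block_indices(12))         # [8, 9, 10, 11]
print(evenly_spaced_block_indices(12))  # [2, 5, 8, 11]
```

With the V1-style selection the 4 outputs are only 1 block apart, while the V2-style outputs are separated by 3 blocks each.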

As for which approach is better, I'd assume that the DPT structure works better with its 4 intermediate tokens, since there's more potential for a variety of information to be encoded, given the extra processing steps between each stage. This also seems likely based on the fusion scaling results: the last 2 stages of V1 can be disabled without dramatically affecting the depth prediction, whereas doing the same has a more noticeable impact on the results for V2.

That all being said, the 4-output structure may be over-complicated. I'm currently working on a similar project using the Segment-Anything model, and one of the papers associated with the SAM model (Exploring Plain Vision Transformer Backbones for Object Detection) argues that single-output models can work just fine. The performance of the SAM model (which has only 1 output) seems to confirm this, so maybe the 4-output structure isn't doing much to help?

command-z-z commented 3 months ago

Thank you for your reply! Sorry that my previous statement was a bit unclear. What I actually want to ask about is the difference between averaging the middle features (which you do) and not averaging the middle features (as Depth Anything V2 and DPT do). Even though the difference is small, deviating from the original implementation still makes me worried.

https://github.com/DepthAnything/Depth-Anything-V2/blob/31dc97708961675ce6b3a8d8ffa729170a4aa273/depth_anything_v2/dpt.py#L164-L169

```python
# Depth Anything V2 extracts features from these intermediate layer indices
self.intermediate_layer_idx = {
    'vits': [2, 5, 8, 11],
    'vitb': [2, 5, 8, 11],
    'vitl': [4, 11, 17, 23],
    'vitg': [9, 19, 29, 39]
}
```

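As a quick sanity check on this config, the sketch below compares each quoted index list with what you would get by splitting the blocks into 4 equal stages and keeping the last block of each. It assumes the block counts are 12, 12, 24 and 40 (one more than the largest index in each list, which matches the usual ViT-S/B/L/g depths); `evenly_spaced_ends` is a hypothetical helper, not part of either codebase.

```python
# Index lists quoted from the original Depth Anything V2 code above
intermediate_layer_idx = {
    "vits": [2, 5, 8, 11],
    "vitb": [2, 5, 8, 11],
    "vitl": [4, 11, 17, 23],
    "vitg": [9, 19, 29, 39],
}

# Assumed total block count per model size: largest listed index + 1
num_blocks = {name: idxs[-1] + 1 for name, idxs in intermediate_layer_idx.items()}


def evenly_spaced_ends(total_blocks: int, num_stages: int = 4) -> list[int]:
    """Hypothetical helper: index of the last block in each equal-sized stage."""
    stage_size = total_blocks // num_stages
    return [(stage + 1) * stage_size - 1 for stage in range(num_stages)]


matches = {}
for name, idxs in intermediate_layer_idx.items():
    even = evenly_spaced_ends(num_blocks[name])
    matches[name] = idxs == even
    print(f"{name}: original={idxs} even_split={even} match={matches[name]}")
```

For vits, vitb and vitg the quoted lists match an even split exactly, but the quoted vitl list starts at 4 rather than 5, so for that size the original indices are not perfectly evenly spaced. This only compares the numbers quoted here; it says nothing about how muggled_dpt actually groups the vit-large blocks.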
heyoeyo commented 2 months ago

I have not (yet) confirmed that my implementation produces the same results as the original Depth Anything V2, but it is intended to be the same.

There are some structural changes that make things look a bit confusing, but there shouldn't be any averaging going on. Instead, the blocks of the transformer are just grouped into 4 'stages' internally. These stages are run in sequence to produce the 4 internal outputs of the image encoder: https://github.com/heyoeyo/muggled_dpt/blob/1961e575c2e714b7e8c4c5ab6caa969b9aa8ef3e/lib/v2_depthanything/image_encoder_model.py#L76-L80

But the stages themselves still just run the transformer blocks in sequence, as you can see from the forward function of the stages. So the end result is something like a double for-loop:

```python
stage_results = []
for stage_blocks in stages:
    for block in stage_blocks:
        tokens = block(tokens)      # blocks still run one after another
    stage_results.append(tokens)    # store the output of each stage's last block
```
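Here is a runnable toy version of that double for-loop, assuming 12 blocks split into 4 stages of 3 (as for vit-small). The "blocks" are hypothetical stand-ins that just append their own index to a list, so each stored result shows exactly which blocks the tokens have passed through.

```python
def make_block(block_index):
    """Hypothetical stand-in for a transformer block: records that it ran."""
    def block(tokens):
        return tokens + [block_index]  # returns a new list, so stored results aren't mutated
    return block


all_blocks = [make_block(i) for i in range(12)]          # vit-small has 12 blocks
stages = [all_blocks[i:i + 3] for i in range(0, 12, 3)]  # 4 stages of 3 blocks each

tokens = []
stage_results = []
for stage_blocks in stages:
    for block in stage_blocks:
        tokens = block(tokens)
    stage_results.append(tokens)

for result in stage_results:
    print(result)
```

Each stored result is simply the running output after a stage's last block (blocks 2, 5, 8 and 11 here); nothing is averaged across blocks.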

As an example, vit-small has 12 blocks in total, so each of the 4 stages has 3 blocks. Because of the way the stages are run, it's as if the tokens are processed by the first 3 blocks, the result is stored for output, then the next 3 blocks are executed, and so on. If you write out the block indices, you get something like:

```text
stage 1: [0, 1, 2]    -> store tokens
stage 2: [3, 4, 5]    -> store tokens
stage 3: [6, 7, 8]    -> store tokens
stage 4: [9, 10, 11]  -> store tokens
```

And you can see that the last indices of these lists are [2, 5, 8, 11], which correspond to the intermediate_layer_idx values from the original code. So it should produce the same outputs as the original implementation; it's just written in a different way. Slightly simplifying the original code (see: step 1, step 2, step 3), it's something like:

```python
# For vit-small
stage_results = []
for block_index, block in enumerate(all_12_blocks):
    tokens = block(tokens)
    if block_index in [2, 5, 8, 11]:  # hard-coded output indices for this model size
        stage_results.append(tokens)
```
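To check the "same outputs" claim in miniature, the sketch below runs the flat, index-based loop and the staged loop over the same 12 toy blocks and compares the stored results. The blocks are hypothetical non-linear functions on a short list of floats (a stand-in for real transformer blocks and tokens), so this only demonstrates that the two control flows agree, not anything about the actual models.

```python
import math
import random


def make_block(seed):
    """Hypothetical non-linear 'block' acting on a list of floats."""
    rng = random.Random(seed)
    w, b = rng.uniform(0.5, 1.5), rng.uniform(-1.0, 1.0)
    return lambda tokens: [math.tanh(w * t + b) for t in tokens]


def run_flat(blocks, tokens, layer_idx):
    """Original-style loop: store tokens after the hard-coded block indices."""
    results = []
    for block_index, block in enumerate(blocks):
        tokens = block(tokens)
        if block_index in layer_idx:
            results.append(tokens)
    return results


def run_staged(blocks, tokens, num_stages=4):
    """Staged loop: group blocks into equal stages, store each stage's output."""
    stage_size = len(blocks) // num_stages
    stages = [blocks[i:i + stage_size] for i in range(0, len(blocks), stage_size)]
    results = []
    for stage_blocks in stages:
        for block in stage_blocks:
            tokens = block(tokens)
        results.append(tokens)
    return results


all_12_blocks = [make_block(i) for i in range(12)]
input_tokens = [0.1, -0.4, 0.7]

flat = run_flat(all_12_blocks, input_tokens, [2, 5, 8, 11])
staged = run_staged(all_12_blocks, input_tokens)
print(flat == staged)  # True
```

Both loops apply the same blocks in the same order, so the stored results are identical, not just close.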

The original code is arguably more straightforward to understand, but I don't like that it needs hard-coded indices (e.g. [2, 5, 8, 11]) for each model size, and that these need to be passed through several function calls for the model to work properly. So I ended up re-working the structure into explicit stages to get rid of the extra config (which also makes it more consistent with the original DPT models, which have similar stages).
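One way to express that design, sketched in plain Python: the stage boundaries come from how the blocks are grouped when the model is built, so no per-size index list needs to be passed around at run time. The `Stage` and `StagedEncoder` classes are hypothetical illustrations of the idea, not muggled_dpt's actual classes, and they assume the block count divides evenly into stages.

```python
class Stage:
    """Hypothetical stage: runs its blocks in sequence, like a stage's forward function."""

    def __init__(self, blocks):
        self.blocks = list(blocks)

    def __call__(self, tokens):
        for block in self.blocks:
            tokens = block(tokens)
        return tokens


class StagedEncoder:
    """Hypothetical encoder: outputs come from stage boundaries, not a hard-coded index list."""

    def __init__(self, blocks, num_stages=4):
        blocks = list(blocks)
        if len(blocks) % num_stages != 0:
            raise ValueError("this sketch assumes the blocks split evenly into stages")
        size = len(blocks) // num_stages
        self.stages = [Stage(blocks[i:i + size]) for i in range(0, len(blocks), size)]

    def __call__(self, tokens):
        stage_results = []
        for stage in self.stages:
            tokens = stage(tokens)
            stage_results.append(tokens)
        return stage_results


# Toy usage: 12 "blocks" that each add 1, so each output counts the blocks run so far
encoder = StagedEncoder([lambda t: t + 1 for _ in range(12)])
print(encoder(0))  # [3, 6, 9, 12]
```

With this layout, changing the model size only changes the list of blocks passed in; the output positions follow from the grouping.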

Sorry that's a bunch of low-level details all at once, but hopefully it makes sense!

command-z-z commented 2 months ago

You are right, I get it. I appreciate your patience.