fundamentalvision / Deformable-DETR

Deformable DETR: Deformable Transformers for End-to-End Object Detection.
Apache License 2.0
3.15k stars, 513 forks

Pure PyTorch version of Deformable Attention #131

Open vadimkantorov opened 2 years ago

vadimkantorov commented 2 years ago

I found that there is a pure PyTorch function, def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights), but its signature doesn't accept value_level_start_index.

How should this function be used correctly? It would be useful for CPU-only debugging.

Thank you!

shayanjoya commented 2 years ago

Hi @vadimkantorov, were you able to use this as a replacement, with the additions?

ZhixiongSun commented 1 year ago

@shayanjoya Use it like this in class MSDeformAttn: comment out the MSDeformAttnFunction.apply call and call the pure PyTorch function in its place (it must be imported into that module):

    # output = MSDeformAttnFunction.apply(
    #     value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
    output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
    output = self.output_proj(output)
    return output
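
On the original question of the missing value_level_start_index: a pure PyTorch version can plausibly do without it because the start offset of each level follows from the spatial shapes (splitting the flattened value tensor into chunks of H*W per level recovers the same boundaries). The block below is a minimal, self-contained sketch of that idea, not the repository's code. The function name ms_deform_attn_sketch, the tensor layouts in the docstring, and the align_corners=False setting are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def ms_deform_attn_sketch(value, spatial_shapes, sampling_locations, attention_weights):
    """Illustrative pure PyTorch multi-scale deformable attention (hypothetical name).

    Assumed layouts:
      value:              (N, S, M, D), with S = sum of H*W over all levels
      spatial_shapes:     list or tensor of (H, W) pairs, one per level
      sampling_locations: (N, Lq, M, L, P, 2), (x, y) normalized to [0, 1]
      attention_weights:  (N, Lq, M, L, P)
    Returns a tensor of shape (N, Lq, M * D).
    """
    N, _, M, D = value.shape
    _, Lq, _, L, P, _ = sampling_locations.shape
    shapes = [(int(h), int(w)) for h, w in spatial_shapes]

    # Splitting by H*W recovers each level, so no explicit start index is needed.
    value_list = value.split([h * w for h, w in shapes], dim=1)

    # grid_sample expects coordinates in [-1, 1].
    grids = 2 * sampling_locations - 1

    sampled = []
    for lvl, (h, w) in enumerate(shapes):
        # (N, H*W, M, D) -> (N*M, D, H, W)
        v = value_list[lvl].permute(0, 2, 3, 1).reshape(N * M, D, h, w)
        # (N, Lq, M, P, 2) -> (N*M, Lq, P, 2)
        g = grids[:, :, :, lvl].permute(0, 2, 1, 3, 4).reshape(N * M, Lq, P, 2)
        # Bilinear sampling, result shape (N*M, D, Lq, P)
        sampled.append(
            F.grid_sample(v, g, mode="bilinear", padding_mode="zeros", align_corners=False)
        )

    # (N, Lq, M, L, P) -> (N*M, 1, Lq, L*P)
    weights = attention_weights.permute(0, 2, 1, 3, 4).reshape(N * M, 1, Lq, L * P)
    # Stack levels, weight every sampled point, and sum over levels and points.
    out = (torch.stack(sampled, dim=-2).flatten(-2) * weights).sum(-1)  # (N*M, D, Lq)
    return out.view(N, M * D, Lq).transpose(1, 2).contiguous()


def level_start_index(spatial_shapes):
    """Start offset of each level in the flattened value tensor (for reference)."""
    sizes = torch.as_tensor([int(h) * int(w) for h, w in spatial_shapes])
    return torch.cat([sizes.new_zeros(1), sizes.cumsum(0)[:-1]])
```

The sketch runs on CPU, which fits the debugging use case raised above. level_start_index is included only to show that the offsets expected by the CUDA path are a cumulative sum over the per-level sizes.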