IDEA-Research / detrex

detrex is a research platform for DETR-based object detection, segmentation, pose estimation and other visual recognition tasks.
https://detrex.readthedocs.io/en/latest/
Apache License 2.0

[Features] Support encoder-decoder checkpoint in DINO #200

Closed rentainhe closed 1 year ago

rentainhe commented 1 year ago

Support checkpoint in DINO

Environment: A100-40GB, batch-size=1, Swin-Large-384-4Scale-DINO

Usage

Simple tutorial on adding checkpoint

detrex uses TransformerLayerSequence and BaseTransformerLayer as the basic building blocks for DETR models. To enable checkpointing, wrap each layer of a TransformerLayerSequence with checkpoint_wrapper from fairscale:

from fairscale.nn.checkpoint import checkpoint_wrapper

class DINOTransformerEncoder(TransformerLayerSequence):
    def __init__(self, *args, **kwargs):
        ...
        if use_checkpoint:
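            # Store each wrapped layer back into self.layers so the wrapping
            # does not depend on the wrapper modifying the layer in place.
            for i, layer in enumerate(self.layers):
                self.layers[i] = checkpoint_wrapper(layer)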
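
For readers new to activation checkpointing, the trade-off behind the wrapper can be sketched in plain Python. This is a toy illustration only, not detrex or fairscale code: the scalar "layers", the function names, and the `segment` parameter are all hypothetical. It compares ordinary backpropagation, which keeps every layer input, with a checkpointed version that keeps one input per segment and recomputes the rest during the backward pass.

```python
import math

# Hypothetical toy "layers": each is a (forward, derivative) pair on scalars.
LAYERS = [
    (lambda x: 2.0 * x, lambda x: 2.0),
    (math.tanh, lambda x: 1.0 - math.tanh(x) ** 2),
    (lambda x: x * x, lambda x: 2.0 * x),
    (math.sin, math.cos),
]


def grad_store_all(x):
    """Plain backprop: every layer input is kept until the backward pass."""
    inputs = []
    for f, _ in LAYERS:
        inputs.append(x)
        x = f(x)
    grad = 1.0
    for (_, df), inp in zip(reversed(LAYERS), reversed(inputs)):
        grad *= df(inp)
    return x, grad, len(inputs)


def grad_checkpointed(x, segment=2):
    """Checkpointed backprop: keep one input per segment, recompute the rest."""
    checkpoints = []
    for start in range(0, len(LAYERS), segment):
        checkpoints.append(x)  # only the segment input is kept
        for f, _ in LAYERS[start:start + segment]:
            x = f(x)
    grad = 1.0
    for idx in reversed(range(len(checkpoints))):
        seg = LAYERS[idx * segment:(idx + 1) * segment]
        # Recompute the inputs inside this segment from its checkpoint.
        inputs, value = [], checkpoints[idx]
        for f, _ in seg:
            inputs.append(value)
            value = f(value)
        for (_, df), inp in zip(reversed(seg), reversed(inputs)):
            grad *= df(inp)
    return x, grad, len(checkpoints)


out_a, grad_a, stored_a = grad_store_all(0.3)
out_b, grad_b, stored_b = grad_checkpointed(0.3, segment=2)
print(math.isclose(grad_a, grad_b), stored_a, stored_b)
```

Both functions return the same output and gradient; the checkpointed version keeps fewer values from the forward pass at the cost of running each segment's forward computation a second time.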