RL4M / MRM-pytorch


Advancing Radiograph Representation Learning with Masked Record Modeling (MRM)

This repository includes an official implementation of the paper Advancing Radiograph Representation Learning with Masked Record Modeling (ICLR'23).

Some code is borrowed from MAE, huggingface, and REFERS.

1 Environment preparation and quick start

Environment requirements

If you are using anaconda/miniconda, we provide an easy way to prepare the environment for pre-training and classification fine-tuning:

  conda env create -f environment.yaml
  pip install -r requirements.txt

2 How to load the pre-trained model

Download the pre-trained weights first!

import torch
import torch.nn as nn
from functools import partial
import timm
assert timm.__version__ == "0.6.12"  # this implementation expects timm 0.6.12
from timm.models.vision_transformer import VisionTransformer

def vit_base_patch16(**kwargs):
    # ViT-Base/16 backbone with the LayerNorm epsilon used during pre-training
    model = VisionTransformer(norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

# model definition: 14-class classification head, average-pooled features
model = vit_base_patch16(num_classes=14, drop_path_rate=0.1, global_pool="avg")
checkpoint_model = torch.load("./MRM.pth", map_location="cpu")["model"]
# load the pre-trained weights; strict=False keeps the newly initialized head
model.load_state_dict(checkpoint_model, strict=False)
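
Because the checkpoint stores pre-training weights rather than a fine-tuned classifier, loading with strict=False is expected to leave some keys unmatched. Below is a minimal sketch, continuing from the snippet above, for inspecting what was actually loaded:

    # Inspect the result returned by load_state_dict.
    load_result = model.load_state_dict(checkpoint_model, strict=False)

    # Parameters defined by the model but absent from the checkpoint,
    # e.g. the freshly initialized classification head.
    print("missing keys:", load_result.missing_keys)

    # Parameters stored in the checkpoint but unused by this model definition.
    print("unexpected keys:", load_result.unexpected_keys)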

3 Pre-training

3.1 Data preparation for pre-training

3.2 Start pre-training

4 Fine-tuning of classification

4.2 Start fine-tuning (taking 1 percent of the data as an example)

| CheXpert | warm-up steps | total steps | learning rate |
| --- | --- | --- | --- |
| 1% | 150 | 2000 | 3e-3 |
| 10% | 1500 | 60000 | 5e-4 |
| 100% | 15000 | 200000 | 5e-4 |

| COVID | warm-up steps | total steps | learning rate |
| --- | --- | --- | --- |
| 100% | 50 | 1000 | 3e-2 |
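
For context, the warm-up steps, total steps, and learning rate above describe a schedule that ramps the learning rate up linearly and then decays it. The sketch below only illustrates such a schedule in plain PyTorch, assuming a cosine decay and an AdamW optimizer; it is not the repository's actual fine-tuning code.

    import math
    import torch

    # Values taken from the 1% CheXpert row above (illustrative).
    warmup_steps, total_steps, base_lr = 150, 2000, 3e-3

    # Placeholder parameters; the real fine-tuning script builds its own model and optimizer.
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=base_lr)

    def lr_lambda(step):
        # Linear warm-up to base_lr, then cosine decay to zero (an assumption).
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Inside the training loop: call optimizer.step() followed by scheduler.step().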

5 Fine-tuning of segmentation

5.1 Data preparation

The directory Siim_Segmentation contains the configuration files needed to reproduce the segmentation experiments. After modifying the MMSegmentation framework with the provided files, start fine-tuning and evaluation with ft.sh and test.sh, respectively.

6 Links to download datasets

7 Dataset splits

In the directory DatasetsSplits, we provide dataset splits that may be helpful for organizing the datasets.

We give the train/valid/test splits of CheXpert, NIH ChestX-ray, and RSNA Pneumonia.

For the COVID-19 Image Data Collection, we randomly split the data into train/valid/test sets 5 times, and the images are provided in the images directory.

For SIIM-ACR_Pneumothorax, please organize the image and annotation directories according to the given splits, as described in section 5.1.