
mrm-pytorch's Introduction

Advancing Radiograph Representation Learning with Masked Record Modeling (MRM)

This repository contains the official implementation of the paper Advancing Radiograph Representation Learning with Masked Record Modeling (ICLR'23).

Some code is borrowed from MAE, Hugging Face, and REFERS.

1 Environment preparation and quick start

Environment requirements

  • Ubuntu 18.04 LTS.

  • Python 3.8.11

If you are using Anaconda or Miniconda, you can prepare the environment for pre-training and for classification fine-tuning as follows:

  conda env create -f environment.yaml
  pip install -r requirements.txt

2 How to load the pre-trained model

Download the pre-trained weights (MRM.pth) first, then load them into a ViT-B/16 model:

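import torch
import torch.nn as nn
from functools import partial
import timm
from timm.models.vision_transformer import VisionTransformer

assert timm.__version__ == "0.6.12"  # version check: this snippet targets timm 0.6.12

def vit_base_patch16(**kwargs):
    # ViT-Base/16 (these values match timm's VisionTransformer defaults)
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

# model definition: a 14-way classification head (e.g. NIH ChestX-ray14)
model = vit_base_patch16(num_classes=14, drop_path_rate=0.1, global_pool="avg")

# load the pre-trained weights, stored under the "model" key of the checkpoint
checkpoint_model = torch.load("./MRM.pth", map_location="cpu")["model"]

# strict=False skips keys that do not match (e.g. the new classification head);
# print the result to check which keys were missing or unexpected
msg = model.load_state_dict(checkpoint_model, strict=False)
print(msg)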

3 Pre-training

3.1 Data preparation for pre-training

  • We use MIMIC-CXR-JPG for pre-training. For more information about this dataset, see Johnson et al., MIMIC-CXR-JPG.
  • The dataset directory specified in run.sh should contain the MIMIC-CXR-JPG dataset. You also need to prepare a file named training.csv and put it into that directory.
  • Each line of training.csv has two columns, image_path and report_content: (a) the path to an image and (b) the text of the corresponding report. The file should be organized as follows:
      image_path, report_content
      /path/to/img1.jpg, FINAL REPORT  EXAMINATION: ...
      /path/to/img2.jpg, FINAL REPORT  CHEST: ...
      ...,...
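A minimal sketch of writing training.csv with Python's csv module, assuming you have already paired each MIMIC-CXR-JPG image path with its report text (that pairing step is not shown). The helper name write_training_csv is hypothetical. The csv writer quotes reports that contain commas, and line breaks are replaced by spaces, which is an assumption based on the double spaces in the example rows above; check the column names and parsing against the repository's data loader before pre-training.

```python
import csv
from pathlib import Path


def write_training_csv(pairs, out_path):
    """Write (image_path, report_text) pairs to a two-column training.csv.

    Line breaks inside a report are replaced by spaces so that each record
    fits on one line, as in the README example (an assumption).
    """
    out_path = Path(out_path)
    with out_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)  # quotes fields that contain commas
        writer.writerow(["image_path", "report_content"])
        for image_path, report in pairs:
            flat = report.replace("\r", " ").replace("\n", " ").strip()
            writer.writerow([str(image_path), flat])
    return out_path
```

For example, call write_training_csv(pairs, "/path/to/dataset/training.csv") with pairs = [("/path/to/img1.jpg", "FINAL REPORT ...")]; the paths here are placeholders.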

3.2 Start pre-training

  • Download the pre-trained weights of MAE and set resume (the --resume argument) in run.sh to their path.

  • Set the data path, GPU IDs, batch size, output directory, and other parameters in run.sh.

  • Start training by running

      chmod a+x run.sh
      ./run.sh
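run.sh itself is not reproduced in this README, but the script posted in the issues below passes --batch_size, --accum_iter and --blr. Assuming MRM keeps MAE's convention (MRM borrows code from MAE), the absolute learning rate is derived from the base rate blr as lr = blr × effective batch / 256, where the effective batch is batch_size × accum_iter × number of GPUs. Under that assumption, changing the batch size or the number of GPUs in run.sh also changes the learning rate. A minimal sketch; mae_lr is a hypothetical helper name.

```python
def mae_lr(blr, batch_size, accum_iter, num_gpus):
    """Return (effective batch size, absolute learning rate).

    Assumes MAE's linear scaling rule: lr = blr * effective_batch / 256,
    with effective_batch = batch_size * accum_iter * num_gpus.
    """
    eff_batch = batch_size * accum_iter * num_gpus
    return eff_batch, blr * eff_batch / 256


# Values from the run.sh posted in the issues: 4 GPUs, batch 128, accum_iter 2.
eff_batch, lr = mae_lr(blr=1.5e-4, batch_size=128, accum_iter=2, num_gpus=4)
print(f"effective batch size: {eff_batch}, lr: {lr:g}")
```

With these values the effective batch size is 128 × 2 × 4 = 1024 and lr = 1.5e-4 × 1024 / 256 = 6e-4.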

4 Fine-tuning of classification (using the NIH ChestX-ray14 dataset as an example)

4.1 Data preparation

  • Download the NIH ChestX-ray14 dataset and split it into train/valid/test sets. The directory should be organized as follows:
      NIH_ChestX-ray/
            all_classes/
                  xxxx1.png
                  xxxx2.png
                  ...
                  xxxxn.png
            train_1.txt
            train_10.txt
            train_list.txt
            val_list.txt
            test_list.txt
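train_1.txt and the 10% list presumably hold the 1% and 10% subsets of train_list.txt used for limited-label fine-tuning. This is an assumption; the README does not say how they were drawn. If you want a subset of a different size, a minimal sketch like the following samples a reproducible fraction of lines. sample_subset is a hypothetical helper, not part of the repository, and it does not parse the line format, so it works whatever the split files contain.

```python
import random
from pathlib import Path


def sample_subset(list_path, out_path, fraction, seed=0):
    """Write a reproducible random fraction of the lines of list_path to out_path.

    Every non-empty line is treated as one training record; the line format
    itself is not parsed. Returns the number of lines written.
    """
    if not 0 < fraction <= 1:
        raise ValueError("fraction must be in (0, 1]")
    lines = [ln for ln in Path(list_path).read_text().splitlines() if ln.strip()]
    k = max(1, round(len(lines) * fraction))
    rng = random.Random(seed)  # fixed seed -> same subset on every run
    chosen = sorted(rng.sample(range(len(lines)), k))  # keep the original order
    Path(out_path).write_text("\n".join(lines[i] for i in chosen) + "\n")
    return k
```

For example, sample_subset("NIH_ChestX-ray/train_list.txt", "NIH_ChestX-ray/train_5.txt", 0.05) would write a hypothetical 5% list.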

4.2 Start fine-tuning (using 1% of the training data as an example)

      chmod a+x finetuning_1percent.sh
      ./finetuning_1percent.sh

4.3 More fine-tuning hyperparameters

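Hyperparameters for fine-tuning on the other classification datasets, by the fraction of training data used:

| Dataset  | Training data | Warm-up steps | Total steps | Learning rate |
|----------|---------------|---------------|-------------|---------------|
| RSNA     | 1%            | 50            | 2000        | 3e-3          |
| RSNA     | 10%           | 200           | 10000       | 5e-4          |
| RSNA     | 100%          | 2000          | 50000       | 5e-4          |
| CheXpert | 1%            | 150           | 2000        | 3e-3          |
| CheXpert | 10%           | 1500          | 60000       | 5e-4          |
| CheXpert | 100%          | 15000         | 200000      | 5e-4          |
| Covid    | 100%          | 50            | 1000        | 3e-2          |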
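The table gives only the warm-up steps, total steps, and peak learning rate; the shape of the schedule between them is not stated here. As a minimal sketch, assuming linear warm-up from 0 followed by cosine decay to 0 (a common choice in ViT fine-tuning code, but verify it against the fine-tuning scripts), the learning rate at each step could be computed as follows. warmup_cosine_lr is a hypothetical name.

```python
import math


def warmup_cosine_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate at `step`: linear warm-up, then cosine decay to 0 (assumed schedule)."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # stay at 0 after total_steps
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# RSNA with 1% of the training data: 50 warm-up steps, 2000 total steps, lr 3e-3.
for s in (0, 25, 50, 1025, 2000):
    print(s, warmup_cosine_lr(s, 3e-3, 50, 2000))
```

Under this schedule the rate reaches its peak of 3e-3 at step 50, is halved midway through the decay phase (step 1025), and reaches 0 at step 2000.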

5 Fine-tuning of segmentation

5.1 Data preparation

  • Download the SIIM-ACR Pneumothorax dataset and preprocess the images and annotations. Then organize the directory as follows:
      siim/
            images/
                  training/
                        xxxx1.png
                        xxxx2.png
                        ...
                        xxxxn.png
                  validation/
                        ...
                  test/
                        ...

            annotations/
                  training/
                        xxxx1.png
                        xxxx2.png
                        ...
                        xxxxn.png
                  validation/
                        ...
                  test/
                        ...
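The preprocessing step is not specified above. If you start from the original SIIM-ACR stage-1 release, the pneumothorax annotations are run-length-encoded (RLE) strings in a CSV file rather than mask images. A minimal sketch of decoding one string into a mask, assuming the stage-1 relative encoding (space-separated offset/length pairs, each offset counted from the end of the previous run, pixels numbered column by column, and "-1" meaning no pneumothorax). rle_decode_siim is a hypothetical helper. The foreground value defaults to 1 because MMSegmentation 0.x uses 255 as its default ignore index; check this against the provided dataset config.

```python
def rle_decode_siim(rle, width, height, fg_value=1):
    """Decode a SIIM-ACR stage-1 style relative RLE string into a 2-D mask.

    Assumed format: space-separated (offset, length) pairs; each offset is
    counted from the end of the previous run, pixels are numbered column by
    column, and "-1" means the image has no pneumothorax.
    Returns `height` rows of `width` ints (0 or fg_value).
    """
    flat = [0] * (width * height)
    tokens = [int(t) for t in rle.split()]
    if tokens and tokens != [-1]:
        pos = 0
        for offset, length in zip(tokens[0::2], tokens[1::2]):
            pos += offset
            if pos + length > len(flat):
                raise ValueError("run exceeds image size")
            for i in range(pos, pos + length):
                flat[i] = fg_value
            pos += length
    # Column-major layout: pixel index i lies at row i % height, column i // height.
    return [[flat[c * height + r] for c in range(width)] for r in range(height)]
```

To save a decoded mask as PNG, one option is Pillow with NumPy: Image.fromarray(np.array(mask, dtype=np.uint8)).save(path). Reading the DICOM images and converting them to PNG is not shown.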

5.2 Necessary files for segmentation

We conduct all segmentation experiments with MMSegmentation (version 0.25.0), so you need to set up the MMSegmentation environment and understand its code structure in advance.

The configuration files needed to reproduce the experiments are provided in the Siim_Segmentation directory. After modifying the MMSegmentation framework with the provided files, run ft.sh for fine-tuning and test.sh for evaluation.
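Segmentation results are reported as Dice scores. The exact metric computed by the provided configuration (per-image mean versus dataset-level Dice, and how images without pneumothorax are scored) is not documented in this section, so the following is only an illustration of the standard Dice coefficient, Dice = 2|P ∩ T| / (|P| + |T|), for flattened binary masks. dice and mean_dice are hypothetical names, and scoring two empty masks as 1.0 is an assumption.

```python
def dice(pred, target, empty_value=1.0):
    """Dice = 2|P ∩ T| / (|P| + |T|) for flat binary masks (sequences of 0/1)."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    inter = sum(1 for p, t in zip(pred, target) if p and t)
    total = sum(1 for p in pred if p) + sum(1 for t in target if t)
    if total == 0:
        return empty_value  # both masks empty: scored as a perfect match (assumption)
    return 2.0 * inter / total


def mean_dice(pairs):
    """Average per-image Dice over (pred, target) pairs."""
    scores = [dice(p, t) for p, t in pairs]
    if not scores:
        raise ValueError("no masks given")
    return sum(scores) / len(scores)
```

A per-image mean like mean_dice can differ noticeably from a Dice computed over all pixels of the dataset at once, so compare numbers only when they use the same definition.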

6 Links to download datasets

7 Dataset splits

In the DatasetsSplits directory, we provide dataset splits that may help you organize the datasets.

We provide the train/valid/test splits of CheXpert, NIH ChestX-ray, and RSNA Pneumonia.

For the COVID-19 Image Data Collection, we randomly split the data into train/valid/test sets five times, and we provide the images in the images directory.

For SIIM-ACR_Pneumothorax, please organize the image and annotation directories as described in Section 5.1, following the given splits.
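The format of the split files in DatasetsSplits is not described here. Assuming you have read each SIIM-ACR_Pneumothorax split into a list of image IDs, and that preprocessing produced one PNG image and one PNG mask per ID with the same file name, a minimal sketch for building the siim/ layout of Section 5.1 could look like this. organize_siim and its arguments are hypothetical.

```python
import shutil
from pathlib import Path


def organize_siim(src_images, src_masks, dst_root, splits, suffix=".png"):
    """Copy image/mask pairs into the siim/ layout of Section 5.1.

    `splits` maps a split name ("training", "validation", "test") to a list
    of image IDs; each ID is assumed to name `<ID><suffix>` in both
    src_images and src_masks. Returns the number of pairs per split.
    """
    counts = {}
    for split, ids in splits.items():
        img_dir = Path(dst_root) / "images" / split
        ann_dir = Path(dst_root) / "annotations" / split
        img_dir.mkdir(parents=True, exist_ok=True)
        ann_dir.mkdir(parents=True, exist_ok=True)
        for image_id in ids:
            name = f"{image_id}{suffix}"
            shutil.copy2(Path(src_images) / name, img_dir / name)
            shutil.copy2(Path(src_masks) / name, ann_dir / name)
        counts[split] = len(ids)
    return counts
```

The split names should match the directory names in Section 5.1 (training, validation, test). shutil.copy2 duplicates the files; moving or symlinking them instead saves disk space.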


mrm-pytorch's Issues

How to generate my own wordpiece.json

Hello, I have found that there are many semantically unrelated tokens in mimic_wordpiece.json, so I want to generate my own wordpiece file. How can I do this?

SIIM segmentation: tools/test.py not found

tools/test.py and tools/dist_train.sh, which are called in test.sh and ft.sh, cannot be found for the SIIM segmentation task. Could you make this code public?

Pretraining takes long time

Hi, I tried to pretrain the MRM model with the provided configuration. The paper says that pretraining takes about 2 days for 200 epochs on 4 RTX 3080Ti GPUs. However, with the default settings, my pretraining has already taken more than 3 days for 100 epochs on 4 RTX 3090 GPUs, so 200 epochs would take more than 6 days. Do you have any idea why pretraining takes longer than reported?

In addition, is the provided pre-trained checkpoint MRM.pth the model saved at the 200th epoch? If not, how was it chosen from the saved checkpoints?

Thank you!

KeyError: 'Dice.front' when training the segmentation task

Thank you for sharing the code!

An error occurs when I train the segmentation task:
File ".../site-packages/mmcv/runner/hooks/evaluation.py", line 389, in evaluate
return eval_res[self.key_indicator]
KeyError: 'Dice.front'

After inspection, I found that in /mmseg/core/evaluation/metrics.py, the function total_area_to_metrics has the following definition:

allowed_metrics = ['mIoU', 'mDice', 'mFscore', 'medDice']

But there is no branch for 'medDice':

for metric in metrics:
    if metric == 'mIoU':
        ...
    elif metric == 'mDice':
        ...

Is some code missing here? Or can I simply use `if metric == 'mDice' or metric == 'medDice':`?

StopIteration Error

Thanks for your reply about the dataset in the previous issue!
When I fine-tune the SIIM-ACR Pneumothorax segmentation task with the pre-trained weights by running ft.sh with MMSegmentation, the following error is reported:

Traceback (most recent call last):
  File "/home/wentaochen/anaconda3/envs/MRM/lib/python3.8/site-packages/mmcv/runner/iter_based_runner.py", line 34, in __next__
    data = next(self.iter_loader)
  File "/home/wentaochen/anaconda3/envs/MRM/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
    data = self._next_data()
  File "/home/wentaochen/anaconda3/envs/MRM/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1172, in _next_data
    raise StopIteration
StopIteration

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "tools/train.py", line 242, in <module>
    main()
  File "tools/train.py", line 231, in main
    train_segmentor(
  File "/home/wentaochen/mmsegmentation/mmseg/apis/train.py", line 194, in train_segmentor
    runner.run(data_loaders, cfg.workflow)
  File "/home/wentaochen/anaconda3/envs/MRM/lib/python3.8/site-packages/mmcv/runner/iter_based_runner.py", line 144, in run
    iter_runner(iter_loaders[i], **kwargs)
  File "/home/wentaochen/anaconda3/envs/MRM/lib/python3.8/site-packages/mmcv/runner/iter_based_runner.py", line 61, in train
    data_batch = next(data_loader)
  File "/home/wentaochen/anaconda3/envs/MRM/lib/python3.8/site-packages/mmcv/runner/iter_based_runner.py", line 41, in __next__
    data = next(self.iter_loader)
  File "/home/wentaochen/anaconda3/envs/MRM/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
    data = self._next_data()
  File "/home/wentaochen/anaconda3/envs/MRM/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1172, in _next_data
    raise StopIteration
StopIteration

Hope you can give me some advice! Thank you very much!

SIIM-ACR Pneumothorax dataset in fine-tuning of segmentation

Hi, thank you for your excellent work!
In Section 5 of the README, how does the SIIM-ACR Pneumothorax data from Kaggle map to the data you used? Files downloaded from the Kaggle link have names (IDs) like ID_0a0adf93f, while the IDs used in this repo look like 1.2.276.0.7230010.3.1.4.8323329.1608.1517875168.577144. Is there any other way to get this part of the dataset (maybe only the stage-1 data)? Thanks.

Can't use distributed processing

Thank you for sharing the code!
This is my script run.sh:

CUDA_VISIBLE_DEVICES=0,1,2,3 OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=4 --master_port 12345 main_pretrain.py \
    --num_workers 10 \
    --accum_iter 2 \
    --batch_size 128 \
    --model mrm \
    --norm_pix_loss \
    --mask_ratio 0.75 \
    --epochs 200 \
    --warmup_epochs 40 \
    --blr 1.5e-4 --weight_decay 0.05 \
    --resume ./MRM/mae_pretrain_vit_base.pth \
    --data_path ./MRM \
    --output_dir ./MRM \

When I use distributed training, the program always gets stuck at this point:

[screenshot]

and never continues. If I use only one GPU, it does train, but very slowly.

How can I deal with this?

Lower results obtained

Hi, thank you for your efforts in organizing and releasing the code. I tried the code on the SIIM segmentation task, carefully following the instructions for organizing and splitting the dataset and for setting up the open-source MMSegmentation framework. However, with the provided configuration files, I cannot reproduce the results.

I downloaded the pre-trained weights MRM.pth from the provided link. Fine-tuning on 100% of the SIIM data gave a Dice of 90.7% and fine-tuning on 10% gave 69.1%, both lower than the values in the paper (91.4% and 73.2%), especially with 10% of the data.

Could you clarify how to reproduce the results? Thanks.

Cannot download the SIIM stage-1 dataset

I have trouble downloading the SIIM dataset: I cannot use Google Cloud to download the stage-1 part. Could you please provide a cloud-drive link to the SIIM dataset? Thank you very much.
