# PCRLv2 (TPAMI'22)
This repository contains an official implementation of our TPAMI paper "A Unified Visual Information Preservation Framework for Self-supervised Pre-training in Medical Image Analysis".
## Dependencies
Make sure your PyTorch version is >= 1.1 before running the code. We strongly recommend installing Anaconda3; we use Python 3.6. We use [apex](https://github.com/NVIDIA/apex) for acceleration, and build on [pretrained-models.pytorch](https://github.com/Cadene/pretrained-models.pytorch) and [segmentation_models_pytorch](https://github.com/qubvel/segmentation_models.pytorch) to speed up the implementation.
## Load Pretrained Model
Download the pretrained weights from https://drive.google.com/drive/folders/1Rp7CJblA5HzX5xjc_hUEcD2kWFEiTboF?usp=sharing and put them under the `pretrained_weight` folder:
```
pretrained_weight
    simance_multi_crop_chexpter_pretask_1.0_240.pt
    simance_multi_crop_luna_pretask_1.0_240.pt
    simance_multi_crop_chest_pretask_1.0_240.pt
```
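Before loading a model, it can help to confirm that the checkpoints landed in the right place. The following is a minimal sketch, not part of the repository: the helper name `missing_weights` is hypothetical, and the file names are copied from the listing above.

```python
from pathlib import Path

# File names copied from the folder listing above.
EXPECTED_WEIGHTS = [
    "simance_multi_crop_chexpter_pretask_1.0_240.pt",
    "simance_multi_crop_luna_pretask_1.0_240.pt",
    "simance_multi_crop_chest_pretask_1.0_240.pt",
]


def missing_weights(folder="pretrained_weight"):
    """Return the expected checkpoint names that are not present in `folder`."""
    root = Path(folder)
    return [name for name in EXPECTED_WEIGHTS if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_weights()
    if missing:
        print("Missing checkpoints:", ", ".join(missing))
    else:
        print("All checkpoints found.")
```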
### 2D Model
```python
import torch
import segmentation_models_pytorch as smp

n_class = 14  # example value; set this to the number of output labels in your task
aux_params = dict(
    pooling='avg',         # one of 'avg', 'max'
    dropout=0.2,           # dropout ratio, default is None
    activation='sigmoid',  # activation function, default is None
    classes=n_class,       # define number of output labels
)
weight_path = './pretrained_weight/simance_multi_crop_chexpter_pretask_1.0_240.pt'
model = smp.Unet('resnet18', in_channels=3, aux_params=aux_params, classes=1, encoder_weights=None)
encoder_dict = torch.load(weight_path)['state_dict']
# Placeholder fc entries: the smp ResNet encoder has no fc layer and discards these keys on load.
encoder_dict['fc.bias'] = 0
encoder_dict['fc.weight'] = 0
model.encoder.load_state_dict(encoder_dict)
```
### 3D Model
```python
import torch
from pcrlv2_model_3d import PCRLv23d

model = PCRLv23d()
weight_path = './pretrained_weight/simance_multi_crop_luna_pretask_1.0_240.pt'
model_dict = torch.load(weight_path)['state_dict']
model.load_state_dict(model_dict)
```
## Pre-training
### NIH ChestX-ray
#### **Step 0**
Download NIH ChestX-ray from [here](https://nihcc.app.box.com/v/ChestXray-NIHCC). The image folder of NIH ChestX-ray should be organized as follows:
```
Chest14/
    images/
        00002639_006.png
        00010571_003.png
        ...
```
We also provide the list of images used for pre-training: ``pytorch/train_val_txt/chest_train.txt``. For semi-supervised experiments, we use the first K% of the images for pre-training and the last (100-K)% for fine-tuning. For instance, given a ratio of 9.5:0.5, the first 95% of the images in ``chest_train.txt`` are used for pre-training and the remaining 5% for fine-tuning.
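The split described above can be sketched as follows. This is an illustration under stated assumptions, not the repository's code: the function name `split_by_ratio` is hypothetical, the file names are made up, and rounding the cut point to the nearest integer is an assumption.

```python
def split_by_ratio(image_names, ratio):
    """Split an ordered list into (pre-training, fine-tuning) parts.

    `ratio` is the fraction used for pre-training, e.g. 0.95 for a 9.5:0.5 split.
    """
    if not 0.0 <= ratio <= 1.0:
        raise ValueError("ratio must be in [0, 1]")
    # Assumption: the cut point is rounded to the nearest integer.
    cut = round(len(image_names) * ratio)
    return image_names[:cut], image_names[cut:]


# Hypothetical file names, used only for illustration.
names = [f"img_{i:03d}.png" for i in range(100)]
pretrain, finetune = split_by_ratio(names, 0.95)
print(len(pretrain), len(finetune))  # 95 5
```

Because the split follows the order of the list, the same ratio always yields the same two subsets.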
#### **Step 1**
```
git clone https://github.com/Luchixiang/PCRLv2.git
cd PCRLv2
```
#### **Step 2**
For pre-training, run the following command:
```
python main.py --data chest14_data_path --model pcrlv2 --b 64 --epochs 240 --lr 1e-2 --output pretrained_model_save_path --n chest --d 2 --gpus 0,1,2,3 --ratio 1.0 --amp
```
or
```
bash run2d.sh
```
``--data`` sets the path where NIH ChestX-ray is stored.
``--d`` sets the dimensionality of the dataset: ``2`` for 2D, ``3`` for 3D.
``--n`` gives the name of the dataset.
``--ratio`` sets the fraction of the training set used for pre-training. Here, ``1.0`` means all training images are used.
``--amp`` enables apex to accelerate training.
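To make the flags concrete, here is a hypothetical `argparse` parser that accepts the command shown above. It is not the repository's `main.py`: the argument types and help strings are assumptions inferred from the example command, including the reading of ``--b`` as a batch size.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags used in the command above."""
    p = argparse.ArgumentParser(description="Illustrative sketch only")
    p.add_argument("--data", type=str, help="path to the dataset")
    p.add_argument("--model", type=str, help="model name, e.g. pcrlv2")
    p.add_argument("--b", type=int, help="batch size (assumed meaning)")
    p.add_argument("--epochs", type=int, help="number of training epochs")
    p.add_argument("--lr", type=float, help="learning rate")
    p.add_argument("--output", type=str, help="where to save the pretrained model")
    p.add_argument("--n", type=str, help="dataset name, e.g. chest or luna")
    p.add_argument("--d", type=int, choices=[2, 3], help="2 for 2D, 3 for 3D")
    p.add_argument("--gpus", type=str, help="comma-separated GPU ids")
    p.add_argument("--ratio", type=float, help="fraction of training images used")
    p.add_argument("--amp", action="store_true", help="enable apex acceleration")
    return p


command = (
    "--data chest14_data_path --model pcrlv2 --b 64 --epochs 240 --lr 1e-2 "
    "--output pretrained_model_save_path --n chest --d 2 --gpus 0,1,2,3 "
    "--ratio 1.0 --amp"
)
args = build_parser().parse_args(command.split())
print(args.d, args.ratio, args.amp, args.gpus.split(","))
```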
### LUNA16
#### **Step 0**
Please download LUNA16 from [here](https://luna16.grand-challenge.org/Download/). The image folder of LUNA16 should be organized as follows:
```
LUNA16/
    subset0/
        1.3.6.1.4.1.14519.5.2.1.6279.6001.979083010707182900091062408058.raw
        1.3.6.1.4.1.14519.5.2.1.6279.6001.979083010707182900091062408058.mhd
        ...
    subset1/
    subset2/
    ...
    subset9/
```
We also provide the list of training images in ``pytorch/train_val_txt/luna_train.txt``.
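Each LUNA16 scan comes as a pair: the `.mhd` file is a plain-text MetaImage header made of `key = value` lines, and the `.raw` file holds the voxel data. The sketch below parses such a header with the standard library only. The function name `parse_mhd_header` is hypothetical and the header values are invented for illustration; in practice a library such as SimpleITK is typically used to read these files.

```python
def parse_mhd_header(text):
    """Parse MetaImage header text into a dict of string fields."""
    fields = {}
    for line in text.splitlines():
        if "=" not in line:
            continue  # skip blank or malformed lines
        key, value = line.split("=", 1)
        fields[key.strip()] = value.strip()
    return fields


# Hypothetical header, shortened for illustration (not taken from LUNA16).
example = """ObjectType = Image
NDims = 3
DimSize = 512 512 121
ElementSpacing = 0.7 0.7 2.5
ElementType = MET_SHORT
ElementDataFile = scan.raw
"""
header = parse_mhd_header(example)
dim_size = [int(v) for v in header["DimSize"].split()]
print(dim_size, header["ElementDataFile"])
```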
#### **Step 1**
```
git clone https://github.com/Luchixiang/PCRLv2.git
cd PCRLv2
```
#### **Step 2**
First, pre-process the LUNA16 dataset to extract cropped pairs from the 3D volumes.
```
python luna_preprocess.py --input_rows 64 --input_cols 64 --input_deps 32 --data LUNA_dataset_path --save processedLUNA_save_path
```
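As a rough illustration of what cropping a fixed-size block from a volume involves, the sketch below picks a random crop origin so that a 64 x 64 x 32 block (matching ``--input_rows``, ``--input_cols`` and ``--input_deps``) stays inside the volume. It is not the logic of `luna_preprocess.py`: the function names, the axis order, the example volume shape, and drawing the two crops of a pair independently are all assumptions.

```python
import random


def random_crop_origin(volume_shape, crop_shape, rng=random):
    """Pick a start index per axis so the crop lies fully inside the volume."""
    if any(c > v for v, c in zip(volume_shape, crop_shape)):
        raise ValueError("crop is larger than the volume")
    return tuple(rng.randint(0, v - c) for v, c in zip(volume_shape, crop_shape))


def crop_slices(origin, crop_shape):
    """Turn an origin and a crop size into slices usable for array indexing."""
    return tuple(slice(o, o + c) for o, c in zip(origin, crop_shape))


volume_shape = (512, 512, 121)  # hypothetical (rows, cols, depth) of one scan
crop_shape = (64, 64, 32)       # values from the command above

# Assumption: a "pair" is two independently drawn crops from the same volume.
rng = random.Random(0)
pair = [crop_slices(random_crop_origin(volume_shape, crop_shape, rng), crop_shape)
        for _ in range(2)]
print(pair)
```

The resulting slice tuples can index a NumPy array directly, e.g. `volume[pair[0]]`.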
#### **Step 3**
```
python main.py --data processedLUNA_save_path --model pcrlv2 --b 32 --epochs 240 --lr 1e-3 --output pretrained_model_save_path --n luna --d 3 --gpus 0,1,2,3 --ratio 1.0 --amp
```
or
```
bash run3d.sh
```