ft42 committed on
Commit
599a397
·
verified ·
1 Parent(s): c196078

Upload 63 files

Browse files

Added inference code, demo data, config, and Slurm script

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +13 -0
  2. NoMAISI_logo.png +3 -0
  3. configs/config_maisi3d-rflow.json +150 -0
  4. configs/infr_config_NoMAISI_controlnet.json +17 -0
  5. configs/infr_env_NoMAISI_DLCSD24_demo.json +11 -0
  6. data/DLCS_1419_seg_sh.nii.gz +3 -0
  7. data/infr_NoMAISI_DLCSD24_demo_512xy_256z_771p25m_dataset.json +32 -0
  8. doc/images/DLCS_1419_ann0_slice134_triple.png +3 -0
  9. doc/images/DLCS_1419_ann1_slice204_triple.png +3 -0
  10. doc/images/DLCS_1443_ann1_slice125_triple.png +3 -0
  11. doc/images/DLCS_1446_ann0_slice122_triple.png +3 -0
  12. doc/images/DLCS_1447_ann0_slice206_triple.png +3 -0
  13. doc/images/DLCS_1453_ann0_slice204_triple.png +3 -0
  14. doc/images/DLCS_1508_ann0_slice46_triple.png +3 -0
  15. doc/images/DLCS_1519_ann3_slice155_triple.png +3 -0
  16. doc/images/GanAI_fid_scatter_marker_legend.png +3 -0
  17. doc/images/NoMAISI_train_and_infer.png +3 -0
  18. doc/images/TaskCls.png +3 -0
  19. doc/images/workflow.png +3 -0
  20. inference.sub +26 -0
  21. logs/NoMAISI-infr-log-38612.out +18 -0
  22. scripts/__init__.py +10 -0
  23. scripts/__pycache__/__init__.cpython-310.pyc +0 -0
  24. scripts/__pycache__/augmentation.cpython-310.pyc +0 -0
  25. scripts/__pycache__/diff_model_create_training_data.cpython-310.pyc +0 -0
  26. scripts/__pycache__/diff_model_setting.cpython-310.pyc +0 -0
  27. scripts/__pycache__/find_masks.cpython-310.pyc +0 -0
  28. scripts/__pycache__/infer_controlnet.cpython-310.pyc +0 -0
  29. scripts/__pycache__/infer_testV2_controlnet.cpython-310.pyc +0 -0
  30. scripts/__pycache__/infer_test_controlnet.cpython-310.pyc +0 -0
  31. scripts/__pycache__/inference.cpython-310.pyc +0 -0
  32. scripts/__pycache__/quality_check.cpython-310.pyc +0 -0
  33. scripts/__pycache__/rectified_flow.cpython-310.pyc +0 -0
  34. scripts/__pycache__/sample.cpython-310.pyc +0 -0
  35. scripts/__pycache__/train_controlnet.cpython-310.pyc +0 -0
  36. scripts/__pycache__/utils.cpython-310.pyc +0 -0
  37. scripts/__pycache__/utils_plot.cpython-310.pyc +0 -0
  38. scripts/augmentation.py +373 -0
  39. scripts/compute_fid_2-5d_ct.py +747 -0
  40. scripts/diff_model_create_training_data.py +231 -0
  41. scripts/diff_model_infer.py +358 -0
  42. scripts/diff_model_setting.py +92 -0
  43. scripts/diff_model_train.py +499 -0
  44. scripts/find_masks.py +157 -0
  45. scripts/infer_controlnet.py +222 -0
  46. scripts/infer_testV2_controlnet.py +220 -0
  47. scripts/infer_test_controlnet.py +220 -0
  48. scripts/inference.py +299 -0
  49. scripts/quality_check.py +149 -0
  50. scripts/rectified_flow.py +322 -0
.gitattributes CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ doc/images/DLCS_1419_ann0_slice134_triple.png filter=lfs diff=lfs merge=lfs -text
37
+ doc/images/DLCS_1419_ann1_slice204_triple.png filter=lfs diff=lfs merge=lfs -text
38
+ doc/images/DLCS_1443_ann1_slice125_triple.png filter=lfs diff=lfs merge=lfs -text
39
+ doc/images/DLCS_1446_ann0_slice122_triple.png filter=lfs diff=lfs merge=lfs -text
40
+ doc/images/DLCS_1447_ann0_slice206_triple.png filter=lfs diff=lfs merge=lfs -text
41
+ doc/images/DLCS_1453_ann0_slice204_triple.png filter=lfs diff=lfs merge=lfs -text
42
+ doc/images/DLCS_1508_ann0_slice46_triple.png filter=lfs diff=lfs merge=lfs -text
43
+ doc/images/DLCS_1519_ann3_slice155_triple.png filter=lfs diff=lfs merge=lfs -text
44
+ doc/images/GanAI_fid_scatter_marker_legend.png filter=lfs diff=lfs merge=lfs -text
45
+ doc/images/NoMAISI_train_and_infer.png filter=lfs diff=lfs merge=lfs -text
46
+ doc/images/TaskCls.png filter=lfs diff=lfs merge=lfs -text
47
+ doc/images/workflow.png filter=lfs diff=lfs merge=lfs -text
48
+ NoMAISI_logo.png filter=lfs diff=lfs merge=lfs -text
NoMAISI_logo.png ADDED

Git LFS Details

  • SHA256: 59e28b561fa2a934150fa912146fc81f75aa8b526defd5c698c46cac09995c94
  • Pointer size: 131 Bytes
  • Size of remote file: 186 kB
configs/config_maisi3d-rflow.json ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "spatial_dims": 3,
3
+ "image_channels": 1,
4
+ "latent_channels": 4,
5
+ "include_body_region": false,
6
+ "mask_generation_latent_shape": [
7
+ 4,
8
+ 64,
9
+ 64,
10
+ 64
11
+ ],
12
+ "autoencoder_def": {
13
+ "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
14
+ "spatial_dims": "@spatial_dims",
15
+ "in_channels": "@image_channels",
16
+ "out_channels": "@image_channels",
17
+ "latent_channels": "@latent_channels",
18
+ "num_channels": [
19
+ 64,
20
+ 128,
21
+ 256
22
+ ],
23
+ "num_res_blocks": [2,2,2],
24
+ "norm_num_groups": 32,
25
+ "norm_eps": 1e-06,
26
+ "attention_levels": [
27
+ false,
28
+ false,
29
+ false
30
+ ],
31
+ "with_encoder_nonlocal_attn": false,
32
+ "with_decoder_nonlocal_attn": false,
33
+ "use_checkpointing": false,
34
+ "use_convtranspose": false,
35
+ "norm_float16": true,
36
+ "num_splits": 4,
37
+ "dim_split": 1
38
+ },
39
+ "diffusion_unet_def": {
40
+ "_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi",
41
+ "spatial_dims": "@spatial_dims",
42
+ "in_channels": "@latent_channels",
43
+ "out_channels": "@latent_channels",
44
+ "num_channels": [64, 128, 256, 512],
45
+ "attention_levels": [
46
+ false,
47
+ false,
48
+ true,
49
+ true
50
+ ],
51
+ "num_head_channels": [
52
+ 0,
53
+ 0,
54
+ 32,
55
+ 32
56
+ ],
57
+ "num_res_blocks": 2,
58
+ "use_flash_attention": true,
59
+ "include_top_region_index_input": "@include_body_region",
60
+ "include_bottom_region_index_input": "@include_body_region",
61
+ "include_spacing_input": true,
62
+ "num_class_embeds": 128,
63
+ "resblock_updown": true,
64
+ "include_fc": true
65
+ },
66
+ "controlnet_def": {
67
+ "_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi",
68
+ "spatial_dims": "@spatial_dims",
69
+ "in_channels": "@latent_channels",
70
+ "num_channels": [64, 128, 256, 512],
71
+ "attention_levels": [
72
+ false,
73
+ false,
74
+ true,
75
+ true
76
+ ],
77
+ "num_head_channels": [
78
+ 0,
79
+ 0,
80
+ 32,
81
+ 32
82
+ ],
83
+ "num_res_blocks": 2,
84
+ "use_flash_attention": true,
85
+ "conditioning_embedding_in_channels": 8,
86
+ "conditioning_embedding_num_channels": [8, 32, 64],
87
+ "num_class_embeds": 128,
88
+ "resblock_updown": true,
89
+ "include_fc": true
90
+ },
91
+ "mask_generation_autoencoder_def": {
92
+ "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
93
+ "spatial_dims": "@spatial_dims",
94
+ "in_channels": 8,
95
+ "out_channels": 125,
96
+ "latent_channels": "@latent_channels",
97
+ "num_channels": [
98
+ 32,
99
+ 64,
100
+ 128
101
+ ],
102
+ "num_res_blocks": [1, 2, 2],
103
+ "norm_num_groups": 32,
104
+ "norm_eps": 1e-06,
105
+ "attention_levels": [
106
+ false,
107
+ false,
108
+ false
109
+ ],
110
+ "with_encoder_nonlocal_attn": false,
111
+ "with_decoder_nonlocal_attn": false,
112
+ "use_flash_attention": false,
113
+ "use_checkpointing": true,
114
+ "use_convtranspose": true,
115
+ "norm_float16": true,
116
+ "num_splits": 8,
117
+ "dim_split": 1
118
+ },
119
+ "mask_generation_diffusion_def": {
120
+ "_target_": "monai.networks.nets.diffusion_model_unet.DiffusionModelUNet",
121
+ "spatial_dims": "@spatial_dims",
122
+ "in_channels": "@latent_channels",
123
+ "out_channels": "@latent_channels",
124
+ "channels":[64, 128, 256, 512],
125
+ "attention_levels":[false, false, true, true],
126
+ "num_head_channels":[0, 0, 32, 32],
127
+ "num_res_blocks": 2,
128
+ "use_flash_attention": true,
129
+ "with_conditioning": true,
130
+ "upcast_attention": true,
131
+ "cross_attention_dim": 10
132
+ },
133
+ "mask_generation_scale_factor": 1.0055984258651733,
134
+ "noise_scheduler": {
135
+ "_target_": "monai.networks.schedulers.rectified_flow.RFlowScheduler",
136
+ "num_train_timesteps": 1000,
137
+ "use_discrete_timesteps": false,
138
+ "use_timestep_transform": true,
139
+ "sample_method": "uniform",
140
+ "scale":1.4
141
+ },
142
+ "mask_generation_noise_scheduler": {
143
+ "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
144
+ "num_train_timesteps": 1000,
145
+ "beta_start": 0.0015,
146
+ "beta_end": 0.0195,
147
+ "schedule": "scaled_linear_beta",
148
+ "clip_sample": false
149
+ }
150
+ }
configs/infr_config_NoMAISI_controlnet.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "controlnet_train": {
3
+ "batch_size": 2,
4
+ "cache_rate": 0.0,
5
+ "fold": 1,
6
+ "lr": 1e-5,
7
+ "n_epochs": 500,
8
+ "weighted_loss_label": [23],
9
+ "weighted_loss": 100
10
+ },
11
+ "controlnet_infer": {
12
+ "num_inference_steps": 30,
13
+ "autoencoder_sliding_window_infer_size": [80, 80, 64],
14
+ "autoencoder_sliding_window_infer_overlap": 0.25,
15
+ "modality": 1
16
+ }
17
+ }
configs/infr_env_NoMAISI_DLCSD24_demo.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_dir": "./models/",
3
+ "output_dir": "./outputs/NoMAISI_DLCSD24_demo_512xy_256z_771p25m",
4
+ "tfevent_path": "./outputs/tfevent",
5
+ "trained_autoencoder_path": "./models/autoencoder.pt",
6
+ "trained_diffusion_path": "./models/diffusion_unet.pt",
7
+ "trained_controlnet_path": "./models/Experiments_NoMAISI_512xy_256z_771p25m_finetune_500epoch_best.pt",
8
+ "exp_name": "NoMAISI_DLCSD24_demo_512xy_256z_771p25m",
9
+ "data_base_dir": ["/home/ft42/NoMAISI/data"],
10
+ "json_data_list": ["/home/ft42/NoMAISI/data/infr_NoMAISI_DLCSD24_demo_512xy_256z_771p25m_dataset.json"]
11
+ }
data/DLCS_1419_seg_sh.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83da8dbf3b165023f3ffcec571fe5766177b65aabfa143f3a0bef5be41af757b
3
+ size 2265286
data/infr_NoMAISI_DLCSD24_demo_512xy_256z_771p25m_dataset.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "NoMAISI_DLCSD24_demo_512xy_256z_771p25m",
3
+ "numTest": 1,
4
+ "testing": [
5
+ {
6
+
7
+ "label": "DLCS_1419_seg_sh.nii.gz",
8
+ "fold": 0,
9
+ "dim": [
10
+ 512,
11
+ 512,
12
+ 256
13
+ ],
14
+ "spacing": [
15
+ 0.703125,
16
+ 0.703125,
17
+ 1.25
18
+ ],
19
+ "top_region_index": [
20
+ 0,
21
+ 1,
22
+ 0,
23
+ 0
24
+ ],
25
+ "bottom_region_index": [
26
+ 0,
27
+ 0,
28
+ 1,
29
+ 0
30
+ ]
31
+ }]
32
+ }
doc/images/DLCS_1419_ann0_slice134_triple.png ADDED

Git LFS Details

  • SHA256: 9729e15104e9f3b6ae675f57bf7d5f9f1aec3e191a4d7a68209bde4a3d148363
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
doc/images/DLCS_1419_ann1_slice204_triple.png ADDED

Git LFS Details

  • SHA256: 5bbcd3ddca8a3623f38764984fed7f9a36c92d8e2c98336c4b3e5e0aadb29e0a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
doc/images/DLCS_1443_ann1_slice125_triple.png ADDED

Git LFS Details

  • SHA256: e6336851a8174aeedd990f169f4dfa1ec8f2524adbbfd048f1d491ba0973ae72
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
doc/images/DLCS_1446_ann0_slice122_triple.png ADDED

Git LFS Details

  • SHA256: 29706ad025325e95dd9ad6cc56e52ea9481866a23d97fb3033198f76a5b65a13
  • Pointer size: 131 Bytes
  • Size of remote file: 955 kB
doc/images/DLCS_1447_ann0_slice206_triple.png ADDED

Git LFS Details

  • SHA256: ad5d313eee8c53edb67c8240963c29c340f2e4456db8cdfc538f7c10fcbf7f2f
  • Pointer size: 131 Bytes
  • Size of remote file: 893 kB
doc/images/DLCS_1453_ann0_slice204_triple.png ADDED

Git LFS Details

  • SHA256: c1cb674c92523eab8a008367b658561cfb94ebc3dfbc84b6f666d609097f2863
  • Pointer size: 132 Bytes
  • Size of remote file: 1.2 MB
doc/images/DLCS_1508_ann0_slice46_triple.png ADDED

Git LFS Details

  • SHA256: 5d3f245b13e4495d01e8585058239c02f2cbc17b72557d8306b58bce23747334
  • Pointer size: 132 Bytes
  • Size of remote file: 1.64 MB
doc/images/DLCS_1519_ann3_slice155_triple.png ADDED

Git LFS Details

  • SHA256: a0a7db06ba28e1412d546d2ef917c50f04dd2cff9f06c5d83d611c406185fd13
  • Pointer size: 132 Bytes
  • Size of remote file: 1.36 MB
doc/images/GanAI_fid_scatter_marker_legend.png ADDED

Git LFS Details

  • SHA256: 60c1e2e2be297fd13de2600aa2559c853db277d9ef3238da7a166c1e3472a237
  • Pointer size: 131 Bytes
  • Size of remote file: 179 kB
doc/images/NoMAISI_train_and_infer.png ADDED

Git LFS Details

  • SHA256: ffc762231f799865c8a36898ae6e23434f0f188edd45fec1be88bbd9f582a3f4
  • Pointer size: 131 Bytes
  • Size of remote file: 457 kB
doc/images/TaskCls.png ADDED

Git LFS Details

  • SHA256: 8d23c4d5110aab51b39e9772122eb98edaa5d260e1fcc3de24ff486fb5feaa06
  • Pointer size: 131 Bytes
  • Size of remote file: 280 kB
doc/images/workflow.png ADDED

Git LFS Details

  • SHA256: 3bfeafa6ca6729ce6808c39e13afaa222d7ce102277b0f5fcb7d3eb29148ef93
  • Pointer size: 131 Bytes
  • Size of remote file: 610 kB
inference.sub ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ #SBATCH --job-name=nomaisi
4
+ #SBATCH --mail-type=END,FAIL
5
+ #SBATCH --mail-user=ft42@duke.edu
6
+ #SBATCH -p vram48
7
+ #SBATCH --ntasks=1 #
8
+ #SBATCH --gpus=1 # 2 GPU per task, chose more if model is capable of multi gpu training
9
+ #SBATCH --cpus-per-task=16 # More if it is CPU intensive job too NNUNET demands lot of CPU
10
+
11
+ ## Make sure logs directory is present on current directory (same as this script)
12
+ #SBATCH --output=logs/NoMAISI-infr-log-%j.out
13
+ #SBATCH --error=logs/NoMAISI-infr-log-%j.out
14
+
15
+
16
+
17
+ echo "Job starting"
18
+ echo "GPUs Given: $CUDA_VISIBLE_DEVICES"
19
+ module load miniconda/py39_4.12.0
20
+ source activate monai-auto3dseg
21
+
22
+
23
+ # Add the correct path to PYTHONPATH
24
+ export MONAI_DATA_DIRECTORY=/home/ft42/NoMAISI/
25
+
26
+ python -m scripts.infer_testV2_controlnet -c ./configs/config_maisi3d-rflow.json -e ./configs/infr_env_NoMAISI_DLCSD24_demo.json -t ./configs/infr_config_NoMAISI_controlnet.json
logs/NoMAISI-infr-log-38612.out ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
0
  0%| | 0/30 [00:00<?, ?it/s]
1
  3%|▎ | 1/30 [00:00<00:23, 1.22it/s]
2
  7%|▋ | 2/30 [00:01<00:14, 1.93it/s]
3
  10%|█ | 3/30 [00:01<00:12, 2.17it/s]
4
  13%|█▎ | 4/30 [00:01<00:11, 2.30it/s]
5
  17%|█▋ | 5/30 [00:02<00:10, 2.39it/s]
6
  20%|██ | 6/30 [00:02<00:09, 2.44it/s]
7
  23%|██▎ | 7/30 [00:03<00:09, 2.47it/s]
8
  27%|██▋ | 8/30 [00:03<00:08, 2.49it/s]
9
  30%|███ | 9/30 [00:03<00:08, 2.51it/s]
10
  33%|███▎ | 10/30 [00:04<00:07, 2.52it/s]
11
  37%|███▋ | 11/30 [00:04<00:07, 2.53it/s]
12
  40%|████ | 12/30 [00:05<00:07, 2.53it/s]
13
  43%|████▎ | 13/30 [00:05<00:06, 2.53it/s]
14
  47%|████▋ | 14/30 [00:05<00:06, 2.54it/s]
15
  50%|█████ | 15/30 [00:06<00:05, 2.54it/s]
16
  53%|█████▎ | 16/30 [00:06<00:05, 2.54it/s]
17
  57%|█████▋ | 17/30 [00:07<00:05, 2.54it/s]
18
  60%|██████ | 18/30 [00:07<00:04, 2.54it/s]
19
  63%|██████▎ | 19/30 [00:07<00:04, 2.54it/s]
20
  67%|██████▋ | 20/30 [00:08<00:03, 2.54it/s]
21
  70%|███████ | 21/30 [00:08<00:03, 2.54it/s]
22
  73%|███████▎ | 22/30 [00:08<00:03, 2.54it/s]
23
  77%|███████▋ | 23/30 [00:09<00:02, 2.53it/s]
24
  80%|████████ | 24/30 [00:09<00:02, 2.54it/s]
25
  83%|████████▎ | 25/30 [00:10<00:01, 2.53it/s]
26
  87%|████████▋ | 26/30 [00:10<00:01, 2.53it/s]
27
  90%|█████████ | 27/30 [00:10<00:01, 2.53it/s]
28
  93%|█████████▎| 28/30 [00:11<00:00, 2.53it/s]
29
  97%|█████████▋| 29/30 [00:11<00:00, 2.53it/s]
 
 
 
30
  0%| | 0/4 [00:00<?, ?it/s]
31
  25%|██▌ | 1/4 [00:04<00:13, 4.36s/it]
32
  50%|█████ | 2/4 [00:08<00:07, 4.00s/it]
33
  75%|███████▌ | 3/4 [00:11<00:03, 3.79s/it]
 
 
 
 
1
+ Job starting
2
+ GPUs Given: 0
3
+ [2025-09-24 13:42:58.511][ INFO](maisi.controlnet.infer) - Number of GPUs: 1
4
+ [2025-09-24 13:42:58.512][ INFO](maisi.controlnet.infer) - World_size: 1
5
+ [2025-09-24 13:42:59.541][ INFO](maisi.controlnet.infer) - Load trained diffusion model from ./models/autoencoder.pt.
6
+ [2025-09-24 13:43:03.285][ INFO](maisi.controlnet.infer) - Load trained diffusion model from ./models/diffusion_unet.pt.
7
+ [2025-09-24 13:43:03.287][ INFO](maisi.controlnet.infer) - loaded scale_factor from diffusion model ckpt -> 1.0311251878738403.
8
+ 2025-09-24 13:43:03,824 - INFO - 'dst' model updated: 180 of 231 variables.
9
+ [2025-09-24 13:43:04.077][ INFO](maisi.controlnet.infer) - load trained controlnet model from ./models/Experiments_NoMAISI_512xy_256z_771p25m_finetune_500epoch_best.pt
10
+ [2025-09-24 13:43:07.130][ INFO](root) - `controllable_anatomy_size` is not provided.
11
+ [2025-09-24 13:43:07.133][ INFO](root) - ---- Start generating latent features... ----
12
+
13
  0%| | 0/30 [00:00<?, ?it/s]
14
  3%|▎ | 1/30 [00:00<00:23, 1.22it/s]
15
  7%|▋ | 2/30 [00:01<00:14, 1.93it/s]
16
  10%|█ | 3/30 [00:01<00:12, 2.17it/s]
17
  13%|█▎ | 4/30 [00:01<00:11, 2.30it/s]
18
  17%|█▋ | 5/30 [00:02<00:10, 2.39it/s]
19
  20%|██ | 6/30 [00:02<00:09, 2.44it/s]
20
  23%|██▎ | 7/30 [00:03<00:09, 2.47it/s]
21
  27%|██▋ | 8/30 [00:03<00:08, 2.49it/s]
22
  30%|███ | 9/30 [00:03<00:08, 2.51it/s]
23
  33%|███▎ | 10/30 [00:04<00:07, 2.52it/s]
24
  37%|███▋ | 11/30 [00:04<00:07, 2.53it/s]
25
  40%|████ | 12/30 [00:05<00:07, 2.53it/s]
26
  43%|████▎ | 13/30 [00:05<00:06, 2.53it/s]
27
  47%|████▋ | 14/30 [00:05<00:06, 2.54it/s]
28
  50%|█████ | 15/30 [00:06<00:05, 2.54it/s]
29
  53%|█████▎ | 16/30 [00:06<00:05, 2.54it/s]
30
  57%|█████▋ | 17/30 [00:07<00:05, 2.54it/s]
31
  60%|██████ | 18/30 [00:07<00:04, 2.54it/s]
32
  63%|██████▎ | 19/30 [00:07<00:04, 2.54it/s]
33
  67%|██████▋ | 20/30 [00:08<00:03, 2.54it/s]
34
  70%|███████ | 21/30 [00:08<00:03, 2.54it/s]
35
  73%|███████▎ | 22/30 [00:08<00:03, 2.54it/s]
36
  77%|███████▋ | 23/30 [00:09<00:02, 2.53it/s]
37
  80%|████████ | 24/30 [00:09<00:02, 2.54it/s]
38
  83%|████████▎ | 25/30 [00:10<00:01, 2.53it/s]
39
  87%|████████▋ | 26/30 [00:10<00:01, 2.53it/s]
40
  90%|█████████ | 27/30 [00:10<00:01, 2.53it/s]
41
  93%|█████████▎| 28/30 [00:11<00:00, 2.53it/s]
42
  97%|█████████▋| 29/30 [00:11<00:00, 2.53it/s]
43
+ [2025-09-24 13:43:19.446][ INFO](root) - ---- DM/ControlNet Latent features generation time: 12.313125371932983 seconds ----
44
+ [2025-09-24 13:43:20.016][ INFO](root) - ---- Start decoding latent features into images... ----
45
+
46
  0%| | 0/4 [00:00<?, ?it/s]
47
  25%|██▌ | 1/4 [00:04<00:13, 4.36s/it]
48
  50%|█████ | 2/4 [00:08<00:07, 4.00s/it]
49
  75%|███████▌ | 3/4 [00:11<00:03, 3.79s/it]
50
+ [2025-09-24 13:43:35.252][ INFO](root) - ---- Image VAE decoding time: 15.23531699180603 seconds ----
51
+ 2025-09-24 13:43:37,053 INFO image_writer.py:197 - writing: outputs/NoMAISI_DLCSD24_demo_512xy_256z_771p25m/DLCS_1419_seg_sh_image.nii.gz
52
+ 2025-09-24 13:43:41,437 INFO image_writer.py:197 - writing: outputs/NoMAISI_DLCSD24_demo_512xy_256z_771p25m/DLCS_1419_seg_sh_label.nii.gz
scripts/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # You may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
scripts/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (140 Bytes). View file
 
scripts/__pycache__/augmentation.cpython-310.pyc ADDED
Binary file (6.81 kB). View file
 
scripts/__pycache__/diff_model_create_training_data.cpython-310.pyc ADDED
Binary file (7.38 kB). View file
 
scripts/__pycache__/diff_model_setting.cpython-310.pyc ADDED
Binary file (2.47 kB). View file
 
scripts/__pycache__/find_masks.cpython-310.pyc ADDED
Binary file (4.48 kB). View file
 
scripts/__pycache__/infer_controlnet.cpython-310.pyc ADDED
Binary file (5.72 kB). View file
 
scripts/__pycache__/infer_testV2_controlnet.cpython-310.pyc ADDED
Binary file (5.76 kB). View file
 
scripts/__pycache__/infer_test_controlnet.cpython-310.pyc ADDED
Binary file (5.75 kB). View file
 
scripts/__pycache__/inference.cpython-310.pyc ADDED
Binary file (7.62 kB). View file
 
scripts/__pycache__/quality_check.cpython-310.pyc ADDED
Binary file (4.39 kB). View file
 
scripts/__pycache__/rectified_flow.cpython-310.pyc ADDED
Binary file (10.9 kB). View file
 
scripts/__pycache__/sample.cpython-310.pyc ADDED
Binary file (31.4 kB). View file
 
scripts/__pycache__/train_controlnet.cpython-310.pyc ADDED
Binary file (8.01 kB). View file
 
scripts/__pycache__/utils.cpython-310.pyc ADDED
Binary file (26.5 kB). View file
 
scripts/__pycache__/utils_plot.cpython-310.pyc ADDED
Binary file (6.66 kB). View file
 
scripts/augmentation.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from monai.transforms import Rand3DElastic, RandAffine, RandZoom
16
+ from monai.utils import ensure_tuple_rep
17
+
18
+
19
+ def erode3d(input_tensor, erosion=3):
20
+ # Define the structuring element
21
+ erosion = ensure_tuple_rep(erosion, 3)
22
+ structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to(input_tensor.device)
23
+
24
+ # Pad the input tensor to handle border pixels
25
+ input_padded = F.pad(
26
+ input_tensor.float().unsqueeze(0).unsqueeze(0),
27
+ (erosion[0] // 2, erosion[0] // 2, erosion[1] // 2, erosion[1] // 2, erosion[2] // 2, erosion[2] // 2),
28
+ mode="constant",
29
+ value=1.0,
30
+ )
31
+
32
+ # Apply erosion operation
33
+ output = F.conv3d(input_padded, structuring_element, padding=0)
34
+
35
+ # Set output values based on the minimum value within the structuring element
36
+ output = torch.where(output == torch.sum(structuring_element), 1.0, 0.0)
37
+
38
+ return output.squeeze(0).squeeze(0)
39
+
40
+
41
+ def dilate3d(input_tensor, erosion=3):
42
+ # Define the structuring element
43
+ erosion = ensure_tuple_rep(erosion, 3)
44
+ structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to(input_tensor.device)
45
+
46
+ # Pad the input tensor to handle border pixels
47
+ input_padded = F.pad(
48
+ input_tensor.float().unsqueeze(0).unsqueeze(0),
49
+ (erosion[0] // 2, erosion[0] // 2, erosion[1] // 2, erosion[1] // 2, erosion[2] // 2, erosion[2] // 2),
50
+ mode="constant",
51
+ value=1.0,
52
+ )
53
+
54
+ # Apply erosion operation
55
+ output = F.conv3d(input_padded, structuring_element, padding=0)
56
+
57
+ # Set output values based on the minimum value within the structuring element
58
+ output = torch.where(output > 0, 1.0, 0.0)
59
+
60
+ return output.squeeze(0).squeeze(0)
61
+
62
+
63
+ def augmentation_tumor_bone(pt_nda, output_size, random_seed=None):
64
+ volume = pt_nda.squeeze(0)
65
+ real_l_volume_ = torch.zeros_like(volume)
66
+ real_l_volume_[volume == 128] = 1
67
+ real_l_volume_ = real_l_volume_.to(torch.uint8)
68
+
69
+ elastic = RandAffine(
70
+ mode="nearest",
71
+ prob=1.0,
72
+ translate_range=(5, 5, 0),
73
+ rotate_range=(0, 0, 0.1),
74
+ scale_range=(0.15, 0.15, 0),
75
+ padding_mode="zeros",
76
+ )
77
+ elastic.set_random_state(seed=random_seed)
78
+
79
+ tumor_szie = torch.sum((real_l_volume_ > 0).float())
80
+ ###########################
81
+ # remove pred in pseudo_label in real lesion region
82
+ volume[real_l_volume_ > 0] = 200
83
+ ###########################
84
+ if tumor_szie > 0:
85
+ # get organ mask
86
+ organ_mask = (
87
+ torch.logical_and(33 <= volume, volume <= 56).float()
88
+ + torch.logical_and(63 <= volume, volume <= 97).float()
89
+ + (volume == 127).float()
90
+ + (volume == 114).float()
91
+ + real_l_volume_
92
+ )
93
+ organ_mask = (organ_mask > 0).float()
94
+ cnt = 0
95
+ while True:
96
+ threshold = 0.8 if cnt < 40 else 0.75
97
+ real_l_volume = real_l_volume_
98
+ # random distor mask
99
+ distored_mask = elastic((real_l_volume > 0).cuda(), spatial_size=tuple(output_size)).as_tensor()
100
+ real_l_volume = distored_mask * organ_mask
101
+ cnt += 1
102
+ print(torch.sum(real_l_volume), "|", tumor_szie * threshold)
103
+ if torch.sum(real_l_volume) >= tumor_szie * threshold:
104
+ real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5)
105
+ real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0).to(torch.uint8)
106
+ break
107
+ else:
108
+ real_l_volume = real_l_volume_
109
+
110
+ volume[real_l_volume == 1] = 128
111
+
112
+ pt_nda = volume.unsqueeze(0)
113
+ return pt_nda
114
+
115
+
116
+ def augmentation_tumor_liver(pt_nda, output_size, random_seed=None):
117
+ volume = pt_nda.squeeze(0)
118
+ real_l_volume_ = torch.zeros_like(volume)
119
+ real_l_volume_[volume == 1] = 1
120
+ real_l_volume_[volume == 26] = 2
121
+ real_l_volume_ = real_l_volume_.to(torch.uint8)
122
+
123
+ elastic = Rand3DElastic(
124
+ mode="nearest",
125
+ prob=1.0,
126
+ sigma_range=(5, 8),
127
+ magnitude_range=(100, 200),
128
+ translate_range=(10, 10, 10),
129
+ rotate_range=(np.pi / 36, np.pi / 36, np.pi / 36),
130
+ scale_range=(0.2, 0.2, 0.2),
131
+ padding_mode="zeros",
132
+ )
133
+ elastic.set_random_state(seed=random_seed)
134
+
135
+ tumor_szie = torch.sum(real_l_volume_ == 2)
136
+ ###########################
137
+ # remove pred organ labels
138
+ volume[volume == 1] = 0
139
+ volume[volume == 26] = 0
140
+ # before move tumor maks, full the original location by organ labels
141
+ volume[real_l_volume_ == 1] = 1
142
+ volume[real_l_volume_ == 2] = 1
143
+ ###########################
144
+ while True:
145
+ real_l_volume = real_l_volume_
146
+ # random distor mask
147
+ real_l_volume = elastic((real_l_volume == 2).cuda(), spatial_size=tuple(output_size)).as_tensor()
148
+ # get organ mask
149
+ organ_mask = (real_l_volume_ == 1).float() + (real_l_volume_ == 2).float()
150
+
151
+ organ_mask = dilate3d(organ_mask.squeeze(0), erosion=5)
152
+ organ_mask = erode3d(organ_mask, erosion=5).unsqueeze(0)
153
+ real_l_volume = real_l_volume * organ_mask
154
+ print(torch.sum(real_l_volume), "|", tumor_szie * 0.80)
155
+ if torch.sum(real_l_volume) >= tumor_szie * 0.80:
156
+ real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5)
157
+ real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0)
158
+ break
159
+
160
+ volume[real_l_volume == 1] = 26
161
+
162
+ pt_nda = volume.unsqueeze(0)
163
+ return pt_nda
164
+
165
+
166
+ def augmentation_tumor_lung(pt_nda, output_size, random_seed=None):
167
+ volume = pt_nda.squeeze(0)
168
+ real_l_volume_ = torch.zeros_like(volume)
169
+ real_l_volume_[volume == 23] = 1
170
+ real_l_volume_ = real_l_volume_.to(torch.uint8)
171
+
172
+ elastic = Rand3DElastic(
173
+ mode="nearest",
174
+ prob=1.0,
175
+ sigma_range=(5, 8),
176
+ magnitude_range=(100, 200),
177
+ translate_range=(20, 20, 20),
178
+ rotate_range=(np.pi / 36, np.pi / 36, np.pi),
179
+ scale_range=(0.15, 0.15, 0.15),
180
+ padding_mode="zeros",
181
+ )
182
+ elastic.set_random_state(seed=random_seed)
183
+
184
+ tumor_szie = torch.sum(real_l_volume_)
185
+ # before move lung tumor maks, full the original location by lung labels
186
+ new_real_l_volume_ = dilate3d(real_l_volume_.squeeze(0), erosion=3)
187
+ new_real_l_volume_ = new_real_l_volume_.unsqueeze(0)
188
+ new_real_l_volume_[real_l_volume_ > 0] = 0
189
+ new_real_l_volume_[volume < 28] = 0
190
+ new_real_l_volume_[volume > 32] = 0
191
+ tmp = volume[(volume * new_real_l_volume_).nonzero(as_tuple=True)].view(-1)
192
+
193
+ mode = torch.mode(tmp, 0)[0].item()
194
+ print(mode)
195
+ assert 28 <= mode <= 32
196
+ volume[real_l_volume_.bool()] = mode
197
+ ###########################
198
+ if tumor_szie > 0:
199
+ # aug
200
+ while True:
201
+ real_l_volume = real_l_volume_
202
+ # random distor mask
203
+ real_l_volume = elastic(real_l_volume, spatial_size=tuple(output_size)).as_tensor()
204
+ # get lung mask v2 (133 order)
205
+ lung_mask = (
206
+ (volume == 28).float()
207
+ + (volume == 29).float()
208
+ + (volume == 30).float()
209
+ + (volume == 31).float()
210
+ + (volume == 32).float()
211
+ )
212
+
213
+ lung_mask = dilate3d(lung_mask.squeeze(0), erosion=5)
214
+ lung_mask = erode3d(lung_mask, erosion=5).unsqueeze(0)
215
+ real_l_volume = real_l_volume * lung_mask
216
+ print(torch.sum(real_l_volume), "|", tumor_szie * 0.85)
217
+ if torch.sum(real_l_volume) >= tumor_szie * 0.85:
218
+ real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5)
219
+ real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0).to(torch.uint8)
220
+ break
221
+ else:
222
+ real_l_volume = real_l_volume_
223
+
224
+ volume[real_l_volume == 1] = 23
225
+
226
+ pt_nda = volume.unsqueeze(0)
227
+ return pt_nda
228
+
229
+
230
+ def augmentation_tumor_pancreas(pt_nda, output_size, random_seed=None):
231
+ volume = pt_nda.squeeze(0)
232
+ real_l_volume_ = torch.zeros_like(volume)
233
+ real_l_volume_[volume == 4] = 1
234
+ real_l_volume_[volume == 24] = 2
235
+ real_l_volume_ = real_l_volume_.to(torch.uint8)
236
+
237
+ elastic = Rand3DElastic(
238
+ mode="nearest",
239
+ prob=1.0,
240
+ sigma_range=(5, 8),
241
+ magnitude_range=(100, 200),
242
+ translate_range=(15, 15, 15),
243
+ rotate_range=(np.pi / 36, np.pi / 36, np.pi / 36),
244
+ scale_range=(0.1, 0.1, 0.1),
245
+ padding_mode="zeros",
246
+ )
247
+ elastic.set_random_state(seed=random_seed)
248
+
249
+ tumor_szie = torch.sum(real_l_volume_ == 2)
250
+ ###########################
251
+ # remove pred organ labels
252
+ volume[volume == 24] = 0
253
+ volume[volume == 4] = 0
254
+ # before move tumor maks, full the original location by organ labels
255
+ volume[real_l_volume_ == 1] = 4
256
+ volume[real_l_volume_ == 2] = 4
257
+ ###########################
258
+ while True:
259
+ real_l_volume = real_l_volume_
260
+ # random distor mask
261
+ real_l_volume = elastic((real_l_volume == 2).cuda(), spatial_size=tuple(output_size)).as_tensor()
262
+ # get organ mask
263
+ organ_mask = (real_l_volume_ == 1).float() + (real_l_volume_ == 2).float()
264
+
265
+ organ_mask = dilate3d(organ_mask.squeeze(0), erosion=5)
266
+ organ_mask = erode3d(organ_mask, erosion=5).unsqueeze(0)
267
+ real_l_volume = real_l_volume * organ_mask
268
+ print(torch.sum(real_l_volume), "|", tumor_szie * 0.80)
269
+ if torch.sum(real_l_volume) >= tumor_szie * 0.80:
270
+ real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5)
271
+ real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0)
272
+ break
273
+
274
+ volume[real_l_volume == 1] = 24
275
+
276
+ pt_nda = volume.unsqueeze(0)
277
+ return pt_nda
278
+
279
+
280
+ def augmentation_tumor_colon(pt_nda, output_size, random_seed=None):
281
+ volume = pt_nda.squeeze(0)
282
+ real_l_volume_ = torch.zeros_like(volume)
283
+ real_l_volume_[volume == 27] = 1
284
+ real_l_volume_ = real_l_volume_.to(torch.uint8)
285
+
286
+ elastic = Rand3DElastic(
287
+ mode="nearest",
288
+ prob=1.0,
289
+ sigma_range=(5, 8),
290
+ magnitude_range=(100, 200),
291
+ translate_range=(5, 5, 5),
292
+ rotate_range=(np.pi / 36, np.pi / 36, np.pi / 36),
293
+ scale_range=(0.1, 0.1, 0.1),
294
+ padding_mode="zeros",
295
+ )
296
+ elastic.set_random_state(seed=random_seed)
297
+
298
+ tumor_szie = torch.sum(real_l_volume_)
299
+ ###########################
300
+ # before move tumor maks, full the original location by organ labels
301
+ volume[real_l_volume_.bool()] = 62
302
+ ###########################
303
+ if tumor_szie > 0:
304
+ # get organ mask
305
+ organ_mask = (volume == 62).float()
306
+ organ_mask = dilate3d(organ_mask.squeeze(0), erosion=5)
307
+ organ_mask = erode3d(organ_mask, erosion=5).unsqueeze(0)
308
+ # cnt = 0
309
+ cnt = 0
310
+ while True:
311
+ threshold = 0.8
312
+ real_l_volume = real_l_volume_
313
+ if cnt < 20:
314
+ # random distor mask
315
+ distored_mask = elastic((real_l_volume == 1).cuda(), spatial_size=tuple(output_size)).as_tensor()
316
+ real_l_volume = distored_mask * organ_mask
317
+ elif 20 <= cnt < 40:
318
+ threshold = 0.75
319
+ else:
320
+ break
321
+
322
+ real_l_volume = real_l_volume * organ_mask
323
+ print(torch.sum(real_l_volume), "|", tumor_szie * threshold)
324
+ cnt += 1
325
+ if torch.sum(real_l_volume) >= tumor_szie * threshold:
326
+ real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5)
327
+ real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0).to(torch.uint8)
328
+ break
329
+ else:
330
+ real_l_volume = real_l_volume_
331
+ # break
332
+ volume[real_l_volume == 1] = 27
333
+
334
+ pt_nda = volume.unsqueeze(0)
335
+ return pt_nda
336
+
337
+
338
+ def augmentation_body(pt_nda, random_seed=None):
339
+ volume = pt_nda.squeeze(0)
340
+
341
+ zoom = RandZoom(min_zoom=0.99, max_zoom=1.01, mode="nearest", align_corners=None, prob=1.0)
342
+ zoom.set_random_state(seed=random_seed)
343
+
344
+ volume = zoom(volume)
345
+
346
+ pt_nda = volume.unsqueeze(0)
347
+ return pt_nda
348
+
349
+
350
+ def augmentation(pt_nda, output_size, random_seed=None):
351
+ label_list = torch.unique(pt_nda)
352
+ label_list = list(label_list.cpu().numpy())
353
+
354
+ if 128 in label_list:
355
+ print("augmenting bone lesion/tumor")
356
+ pt_nda = augmentation_tumor_bone(pt_nda, output_size, random_seed)
357
+ elif 26 in label_list:
358
+ print("augmenting liver tumor")
359
+ pt_nda = augmentation_tumor_liver(pt_nda, output_size, random_seed)
360
+ elif 23 in label_list:
361
+ print("augmenting lung tumor")
362
+ pt_nda = augmentation_tumor_lung(pt_nda, output_size, random_seed)
363
+ elif 24 in label_list:
364
+ print("augmenting pancreas tumor")
365
+ pt_nda = augmentation_tumor_pancreas(pt_nda, output_size, random_seed)
366
+ elif 27 in label_list:
367
+ print("augmenting colon tumor")
368
+ pt_nda = augmentation_tumor_colon(pt_nda, output_size, random_seed)
369
+ else:
370
+ print("augmenting body")
371
+ pt_nda = augmentation_body(pt_nda, random_seed)
372
+
373
+ return pt_nda
scripts/compute_fid_2-5d_ct.py ADDED
@@ -0,0 +1,747 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at:
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an
9
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
10
+ # either express or implied.
11
+ # See the License for the specific language governing permissions
12
+ # and limitations under the License.
13
+
14
+ """
15
+ Compute 2.5D FID using distributed GPU processing.
16
+
17
+ SHELL Usage Example:
18
+ -------------------
19
+ #!/bin/bash
20
+
21
+ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6
22
+ NUM_GPUS=7
23
+
24
+ torchrun --nproc_per_node=${NUM_GPUS} compute_fid_2-5d_ct.py \
25
+ --model_name "radimagenet_resnet50" \
26
+ --real_dataset_root "path/to/datasetA" \
27
+ --real_filelist "path/to/filelistA.txt" \
28
+ --real_features_dir "datasetA" \
29
+ --synth_dataset_root "path/to/datasetB" \
30
+ --synth_filelist "path/to/filelistB.txt" \
31
+ --synth_features_dir "datasetB" \
32
+ --enable_center_slices_ratio 0.4 \
33
+ --enable_padding True \
34
+ --enable_center_cropping True \
35
+ --enable_resampling_spacing "1.0x1.0x1.0" \
36
+ --ignore_existing True \
37
+ --num_images 100 \
38
+ --output_root "./features/features-512x512x512" \
39
+ --target_shape "512x512x512"
40
+
41
+ This script loads two datasets (real vs. synthetic) in 3D medical format (NIfTI)
42
+ and extracts feature maps via a 2.5D approach. It then computes the Frechet
43
+ Inception Distance (FID) across three orthogonal planes. Data parallelism
44
+ is implemented using torch.distributed with an NCCL backend.
45
+
46
+ Function Arguments (main):
47
+ --------------------------
48
+ real_dataset_root (str):
49
+ Root folder for the real dataset.
50
+
51
+ real_filelist (str):
52
+ Text file listing 3D images for the real dataset.
53
+
54
+ real_features_dir (str):
55
+ Subdirectory (under `output_root`) in which to store feature files
56
+ extracted from the real dataset.
57
+
58
+ synth_dataset_root (str):
59
+ Root folder for the synthetic dataset.
60
+
61
+ synth_filelist (str):
62
+ Text file listing 3D images for the synthetic dataset.
63
+
64
+ synth_features_dir (str):
65
+ Subdirectory (under `output_root`) in which to store feature files
66
+ extracted from the synthetic dataset.
67
+
68
+ enable_center_slices_ratio (float or None):
69
+ - If not None, only slices around the specified center ratio will be used
70
+ (analogous to "enable_center_slices=True" with that ratio).
71
+ - If None, no center-slice selection is performed
72
+ (analogous to "enable_center_slices=False").
73
+
74
+ enable_padding (bool):
75
+ Whether to pad images to `target_shape`.
76
+
77
+ enable_center_cropping (bool):
78
+ Whether to center-crop images to `target_shape`.
79
+
80
+ enable_resampling_spacing (str or None):
81
+ - If not None, resample images to the specified voxel spacing (e.g. "1.0x1.0x1.0")
82
+ (analogous to "enable_resampling=True" with that spacing).
83
+ - If None, resampling is skipped
84
+ (analogous to "enable_resampling=False").
85
+
86
+ ignore_existing (bool):
87
+ If True, ignore any existing .pt feature files and force re-extraction.
88
+
89
+ model_name (str):
90
+ Model identifier. Typically "radimagenet_resnet50" or "squeezenet1_1".
91
+
92
+ num_images (int):
93
+ Max number of images to process from each dataset (truncate if more are present).
94
+
95
+ output_root (str):
96
+ Folder where extracted .pt feature files, logs, and results are saved.
97
+
98
+ target_shape (str):
99
+ Target shape as "XxYxZ" for padding, cropping, or resampling operations.
100
+ """
101
+
102
+
103
+ from __future__ import annotations
104
+
105
+ import os
106
+ import sys
107
+ import torch
108
+ import fire
109
+ import monai
110
+ import re
111
+ import torch.distributed as dist
112
+ import torch.nn.functional as F
113
+
114
+ from datetime import timedelta
115
+ from pathlib import Path
116
+ from monai.metrics.fid import FIDMetric
117
+ from monai.transforms import Compose
118
+
119
+ import logging
120
+
121
+ # ------------------------------------------------------------------------------
122
+ # Create logger
123
+ # ------------------------------------------------------------------------------
124
# Module-level logger shared by every function in this script.
logger = logging.getLogger("fid_2-5d_ct")
if not logger.handlers:
    # Configure logger only if it has no handlers (avoid reconfiguring in multi-rank scenarios)
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    logger.setLevel(logging.INFO)
129
+
130
+
131
def drop_empty_slice(slices, empty_threshold: float):
    """
    Decide which 2D slices to keep: a slice is considered "empty" when its
    maximum intensity is below `empty_threshold`.

    Args:
        slices (tuple or list of Tensors): Each element is (B, C, H, W).
        empty_threshold (float): If the slice's maximum value is below this
            threshold, the slice is dropped.

    Returns:
        list[bool]: Keep-flag per slice (False means "empty", drop it).
    """
    # Guard against an empty input, which would otherwise divide by zero below.
    if not slices:
        return []

    # Fetch the module logger by name so the function is self-contained;
    # logging.getLogger returns the same logger object as the module-level one.
    log = logging.getLogger("fid_2-5d_ct")

    # torch.max(s) equals max(torch.unique(s)) without the O(n log n) unique pass.
    keep_flags = [bool(torch.max(s) >= empty_threshold) for s in slices]

    n_drop = len(keep_flags) - sum(keep_flags)
    log.info(f"Empty slice drop rate {round((n_drop / len(slices)) * 100, 1)}%")
    return keep_flags
156
+
157
+
158
def subtract_mean(x: torch.Tensor) -> torch.Tensor:
    """
    Subtract per-channel means [0.406, 0.456, 0.485] (ImageNet means in BGR
    order) from a 4D or 5D tensor shaped (B, C, H, W) or (B, C, H, W, D).

    NOTE: mutates `x` in place and returns it.
    """
    for channel, channel_mean in enumerate((0.406, 0.456, 0.485)):
        x[:, channel, ...] -= channel_mean
    return x
169
+
170
+
171
def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
    """
    Mean-reduce the spatial axes of a feature map to produce a feature vector.

    For 3D/4D/5D inputs (B, C, ...spatial...), every axis after the channel
    axis is averaged; 2D inputs (and any other rank) pass through unchanged.

    Args:
        x (torch.Tensor): Input tensor.
        keepdim (bool): Whether reduced axes are kept as size-1 dims.

    Returns:
        torch.Tensor: Tensor with spatial dimensions averaged out.
    """
    ndim = len(x.shape)
    if 3 <= ndim <= 5:
        # Axes 2..ndim-1 are the spatial ones (after batch and channel).
        spatial_axes = list(range(2, ndim))
        return x.mean(spatial_axes, keepdim=keepdim)
    return x
198
+
199
+
200
def medicalnet_intensity_normalisation(volume: torch.Tensor) -> torch.Tensor:
    """
    MedicalNet-style intensity normalization: (volume - mean) / (std + 1e-5)
    computed over the spatial axes.

    Expects (B, C, H, W) or (B, C, H, W, D); any other rank is returned
    unchanged.
    """
    spatial_axes = {4: [2, 3], 5: [2, 3, 4]}.get(len(volume.shape))
    if spatial_axes is None:
        return volume
    mean = volume.mean(spatial_axes, keepdim=True)
    std = volume.std(spatial_axes, keepdim=True)
    return (volume - mean) / (std + 1e-5)
216
+
217
+
218
def radimagenet_intensity_normalisation(volume: torch.Tensor, norm2d: bool = False) -> torch.Tensor:
    """
    Intensity normalization for radimagenet_resnet: min-max scale to [0, 1],
    then subtract the ImageNet channel means (via `subtract_mean`).

    Args:
        volume (torch.Tensor): Input (B, C, H, W) or (B, C, H, W, D).
        norm2d (bool): If True (4D input only), each (H, W) slice is scaled
            to [0, 1] independently per sample/channel; otherwise a single
            global min/max over the whole tensor is used.
    """
    logger.info(f"norm2d: {norm2d}")
    ndim = len(volume.shape)

    if ndim == 4 and norm2d:
        # Per-slice scaling: reduce over H and W only, keeping dims so the
        # result broadcasts back over the input.
        hi = volume.amax(dim=(2, 3), keepdim=True)
        lo = volume.amin(dim=(2, 3), keepdim=True)
        scaled = (volume - lo) / (hi - lo + 1e-10)
        return subtract_mean(scaled)

    if ndim in (4, 5):
        # Global scaling over the entire tensor.
        lo = torch.min(volume)
        hi = torch.max(volume)
        scaled = (volume - lo) / (hi - lo + 1e-10)
        return subtract_mean(scaled)

    # Any other rank passes through untouched.
    return volume
251
+
252
+
253
def _plane_features_2p5d(
    image: torch.Tensor,
    feature_network: torch.nn.Module,
    axis: int,
    center_slices: bool,
    center_slices_ratio: float,
    sample_every_k: int,
    drop_empty: bool,
    empty_threshold: float,
) -> torch.Tensor:
    """Slice a 5D image along `axis`, run the 2D slices through
    `feature_network`, and return spatially-averaged per-slice features."""
    if center_slices:
        # Keep only the central `center_slices_ratio` fraction of the axis,
        # subsampled every `sample_every_k` slices.
        length = image.shape[axis]
        start = int((1.0 - center_slices_ratio) / 2.0 * length)
        end = int((1.0 + center_slices_ratio) / 2.0 * length)
        index = [slice(None)] * image.dim()
        index[axis] = slice(start, end, sample_every_k)
        slices = torch.unbind(image[tuple(index)], dim=axis)
    else:
        # NOTE: matches the original behavior — sample_every_k is only
        # applied when center_slices is enabled.
        slices = torch.unbind(image, dim=axis)

    if drop_empty:
        mapping_index = drop_empty_slice(slices, empty_threshold)
    else:
        mapping_index = [True] * len(slices)

    images_2d = torch.cat(slices, dim=0)
    images_2d = radimagenet_intensity_normalisation(images_2d)
    # Boolean-mask out empty slices. Assumes batch size 1 so slice order in
    # the concatenation matches the mask — TODO confirm for B > 1.
    images_2d = images_2d[mapping_index]

    features = feature_network.forward(images_2d)
    return spatial_average(features, keepdim=False)


def get_features_2p5d(
    image: torch.Tensor,
    feature_network: torch.nn.Module,
    center_slices: bool = False,
    center_slices_ratio: float = 1.0,
    sample_every_k: int = 1,
    xy_only: bool = True,
    drop_empty: bool = False,
    empty_threshold: float = -700,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
    """
    Extract 2.5D features from a 3D image by slicing it along XY, YZ, ZX planes.

    The per-plane work (center-window selection, empty-slice dropping,
    normalization, feature extraction, spatial averaging) is shared in
    `_plane_features_2p5d`; only the slicing axis differs between planes.

    Args:
        image (torch.Tensor): Input 5D tensor in shape (B, C, H, W, D).
        feature_network (torch.nn.Module): Model that processes 2D slices (C,H,W).
        center_slices (bool): Whether to slice only the center portion of each axis.
        center_slices_ratio (float): Ratio of slices to keep in the center if
            `center_slices` is True.
        sample_every_k (int): Downsampling factor along each axis when slicing
            (applied only when `center_slices` is True).
        xy_only (bool): If True, return only the XY-plane features.
        drop_empty (bool): Drop slices that are deemed "empty" below `empty_threshold`.
        empty_threshold (float): Threshold to decide emptiness of slices.

    Returns:
        tuple of torch.Tensor or None: (XY_features, YZ_features, ZX_features);
        YZ and ZX are None when `xy_only` is True.
    """
    logger.info(f"center_slices: {center_slices}, ratio: {center_slices_ratio}")

    # If there's only 1 channel, replicate to 3 channels
    if image.shape[1] == 1:
        image = image.repeat(1, 3, 1, 1, 1)

    # Convert from 'RGB'→(R,G,B) to (B,G,R)
    image = image[:, [2, 1, 0], ...]

    plane_kwargs = dict(
        center_slices=center_slices,
        center_slices_ratio=center_slices_ratio,
        sample_every_k=sample_every_k,
        drop_empty=drop_empty,
        empty_threshold=empty_threshold,
    )
    with torch.no_grad():
        # XY plane: slice along D (axis 4).
        feature_image_xy = _plane_features_2p5d(image, feature_network, 4, **plane_kwargs)
        if xy_only:
            return feature_image_xy, None, None

        # YZ plane: slice along H (axis 2).
        feature_image_yz = _plane_features_2p5d(image, feature_network, 2, **plane_kwargs)
        # ZX plane: slice along W (axis 3).
        feature_image_zx = _plane_features_2p5d(image, feature_network, 3, **plane_kwargs)

    return feature_image_xy, feature_image_yz, feature_image_zx
353
+
354
+
355
def pad_to_max_size(tensor: torch.Tensor, max_size: int, padding_value: float = 0.0) -> torch.Tensor:
    """
    Pad a tensor along its first dimension up to `max_size`, filling new
    rows with `padding_value`.

    Args:
        tensor (torch.Tensor): The feature tensor to pad.
        max_size (int): Desired size along dim 0.
        padding_value (float): Fill value for the padded rows.

    Returns:
        torch.Tensor: Tensor whose dim 0 has length `max_size`.
    """
    trailing_dims = len(tensor.shape) - 1
    # F.pad consumes (before, after) pairs from the LAST dim backwards;
    # only the leading dim grows, everything else stays (0, 0).
    pad_spec = [0, 0] * trailing_dims + [0, max_size - tensor.shape[0]]
    return F.pad(tensor, pad_spec, "constant", padding_value)
369
+
370
+
371
def main(
    real_dataset_root: str = "path/to/datasetA",
    real_filelist: str = "path/to/filelistA.txt",
    real_features_dir: str = "datasetA",
    synth_dataset_root: str = "path/to/datasetB",
    synth_filelist: str = "path/to/filelistB.txt",
    synth_features_dir: str = "datasetB",
    enable_center_slices_ratio: float | None = None,
    enable_padding: bool = True,
    enable_center_cropping: bool = True,
    enable_resampling_spacing: str | None = None,
    ignore_existing: bool = False,
    model_name: str = "radimagenet_resnet50",
    num_images: int = 100,
    output_root: str = "./features/features-512x512x512",
    target_shape: str = "512x512x512",
):
    """
    Compute 2.5D FID between a real and a synthetic 3D CT dataset with
    distributed GPU processing (torch.distributed, NCCL backend).

    Each rank loads its partition of both datasets (NIfTI volumes), extracts
    2.5D features along the XY/YZ/ZX planes (cached as .pt files), then all
    ranks gather features and rank 0 computes per-plane FID and their average.

    Args:
        real_dataset_root (str): Root folder for the real dataset.
        real_filelist (str): Text file listing NIfTI paths (one per line),
            joined onto `real_dataset_root`.
        real_features_dir (str): Subdirectory under `output_root` for the real
            dataset's cached .pt feature files.
        synth_dataset_root (str): Root folder for the synthetic dataset.
        synth_filelist (str): Text file listing NIfTI paths, joined onto
            `synth_dataset_root`.
        synth_features_dir (str): Subdirectory under `output_root` for the
            synthetic dataset's cached .pt feature files.
        enable_center_slices_ratio (float | None): If not None, only this
            central fraction of slices per axis is used; None disables
            center-slice selection.
        enable_padding (bool): Pad images to `target_shape`.
        enable_center_cropping (bool): Center-crop images to `target_shape`.
        enable_resampling_spacing (str | None): If not None, resample to this
            voxel spacing, e.g. "1.0x1.0x1.0"; None skips resampling.
        ignore_existing (bool): If True, recompute features even when a cached
            .pt file already exists.
        model_name (str): "radimagenet_resnet50", or anything else to fall
            back to torchvision squeezenet1_1.
        num_images (int): Max number of images used from each dataset.
        output_root (str): Folder where extracted .pt files and logs are saved.
        target_shape (str): Target shape "XxYxZ" for padding/cropping.

    Returns:
        None
    """
    # -------------------------------------------------------------------------
    # Initialize Process Group (Distributed)
    # -------------------------------------------------------------------------
    # Long timeout (2h): feature extraction of large volumes can keep ranks
    # apart for a while before the first collective.
    dist.init_process_group(backend="nccl", init_method="env://", timeout=timedelta(seconds=7200))

    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(dist.get_world_size())
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    logger.info(f"[INFO] Running process on {device} of total {world_size} ranks.")

    # Convert potential string bools to actual bools (if using Fire or similar)
    if not isinstance(enable_padding, bool):
        enable_padding = enable_padding.lower() == "true"
    if not isinstance(enable_center_cropping, bool):
        enable_center_cropping = enable_center_cropping.lower() == "true"
    if not isinstance(ignore_existing, bool):
        ignore_existing = ignore_existing.lower() == "true"

    # Merge logic for center slices
    enable_center_slices = enable_center_slices_ratio is not None

    # Merge logic for resampling
    enable_resampling = enable_resampling_spacing is not None

    # Print out some flags on rank 0
    if local_rank == 0:
        logger.info(f"Real dataset root: {real_dataset_root}")
        logger.info(f"Synth dataset root: {synth_dataset_root}")
        logger.info(f"enable_center_slices_ratio: {enable_center_slices_ratio}")
        logger.info(f"enable_center_slices: {enable_center_slices}")
        logger.info(f"enable_padding: {enable_padding}")
        logger.info(f"enable_center_cropping: {enable_center_cropping}")
        logger.info(f"enable_resampling_spacing: {enable_resampling_spacing}")
        logger.info(f"enable_resampling: {enable_resampling}")
        logger.info(f"ignore_existing: {ignore_existing}")

    # -------------------------------------------------------------------------
    # Load feature extraction model
    # -------------------------------------------------------------------------
    if model_name == "radimagenet_resnet50":
        feature_network = torch.hub.load(
            "Warvito/radimagenet-models", model="radimagenet_resnet50", verbose=True, trust_repo=True
        )
        suffix = "radimagenet_resnet50"
    else:
        import torchvision

        feature_network = torchvision.models.squeezenet1_1(pretrained=True)
        suffix = "squeezenet1_1"
    # NOTE(review): `suffix` is assigned but never used below — presumably a
    # leftover from an earlier naming scheme for output files; confirm.

    feature_network.to(device)
    feature_network.eval()

    # -------------------------------------------------------------------------
    # Parse shape/spacings
    # -------------------------------------------------------------------------
    t_shape = [int(x) for x in target_shape.split("x")]
    target_shape_tuple = tuple(t_shape)

    # If not None, parse the resampling spacing
    if enable_resampling:
        rs_spacing = [float(x) for x in enable_resampling_spacing.split("x")]
        rs_spacing_tuple = tuple(rs_spacing)
        if local_rank == 0:
            logger.info(f"Resampling spacing: {rs_spacing_tuple}")
    else:
        rs_spacing_tuple = (1.0, 1.0, 1.0)

    # Use the ratio if provided, otherwise 1.0
    center_slices_ratio_final = enable_center_slices_ratio if enable_center_slices else 1.0
    if local_rank == 0:
        logger.info(f"center_slices_ratio: {center_slices_ratio_final}")

    # -------------------------------------------------------------------------
    # Prepare Real Dataset
    # -------------------------------------------------------------------------
    output_root_real = os.path.join(output_root, real_features_dir)
    with open(real_filelist, "r") as rf:
        real_lines = [l.strip() for l in rf.readlines()]
    # Sort before truncation so every rank sees the same deterministic subset.
    real_lines.sort()
    real_lines = real_lines[:num_images]

    real_filenames = [{"image": os.path.join(real_dataset_root, f)} for f in real_lines]
    # Each rank keeps only its own shard of the file list.
    real_filenames = monai.data.partition_dataset(
        data=real_filenames, shuffle=False, num_partitions=world_size, even_divisible=False
    )[local_rank]

    # -------------------------------------------------------------------------
    # Prepare Synthetic Dataset
    # -------------------------------------------------------------------------
    output_root_synth = os.path.join(output_root, synth_features_dir)
    with open(synth_filelist, "r") as sf:
        synth_lines = [l.strip() for l in sf.readlines()]
    synth_lines.sort()
    synth_lines = synth_lines[:num_images]

    synth_filenames = [{"image": os.path.join(synth_dataset_root, f)} for f in synth_lines]
    synth_filenames = monai.data.partition_dataset(
        data=synth_filenames, shuffle=False, num_partitions=world_size, even_divisible=False
    )[local_rank]

    # -------------------------------------------------------------------------
    # Build MONAI transforms
    # -------------------------------------------------------------------------
    transform_list = [
        monai.transforms.LoadImaged(keys=["image"]),
        monai.transforms.EnsureChannelFirstd(keys=["image"]),
        monai.transforms.Orientationd(keys=["image"], axcodes="RAS"),
    ]

    if enable_resampling:
        transform_list.append(monai.transforms.Spacingd(keys=["image"], pixdim=rs_spacing_tuple, mode=["bilinear"]))

    if enable_padding:
        # -1000 HU ≈ air, so padding does not introduce tissue-like intensities.
        transform_list.append(
            monai.transforms.SpatialPadd(keys=["image"], spatial_size=target_shape_tuple, mode="constant", value=-1000)
        )

    if enable_center_cropping:
        transform_list.append(monai.transforms.CenterSpatialCropd(keys=["image"], roi_size=target_shape_tuple))

    # Clip to [-1000, 1000] HU without rescaling (b_min/b_max equal a_min/a_max).
    transform_list.append(
        monai.transforms.ScaleIntensityRanged(
            keys=["image"], a_min=-1000, a_max=1000, b_min=-1000, b_max=1000, clip=True
        )
    )
    transforms = Compose(transform_list)

    # -------------------------------------------------------------------------
    # Create DataLoaders
    # -------------------------------------------------------------------------
    real_ds = monai.data.Dataset(data=real_filenames, transform=transforms)
    real_loader = monai.data.DataLoader(real_ds, num_workers=6, batch_size=1, shuffle=False)

    synth_ds = monai.data.Dataset(data=synth_filenames, transform=transforms)
    synth_loader = monai.data.DataLoader(synth_ds, num_workers=6, batch_size=1, shuffle=False)

    # -------------------------------------------------------------------------
    # Extract features for Real Dataset
    # -------------------------------------------------------------------------
    real_features_xy, real_features_yz, real_features_zx = [], [], []
    for idx, batch_data in enumerate(real_loader, start=1):
        img = batch_data["image"].to(device)
        fn = img.meta["filename_or_obj"][0]
        logger.info(f"[Rank {local_rank}] Real data {idx}/{len(real_filenames)}: {fn}")

        # Mirror the dataset tree under the features dir, swapping the extension.
        out_fp = fn.replace(real_dataset_root, output_root_real).replace(".nii.gz", ".pt")
        out_fp = Path(out_fp)
        out_fp.parent.mkdir(parents=True, exist_ok=True)

        # Reuse the cached features unless recomputation was requested.
        if (not ignore_existing) and os.path.isfile(out_fp):
            feats = torch.load(out_fp, weights_only=True)
        else:
            img_t = img.as_tensor()
            logger.info(f"image shape: {tuple(img_t.shape)}")

            feats = get_features_2p5d(
                img_t,
                feature_network,
                center_slices=enable_center_slices,
                center_slices_ratio=center_slices_ratio_final,
                xy_only=False,
            )
            logger.info(f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}")
            torch.save(feats, out_fp)

        real_features_xy.append(feats[0])
        real_features_yz.append(feats[1])
        real_features_zx.append(feats[2])

    real_features_xy = torch.vstack(real_features_xy)
    real_features_yz = torch.vstack(real_features_yz)
    real_features_zx = torch.vstack(real_features_zx)
    logger.info(
        f"Real feature shapes: {real_features_xy.shape}, " f"{real_features_yz.shape}, {real_features_zx.shape}"
    )

    # -------------------------------------------------------------------------
    # Extract features for Synthetic Dataset
    # -------------------------------------------------------------------------
    synth_features_xy, synth_features_yz, synth_features_zx = [], [], []
    for idx, batch_data in enumerate(synth_loader, start=1):
        img = batch_data["image"].to(device)
        fn = img.meta["filename_or_obj"][0]
        logger.info(f"[Rank {local_rank}] Synth data {idx}/{len(synth_filenames)}: {fn}")

        out_fp = fn.replace(synth_dataset_root, output_root_synth).replace(".nii.gz", ".pt")
        out_fp = Path(out_fp)
        out_fp.parent.mkdir(parents=True, exist_ok=True)

        if (not ignore_existing) and os.path.isfile(out_fp):
            feats = torch.load(out_fp, weights_only=True)
        else:
            img_t = img.as_tensor()
            logger.info(f"image shape: {tuple(img_t.shape)}")

            feats = get_features_2p5d(
                img_t,
                feature_network,
                center_slices=enable_center_slices,
                center_slices_ratio=center_slices_ratio_final,
                xy_only=False,
            )
            logger.info(f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}")
            torch.save(feats, out_fp)

        synth_features_xy.append(feats[0])
        synth_features_yz.append(feats[1])
        synth_features_zx.append(feats[2])

    synth_features_xy = torch.vstack(synth_features_xy)
    synth_features_yz = torch.vstack(synth_features_yz)
    synth_features_zx = torch.vstack(synth_features_zx)
    logger.info(
        f"Synth feature shapes: {synth_features_xy.shape}, " f"{synth_features_yz.shape}, {synth_features_zx.shape}"
    )

    # -------------------------------------------------------------------------
    # All-reduce / gather features across ranks
    # -------------------------------------------------------------------------
    features = [
        real_features_xy,
        real_features_yz,
        real_features_zx,
        synth_features_xy,
        synth_features_yz,
        synth_features_zx,
    ]

    # 1) Gather local feature sizes across ranks
    # Ranks may hold different numbers of slices, so sizes are exchanged first.
    local_sizes = []
    for ft_idx in range(len(features)):
        local_size = torch.tensor([features[ft_idx].shape[0]], dtype=torch.int64, device=device)
        local_sizes.append(local_size)

    all_sizes = []
    for ft_idx in range(len(features)):
        rank_sizes = [torch.tensor([0], dtype=torch.int64, device=device) for _ in range(world_size)]
        dist.all_gather(rank_sizes, local_sizes[ft_idx])
        all_sizes.append(rank_sizes)

    # 2) Pad and gather all features
    # all_gather requires equal shapes, so each tensor is padded to the max
    # rank size, gathered, then trimmed back to its true length.
    all_tensors_list = []
    for ft_idx, ft in enumerate(features):
        max_size = max(all_sizes[ft_idx]).item()
        ft_padded = pad_to_max_size(ft, max_size)

        gather_list = [torch.empty_like(ft_padded) for _ in range(world_size)]
        dist.all_gather(gather_list, ft_padded)

        # Trim each gather back to the real size
        for rk in range(world_size):
            gather_list[rk] = gather_list[rk][: all_sizes[ft_idx][rk], :]

        all_tensors_list.append(gather_list)

    # On rank 0, compute FID
    if local_rank == 0:
        real_xy = torch.vstack(all_tensors_list[0])
        real_yz = torch.vstack(all_tensors_list[1])
        real_zx = torch.vstack(all_tensors_list[2])

        synth_xy = torch.vstack(all_tensors_list[3])
        synth_yz = torch.vstack(all_tensors_list[4])
        synth_zx = torch.vstack(all_tensors_list[5])

        logger.info(f"Final Real shapes: {real_xy.shape}, {real_yz.shape}, {real_zx.shape}")
        logger.info(f"Final Synth shapes: {synth_xy.shape}, {synth_yz.shape}, {synth_zx.shape}")

        fid = FIDMetric()
        logger.info(f"Computing FID for: {output_root_real} | {output_root_synth}")
        fid_res_xy = fid(synth_xy, real_xy)
        fid_res_yz = fid(synth_yz, real_yz)
        fid_res_zx = fid(synth_zx, real_zx)

        logger.info(f"FID XY: {fid_res_xy}")
        logger.info(f"FID YZ: {fid_res_yz}")
        logger.info(f"FID ZX: {fid_res_zx}")
        fid_avg = (fid_res_xy + fid_res_yz + fid_res_zx) / 3.0
        logger.info(f"FID Avg: {fid_avg}")

    dist.destroy_process_group()
744
+
745
+
746
if __name__ == "__main__":
    # python-fire maps CLI flags (e.g. --num_images) onto main()'s keyword args.
    fire.Fire(main)
scripts/diff_model_create_training_data.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import json
16
+ import logging
17
+ import os
18
+ from pathlib import Path
19
+
20
+ import monai
21
+ import nibabel as nib
22
+ import numpy as np
23
+ import torch
24
+ import torch.distributed as dist
25
+ from monai.transforms import Compose
26
+ from monai.utils import set_determinism
27
+
28
+ from .diff_model_setting import initialize_distributed, load_config, setup_logging
29
+ from .utils import define_instance
30
+
31
+ # Set the random seed for reproducibility
32
+ set_determinism(seed=0)
33
+
34
+
35
+ def create_transforms(dim: tuple = None) -> Compose:
36
+ """
37
+ Create a set of MONAI transforms for preprocessing.
38
+
39
+ Args:
40
+ dim (tuple, optional): New dimensions for resizing. Defaults to None.
41
+
42
+ Returns:
43
+ Compose: Composed MONAI transforms.
44
+ """
45
+ if dim:
46
+ return Compose(
47
+ [
48
+ monai.transforms.LoadImaged(keys="image"),
49
+ monai.transforms.EnsureChannelFirstd(keys="image"),
50
+ monai.transforms.Orientationd(keys="image", axcodes="RAS"),
51
+ monai.transforms.EnsureTyped(keys="image", dtype=torch.float32),
52
+ monai.transforms.ScaleIntensityRanged(
53
+ keys="image", a_min=-1000, a_max=1000, b_min=0, b_max=1, clip=True
54
+ ),
55
+ monai.transforms.Resized(keys="image", spatial_size=dim, mode="trilinear"),
56
+ ]
57
+ )
58
+ else:
59
+ return Compose(
60
+ [
61
+ monai.transforms.LoadImaged(keys="image"),
62
+ monai.transforms.EnsureChannelFirstd(keys="image"),
63
+ monai.transforms.Orientationd(keys="image", axcodes="RAS"),
64
+ ]
65
+ )
66
+
67
+
68
+ def round_number(number: int, base_number: int = 128) -> int:
69
+ """
70
+ Round the number to the nearest multiple of the base number, with a minimum value of the base number.
71
+
72
+ Args:
73
+ number (int): Number to be rounded.
74
+ base_number (int): Number to be common divisor.
75
+
76
+ Returns:
77
+ int: Rounded number.
78
+ """
79
+ new_number = max(round(float(number) / float(base_number)), 1.0) * float(base_number)
80
+ return int(new_number)
81
+
82
+
83
+ def load_filenames(data_list_path: str) -> list:
84
+ """
85
+ Load filenames from the JSON data list.
86
+
87
+ Args:
88
+ data_list_path (str): Path to the JSON data list file.
89
+
90
+ Returns:
91
+ list: List of filenames.
92
+ """
93
+ with open(data_list_path, "r") as file:
94
+ json_data = json.load(file)
95
+ filenames_raw = json_data["training"]
96
+ return [_item["image"] for _item in filenames_raw]
97
+
98
+
99
+ def process_file(
100
+ filepath: str,
101
+ args: argparse.Namespace,
102
+ autoencoder: torch.nn.Module,
103
+ device: torch.device,
104
+ plain_transforms: Compose,
105
+ new_transforms: Compose,
106
+ logger: logging.Logger,
107
+ ) -> None:
108
+ """
109
+ Process a single file to create training data.
110
+
111
+ Args:
112
+ filepath (str): Path to the file to be processed.
113
+ args (argparse.Namespace): Configuration arguments.
114
+ autoencoder (torch.nn.Module): Autoencoder model.
115
+ device (torch.device): Device to process the file on.
116
+ plain_transforms (Compose): Plain transforms.
117
+ new_transforms (Compose): New transforms.
118
+ logger (logging.Logger): Logger for logging information.
119
+ """
120
+ out_filename_base = filepath.replace(".gz", "").replace(".nii", "")
121
+ out_filename_base = os.path.join(args.embedding_base_dir, out_filename_base)
122
+ out_filename = out_filename_base + "_emb.nii.gz"
123
+
124
+ if os.path.isfile(out_filename):
125
+ return
126
+
127
+ test_data = {"image": os.path.join(args.data_base_dir, filepath)}
128
+ transformed_data = plain_transforms(test_data)
129
+ nda = transformed_data["image"]
130
+
131
+ dim = [int(nda.meta["dim"][_i]) for _i in range(1, 4)]
132
+ spacing = [float(nda.meta["pixdim"][_i]) for _i in range(1, 4)]
133
+
134
+ logger.info(f"old dim: {dim}, old spacing: {spacing}")
135
+
136
+ new_data = new_transforms(test_data)
137
+ nda_image = new_data["image"]
138
+
139
+ new_affine = nda_image.meta["affine"].numpy()
140
+ nda_image = nda_image.numpy().squeeze()
141
+
142
+ logger.info(f"new dim: {nda_image.shape}, new affine: {new_affine}")
143
+
144
+ try:
145
+ out_path = Path(out_filename)
146
+ out_path.parent.mkdir(parents=True, exist_ok=True)
147
+ logger.info(f"out_filename: {out_filename}")
148
+
149
+ with torch.amp.autocast("cuda"):
150
+ pt_nda = torch.from_numpy(nda_image).float().to(device).unsqueeze(0).unsqueeze(0)
151
+ z = autoencoder.encode_stage_2_inputs(pt_nda)
152
+ logger.info(f"z: {z.size()}, {z.dtype}")
153
+
154
+ out_nda = z.squeeze().cpu().detach().numpy().transpose(1, 2, 3, 0)
155
+ out_img = nib.Nifti1Image(np.float32(out_nda), affine=new_affine)
156
+ nib.save(out_img, out_filename)
157
+ except Exception as e:
158
+ logger.error(f"Error processing {filepath}: {e}")
159
+
160
+
161
+ @torch.inference_mode()
162
+ def diff_model_create_training_data(
163
+ env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int
164
+ ) -> None:
165
+ """
166
+ Create training data for the diffusion model.
167
+
168
+ Args:
169
+ env_config_path (str): Path to the environment configuration file.
170
+ model_config_path (str): Path to the model configuration file.
171
+ model_def_path (str): Path to the model definition file.
172
+ """
173
+ args = load_config(env_config_path, model_config_path, model_def_path)
174
+ local_rank, world_size, device = initialize_distributed(num_gpus=num_gpus)
175
+ logger = setup_logging("creating training data")
176
+ logger.info(f"Using device {device}")
177
+
178
+ autoencoder = define_instance(args, "autoencoder_def").to(device)
179
+ try:
180
+ checkpoint_autoencoder = torch.load(args.trained_autoencoder_path, weights_only=True)
181
+ autoencoder.load_state_dict(checkpoint_autoencoder)
182
+ except Exception:
183
+ logger.error("The trained_autoencoder_path does not exist!")
184
+
185
+ Path(args.embedding_base_dir).mkdir(parents=True, exist_ok=True)
186
+
187
+ filenames_raw = load_filenames(args.json_data_list)
188
+ logger.info(f"filenames_raw: {filenames_raw}")
189
+
190
+ plain_transforms = create_transforms(dim=None)
191
+
192
+ for _iter in range(len(filenames_raw)):
193
+ if _iter % world_size != local_rank:
194
+ continue
195
+
196
+ filepath = filenames_raw[_iter]
197
+ new_dim = tuple(
198
+ round_number(
199
+ int(plain_transforms({"image": os.path.join(args.data_base_dir, filepath)})["image"].meta["dim"][_i])
200
+ )
201
+ for _i in range(1, 4)
202
+ )
203
+ new_transforms = create_transforms(new_dim)
204
+
205
+ process_file(filepath, args, autoencoder, device, plain_transforms, new_transforms, logger)
206
+
207
+ if dist.is_initialized():
208
+ dist.destroy_process_group()
209
+
210
+
211
+ if __name__ == "__main__":
212
+ parser = argparse.ArgumentParser(description="Diffusion Model Training Data Creation")
213
+ parser.add_argument(
214
+ "--env_config",
215
+ type=str,
216
+ default="./configs/environment_maisi_diff_model_train.json",
217
+ help="Path to environment configuration file",
218
+ )
219
+ parser.add_argument(
220
+ "--model_config",
221
+ type=str,
222
+ default="./configs/config_maisi_diff_model_train.json",
223
+ help="Path to model training/inference configuration",
224
+ )
225
+ parser.add_argument(
226
+ "--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
227
+ )
228
+ parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for distributed training")
229
+
230
+ args = parser.parse_args()
231
+ diff_model_create_training_data(args.env_config, args.model_config, args.model_def, args.num_gpus)
scripts/diff_model_infer.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import logging
16
+ import os
17
+ import random
18
+ from datetime import datetime
19
+
20
+ import nibabel as nib
21
+ import numpy as np
22
+ import torch
23
+ import torch.distributed as dist
24
+ from monai.inferers import sliding_window_inference
25
+ from monai.inferers.inferer import SlidingWindowInferer
26
+ from monai.networks.schedulers import RFlowScheduler
27
+ from monai.utils import set_determinism
28
+ from tqdm import tqdm
29
+
30
+ from .diff_model_setting import initialize_distributed, load_config, setup_logging
31
+ from .sample import ReconModel, check_input
32
+ from .utils import define_instance, dynamic_infer
33
+
34
+
35
+ def set_random_seed(seed: int) -> int:
36
+ """
37
+ Set random seed for reproducibility.
38
+
39
+ Args:
40
+ seed (int): Random seed.
41
+
42
+ Returns:
43
+ int: Set random seed.
44
+ """
45
+ random_seed = random.randint(0, 99999) if seed is None else seed
46
+ set_determinism(random_seed)
47
+ return random_seed
48
+
49
+
50
+ def load_models(args: argparse.Namespace, device: torch.device, logger: logging.Logger) -> tuple:
51
+ """
52
+ Load the autoencoder and UNet models.
53
+
54
+ Args:
55
+ args (argparse.Namespace): Configuration arguments.
56
+ device (torch.device): Device to load models on.
57
+ logger (logging.Logger): Logger for logging information.
58
+
59
+ Returns:
60
+ tuple: Loaded autoencoder, UNet model, and scale factor.
61
+ """
62
+ autoencoder = define_instance(args, "autoencoder_def").to(device)
63
+ try:
64
+ checkpoint_autoencoder = torch.load(args.trained_autoencoder_path, weights_only=True)
65
+ autoencoder.load_state_dict(checkpoint_autoencoder)
66
+ except Exception:
67
+ logger.error("The trained_autoencoder_path does not exist!")
68
+
69
+ unet = define_instance(args, "diffusion_unet_def").to(device)
70
+ checkpoint = torch.load(f"{args.model_dir}/{args.model_filename}", map_location=device, weights_only=False)
71
+ unet.load_state_dict(checkpoint["unet_state_dict"], strict=True)
72
+ logger.info(f"checkpoints {args.model_dir}/{args.model_filename} loaded.")
73
+
74
+ scale_factor = checkpoint["scale_factor"]
75
+ logger.info(f"scale_factor -> {scale_factor}.")
76
+
77
+ return autoencoder, unet, scale_factor
78
+
79
+
80
+ def prepare_tensors(args: argparse.Namespace, device: torch.device) -> tuple:
81
+ """
82
+ Prepare necessary tensors for inference.
83
+
84
+ Args:
85
+ args (argparse.Namespace): Configuration arguments.
86
+ device (torch.device): Device to load tensors on.
87
+
88
+ Returns:
89
+ tuple: Prepared top_region_index_tensor, bottom_region_index_tensor, and spacing_tensor.
90
+ """
91
+ top_region_index_tensor = np.array(args.diffusion_unet_inference["top_region_index"]).astype(float) * 1e2
92
+ bottom_region_index_tensor = np.array(args.diffusion_unet_inference["bottom_region_index"]).astype(float) * 1e2
93
+ spacing_tensor = np.array(args.diffusion_unet_inference["spacing"]).astype(float) * 1e2
94
+
95
+ top_region_index_tensor = torch.from_numpy(top_region_index_tensor[np.newaxis, :]).half().to(device)
96
+ bottom_region_index_tensor = torch.from_numpy(bottom_region_index_tensor[np.newaxis, :]).half().to(device)
97
+ spacing_tensor = torch.from_numpy(spacing_tensor[np.newaxis, :]).half().to(device)
98
+ modality_tensor = args.diffusion_unet_inference["modality"] * torch.ones(
99
+ (len(spacing_tensor)), dtype=torch.long
100
+ ).to(device)
101
+
102
+ return top_region_index_tensor, bottom_region_index_tensor, spacing_tensor, modality_tensor
103
+
104
+
105
+ def run_inference(
106
+ args: argparse.Namespace,
107
+ device: torch.device,
108
+ autoencoder: torch.nn.Module,
109
+ unet: torch.nn.Module,
110
+ scale_factor: float,
111
+ top_region_index_tensor: torch.Tensor,
112
+ bottom_region_index_tensor: torch.Tensor,
113
+ spacing_tensor: torch.Tensor,
114
+ modality_tensor: torch.Tensor,
115
+ output_size: tuple,
116
+ divisor: int,
117
+ logger: logging.Logger,
118
+ ) -> np.ndarray:
119
+ """
120
+ Run the inference to generate synthetic images.
121
+
122
+ Args:
123
+ args (argparse.Namespace): Configuration arguments.
124
+ device (torch.device): Device to run inference on.
125
+ autoencoder (torch.nn.Module): Autoencoder model.
126
+ unet (torch.nn.Module): UNet model.
127
+ scale_factor (float): Scale factor for the model.
128
+ top_region_index_tensor (torch.Tensor): Top region index tensor.
129
+ bottom_region_index_tensor (torch.Tensor): Bottom region index tensor.
130
+ spacing_tensor (torch.Tensor): Spacing tensor.
131
+ modality_tensor (torch.Tensor): Modality tensor.
132
+ output_size (tuple): Output size of the synthetic image.
133
+ divisor (int): Divisor for downsample level.
134
+ logger (logging.Logger): Logger for logging information.
135
+
136
+ Returns:
137
+ np.ndarray: Generated synthetic image data.
138
+ """
139
+ include_body_region = unet.include_top_region_index_input
140
+ include_modality = unet.num_class_embeds is not None
141
+
142
+ noise = torch.randn(
143
+ (
144
+ 1,
145
+ args.latent_channels,
146
+ output_size[0] // divisor,
147
+ output_size[1] // divisor,
148
+ output_size[2] // divisor,
149
+ ),
150
+ device=device,
151
+ )
152
+ logger.info(f"noise: {noise.device}, {noise.dtype}, {type(noise)}")
153
+
154
+ image = noise
155
+ noise_scheduler = define_instance(args, "noise_scheduler")
156
+ if isinstance(noise_scheduler, RFlowScheduler):
157
+ noise_scheduler.set_timesteps(
158
+ num_inference_steps=args.diffusion_unet_inference["num_inference_steps"],
159
+ input_img_size_numel=torch.prod(torch.tensor(noise.shape[2:])),
160
+ )
161
+ else:
162
+ noise_scheduler.set_timesteps(num_inference_steps=args.diffusion_unet_inference["num_inference_steps"])
163
+
164
+ recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device)
165
+ autoencoder.eval()
166
+ unet.eval()
167
+
168
+ all_timesteps = noise_scheduler.timesteps
169
+ all_next_timesteps = torch.cat((all_timesteps[1:], torch.tensor([0], dtype=all_timesteps.dtype)))
170
+ progress_bar = tqdm(
171
+ zip(all_timesteps, all_next_timesteps),
172
+ total=min(len(all_timesteps), len(all_next_timesteps)),
173
+ )
174
+ with torch.amp.autocast("cuda", enabled=True):
175
+ for t, next_t in progress_bar:
176
+ # Create a dictionary to store the inputs
177
+ unet_inputs = {
178
+ "x": image,
179
+ "timesteps": torch.Tensor((t,)).to(device),
180
+ "spacing_tensor": spacing_tensor,
181
+ }
182
+
183
+ # Add extra arguments if include_body_region is True
184
+ if include_body_region:
185
+ unet_inputs.update(
186
+ {
187
+ "top_region_index_tensor": top_region_index_tensor,
188
+ "bottom_region_index_tensor": bottom_region_index_tensor,
189
+ }
190
+ )
191
+
192
+ if include_modality:
193
+ unet_inputs.update(
194
+ {
195
+ "class_labels": modality_tensor,
196
+ }
197
+ )
198
+ model_output = unet(**unet_inputs)
199
+ if not isinstance(noise_scheduler, RFlowScheduler):
200
+ image, _ = noise_scheduler.step(model_output, t, image) # type: ignore
201
+ else:
202
+ image, _ = noise_scheduler.step(model_output, t, image, next_t) # type: ignore
203
+
204
+ inferer = SlidingWindowInferer(
205
+ roi_size=[80, 80, 80],
206
+ sw_batch_size=1,
207
+ progress=True,
208
+ mode="gaussian",
209
+ overlap=0.4,
210
+ sw_device=device,
211
+ device=device,
212
+ )
213
+ synthetic_images = dynamic_infer(inferer, recon_model, image)
214
+ data = synthetic_images.squeeze().cpu().detach().numpy()
215
+ a_min, a_max, b_min, b_max = -1000, 1000, 0, 1
216
+ data = (data - b_min) / (b_max - b_min) * (a_max - a_min) + a_min
217
+ data = np.clip(data, a_min, a_max)
218
+ return np.int16(data)
219
+
220
+
221
+ def save_image(
222
+ data: np.ndarray,
223
+ output_size: tuple,
224
+ out_spacing: tuple,
225
+ output_path: str,
226
+ logger: logging.Logger,
227
+ ) -> None:
228
+ """
229
+ Save the generated synthetic image to a file.
230
+
231
+ Args:
232
+ data (np.ndarray): Synthetic image data.
233
+ output_size (tuple): Output size of the image.
234
+ out_spacing (tuple): Spacing of the output image.
235
+ output_path (str): Path to save the output image.
236
+ logger (logging.Logger): Logger for logging information.
237
+ """
238
+ out_affine = np.eye(4)
239
+ for i in range(3):
240
+ out_affine[i, i] = out_spacing[i]
241
+
242
+ new_image = nib.Nifti1Image(data, affine=out_affine)
243
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
244
+ nib.save(new_image, output_path)
245
+ logger.info(f"Saved {output_path}.")
246
+
247
+
248
+ @torch.inference_mode()
249
+ def diff_model_infer(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
250
+ """
251
+ Main function to run the diffusion model inference.
252
+
253
+ Args:
254
+ env_config_path (str): Path to the environment configuration file.
255
+ model_config_path (str): Path to the model configuration file.
256
+ model_def_path (str): Path to the model definition file.
257
+ """
258
+ args = load_config(env_config_path, model_config_path, model_def_path)
259
+ local_rank, world_size, device = initialize_distributed(num_gpus)
260
+ logger = setup_logging("inference")
261
+ random_seed = set_random_seed(
262
+ args.diffusion_unet_inference["random_seed"] + local_rank
263
+ if args.diffusion_unet_inference["random_seed"]
264
+ else None
265
+ )
266
+ logger.info(f"Using {device} of {world_size} with random seed: {random_seed}")
267
+
268
+ output_size = tuple(args.diffusion_unet_inference["dim"])
269
+ out_spacing = tuple(args.diffusion_unet_inference["spacing"])
270
+ output_prefix = args.output_prefix
271
+ ckpt_filepath = f"{args.model_dir}/{args.model_filename}"
272
+
273
+ if local_rank == 0:
274
+ logger.info(f"[config] ckpt_filepath -> {ckpt_filepath}.")
275
+ logger.info(f"[config] random_seed -> {random_seed}.")
276
+ logger.info(f"[config] output_prefix -> {output_prefix}.")
277
+ logger.info(f"[config] output_size -> {output_size}.")
278
+ logger.info(f"[config] out_spacing -> {out_spacing}.")
279
+
280
+ check_input(None, None, None, output_size, out_spacing, None)
281
+
282
+ autoencoder, unet, scale_factor = load_models(args, device, logger)
283
+ num_downsample_level = max(
284
+ 1,
285
+ (
286
+ len(args.diffusion_unet_def["num_channels"])
287
+ if isinstance(args.diffusion_unet_def["num_channels"], list)
288
+ else len(args.diffusion_unet_def["attention_levels"])
289
+ ),
290
+ )
291
+ divisor = 2 ** (num_downsample_level - 2)
292
+ logger.info(f"num_downsample_level -> {num_downsample_level}, divisor -> {divisor}.")
293
+
294
+ top_region_index_tensor, bottom_region_index_tensor, spacing_tensor, modality_tensor = prepare_tensors(args, device)
295
+ data = run_inference(
296
+ args,
297
+ device,
298
+ autoencoder,
299
+ unet,
300
+ scale_factor,
301
+ top_region_index_tensor,
302
+ bottom_region_index_tensor,
303
+ spacing_tensor,
304
+ modality_tensor,
305
+ output_size,
306
+ divisor,
307
+ logger,
308
+ )
309
+
310
+ timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
311
+ output_path = "{0}/{1}_seed{2}_size{3:d}x{4:d}x{5:d}_spacing{6:.2f}x{7:.2f}x{8:.2f}_{9}_rank{10}.nii.gz".format(
312
+ args.output_dir,
313
+ output_prefix,
314
+ random_seed,
315
+ output_size[0],
316
+ output_size[1],
317
+ output_size[2],
318
+ out_spacing[0],
319
+ out_spacing[1],
320
+ out_spacing[2],
321
+ timestamp,
322
+ local_rank,
323
+ )
324
+ save_image(data, output_size, out_spacing, output_path, logger)
325
+
326
+ if dist.is_initialized():
327
+ dist.destroy_process_group()
328
+
329
+
330
+ if __name__ == "__main__":
331
+ parser = argparse.ArgumentParser(description="Diffusion Model Inference")
332
+ parser.add_argument(
333
+ "--env_config",
334
+ type=str,
335
+ default="./configs/environment_maisi_diff_model_train.json",
336
+ help="Path to environment configuration file",
337
+ )
338
+ parser.add_argument(
339
+ "--model_config",
340
+ type=str,
341
+ default="./configs/config_maisi_diff_model_train.json",
342
+ help="Path to model training/inference configuration",
343
+ )
344
+ parser.add_argument(
345
+ "--model_def",
346
+ type=str,
347
+ default="./configs/config_maisi.json",
348
+ help="Path to model definition file",
349
+ )
350
+ parser.add_argument(
351
+ "--num_gpus",
352
+ type=int,
353
+ default=1,
354
+ help="Number of GPUs to use for distributed inference",
355
+ )
356
+
357
+ args = parser.parse_args()
358
+ diff_model_infer(args.env_config, args.model_config, args.model_def, args.num_gpus)
scripts/diff_model_setting.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import json
16
+ import logging
17
+
18
+ import torch
19
+ import torch.distributed as dist
20
+ from monai.utils import RankFilter
21
+
22
+
23
+ def setup_logging(logger_name: str = "") -> logging.Logger:
24
+ """
25
+ Setup the logging configuration.
26
+
27
+ Args:
28
+ logger_name (str): logger name.
29
+
30
+ Returns:
31
+ logging.Logger: Configured logger.
32
+ """
33
+ logger = logging.getLogger(logger_name)
34
+ if dist.is_initialized():
35
+ logger.addFilter(RankFilter())
36
+ logging.basicConfig(
37
+ level=logging.INFO,
38
+ format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
39
+ datefmt="%Y-%m-%d %H:%M:%S",
40
+ )
41
+ return logger
42
+
43
+
44
+ def load_config(env_config_path: str, model_config_path: str, model_def_path: str) -> argparse.Namespace:
45
+ """
46
+ Load configuration from JSON files.
47
+
48
+ Args:
49
+ env_config_path (str): Path to the environment configuration file.
50
+ model_config_path (str): Path to the model configuration file.
51
+ model_def_path (str): Path to the model definition file.
52
+
53
+ Returns:
54
+ argparse.Namespace: Loaded configuration.
55
+ """
56
+ args = argparse.Namespace()
57
+
58
+ with open(env_config_path, "r") as f:
59
+ env_config = json.load(f)
60
+ for k, v in env_config.items():
61
+ setattr(args, k, v)
62
+
63
+ with open(model_config_path, "r") as f:
64
+ model_config = json.load(f)
65
+ for k, v in model_config.items():
66
+ setattr(args, k, v)
67
+
68
+ with open(model_def_path, "r") as f:
69
+ model_def = json.load(f)
70
+ for k, v in model_def.items():
71
+ setattr(args, k, v)
72
+
73
+ return args
74
+
75
+
76
+ def initialize_distributed(num_gpus: int) -> tuple:
77
+ """
78
+ Initialize distributed training.
79
+
80
+ Returns:
81
+ tuple: local_rank, world_size, and device.
82
+ """
83
+ if torch.cuda.is_available() and num_gpus > 1:
84
+ dist.init_process_group(backend="nccl", init_method="env://")
85
+ local_rank = dist.get_rank()
86
+ world_size = dist.get_world_size()
87
+ else:
88
+ local_rank = 0
89
+ world_size = 1
90
+ device = torch.device("cuda", local_rank)
91
+ torch.cuda.set_device(device)
92
+ return local_rank, world_size, device
scripts/diff_model_train.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import json
16
+ import logging
17
+ import os
18
+ from datetime import datetime
19
+ from pathlib import Path
20
+
21
+ import monai
22
+ import torch
23
+ import torch.distributed as dist
24
+ from monai.data import DataLoader, partition_dataset
25
+ from monai.networks.schedulers import RFlowScheduler
26
+ from monai.networks.schedulers.ddpm import DDPMPredictionType
27
+ from monai.transforms import Compose
28
+ from monai.utils import first
29
+ from torch.amp import GradScaler, autocast
30
+ from torch.nn.parallel import DistributedDataParallel
31
+
32
+ from .diff_model_setting import initialize_distributed, load_config, setup_logging
33
+ from .utils import define_instance
34
+
35
+
36
+ def load_filenames(data_list_path: str) -> list:
37
+ """
38
+ Load filenames from the JSON data list.
39
+
40
+ Args:
41
+ data_list_path (str): Path to the JSON data list file.
42
+
43
+ Returns:
44
+ list: List of filenames.
45
+ """
46
+ with open(data_list_path, "r") as file:
47
+ json_data = json.load(file)
48
+ filenames_train = json_data["training"]
49
+ return [_item["image"].replace(".nii.gz", "_emb.nii.gz") for _item in filenames_train]
50
+
51
+
52
+ def prepare_data(
53
+ train_files: list,
54
+ device: torch.device,
55
+ cache_rate: float,
56
+ num_workers: int = 2,
57
+ batch_size: int = 1,
58
+ include_body_region: bool = False,
59
+ ) -> DataLoader:
60
+ """
61
+ Prepare training data.
62
+
63
+ Args:
64
+ train_files (list): List of training files.
65
+ device (torch.device): Device to use for training.
66
+ cache_rate (float): Cache rate for dataset.
67
+ num_workers (int): Number of workers for data loading.
68
+ batch_size (int): Mini-batch size.
69
+ include_body_region (bool): Whether to include body region in data
70
+
71
+ Returns:
72
+ DataLoader: Data loader for training.
73
+ """
74
+
75
+ def _load_data_from_file(file_path, key):
76
+ with open(file_path) as f:
77
+ return torch.FloatTensor(json.load(f)[key])
78
+
79
+ train_transforms_list = [
80
+ monai.transforms.LoadImaged(keys=["image"]),
81
+ monai.transforms.EnsureChannelFirstd(keys=["image"]),
82
+ monai.transforms.Lambdad(keys="spacing", func=lambda x: _load_data_from_file(x, "spacing")),
83
+ monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2),
84
+ ]
85
+ if include_body_region:
86
+ train_transforms_list += [
87
+ monai.transforms.Lambdad(
88
+ keys="top_region_index", func=lambda x: _load_data_from_file(x, "top_region_index")
89
+ ),
90
+ monai.transforms.Lambdad(
91
+ keys="bottom_region_index", func=lambda x: _load_data_from_file(x, "bottom_region_index")
92
+ ),
93
+ monai.transforms.Lambdad(keys="top_region_index", func=lambda x: x * 1e2),
94
+ monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: x * 1e2),
95
+ ]
96
+ train_transforms = Compose(train_transforms_list)
97
+
98
+ train_ds = monai.data.CacheDataset(
99
+ data=train_files, transform=train_transforms, cache_rate=cache_rate, num_workers=num_workers
100
+ )
101
+
102
+ return DataLoader(train_ds, num_workers=6, batch_size=batch_size, shuffle=True)
103
+
104
+
105
+ def load_unet(args: argparse.Namespace, device: torch.device, logger: logging.Logger) -> torch.nn.Module:
106
+ """
107
+ Load the UNet model.
108
+
109
+ Args:
110
+ args (argparse.Namespace): Configuration arguments.
111
+ device (torch.device): Device to load the model on.
112
+ logger (logging.Logger): Logger for logging information.
113
+
114
+ Returns:
115
+ torch.nn.Module: Loaded UNet model.
116
+ """
117
+ unet = define_instance(args, "diffusion_unet_def").to(device)
118
+ unet = torch.nn.SyncBatchNorm.convert_sync_batchnorm(unet)
119
+
120
+ if dist.is_initialized():
121
+ unet = DistributedDataParallel(unet, device_ids=[device], find_unused_parameters=True)
122
+
123
+ if args.existing_ckpt_filepath is None:
124
+ logger.info("Training from scratch.")
125
+ else:
126
+ checkpoint_unet = torch.load(f"{args.existing_ckpt_filepath}", map_location=device, weights_only=False)
127
+ if dist.is_initialized():
128
+ unet.module.load_state_dict(checkpoint_unet["unet_state_dict"], strict=True)
129
+ else:
130
+ unet.load_state_dict(checkpoint_unet["unet_state_dict"], strict=True)
131
+ logger.info(f"Pretrained checkpoint {args.existing_ckpt_filepath} loaded.")
132
+
133
+ return unet
134
+
135
+
136
def calculate_scale_factor(train_loader: DataLoader, device: torch.device, logger: logging.Logger) -> torch.Tensor:
    """
    Calculate the scaling factor for the dataset.

    The factor is the reciprocal of the latent standard deviation of the first batch,
    averaged across ranks when running distributed.

    Args:
        train_loader (DataLoader): Data loader for training.
        device (torch.device): Device to use for calculation.
        logger (logging.Logger): Logger for logging information.

    Returns:
        torch.Tensor: Calculated scaling factor.
    """
    sample_batch = first(train_loader)
    latent = sample_batch["image"].to(device)
    scale_factor = 1 / torch.std(latent)
    logger.info(f"Scaling factor set to {scale_factor}.")

    if dist.is_initialized():
        dist.barrier()
        # average the per-rank estimates so all ranks share one factor
        dist.all_reduce(scale_factor, op=torch.distributed.ReduceOp.AVG)
    logger.info(f"scale_factor -> {scale_factor}.")
    return scale_factor
158
+
159
+
160
def create_optimizer(model: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
    """
    Create optimizer for training.

    Args:
        model (torch.nn.Module): Model to optimize.
        lr (float): Learning rate.

    Returns:
        torch.optim.Optimizer: Adam optimizer over all of ``model``'s parameters.
    """
    trainable_params = model.parameters()
    return torch.optim.Adam(params=trainable_params, lr=lr)
172
+
173
+
174
def create_lr_scheduler(optimizer: torch.optim.Optimizer, total_steps: int) -> torch.optim.lr_scheduler.PolynomialLR:
    """
    Create learning rate scheduler.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer to schedule.
        total_steps (int): Total number of training steps.

    Returns:
        torch.optim.lr_scheduler.PolynomialLR: Polynomial decay scheduler (power 2.0).
    """
    scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_steps, power=2.0)
    return scheduler
186
+
187
+
188
def train_one_epoch(
    epoch: int,
    unet: torch.nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler.PolynomialLR,
    loss_pt: torch.nn.L1Loss,
    scaler: GradScaler,
    scale_factor: torch.Tensor,
    noise_scheduler: torch.nn.Module,
    num_images_per_batch: int,
    num_train_timesteps: int,
    device: torch.device,
    logger: logging.Logger,
    local_rank: int,
    amp: bool = True,
) -> torch.Tensor:
    """
    Train the model for one epoch.

    Args:
        epoch (int): Current epoch number.
        unet (torch.nn.Module): UNet model.
        train_loader (DataLoader): Data loader for training.
        optimizer (torch.optim.Optimizer): Optimizer.
        lr_scheduler (torch.optim.lr_scheduler.PolynomialLR): Learning rate scheduler.
        loss_pt (torch.nn.L1Loss): Loss function.
        scaler (GradScaler): Gradient scaler for mixed precision training.
        scale_factor (torch.Tensor): Scaling factor.
        noise_scheduler (torch.nn.Module): Noise scheduler.
        num_images_per_batch (int): Number of images per batch.
        num_train_timesteps (int): Number of training timesteps.
        device (torch.device): Device to use for training.
        logger (logging.Logger): Logger for logging information.
        local_rank (int): Local rank for distributed training.
        amp (bool): Use automatic mixed precision training.

    Returns:
        torch.Tensor: Two-element tensor: [0] summed loss over iterations, [1] iteration count
        (summed across ranks when distributed).
    """
    # NOTE(review): `num_images_per_batch` is never used in this function body — confirm whether
    # it can be dropped from the signature or was meant to drive the modality tensor size.
    include_body_region = unet.include_top_region_index_input
    include_modality = unet.num_class_embeds is not None

    if local_rank == 0:
        current_lr = optimizer.param_groups[0]["lr"]
        logger.info(f"Epoch {epoch + 1}, lr {current_lr}.")

    _iter = 0
    # loss_torch[0] accumulates the loss, loss_torch[1] counts iterations
    loss_torch = torch.zeros(2, dtype=torch.float, device=device)

    unet.train()
    for train_data in train_loader:
        current_lr = optimizer.param_groups[0]["lr"]

        _iter += 1
        images = train_data["image"].to(device)
        # latents are pre-scaled so their std is ~1 (see calculate_scale_factor)
        images = images * scale_factor

        if include_body_region:
            top_region_index_tensor = train_data["top_region_index"].to(device)
            bottom_region_index_tensor = train_data["bottom_region_index"].to(device)
        # We trained with only CT in this version
        if include_modality:
            modality_tensor = torch.ones((len(images),), dtype=torch.long).to(device)
        spacing_tensor = train_data["spacing"].to(device)

        optimizer.zero_grad(set_to_none=True)

        with autocast("cuda", enabled=amp):
            noise = torch.randn_like(images)

            # RFlowScheduler samples its own (possibly non-uniform) timestep distribution
            if isinstance(noise_scheduler, RFlowScheduler):
                timesteps = noise_scheduler.sample_timesteps(images)
            else:
                timesteps = torch.randint(0, num_train_timesteps, (images.shape[0],), device=images.device).long()

            noisy_latent = noise_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)

            # Create a dictionary to store the inputs
            unet_inputs = {
                "x": noisy_latent,
                "timesteps": timesteps,
                "spacing_tensor": spacing_tensor,
            }
            # Add extra arguments if include_body_region is True
            if include_body_region:
                unet_inputs.update(
                    {
                        "top_region_index_tensor": top_region_index_tensor,
                        "bottom_region_index_tensor": bottom_region_index_tensor,
                    }
                )
            if include_modality:
                unet_inputs.update(
                    {
                        "class_labels": modality_tensor,
                    }
                )
            model_output = unet(**unet_inputs)

            # select the regression target according to the scheduler's prediction type
            if noise_scheduler.prediction_type == DDPMPredictionType.EPSILON:
                # predict noise
                model_gt = noise
            elif noise_scheduler.prediction_type == DDPMPredictionType.SAMPLE:
                # predict sample
                model_gt = images
            elif noise_scheduler.prediction_type == DDPMPredictionType.V_PREDICTION:
                # predict velocity
                model_gt = images - noise
            else:
                raise ValueError(
                    "noise scheduler prediction type has to be chosen from ",
                    f"[{DDPMPredictionType.EPSILON},{DDPMPredictionType.SAMPLE},{DDPMPredictionType.V_PREDICTION}]",
                )

            loss = loss_pt(model_output.float(), model_gt.float())

        if amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # scheduler steps once per iteration, not per epoch
        lr_scheduler.step()

        loss_torch[0] += loss.item()
        loss_torch[1] += 1.0

        if local_rank == 0:
            logger.info(
                "[{0}] epoch {1}, iter {2}/{3}, loss: {4:.4f}, lr: {5:.12f}.".format(
                    str(datetime.now())[:19], epoch + 1, _iter, len(train_loader), loss.item(), current_lr
                )
            )

    if dist.is_initialized():
        dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)

    return loss_torch
329
+
330
+
331
def save_checkpoint(
    epoch: int,
    unet: torch.nn.Module,
    loss_torch_epoch: float,
    num_train_timesteps: int,
    scale_factor: torch.Tensor,
    ckpt_folder: str,
    args: argparse.Namespace,
) -> None:
    """
    Save checkpoint.

    Args:
        epoch (int): Current epoch number.
        unet (torch.nn.Module): UNet model.
        loss_torch_epoch (float): Training loss for the epoch.
        num_train_timesteps (int): Number of training timesteps.
        scale_factor (torch.Tensor): Scaling factor.
        ckpt_folder (str): Checkpoint folder path.
        args (argparse.Namespace): Configuration arguments.
    """
    # unwrap the DDP container so the saved state_dict has no "module." prefixes
    if dist.is_initialized():
        unet_state_dict = unet.module.state_dict()
    else:
        unet_state_dict = unet.state_dict()

    payload = {
        "epoch": epoch + 1,
        "loss": loss_torch_epoch,
        "num_train_timesteps": num_train_timesteps,
        "scale_factor": scale_factor,
        "unet_state_dict": unet_state_dict,
    }
    torch.save(payload, f"{ckpt_folder}/{args.model_filename}")
363
+
364
+
365
def diff_model_train(
    env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, amp: bool = True
) -> None:
    """
    Main function to train a diffusion model.

    Args:
        env_config_path (str): Path to the environment configuration file.
        model_config_path (str): Path to the model configuration file.
        model_def_path (str): Path to the model definition file.
        num_gpus (int): Number of GPUs to use for training.
        amp (bool): Use automatic mixed precision training.
    """
    args = load_config(env_config_path, model_config_path, model_def_path)
    local_rank, world_size, device = initialize_distributed(num_gpus)
    logger = setup_logging("training")

    logger.info(f"Using {device} of {world_size}")

    # rank 0 logs the effective configuration once
    if local_rank == 0:
        logger.info(f"[config] ckpt_folder -> {args.model_dir}.")
        logger.info(f"[config] data_root -> {args.embedding_base_dir}.")
        logger.info(f"[config] data_list -> {args.json_data_list}.")
        logger.info(f"[config] lr -> {args.diffusion_unet_train['lr']}.")
        logger.info(f"[config] num_epochs -> {args.diffusion_unet_train['n_epochs']}.")
        logger.info(f"[config] num_train_timesteps -> {args.noise_scheduler['num_train_timesteps']}.")

    Path(args.model_dir).mkdir(parents=True, exist_ok=True)

    unet = load_unet(args, device, logger)
    noise_scheduler = define_instance(args, "noise_scheduler")
    include_body_region = unet.include_top_region_index_input

    filenames_train = load_filenames(args.json_data_list)
    if local_rank == 0:
        logger.info(f"num_files_train: {len(filenames_train)}")

    # build the data dicts; embeddings missing on disk are silently skipped
    train_files = []
    for _i in range(len(filenames_train)):
        str_img = os.path.join(args.embedding_base_dir, filenames_train[_i])
        if not os.path.exists(str_img):
            continue

        # the sidecar .json next to each embedding carries spacing (and region indices)
        str_info = os.path.join(args.embedding_base_dir, filenames_train[_i]) + ".json"
        train_files_i = {"image": str_img, "spacing": str_info}
        if include_body_region:
            train_files_i["top_region_index"] = str_info
            train_files_i["bottom_region_index"] = str_info
        train_files.append(train_files_i)
    if dist.is_initialized():
        # shard the file list across ranks; even_divisible may duplicate a few items
        train_files = partition_dataset(
            data=train_files, shuffle=True, num_partitions=dist.get_world_size(), even_divisible=True
        )[local_rank]

    train_loader = prepare_data(
        train_files,
        device,
        args.diffusion_unet_train["cache_rate"],
        batch_size=args.diffusion_unet_train["batch_size"],
        include_body_region=include_body_region,
    )

    scale_factor = calculate_scale_factor(train_loader, device, logger)
    optimizer = create_optimizer(unet, args.diffusion_unet_train["lr"])

    # total optimizer steps over all epochs; NOTE(review): this is a float and the loader is
    # already batched, so this approximates n_epochs * iters_per_epoch — confirm intended.
    total_steps = (args.diffusion_unet_train["n_epochs"] * len(train_loader.dataset)) / args.diffusion_unet_train[
        "batch_size"
    ]
    lr_scheduler = create_lr_scheduler(optimizer, total_steps)
    loss_pt = torch.nn.L1Loss()
    scaler = GradScaler("cuda")

    torch.set_float32_matmul_precision("highest")
    logger.info("torch.set_float32_matmul_precision -> highest.")

    for epoch in range(args.diffusion_unet_train["n_epochs"]):
        loss_torch = train_one_epoch(
            epoch,
            unet,
            train_loader,
            optimizer,
            lr_scheduler,
            loss_pt,
            scaler,
            scale_factor,
            noise_scheduler,
            args.diffusion_unet_train["batch_size"],
            args.noise_scheduler["num_train_timesteps"],
            device,
            logger,
            local_rank,
            amp=amp,
        )

        loss_torch = loss_torch.tolist()
        # only rank 0 (or single-GPU) logs and writes the checkpoint
        if torch.cuda.device_count() == 1 or local_rank == 0:
            loss_torch_epoch = loss_torch[0] / loss_torch[1]
            logger.info(f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}.")

            save_checkpoint(
                epoch,
                unet,
                loss_torch_epoch,
                args.noise_scheduler["num_train_timesteps"],
                scale_factor,
                args.model_dir,
                args,
            )

    if dist.is_initialized():
        dist.destroy_process_group()
476
+
477
+
478
if __name__ == "__main__":
    # CLI entry point: parse config paths / GPU count and launch training.
    parser = argparse.ArgumentParser(description="Diffusion Model Training")
    parser.add_argument(
        "--env_config",
        type=str,
        default="./configs/environment_maisi_diff_model.json",
        help="Path to environment configuration file",
    )
    parser.add_argument(
        "--model_config",
        type=str,
        default="./configs/config_maisi_diff_model.json",
        help="Path to model training/inference configuration",
    )
    parser.add_argument(
        "--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
    )
    parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for training")
    # store_false: AMP is on by default, --no_amp disables it
    parser.add_argument("--no_amp", dest="amp", action="store_false", help="Disable automatic mixed precision training")

    args = parser.parse_args()
    diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus, args.amp)
scripts/find_masks.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+
13
+ import json
14
+ import os
15
+ from typing import Sequence
16
+
17
+ from monai.apps.utils import extractall
18
+ from monai.utils import ensure_tuple_rep
19
+
20
+
21
+ def convert_body_region(body_region: str | Sequence[str]) -> Sequence[int]:
22
+ """
23
+ Convert body region string to body region index.
24
+ Args:
25
+ body_region: list of input body region string. If single str, will be converted to list of str.
26
+ Return:
27
+ body_region_indices, list of input body region index.
28
+ """
29
+ if type(body_region) is str:
30
+ body_region = [body_region]
31
+
32
+ # body region mapping for maisi
33
+ region_mapping_maisi = {
34
+ "head": 0,
35
+ "chest": 1,
36
+ "thorax": 1,
37
+ "chest/thorax": 1,
38
+ "abdomen": 2,
39
+ "pelvis": 3,
40
+ "lower": 3,
41
+ "pelvis/lower": 3,
42
+ }
43
+
44
+ # perform mapping
45
+ body_region_indices = []
46
+ for region in body_region:
47
+ normalized_region = region.lower() # norm str to lower case
48
+ if normalized_region not in region_mapping_maisi:
49
+ raise ValueError(f"Invalid region: {normalized_region}")
50
+ body_region_indices.append(region_mapping_maisi[normalized_region])
51
+
52
+ return body_region_indices
53
+
54
+
55
def find_masks(
    body_region: str | Sequence[str],
    anatomy_list: int | Sequence[int],
    spacing: Sequence[float] | float = 1.0,
    output_size: Sequence[int] = (512, 512, 512),
    check_spacing_and_output_size: bool = False,
    database_filepath: str = "./configs/database.json",
    mask_foldername: str = "./datasets/masks/",
):
    """
    Find candidate masks that fulfill all the requirements.

    They should contain all the body regions in `body_region` and all the anatomies in `anatomy_list`.
    If there is no tumor specified in `anatomy_list`, we also expect the candidate masks to be tumor free.
    If check_spacing_and_output_size is True, the candidate masks need to have the expected `spacing` and `output_size`.

    Args:
        body_region: list of input body region strings. If a single str, will be converted to a list of str.
            The found candidate masks will include these body regions.
        anatomy_list: list of input anatomy labels. The found candidate masks will include these anatomies.
        spacing: list of three floats, voxel spacing. If providing a single number, it is used for all three dimensions.
        output_size: sequence of three ints, expected candidate mask spatial size.
            Default is an immutable tuple to avoid the shared-mutable-default pitfall.
        check_spacing_and_output_size: whether we expect candidate masks to have spatial size `output_size`
            and voxel size `spacing`.
        database_filepath: path for the json file that stores the information of all the candidate masks.
        mask_foldername: directory that saves all the candidate masks.

    Returns:
        candidate_masks, list of dicts; each dict contains the information of one candidate mask
        that fulfills all the requirements.

    Raises:
        ValueError: if the mask zip/database files are missing, or if no candidate is found
            (only when ``check_spacing_and_output_size`` is False).
    """
    # check and preprocess input
    body_region = convert_body_region(body_region)

    if isinstance(anatomy_list, int):
        anatomy_list = [anatomy_list]

    spacing = ensure_tuple_rep(spacing, 3)

    # lazily extract the mask archive on first use
    if not os.path.exists(mask_foldername):
        zip_file_path = mask_foldername + ".zip"

        if not os.path.isfile(zip_file_path):
            raise ValueError(f"Please download {zip_file_path} following the instruction in ./datasets/README.md.")

        print(f"Extracting {zip_file_path} to {os.path.dirname(zip_file_path)}")
        extractall(filepath=zip_file_path, output_dir=os.path.dirname(zip_file_path), file_type="zip")
        print(f"Unzipped {zip_file_path} to {mask_foldername}.")

    if not os.path.isfile(database_filepath):
        raise ValueError(f"Please download {database_filepath} following the instruction in ./datasets/README.md.")
    with open(database_filepath, "r") as f:
        db = json.load(f)

    # select candidate_masks
    candidate_masks = []
    for _item in db:
        # candidate must contain every requested anatomy label
        if not set(anatomy_list).issubset(_item["label_list"]):
            continue

        # whether to keep this mask, default to be True.
        keep_mask = True

        # extract region indices (top_index and bottom_index) for candidate mask
        include_body_region = "top_region_index" in _item.keys()
        if include_body_region:
            # region index vectors are one-hot-ish; the nonzero position encodes the region
            top_index = [index for index, element in enumerate(_item["top_region_index"]) if element != 0]
            top_index = top_index[0]
            bottom_index = [index for index, element in enumerate(_item["bottom_region_index"]) if element != 0]
            bottom_index = bottom_index[0]

            # if candidate mask does not contain all the body_region, skip it
            for _idx in body_region:
                if _idx > bottom_index or _idx < top_index:
                    keep_mask = False

        for tumor_label in [23, 24, 26, 27, 128]:
            # we skip those masks with tumors if users do not provide a tumor label in anatomy_list
            if tumor_label not in anatomy_list and tumor_label in _item["label_list"]:
                keep_mask = False

        if check_spacing_and_output_size:
            # if the output_size and spacing are different from the user's input, skip it
            for axis in range(3):
                if _item["dim"][axis] != output_size[axis] or _item["spacing"][axis] != spacing[axis]:
                    keep_mask = False

        if keep_mask:
            # if we decide to keep this mask, pack its information and add it to the final output.
            candidate = {
                "pseudo_label": os.path.join(mask_foldername, _item["pseudo_label_filename"]),
                "spacing": _item["spacing"],
                "dim": _item["dim"],
            }
            if include_body_region:
                candidate["top_region_index"] = _item["top_region_index"]
                candidate["bottom_region_index"] = _item["bottom_region_index"]

            # Conditionally add the label to the candidate dictionary
            if "label_filename" in _item:
                candidate["label"] = os.path.join(mask_foldername, _item["label_filename"])

            candidate_masks.append(candidate)

    if len(candidate_masks) == 0 and not check_spacing_and_output_size:
        raise ValueError("Cannot find body region with given anatomy list.")

    return candidate_masks
scripts/infer_controlnet.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import argparse
13
+ import json
14
+ import logging
15
+ import os
16
+ import sys
17
+ from datetime import datetime
18
+
19
+ import torch
20
+ import torch.distributed as dist
21
+ from monai.data import MetaTensor, decollate_batch
22
+ from monai.networks.utils import copy_model_state
23
+ from monai.transforms import SaveImage
24
+ from monai.utils import RankFilter
25
+
26
+ from .sample import check_input, ldm_conditional_sample_one_image
27
+ from .utils import define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp
28
+
29
+
30
+ @torch.inference_mode()
31
+ def main():
32
+ parser = argparse.ArgumentParser(description="maisi.controlnet.infer")
33
+ parser.add_argument(
34
+ "-e",
35
+ "--environment-file",
36
+ default="./configs/environment_maisi_controlnet_train.json",
37
+ help="environment json file that stores environment path",
38
+ )
39
+ parser.add_argument(
40
+ "-c",
41
+ "--config-file",
42
+ default="./configs/config_maisi.json",
43
+ help="config json file that stores network hyper-parameters",
44
+ )
45
+ parser.add_argument(
46
+ "-t",
47
+ "--training-config",
48
+ default="./configs/config_maisi_controlnet_train.json",
49
+ help="config json file that stores training hyper-parameters",
50
+ )
51
+ parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
52
+
53
+ args = parser.parse_args()
54
+
55
+ # Step 0: configuration
56
+ logger = logging.getLogger("maisi.controlnet.infer")
57
+ # whether to use distributed data parallel
58
+ use_ddp = args.gpus > 1
59
+ if use_ddp:
60
+ rank = int(os.environ["LOCAL_RANK"])
61
+ world_size = int(os.environ["WORLD_SIZE"])
62
+ device = setup_ddp(rank, world_size)
63
+ logger.addFilter(RankFilter())
64
+ else:
65
+ rank = 0
66
+ world_size = 1
67
+ device = torch.device(f"cuda:{rank}")
68
+
69
+ torch.cuda.set_device(device)
70
+ logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
71
+ logger.info(f"World_size: {world_size}")
72
+
73
+ with open(args.environment_file, "r") as env_file:
74
+ env_dict = json.load(env_file)
75
+ with open(args.config_file, "r") as config_file:
76
+ config_dict = json.load(config_file)
77
+ with open(args.training_config, "r") as training_config_file:
78
+ training_config_dict = json.load(training_config_file)
79
+
80
+ for k, v in env_dict.items():
81
+ setattr(args, k, v)
82
+ for k, v in config_dict.items():
83
+ setattr(args, k, v)
84
+ for k, v in training_config_dict.items():
85
+ setattr(args, k, v)
86
+
87
+ # Step 1: set data loader
88
+ _, val_loader = prepare_maisi_controlnet_json_dataloader(
89
+ json_data_list=args.json_data_list,
90
+ data_base_dir=args.data_base_dir,
91
+ rank=rank,
92
+ world_size=world_size,
93
+ batch_size=args.controlnet_train["batch_size"],
94
+ cache_rate=args.controlnet_train["cache_rate"],
95
+ fold=args.controlnet_train["fold"],
96
+ )
97
+
98
+ # Step 2: define AE, diffusion model and controlnet
99
+ # define AE
100
+ autoencoder = define_instance(args, "autoencoder_def").to(device)
101
+ # load trained autoencoder model
102
+ if args.trained_autoencoder_path is not None:
103
+ if not os.path.exists(args.trained_autoencoder_path):
104
+ raise ValueError("Please download the autoencoder checkpoint.")
105
+ autoencoder_ckpt = torch.load(args.trained_autoencoder_path, weights_only=True)
106
+ autoencoder.load_state_dict(autoencoder_ckpt)
107
+ logger.info(f"Load trained diffusion model from {args.trained_autoencoder_path}.")
108
+ else:
109
+ logger.info("trained autoencoder model is not loaded.")
110
+
111
+ # define diffusion Model
112
+ unet = define_instance(args, "diffusion_unet_def").to(device)
113
+ include_body_region = unet.include_top_region_index_input
114
+ include_modality = unet.num_class_embeds is not None
115
+
116
+ # load trained diffusion model
117
+ if args.trained_diffusion_path is not None:
118
+ if not os.path.exists(args.trained_diffusion_path):
119
+ raise ValueError("Please download the trained diffusion unet checkpoint.")
120
+ diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device, weights_only=False)
121
+ unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"])
122
+ # load scale factor from diffusion model checkpoint
123
+ scale_factor = diffusion_model_ckpt["scale_factor"]
124
+ logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.")
125
+ logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.")
126
+ else:
127
+ logger.info("trained diffusion model is not loaded.")
128
+ scale_factor = 1.0
129
+ logger.info(f"set scale_factor -> {scale_factor}.")
130
+
131
+ # define ControlNet
132
+ controlnet = define_instance(args, "controlnet_def").to(device)
133
+ # copy weights from the DM to the controlnet
134
+ copy_model_state(controlnet, unet.state_dict())
135
+ # load trained controlnet model if it is provided
136
+ if args.trained_controlnet_path is not None:
137
+ if not os.path.exists(args.trained_controlnet_path):
138
+ raise ValueError("Please download the trained ControlNet checkpoint.")
139
+ controlnet.load_state_dict(
140
+ torch.load(args.trained_controlnet_path, map_location=device, weights_only=False)["controlnet_state_dict"]
141
+ )
142
+ logger.info(f"load trained controlnet model from {args.trained_controlnet_path}")
143
+ else:
144
+ logger.info("trained controlnet is not loaded.")
145
+
146
+ noise_scheduler = define_instance(args, "noise_scheduler")
147
+
148
+ # Step 3: inference
149
+ autoencoder.eval()
150
+ controlnet.eval()
151
+ unet.eval()
152
+
153
+ for batch in val_loader:
154
+
155
+ # get label mask
156
+ labels = batch["label"].to(device)
157
+ # get corresponding conditions
158
+ if include_body_region:
159
+ top_region_index_tensor = batch["top_region_index"].to(device)
160
+ bottom_region_index_tensor = batch["bottom_region_index"].to(device)
161
+ else:
162
+ top_region_index_tensor = None
163
+ bottom_region_index_tensor = None
164
+ spacing_tensor = batch["spacing"].to(device)
165
+ modality_tensor = args.controlnet_infer["modality"] * torch.ones((len(labels),), dtype=torch.long).to(device)
166
+ out_spacing = tuple((batch["spacing"].squeeze().numpy() / 100).tolist())
167
+ # get target dimension
168
+ dim = batch["dim"]
169
+ output_size = (dim[0].item(), dim[1].item(), dim[2].item())
170
+ latent_shape = (args.latent_channels, output_size[0] // 4, output_size[1] // 4, output_size[2] // 4)
171
+ # check if output_size and out_spacing are valid.
172
+ check_input(None, None, None, output_size, out_spacing, None)
173
+ # generate a single synthetic image using a latent diffusion model with controlnet.
174
+ synthetic_images, _ = ldm_conditional_sample_one_image(
175
+ autoencoder=autoencoder,
176
+ diffusion_unet=unet,
177
+ controlnet=controlnet,
178
+ noise_scheduler=noise_scheduler,
179
+ scale_factor=scale_factor,
180
+ device=device,
181
+ combine_label_or=labels,
182
+ top_region_index_tensor=top_region_index_tensor,
183
+ bottom_region_index_tensor=bottom_region_index_tensor,
184
+ spacing_tensor=spacing_tensor,
185
+ modality_tensor=modality_tensor,
186
+ latent_shape=latent_shape,
187
+ output_size=output_size,
188
+ noise_factor=1.0,
189
+ num_inference_steps=args.controlnet_infer["num_inference_steps"],
190
+ autoencoder_sliding_window_infer_size=args.controlnet_infer["autoencoder_sliding_window_infer_size"],
191
+ autoencoder_sliding_window_infer_overlap=args.controlnet_infer["autoencoder_sliding_window_infer_overlap"],
192
+ )
193
+ # save image/label pairs
194
+ labels = decollate_batch(batch)[0]["label"]
195
+ real_object_name = labels.meta.get("filename_or_obj", "default_name.nii.gz")
196
+ labels.meta["filename_or_obj"] = real_object_name
197
+ output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
198
+ synthetic_images = MetaTensor(synthetic_images.squeeze(0), meta=labels.meta)
199
+ img_saver = SaveImage(
200
+ output_dir=args.output_dir,
201
+ output_postfix="image",
202
+ separate_folder=False,
203
+ )
204
+ img_saver(synthetic_images)
205
+ label_saver = SaveImage(
206
+ output_dir=args.output_dir,
207
+ output_postfix="label",
208
+ separate_folder=False,
209
+ )
210
+ label_saver(labels)
211
+ if use_ddp:
212
+ dist.destroy_process_group()
213
+
214
+
215
if __name__ == "__main__":
    # configure root logging to stdout before running inference
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    main()
scripts/infer_testV2_controlnet.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import argparse
13
+ import json
14
+ import logging
15
+ import os
16
+ import sys
17
+ from datetime import datetime
18
+
19
+ import torch
20
+ import torch.distributed as dist
21
+ from monai.data import MetaTensor, decollate_batch
22
+ from monai.networks.utils import copy_model_state
23
+ from monai.transforms import SaveImage
24
+ from monai.utils import RankFilter
25
+
26
+ from .sample import check_input, ldm_conditional_sample_one_image
27
+ from .utils import define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp, prepare_maisi_controlnet_test_dataloader
28
+
29
+
30
@torch.inference_mode()
def main():
    """Run ControlNet-conditioned MAISI inference over the test split.

    For each batch, a synthetic 3D image is generated from the label mask with
    the latent diffusion model + ControlNet, then the image/label pair is saved
    to ``args.output_dir``. Supports optional multi-GPU (DDP) execution.
    """
    parser = argparse.ArgumentParser(description="maisi.controlnet.infer")
    parser.add_argument(
        "-e",
        "--environment-file",
        default="./configs/environment_maisi_controlnet_train.json",
        help="environment json file that stores environment path",
    )
    parser.add_argument(
        "-c",
        "--config-file",
        default="./configs/config_maisi.json",
        help="config json file that stores network hyper-parameters",
    )
    parser.add_argument(
        "-t",
        "--training-config",
        default="./configs/config_maisi_controlnet_train.json",
        help="config json file that stores training hyper-parameters",
    )
    parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
    args = parser.parse_args()

    # Step 0: configuration
    logger = logging.getLogger("maisi.controlnet.infer")
    # whether to use distributed data parallel
    use_ddp = args.gpus > 1
    if use_ddp:
        rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        device = setup_ddp(rank, world_size)
        logger.addFilter(RankFilter())
    else:
        rank = 0
        world_size = 1
        device = torch.device(f"cuda:{rank}")

    torch.cuda.set_device(device)
    logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
    logger.info(f"World_size: {world_size}")

    # Merge environment / network / training configs into `args`
    # (later files overwrite earlier keys, matching the original precedence).
    for config_path in (args.environment_file, args.config_file, args.training_config):
        with open(config_path, "r") as config_file:
            for k, v in json.load(config_file).items():
                setattr(args, k, v)

    # Step 1: set data loader (test split)
    val_loader = prepare_maisi_controlnet_test_dataloader(
        json_data_list=args.json_data_list,
        data_base_dir=args.data_base_dir,
        batch_size=args.controlnet_train["batch_size"],
        cache_rate=args.controlnet_train["cache_rate"],
        rank=rank,
        world_size=world_size,
    )

    # Step 2: define AE, diffusion model and controlnet
    # define AE
    autoencoder = define_instance(args, "autoencoder_def").to(device)
    # load trained autoencoder model
    if args.trained_autoencoder_path is not None:
        if not os.path.exists(args.trained_autoencoder_path):
            raise ValueError("Please download the autoencoder checkpoint.")
        autoencoder_ckpt = torch.load(args.trained_autoencoder_path, weights_only=True)
        autoencoder.load_state_dict(autoencoder_ckpt)
        # fixed: message previously claimed the "diffusion model" was loaded here
        logger.info(f"Load trained autoencoder model from {args.trained_autoencoder_path}.")
    else:
        logger.info("trained autoencoder model is not loaded.")

    # define diffusion Model
    unet = define_instance(args, "diffusion_unet_def").to(device)
    # whether the UNet expects body-region index conditioning inputs
    include_body_region = unet.include_top_region_index_input

    # load trained diffusion model
    if args.trained_diffusion_path is not None:
        if not os.path.exists(args.trained_diffusion_path):
            raise ValueError("Please download the trained diffusion unet checkpoint.")
        diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device, weights_only=False)
        unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"])
        # the latent scale factor is stored alongside the diffusion weights
        scale_factor = diffusion_model_ckpt["scale_factor"]
        logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.")
        logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.")
    else:
        logger.info("trained diffusion model is not loaded.")
        scale_factor = 1.0
        logger.info(f"set scale_factor -> {scale_factor}.")

    # define ControlNet
    controlnet = define_instance(args, "controlnet_def").to(device)
    # initialize ControlNet from the diffusion UNet weights before loading its own checkpoint
    copy_model_state(controlnet, unet.state_dict())
    if args.trained_controlnet_path is not None:
        if not os.path.exists(args.trained_controlnet_path):
            raise ValueError("Please download the trained ControlNet checkpoint.")
        controlnet.load_state_dict(
            torch.load(args.trained_controlnet_path, map_location=device, weights_only=False)["controlnet_state_dict"]
        )
        logger.info(f"load trained controlnet model from {args.trained_controlnet_path}")
    else:
        logger.info("trained controlnet is not loaded.")

    noise_scheduler = define_instance(args, "noise_scheduler")

    # Step 3: inference
    autoencoder.eval()
    controlnet.eval()
    unet.eval()

    for batch in val_loader:
        # get label mask
        labels = batch["label"].to(device)
        # get corresponding conditions
        if include_body_region:
            top_region_index_tensor = batch["top_region_index"].to(device)
            bottom_region_index_tensor = batch["bottom_region_index"].to(device)
        else:
            top_region_index_tensor = None
            bottom_region_index_tensor = None
        spacing_tensor = batch["spacing"].to(device)
        modality_tensor = args.controlnet_infer["modality"] * torch.ones((len(labels),), dtype=torch.long).to(device)
        # NOTE(review): spacing appears to be stored scaled by 100 in the dataloader output — confirm upstream
        out_spacing = tuple((batch["spacing"].squeeze().numpy() / 100).tolist())
        # get target dimension
        dim = batch["dim"]
        output_size = (dim[0].item(), dim[1].item(), dim[2].item())
        # latent resolution is 1/4 of the image resolution per axis
        latent_shape = (args.latent_channels, output_size[0] // 4, output_size[1] // 4, output_size[2] // 4)
        # check if output_size and out_spacing are valid.
        check_input(None, None, None, output_size, out_spacing, None)
        # generate a single synthetic image using a latent diffusion model with controlnet.
        synthetic_images, _ = ldm_conditional_sample_one_image(
            autoencoder=autoencoder,
            diffusion_unet=unet,
            controlnet=controlnet,
            noise_scheduler=noise_scheduler,
            scale_factor=scale_factor,
            device=device,
            combine_label_or=labels,
            top_region_index_tensor=top_region_index_tensor,
            bottom_region_index_tensor=bottom_region_index_tensor,
            spacing_tensor=spacing_tensor,
            modality_tensor=modality_tensor,
            latent_shape=latent_shape,
            output_size=output_size,
            noise_factor=1.0,
            num_inference_steps=args.controlnet_infer["num_inference_steps"],
            autoencoder_sliding_window_infer_size=args.controlnet_infer["autoencoder_sliding_window_infer_size"],
            autoencoder_sliding_window_infer_overlap=args.controlnet_infer["autoencoder_sliding_window_infer_overlap"],
        )
        # save image/label pairs; reuse the label metadata so filename/affine carry over
        labels = decollate_batch(batch)[0]["label"]
        labels.meta.setdefault("filename_or_obj", "default_name.nii.gz")
        synthetic_images = MetaTensor(synthetic_images.squeeze(0), meta=labels.meta)
        img_saver = SaveImage(
            output_dir=args.output_dir,
            output_postfix="image",
            separate_folder=False,
        )
        img_saver(synthetic_images)
        label_saver = SaveImage(
            output_dir=args.output_dir,
            output_postfix="label",
            separate_folder=False,
        )
        label_saver(labels)
    if use_ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    main()
scripts/infer_test_controlnet.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import argparse
13
+ import json
14
+ import logging
15
+ import os
16
+ import sys
17
+ from datetime import datetime
18
+
19
+ import torch
20
+ import torch.distributed as dist
21
+ from monai.data import MetaTensor, decollate_batch
22
+ from monai.networks.utils import copy_model_state
23
+ from monai.transforms import SaveImage
24
+ from monai.utils import RankFilter
25
+
26
+ from .sample import check_input, ldm_conditional_sample_one_image
27
+ from .utils import define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp, prepare_maisi_controlnet_infer_dataloader
28
+
29
+
30
@torch.inference_mode()
def main():
    """Run ControlNet-conditioned MAISI inference over the inference split.

    For each batch, a synthetic 3D image is generated from the label mask with
    the latent diffusion model + ControlNet, then the image/label pair is saved
    to ``args.output_dir``. Supports optional multi-GPU (DDP) execution.
    """
    parser = argparse.ArgumentParser(description="maisi.controlnet.infer")
    parser.add_argument(
        "-e",
        "--environment-file",
        default="./configs/environment_maisi_controlnet_train.json",
        help="environment json file that stores environment path",
    )
    parser.add_argument(
        "-c",
        "--config-file",
        default="./configs/config_maisi.json",
        help="config json file that stores network hyper-parameters",
    )
    parser.add_argument(
        "-t",
        "--training-config",
        default="./configs/config_maisi_controlnet_train.json",
        help="config json file that stores training hyper-parameters",
    )
    parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
    args = parser.parse_args()

    # Step 0: configuration
    logger = logging.getLogger("maisi.controlnet.infer")
    # whether to use distributed data parallel
    use_ddp = args.gpus > 1
    if use_ddp:
        rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        device = setup_ddp(rank, world_size)
        logger.addFilter(RankFilter())
    else:
        rank = 0
        world_size = 1
        device = torch.device(f"cuda:{rank}")

    torch.cuda.set_device(device)
    logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
    logger.info(f"World_size: {world_size}")

    # Merge environment / network / training configs into `args`
    # (later files overwrite earlier keys, matching the original precedence).
    for config_path in (args.environment_file, args.config_file, args.training_config):
        with open(config_path, "r") as config_file:
            for k, v in json.load(config_file).items():
                setattr(args, k, v)

    # Step 1: set data loader (inference split)
    val_loader = prepare_maisi_controlnet_infer_dataloader(
        json_data_list=args.json_data_list,
        data_base_dir=args.data_base_dir,
        batch_size=args.controlnet_train["batch_size"],
        cache_rate=args.controlnet_train["cache_rate"],
        rank=rank,
        world_size=world_size,
    )

    # Step 2: define AE, diffusion model and controlnet
    # define AE
    autoencoder = define_instance(args, "autoencoder_def").to(device)
    # load trained autoencoder model
    if args.trained_autoencoder_path is not None:
        if not os.path.exists(args.trained_autoencoder_path):
            raise ValueError("Please download the autoencoder checkpoint.")
        autoencoder_ckpt = torch.load(args.trained_autoencoder_path, weights_only=True)
        autoencoder.load_state_dict(autoencoder_ckpt)
        # fixed: message previously claimed the "diffusion model" was loaded here
        logger.info(f"Load trained autoencoder model from {args.trained_autoencoder_path}.")
    else:
        logger.info("trained autoencoder model is not loaded.")

    # define diffusion Model
    unet = define_instance(args, "diffusion_unet_def").to(device)
    # whether the UNet expects body-region index conditioning inputs
    include_body_region = unet.include_top_region_index_input

    # load trained diffusion model
    if args.trained_diffusion_path is not None:
        if not os.path.exists(args.trained_diffusion_path):
            raise ValueError("Please download the trained diffusion unet checkpoint.")
        diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device, weights_only=False)
        unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"])
        # the latent scale factor is stored alongside the diffusion weights
        scale_factor = diffusion_model_ckpt["scale_factor"]
        logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.")
        logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.")
    else:
        logger.info("trained diffusion model is not loaded.")
        scale_factor = 1.0
        logger.info(f"set scale_factor -> {scale_factor}.")

    # define ControlNet
    controlnet = define_instance(args, "controlnet_def").to(device)
    # initialize ControlNet from the diffusion UNet weights before loading its own checkpoint
    copy_model_state(controlnet, unet.state_dict())
    if args.trained_controlnet_path is not None:
        if not os.path.exists(args.trained_controlnet_path):
            raise ValueError("Please download the trained ControlNet checkpoint.")
        controlnet.load_state_dict(
            torch.load(args.trained_controlnet_path, map_location=device, weights_only=False)["controlnet_state_dict"]
        )
        logger.info(f"load trained controlnet model from {args.trained_controlnet_path}")
    else:
        logger.info("trained controlnet is not loaded.")

    noise_scheduler = define_instance(args, "noise_scheduler")

    # Step 3: inference
    autoencoder.eval()
    controlnet.eval()
    unet.eval()

    for batch in val_loader:
        # get label mask
        labels = batch["label"].to(device)
        # get corresponding conditions
        if include_body_region:
            top_region_index_tensor = batch["top_region_index"].to(device)
            bottom_region_index_tensor = batch["bottom_region_index"].to(device)
        else:
            top_region_index_tensor = None
            bottom_region_index_tensor = None
        spacing_tensor = batch["spacing"].to(device)
        modality_tensor = args.controlnet_infer["modality"] * torch.ones((len(labels),), dtype=torch.long).to(device)
        # NOTE(review): spacing appears to be stored scaled by 100 in the dataloader output — confirm upstream
        out_spacing = tuple((batch["spacing"].squeeze().numpy() / 100).tolist())
        # get target dimension
        dim = batch["dim"]
        output_size = (dim[0].item(), dim[1].item(), dim[2].item())
        # latent resolution is 1/4 of the image resolution per axis
        latent_shape = (args.latent_channels, output_size[0] // 4, output_size[1] // 4, output_size[2] // 4)
        # check if output_size and out_spacing are valid.
        check_input(None, None, None, output_size, out_spacing, None)
        # generate a single synthetic image using a latent diffusion model with controlnet.
        synthetic_images, _ = ldm_conditional_sample_one_image(
            autoencoder=autoencoder,
            diffusion_unet=unet,
            controlnet=controlnet,
            noise_scheduler=noise_scheduler,
            scale_factor=scale_factor,
            device=device,
            combine_label_or=labels,
            top_region_index_tensor=top_region_index_tensor,
            bottom_region_index_tensor=bottom_region_index_tensor,
            spacing_tensor=spacing_tensor,
            modality_tensor=modality_tensor,
            latent_shape=latent_shape,
            output_size=output_size,
            noise_factor=1.0,
            num_inference_steps=args.controlnet_infer["num_inference_steps"],
            autoencoder_sliding_window_infer_size=args.controlnet_infer["autoencoder_sliding_window_infer_size"],
            autoencoder_sliding_window_infer_overlap=args.controlnet_infer["autoencoder_sliding_window_infer_overlap"],
        )
        # save image/label pairs; reuse the label metadata so filename/affine carry over
        labels = decollate_batch(batch)[0]["label"]
        labels.meta.setdefault("filename_or_obj", "default_name.nii.gz")
        synthetic_images = MetaTensor(synthetic_images.squeeze(0), meta=labels.meta)
        img_saver = SaveImage(
            output_dir=args.output_dir,
            output_postfix="image",
            separate_folder=False,
        )
        img_saver(synthetic_images)
        label_saver = SaveImage(
            output_dir=args.output_dir,
            output_postfix="label",
            separate_folder=False,
        )
        label_saver(labels)
    if use_ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    main()
scripts/inference.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # &nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+ #
12
+ # # MAISI Inference Script
13
+ import argparse
14
+ import json
15
+ import logging
16
+ import os
17
+ import sys
18
+ import tempfile
19
+
20
+ import monai
21
+ import torch
22
+ from monai.apps import download_url
23
+ from monai.config import print_config
24
+ from monai.transforms import LoadImage, Orientation
25
+ from monai.utils import set_determinism
26
+
27
+ from scripts.sample import LDMSampler, check_input
28
+ from scripts.utils import define_instance
29
+ from scripts.utils_plot import find_label_center_loc, get_xyz_plot, show_image
30
+
31
+
32
+ def main():
33
+ parser = argparse.ArgumentParser(description="maisi.controlnet.training")
34
+ parser.add_argument(
35
+ "-e",
36
+ "--environment-file",
37
+ default="./configs/environment.json",
38
+ help="environment json file that stores environment path",
39
+ )
40
+ parser.add_argument(
41
+ "-c",
42
+ "--config-file",
43
+ default="./configs/config_maisi.json",
44
+ help="config json file that stores network hyper-parameters",
45
+ )
46
+ parser.add_argument(
47
+ "-i",
48
+ "--inference-file",
49
+ default="./configs/config_infer.json",
50
+ help="config json file that stores inference hyper-parameters",
51
+ )
52
+ parser.add_argument(
53
+ "-x",
54
+ "--extra-config-file",
55
+ default=None,
56
+ help="config json file that stores inference extra parameters",
57
+ )
58
+ parser.add_argument(
59
+ "-s",
60
+ "--random-seed",
61
+ default=None,
62
+ help="random seed, can be None or int",
63
+ )
64
+ parser.add_argument(
65
+ "--version",
66
+ default="maisi3d-rflow",
67
+ type=str,
68
+ help="maisi_version, choose from ['maisi3d-ddpm', 'maisi3d-rflow']",
69
+ )
70
+ args = parser.parse_args()
71
+ # Step 0: configuration
72
+ logger = logging.getLogger("maisi.inference")
73
+
74
+ maisi_version = args.version
75
+
76
+ # ## Set deterministic training for reproducibility
77
+ if args.random_seed is not None:
78
+ set_determinism(seed=args.random_seed)
79
+
80
+ # ## Setup data directory
81
+ # You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.
82
+ # This allows you to save results and reuse downloads.
83
+ # If not specified a temporary directory will be used.
84
+
85
+ directory = os.environ.get("MONAI_DATA_DIRECTORY")
86
+ if directory is not None:
87
+ os.makedirs(directory, exist_ok=True)
88
+ root_dir = tempfile.mkdtemp() if directory is None else directory
89
+ print(root_dir)
90
+
91
+ # TODO: remove the `files` after the files are uploaded to the NGC
92
+ files = [
93
+ {
94
+ "path": "models/autoencoder_epoch273.pt",
95
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials"
96
+ "/model_zoo/model_maisi_autoencoder_epoch273_alternative.pt",
97
+ },
98
+ {
99
+ "path": "models/mask_generation_autoencoder.pt",
100
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai"
101
+ "/tutorials/mask_generation_autoencoder.pt",
102
+ },
103
+ {
104
+ "path": "models/mask_generation_diffusion_unet.pt",
105
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai"
106
+ "/tutorials/model_zoo/model_maisi_mask_generation_diffusion_unet_v2.pt",
107
+ },
108
+ {
109
+ "path": "configs/all_anatomy_size_condtions.json",
110
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/all_anatomy_size_condtions.json",
111
+ },
112
+ {
113
+ "path": "datasets/all_masks_flexible_size_and_spacing_4000.zip",
114
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai"
115
+ "/tutorials/all_masks_flexible_size_and_spacing_4000.zip",
116
+ },
117
+ ]
118
+
119
+ if maisi_version == "maisi3d-ddpm":
120
+ files += [
121
+ {
122
+ "path": "models/diff_unet_3d_ddpm.pt",
123
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo"
124
+ "/model_maisi_input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1_alternative.pt",
125
+ },
126
+ {
127
+ "path": "models/controlnet_3d_ddpm.pt",
128
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo"
129
+ "/model_maisi_controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current_alternative.pt",
130
+ },
131
+ {
132
+ "path": "configs/candidate_masks_flexible_size_and_spacing_3000.json",
133
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai"
134
+ "/tutorials/candidate_masks_flexible_size_and_spacing_3000.json",
135
+ },
136
+ ]
137
+ elif maisi_version == "maisi3d-rflow":
138
+ files += [
139
+ {
140
+ "path": "models/diff_unet_3d_rflow.pt",
141
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/"
142
+ "diff_unet_ckpt_rflow_epoch19350.pt",
143
+ },
144
+ {
145
+ "path": "models/controlnet_3d_rflow.pt",
146
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/"
147
+ "controlnet_rflow_epoch60.pt",
148
+ },
149
+ {
150
+ "path": "configs/candidate_masks_flexible_size_and_spacing_4000.json",
151
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai"
152
+ "/tutorials/candidate_masks_flexible_size_and_spacing_4000.json",
153
+ },
154
+ ]
155
+ else:
156
+ raise ValueError(
157
+ f"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}."
158
+ )
159
+
160
+ for file in files:
161
+ file["path"] = file["path"] if "datasets/" not in file["path"] else os.path.join(root_dir, file["path"])
162
+ download_url(url=file["url"], filepath=file["path"])
163
+
164
+ # ## Read in environment setting, including data directory, model directory, and output directory
165
+ # The information for data directory, model directory, and output directory are saved in ./configs/environment.json
166
+ env_dict = json.load(open(args.environment_file, "r"))
167
+ for k, v in env_dict.items():
168
+ # Update the path to the downloaded dataset in MONAI_DATA_DIRECTORY
169
+ val = v if "datasets/" not in v else os.path.join(root_dir, v)
170
+ setattr(args, k, val)
171
+ print(f"{k}: {val}")
172
+ print("Global config variables have been loaded.")
173
+
174
+ # ## Read in configuration setting, including network definition, body region and anatomy to generate, etc.
175
+ #
176
+ # The information for the inference input, like body region and anatomy to generate, is stored in "./configs/config_infer.json".
177
+ # Please refer to README.md for the details.
178
+ config_dict = json.load(open(args.config_file, "r"))
179
+ for k, v in config_dict.items():
180
+ setattr(args, k, v)
181
+
182
+ # check the format of inference inputs
183
+ config_infer_dict = json.load(open(args.inference_file, "r"))
184
+ # override num_split if asked
185
+ if "autoencoder_tp_num_splits" in config_infer_dict:
186
+ args.autoencoder_def["num_splits"] = config_infer_dict["autoencoder_tp_num_splits"]
187
+ args.mask_generation_autoencoder_def["num_splits"] = config_infer_dict["autoencoder_tp_num_splits"]
188
+ for k, v in config_infer_dict.items():
189
+ setattr(args, k, v)
190
+ print(f"{k}: {v}")
191
+
192
+ #
193
+ # ## Read in optional extra configuration setting - typically acceleration options (TRT)
194
+ #
195
+ #
196
+ if args.extra_config_file is not None:
197
+ extra_config_dict = json.load(open(args.extra_config_file, "r"))
198
+ for k, v in extra_config_dict.items():
199
+ setattr(args, k, v)
200
+ print(f"{k}: {v}")
201
+
202
+ check_input(
203
+ args.body_region,
204
+ args.anatomy_list,
205
+ args.label_dict_json,
206
+ args.output_size,
207
+ args.spacing,
208
+ args.controllable_anatomy_size,
209
+ )
210
+ latent_shape = [args.latent_channels, args.output_size[0] // 4, args.output_size[1] // 4, args.output_size[2] // 4]
211
+ print("Network definition and inference inputs have been loaded.")
212
+
213
+ # ## Initialize networks and noise scheduler, then load the trained model weights.
214
+ # The networks and noise scheduler are defined in `config_file`. We will read them in and load the model weights.
215
+ noise_scheduler = define_instance(args, "noise_scheduler")
216
+ mask_generation_noise_scheduler = define_instance(args, "mask_generation_noise_scheduler")
217
+
218
+ device = torch.device("cuda")
219
+
220
+ autoencoder = define_instance(args, "autoencoder").to(device)
221
+ checkpoint_autoencoder = torch.load(args.trained_autoencoder_path, weights_only=True)
222
+ autoencoder.load_state_dict(checkpoint_autoencoder)
223
+
224
+ diffusion_unet = define_instance(args, "diffusion_unet").to(device)
225
+ checkpoint_diffusion_unet = torch.load(args.trained_diffusion_path, weights_only=False)
226
+ diffusion_unet.load_state_dict(checkpoint_diffusion_unet["unet_state_dict"], strict=True)
227
+ scale_factor = checkpoint_diffusion_unet["scale_factor"].to(device)
228
+
229
+ controlnet = define_instance(args, "controlnet").to(device)
230
+ checkpoint_controlnet = torch.load(args.trained_controlnet_path, weights_only=False)
231
+ monai.networks.utils.copy_model_state(controlnet, diffusion_unet.state_dict())
232
+ controlnet.load_state_dict(checkpoint_controlnet["controlnet_state_dict"], strict=True)
233
+
234
+ mask_generation_autoencoder = define_instance(args, "mask_generation_autoencoder").to(device)
235
+ checkpoint_mask_generation_autoencoder = torch.load(
236
+ args.trained_mask_generation_autoencoder_path, weights_only=True
237
+ )
238
+ mask_generation_autoencoder.load_state_dict(checkpoint_mask_generation_autoencoder)
239
+
240
+ mask_generation_diffusion_unet = define_instance(args, "mask_generation_diffusion").to(device)
241
+ checkpoint_mask_generation_diffusion_unet = torch.load(
242
+ args.trained_mask_generation_diffusion_path, weights_only=False
243
+ )
244
+ mask_generation_diffusion_unet.load_state_dict(checkpoint_mask_generation_diffusion_unet["unet_state_dict"])
245
+ mask_generation_scale_factor = checkpoint_mask_generation_diffusion_unet["scale_factor"]
246
+
247
+ print("All the trained model weights have been loaded.")
248
+
249
+ # ## Define the LDM Sampler, which contains functions that will perform the inference.
250
+ ldm_sampler = LDMSampler(
251
+ args.body_region,
252
+ args.anatomy_list,
253
+ args.all_mask_files_json,
254
+ args.all_anatomy_size_conditions_json,
255
+ args.all_mask_files_base_dir,
256
+ args.label_dict_json,
257
+ args.label_dict_remap_json,
258
+ autoencoder,
259
+ diffusion_unet,
260
+ controlnet,
261
+ noise_scheduler,
262
+ scale_factor,
263
+ mask_generation_autoencoder,
264
+ mask_generation_diffusion_unet,
265
+ mask_generation_scale_factor,
266
+ mask_generation_noise_scheduler,
267
+ device,
268
+ latent_shape,
269
+ args.mask_generation_latent_shape,
270
+ args.output_size,
271
+ args.output_dir,
272
+ args.controllable_anatomy_size,
273
+ image_output_ext=args.image_output_ext,
274
+ label_output_ext=args.label_output_ext,
275
+ spacing=args.spacing,
276
+ modality=args.modality,
277
+ num_inference_steps=args.num_inference_steps,
278
+ mask_generation_num_inference_steps=args.mask_generation_num_inference_steps,
279
+ random_seed=args.random_seed,
280
+ autoencoder_sliding_window_infer_size=args.autoencoder_sliding_window_infer_size,
281
+ autoencoder_sliding_window_infer_overlap=args.autoencoder_sliding_window_infer_overlap,
282
+ )
283
+
284
+ print(f"The generated image/mask pairs will be saved in {args.output_dir}.")
285
+ output_filenames = ldm_sampler.sample_multiple_images(args.num_output_samples)
286
+ print("MAISI image/mask generation finished")
287
+
288
+
289
+ if __name__ == "__main__":
290
+ logging.basicConfig(
291
+ stream=sys.stdout,
292
+ level=logging.INFO,
293
+ format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
294
+ datefmt="%Y-%m-%d %H:%M:%S",
295
+ )
296
+ torch.cuda.reset_peak_memory_stats()
297
+ main()
298
+ peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3) # Convert to GB
299
+ print(f"Peak GPU memory usage: {peak_memory_gb:.2f} GB")
scripts/quality_check.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import numpy as np
13
+
14
+
15
+ def get_masked_data(label_data, image_data, labels):
16
+ """
17
+ Extracts and returns the image data corresponding to specified labels within a 3D volume.
18
+
19
+ This function efficiently masks the `image_data` array based on the provided `labels` in the `label_data` array.
20
+ The function handles cases with both a large and small number of labels, optimizing performance accordingly.
21
+
22
+ Args:
23
+ label_data (np.ndarray): A NumPy array containing label data, representing different anatomical
24
+ regions or classes in a 3D medical image.
25
+ image_data (np.ndarray): A NumPy array containing the image data from which the relevant regions
26
+ will be extracted.
27
+ labels (list of int): A list of integers representing the label values to be used for masking.
28
+
29
+ Returns:
30
+ np.ndarray: A NumPy array containing the elements of `image_data` that correspond to the specified
31
+ labels in `label_data`. If no labels are provided, an empty array is returned.
32
+
33
+ Raises:
34
+ ValueError: If `image_data` and `label_data` do not have the same shape.
35
+
36
+ Example:
37
+ label_int_dict = {"liver": [1], "kidney": [5, 14]}
38
+ masked_data = get_masked_data(label_data, image_data, label_int_dict["kidney"])
39
+ """
40
+
41
+ # Check if the shapes of image_data and label_data match
42
+ if image_data.shape != label_data.shape:
43
+ raise ValueError(
44
+ f"Shape mismatch: image_data has shape {image_data.shape}, "
45
+ f"but label_data has shape {label_data.shape}. They must be the same."
46
+ )
47
+
48
+ if not labels:
49
+ return np.array([]) # Return an empty array if no labels are provided
50
+
51
+ labels = list(set(labels)) # remove duplicate items
52
+
53
+ # Optimize performance based on the number of labels
54
+ num_label_acceleration_thresh = 3
55
+ if len(labels) >= num_label_acceleration_thresh:
56
+ # if many labels, np.isin is faster
57
+ mask = np.isin(label_data, labels)
58
+ else:
59
+ # Use logical OR to combine masks if the number of labels is small
60
+ mask = np.zeros_like(label_data, dtype=bool)
61
+ for label in labels:
62
+ mask = np.logical_or(mask, label_data == label)
63
+
64
+ # Retrieve the masked data
65
+ masked_data = image_data[mask.astype(bool)]
66
+
67
+ return masked_data
68
+
69
+
70
def is_outlier(statistics, image_data, label_data, label_int_dict):
    """
    Perform a quality check on the generated image by comparing its statistics with precomputed thresholds.

    Args:
        statistics (dict): Dictionary containing precomputed statistics. Each entry must
            provide the keys "sigma_6_low", "sigma_6_high", "percentile_0_5", and
            "percentile_99_5" (the code takes the wider of the sigma and percentile bands).
        image_data (np.ndarray): The image data to be checked, typically a 3D NumPy array.
        label_data (np.ndarray): The label data corresponding to the image, used for masking regions of interest.
        label_int_dict (dict): Dictionary mapping label names to their corresponding integer lists.
            e.g., label_int_dict = {"liver": [1], "kidney": [5, 14]}

    Returns:
        dict: A dictionary with labels as keys, each containing the quality check result,
        including whether it's an outlier, the median value, and the thresholds used.
        If no data is found for a label, the median value will be `None` and `is_outlier` will be `False`.

    Example:
        # Example input data (note: percentile keys are required alongside the sigma keys)
        statistics = {
            "liver": {
                "sigma_6_low": -21.596463547885904,
                "sigma_6_high": 156.27881534763367,
                "percentile_0_5": -10.0,
                "percentile_99_5": 150.0,
            },
            "kidney": {
                "sigma_6_low": -15.0,
                "sigma_6_high": 120.0,
                "percentile_0_5": -10.0,
                "percentile_99_5": 110.0,
            },
        }
        label_int_dict = {
            "liver": [1],
            "kidney": [5, 14]
        }
        image_data = np.random.rand(100, 100, 100)  # Replace with actual image data
        label_data = np.zeros((100, 100, 100))  # Replace with actual label data
        label_data[40:60, 40:60, 40:60] = 1  # Example region for liver
        label_data[70:90, 70:90, 70:90] = 5  # Example region for kidney
        result = is_outlier(statistics, image_data, label_data, label_int_dict)
    """
    outlier_results = {}

    for label_name, stats in statistics.items():
        # Take the wider of the 6-sigma band and the robust percentile band so
        # borderline medians are not flagged spuriously.
        low_thresh = min(stats["sigma_6_low"], stats["percentile_0_5"])  # or "sigma_12_low" depending on your needs
        high_thresh = max(stats["sigma_6_high"], stats["percentile_99_5"])  # or "sigma_12_high" depending on your needs

        # Bone intensities legitimately reach high values; cap the upper bound
        # instead of trusting the precomputed statistics.
        if label_name == "bone":
            high_thresh = 1000.0

        # Retrieve the corresponding label integers (missing names yield no voxels).
        labels = label_int_dict.get(label_name, [])
        masked_data = get_masked_data(label_data, image_data, labels)
        masked_data = masked_data[~np.isnan(masked_data)]

        # No voxels for this label: report a neutral (non-outlier) result.
        if masked_data.size == 0:
            outlier_results[label_name] = {
                "is_outlier": False,
                "median_value": None,
                "low_thresh": low_thresh,
                "high_thresh": high_thresh,
            }
            continue

        # Compute the median of the masked region.
        median_value = np.nanmedian(masked_data)

        if np.isnan(median_value):
            median_value = None
            outlier_flag = False
        else:
            # Determine if the median value falls outside the accepted band.
            outlier_flag = median_value < low_thresh or median_value > high_thresh

        outlier_results[label_name] = {
            "is_outlier": outlier_flag,
            "median_value": median_value,
            "low_thresh": low_thresh,
            "high_thresh": high_thresh,
        }

    return outlier_results
scripts/rectified_flow.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+ #
12
+ # =========================================================================
13
+ # Adapted from https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py
14
+ # which has the following license:
15
+ # https://github.com/hpcaitech/Open-Sora/blob/main/LICENSE
16
+ # Licensed under the Apache License, Version 2.0 (the "License");
17
+ # you may not use this file except in compliance with the License.
18
+ # You may obtain a copy of the License at
19
+ #
20
+ # http://www.apache.org/licenses/LICENSE-2.0
21
+ #
22
+ # Unless required by applicable law or agreed to in writing, software
23
+ # distributed under the License is distributed on an "AS IS" BASIS,
24
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25
+ # See the License for the specific language governing permissions and
26
+ # limitations under the License.
27
+ # =========================================================================
28
+
29
+ from __future__ import annotations
30
+
31
+ from typing import Union
32
+
33
+ import numpy as np
34
+ import torch
35
+ from torch.distributions import LogisticNormal
36
+
37
+ from monai.utils import StrEnum
38
+
39
+ from .ddpm import DDPMPredictionType
40
+ from .scheduler import Scheduler
41
+
42
+
43
class RFlowPredictionType(StrEnum):
    """
    Set of valid prediction type names for the RFlow scheduler's `prediction_type` argument.

    v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
    """

    # Rectified flow is formulated purely in terms of velocity, so this is the
    # only member; it aliases DDPM's v-prediction value for API compatibility.
    V_PREDICTION = DDPMPredictionType.V_PREDICTION
51
+
52
+
53
def timestep_transform(
    t, input_img_size_numel, base_img_size_numel=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3
):
    """
    Rescale timestep(s) according to the ratio between the input image volume
    and a reference volume, so that larger images get more steps at the noisy end.

    Args:
        t (torch.Tensor): The original timestep(s).
        input_img_size_numel (torch.Tensor): The input image's size (H * W * D).
        base_img_size_numel (int): reference H*W*D size, usually smaller than input_img_size_numel.
        scale (float): Scaling factor for the transformation.
        num_train_timesteps (int): Total number of training timesteps.
        spatial_dim (int): Number of spatial dimensions in the image.

    Returns:
        torch.Tensor: Transformed timestep(s), in the same range as the input.
    """
    # Work in normalized time on [0, 1].
    normalized_t = t / num_train_timesteps
    # Per-axis resolution ratio relative to the reference volume, optionally scaled.
    ratio = ((input_img_size_numel / base_img_size_numel) ** (1.0 / spatial_dim)) * scale
    # Warp toward t=1 when ratio > 1; identity when ratio == 1.
    warped_t = (ratio * normalized_t) / (1 + (ratio - 1) * normalized_t)
    # Map back to the [0, num_train_timesteps] scale.
    return warped_t * num_train_timesteps
78
+
79
+
80
class RFlowScheduler(Scheduler):
    """
    A rectified flow scheduler for guiding the diffusion process in a generative model.

    Supports uniform and logit-normal sampling methods, timestep transformation for
    different resolutions, and noise addition during diffusion.

    Args:
        num_train_timesteps (int): Total number of training timesteps.
        use_discrete_timesteps (bool): Whether to use discrete timesteps.
        sample_method (str): Training time step sampling method ('uniform' or 'logit-normal').
        loc (float): Location parameter for logit-normal distribution, used only if sample_method='logit-normal'.
        scale (float): Scale parameter for logit-normal distribution, used only if sample_method='logit-normal'.
        use_timestep_transform (bool): Whether to apply timestep transformation.
            If true, there will be more inference timesteps at early(noisy) stages for larger image volumes.
        transform_scale (float): Scaling factor for timestep transformation, used only if use_timestep_transform=True.
        steps_offset (int): Offset added to computed timesteps, used only if use_timestep_transform=True.
        base_img_size_numel (int): Reference image volume size for scaling, used only if use_timestep_transform=True.
        spatial_dim (int): 2 or 3, indicating 2D or 3D images, used only if use_timestep_transform=True.

    Example:

        .. code-block:: python

            # define a scheduler
            noise_scheduler = RFlowScheduler(
                num_train_timesteps = 1000,
                use_discrete_timesteps = True,
                sample_method = 'logit-normal',
                use_timestep_transform = True,
                base_img_size_numel = 32 * 32 * 32,
                spatial_dim = 3
            )

            # during training
            inputs = torch.ones(2,4,64,64,32)
            noise = torch.randn_like(inputs)
            timesteps = noise_scheduler.sample_timesteps(inputs)
            noisy_inputs = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
            predicted_velocity = diffusion_unet(
                x=noisy_inputs,
                timesteps=timesteps
            )
            loss = loss_l1(predicted_velocity, (inputs - noise))

            # during inference
            noisy_inputs = torch.randn(2,4,64,64,32)
            input_img_size_numel = torch.prod(torch.tensor(noisy_inputs.shape[-3:]))
            noise_scheduler.set_timesteps(
                num_inference_steps=30, input_img_size_numel=input_img_size_numel
            )
            all_next_timesteps = torch.cat(
                (noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype))
            )
            for t, next_t in tqdm(
                zip(noise_scheduler.timesteps, all_next_timesteps),
                total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)),
            ):
                predicted_velocity = diffusion_unet(
                    x=noisy_inputs,
                    timesteps=timesteps
                )
                noisy_inputs, _ = noise_scheduler.step(predicted_velocity, t, noisy_inputs, next_t)
            final_output = noisy_inputs
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        use_discrete_timesteps: bool = True,
        sample_method: str = "uniform",
        loc: float = 0.0,
        scale: float = 1.0,
        use_timestep_transform: bool = False,
        transform_scale: float = 1.0,
        steps_offset: int = 0,
        base_img_size_numel: int = 32 * 32 * 32,
        spatial_dim: int = 3,
    ):
        # rectified flow only accepts velocity prediction
        self.prediction_type = RFlowPredictionType.V_PREDICTION

        self.num_train_timesteps = num_train_timesteps
        self.use_discrete_timesteps = use_discrete_timesteps
        self.base_img_size_numel = base_img_size_numel
        self.spatial_dim = spatial_dim

        # sample method
        if sample_method not in ["uniform", "logit-normal"]:
            raise ValueError(
                f"sample_method = {sample_method}, which has to be chosen from ['uniform', 'logit-normal']."
            )
        self.sample_method = sample_method
        if sample_method == "logit-normal":
            # NOTE: self.distribution / self.sample_t only exist in
            # 'logit-normal' mode; 'uniform' mode never touches them.
            self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
            self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)

        # timestep transform
        self.use_timestep_transform = use_timestep_transform
        self.transform_scale = transform_scale
        self.steps_offset = steps_offset

    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """
        Add noise to the original samples by linear interpolation (rectified flow).

        Args:
            original_samples: original samples
            noise: noise to add to samples
            timesteps: timesteps tensor with shape of (N,), indicating the timestep to be computed for each sample.

        Returns:
            noisy_samples: sample with added noise
        """
        # Normalize timesteps to [0, 1], then flip so that timepoint 1 means
        # "clean sample" and 0 means "pure noise".
        timepoints: torch.Tensor = timesteps.float() / self.num_train_timesteps
        timepoints = 1 - timepoints  # [1,1/1000]

        # expand timepoint to noise shape (5D = 3D images, 4D = 2D images)
        if noise.ndim == 5:
            timepoints = timepoints[..., None, None, None, None].expand(-1, *noise.shape[1:])
        elif noise.ndim == 4:
            timepoints = timepoints[..., None, None, None].expand(-1, *noise.shape[1:])
        else:
            raise ValueError(f"noise tensor has to be 4D or 5D tensor, yet got shape of {noise.shape}")

        # Straight-line interpolation between data and noise.
        noisy_samples: torch.Tensor = timepoints * original_samples + (1 - timepoints) * noise

        return noisy_samples

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: str | torch.device | None = None,
        input_img_size_numel: int | None = None,
    ) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
            device: target device to put the data.
            input_img_size_numel: int, H*W*D of the image, used with self.use_timestep_transform is True.
        """
        if num_inference_steps > self.num_train_timesteps or num_inference_steps < 1:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} should be at least 1, "
                "and cannot be larger than `self.num_train_timesteps`:"
                f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps
        # prepare timesteps: evenly spaced, descending from num_train_timesteps toward 0
        timesteps = [
            (1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps)
        ]
        if self.use_discrete_timesteps:
            timesteps = [int(round(t)) for t in timesteps]
        if self.use_timestep_transform:
            # Warp the schedule so larger volumes spend more steps at the noisy end.
            timesteps = [
                timestep_transform(
                    t,
                    input_img_size_numel=input_img_size_numel,
                    base_img_size_numel=self.base_img_size_numel,
                    num_train_timesteps=self.num_train_timesteps,
                    spatial_dim=self.spatial_dim,
                )
                for t in timesteps
            ]
        # NOTE(review): float16 limits timestep precision for values above 2048
        # (inherited from upstream) — confirm acceptable for num_train_timesteps > 1000.
        timesteps_np = np.array(timesteps).astype(np.float16)
        if self.use_discrete_timesteps:
            timesteps_np = timesteps_np.astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps_np).to(device)
        # NOTE(review): steps_offset is applied unconditionally here, although the
        # class docstring says it is used only with use_timestep_transform=True.
        self.timesteps += self.steps_offset

    def sample_timesteps(self, x_start):
        """
        Randomly samples training timesteps using the chosen sampling method.

        Args:
            x_start (torch.Tensor): The input tensor for sampling.

        Returns:
            torch.Tensor: Sampled timesteps, one per batch element.
        """
        # sample_method was validated in __init__, so exactly one branch runs.
        if self.sample_method == "uniform":
            t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_train_timesteps
        elif self.sample_method == "logit-normal":
            t = self.sample_t(x_start) * self.num_train_timesteps

        if self.use_discrete_timesteps:
            t = t.long()

        if self.use_timestep_transform:
            # Derive the spatial volume from the input's spatial dims.
            input_img_size_numel = torch.prod(torch.tensor(x_start.shape[2:]))
            t = timestep_transform(
                t,
                input_img_size_numel=input_img_size_numel,
                base_img_size_numel=self.base_img_size_numel,
                num_train_timesteps=self.num_train_timesteps,
                spatial_dim=len(x_start.shape) - 2,
            )

        return t

    def step(
        self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: Union[int, None] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predicts the next sample in the diffusion process.

        Args:
            model_output (torch.Tensor): Output from the trained diffusion model (velocity prediction).
            timestep (int): Current timestep in the diffusion chain.
            sample (torch.Tensor): Current sample in the process.
            next_timestep (Union[int, None]): Optional next timestep.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Predicted sample at the next step and additional info.
        """
        # Ensure num_inference_steps exists and is a valid integer
        if not hasattr(self, "num_inference_steps") or not isinstance(self.num_inference_steps, int):
            raise AttributeError(
                "num_inference_steps is missing or not an integer in the class."
                "Please run self.set_timesteps(num_inference_steps,device,input_img_size_numel) to set it."
            )

        v_pred = model_output

        # Integration step size: exact gap to next_timestep when given,
        # otherwise a uniform step of the inference schedule.
        if next_timestep is not None:
            next_timestep = int(next_timestep)
            dt: float = (
                float(timestep - next_timestep) / self.num_train_timesteps
            )  # Now next_timestep is guaranteed to be int
        else:
            dt = (
                1.0 / float(self.num_inference_steps) if self.num_inference_steps > 0 else 0.0
            )  # Avoid division by zero

        # Euler step along the predicted velocity; also extrapolate all the
        # way to t=0 for the "predicted original sample".
        pred_post_sample = sample + v_pred * dt
        pred_original_sample = sample + v_pred * timestep / self.num_train_timesteps

        return pred_post_sample, pred_original_sample