Gregory Ksenofontov commited on
Commit
c67cc3f
·
1 Parent(s): 1eb7106

Init commit

Browse files
Files changed (40) hide show
  1. .gitattributes +2 -34
  2. README.md +67 -3
  3. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_10/checkpoints/backward_3/model.safetensors +3 -0
  4. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_10/checkpoints/forward_3/model.safetensors +3 -0
  5. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_10/config.yaml +47 -0
  6. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_100/checkpoints/backward_3/model.safetensors +3 -0
  7. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_100/checkpoints/forward_3/model.safetensors +3 -0
  8. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_100/config.yaml +47 -0
  9. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_2/checkpoints/backward_3/model.safetensors +3 -0
  10. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_2/checkpoints/forward_3/model.safetensors +3 -0
  11. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_2/config.yaml +47 -0
  12. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_25/checkpoints/backward_3/model.safetensors +3 -0
  13. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_25/checkpoints/forward_3/model.safetensors +3 -0
  14. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_25/config.yaml +47 -0
  15. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_4/checkpoints/backward_3/model.safetensors +3 -0
  16. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_4/checkpoints/forward_3/model.safetensors +3 -0
  17. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_4/config.yaml +47 -0
  18. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_50/checkpoints/backward_3/model.safetensors +3 -0
  19. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_50/checkpoints/forward_3/model.safetensors +3 -0
  20. checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_50/config.yaml +47 -0
  21. checkpoints /images/cmnist/uniform/dim_32_aplha_0.01/checkpoints/backward_3/model.safetensors +3 -0
  22. checkpoints /images/cmnist/uniform/dim_32_aplha_0.01/checkpoints/forward_3/model.safetensors +3 -0
  23. checkpoints /images/cmnist/uniform/dim_32_aplha_0.01/config.yaml +47 -0
  24. checkpoints /images/cmnist/uniform/dim_32_aplha_0.05/checkpoints/backward_3/model.safetensors +3 -0
  25. checkpoints /images/cmnist/uniform/dim_32_aplha_0.05/checkpoints/forward_3/model.safetensors +3 -0
  26. checkpoints /images/cmnist/uniform/dim_32_aplha_0.05/config.yaml +47 -0
  27. checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.005/checkpoints/backward_4/model.safetensors +3 -0
  28. checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.005/checkpoints/forward_4/model.safetensors +3 -0
  29. checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.005/config.yaml +45 -0
  30. checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.01/checkpoints/backward_4/model.safetensors +3 -0
  31. checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.01/checkpoints/forward_4/model.safetensors +3 -0
  32. checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.01/config.yaml +52 -0
  33. checkpoints /texts/amazon/uniform/dim_100_aplha_0.005/checkpoints/backward_5/model.safetensors +3 -0
  34. checkpoints /texts/amazon/uniform/dim_100_aplha_0.005/checkpoints/forward_5/model.safetensors +3 -0
  35. checkpoints /texts/amazon/uniform/dim_100_aplha_0.005/config.yaml +49 -0
  36. checkpoints /texts/amazon/uniform/dim_100_aplha_0.01/checkpoints/backward_5/model.safetensors +3 -0
  37. checkpoints /texts/amazon/uniform/dim_100_aplha_0.01/checkpoints/forward_5/model.safetensors +3 -0
  38. checkpoints /texts/amazon/uniform/dim_100_aplha_0.01/config.yaml +49 -0
  39. checkpoints /tokenizer_amazon.json +3 -0
  40. checkpoints /vqgan_celeba_f8_1024.ckpt +3 -0
.gitattributes CHANGED
@@ -1,35 +1,3 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.safetensors filter=lfs diff=lfs merge=lfs -text
2
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
3
+ *.json filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,3 +1,67 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
4
+ <div align="center">
5
+
6
+ # Categorical Schrödinger Bridge Matching (CSBM)
7
+
8
+ [Grigoriy Ksenofontov](https://scholar.google.com/citations?user=e0mirzYAAAAJ),
9
+ [Alexander Korotin](https://scholar.google.ru/citations?user=1rIIvjAAAAAJ)
10
+
11
+ [![arXiv Paper](https://img.shields.io/badge/arXiv-2502.01416-b31b1b)](https://arxiv.org/abs/2502.01416)
12
+ [![OpenReview Paper](https://img.shields.io/badge/OpenReview-PDF-8c1b13)](https://openreview.net/forum?id=RBly0nOr2h)
13
+ [![GitHub](https://img.shields.io/github/stars/gregkseno/csbm?style=social)](https://github.com/gregkseno/csbm)
14
+ ![GitHub License](https://img.shields.io/github/license/gregkseno/csbm)
15
+ [![WandB](https://img.shields.io/badge/W%26B-view-FFCC33?logo=wandb)](https://wandb.ai/gregkseno/csbm)
16
+
17
+ </div>
18
+
19
+ This repository hosts the official checkpoints for the paper "Categorical Schrödinger Bridge Matching", accepted at ICML 2025.
20
+
21
+ ## 📌 TL;DR
22
+
23
+ This paper extends the Schrödinger Bridge problem to work with discrete time and spaces.
24
+
25
+ <!-- ![teaser](./images/teaser.png) -->
26
+
27
+ ## 💾 Checkpoints
28
+
29
+ ### CSBM
30
+
31
+ | Dataset | Reference Process | $\alpha$ | $N$ | Saved Iteration |
32
+ | ------------- | ----------------- | ----------- | --------------------- | --------------- |
33
+ | Colored MNIST | **gaussian** | 0.01 | 2, 4, 10, 25, 50, 100 | 3 |
34
+ | Colored MNIST | **uniform** | 0.01, 0.05 | 25 | 3 |
35
+ | CelebA | **uniform** | 0.01, 0.005 | 100 | 4 |
36
+ | Amazon Review | **uniform** | 0.01, 0.005 | 100 | 5 |
37
+
38
+ > [!NOTE]
39
+ > Each experiment directory includes a `config.yaml` file with the full training configuration.
40
+
41
+ ### Additional Components
42
+
43
+ 1. `vqgan_celeba_f8_1024.ckpt` — **VQ-GAN** pretrained on the CelebA dataset
44
+ 2. `tokenizer_amazon.json` — **Tokenizer** trained on the Amazon Reviews dataset
45
+
46
+ ## 🎓 Citation
47
+
48
+ ```bibtex
49
+ @article{ksenofontov2025categorical,
50
+ title={Categorical {Schr\"odinger} Bridge Matching},
51
+ author={Ksenofontov, Grigoriy and Korotin, Alexander},
52
+ journal={arXiv preprint arXiv:2502.01416},
53
+ year={2025}
54
+ }
55
+ ```
56
+
57
+ ## 🙏 Credits
58
+
59
+ - [Weights & Biases](https://wandb.ai) — experiment-tracking and visualization toolkit;
60
+ - [Hugging Face](https://huggingface.co) — Tokenizers and Accelerate libraries for tokenizer implementation, parallel training, and checkpoint hosting on the Hub;
61
+ - [D3PM](https://github.com/google-research/google-research/tree/master/d3pm) — reference implementation of discrete-diffusion models;
62
+ - [Taming Transformers](https://github.com/CompVis/taming-transformers) — original VQ-GAN codebase;
63
+ - [VQ-Diffusion](https://github.com/microsoft/VQ-Diffusion) — vector-quantized diffusion architecture;
64
+ - [MDLM](https://github.com/kuleshov-group/mdlm) — diffusion architecture for text-generation experiments;
65
+ - [ASBM](https://arxiv.org/abs/2405.14449) — evaluation metrics and baseline models for CelebA face transfer;
66
+ - [Balancing the Style-Content Trade-Off in Sentiment Transfer Using Polarity-Aware Denoising](https://arxiv.org/abs/2312.14708) — processed Amazon Reviews dataset and sentiment-transfer baselines;
67
+ - [Inkscape](https://inkscape.org/) — an excellent open-source editor for vector graphics.
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_10/checkpoints/backward_3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d64537aa661f8ee3e8516b525dbc0b961eff87cecba8101296bbff5117ae445e
3
+ size 139416416
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_10/checkpoints/forward_3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5870abf0ad490bbdfcc568ddc0aa82e5b087d46de7420bcb26589e68207dafd
3
+ size 139416416
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_10/config.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ type: images
3
+ dataset: cmnist
4
+ dim: 32
5
+ num_categories: 256
6
+ num_timesteps: 10
7
+ num_skip_steps: 10
8
+ coupling_type: independent
9
+ model:
10
+ in_channels: 3
11
+ num_channels: 64
12
+ num_layers: 2
13
+ ch_mults:
14
+ - 1
15
+ - 2
16
+ - 2
17
+ - 2
18
+ attention_resolution: 16
19
+ num_groups: 32
20
+ num_att_heads: 4
21
+ dropout: 0.1
22
+ prior:
23
+ alpha: 0.01
24
+ type: gaussian
25
+ eps: 1.0e-06
26
+ train:
27
+ batch_size: 128
28
+ low_precision: false
29
+ gradient_accumulation_steps: 1
30
+ iterations: 20
31
+ prior_iterations: 5
32
+ inner_iterations: 20000
33
+ use_mini_batch: true
34
+ ce_loss_coeff: 0.001
35
+ kl_loss_coeff: 1
36
+ mse_loss_coeff: 0
37
+ ema_decay: 0.9999
38
+ optimizer:
39
+ lr: 0.0002
40
+ betas:
41
+ - 0.95
42
+ - 0.99
43
+ eval:
44
+ freq: 1000
45
+ num_samples: 25
46
+ num_trajectories: 4
47
+ num_translations: 2
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_100/checkpoints/backward_3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5896a264f6e29453ea4cbba2c625487ad3423ddfda8f93d9a94acaf7f36e145
3
+ size 139416416
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_100/checkpoints/forward_3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e99e791d3d37ff2b4bfe1f88461526b2216cc1e5dc785917beb8c4e973e3d48d
3
+ size 139416416
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_100/config.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ type: images
3
+ dataset: cmnist
4
+ dim: 32
5
+ num_categories: 256
6
+ num_timesteps: 100
7
+ num_skip_steps: 1
8
+ coupling_type: independent
9
+ model:
10
+ in_channels: 3
11
+ num_channels: 64
12
+ num_layers: 2
13
+ ch_mults:
14
+ - 1
15
+ - 2
16
+ - 2
17
+ - 2
18
+ attention_resolution: 16
19
+ num_groups: 32
20
+ num_att_heads: 4
21
+ dropout: 0.1
22
+ prior:
23
+ alpha: 0.01
24
+ type: gaussian
25
+ eps: 1.0e-06
26
+ train:
27
+ batch_size: 128
28
+ low_precision: false
29
+ gradient_accumulation_steps: 1
30
+ iterations: 20
31
+ prior_iterations: 10
32
+ inner_iterations: 20000
33
+ use_mini_batch: true
34
+ ce_loss_coeff: 0.001
35
+ kl_loss_coeff: 1
36
+ mse_loss_coeff: 0
37
+ ema_decay: 0.9999
38
+ optimizer:
39
+ lr: 0.0002
40
+ betas:
41
+ - 0.95
42
+ - 0.99
43
+ eval:
44
+ freq: 1000
45
+ num_samples: 25
46
+ num_trajectories: 4
47
+ num_translations: 2
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_2/checkpoints/backward_3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca3e655484abba1645b826534921224c05db3126fb92d352ee11a89085547a3d
3
+ size 139416416
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_2/checkpoints/forward_3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b8747b4c6f4cea35f13596dc81992b70750f9a0fe9d8c7ee8c080da52bf3451
3
+ size 139416416
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_2/config.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ type: images
3
+ dataset: cmnist
4
+ dim: 32
5
+ num_categories: 256
6
+ num_timesteps: 2
7
+ num_skip_steps: 50
8
+ coupling_type: independent
9
+ model:
10
+ in_channels: 3
11
+ num_channels: 64
12
+ num_layers: 2
13
+ ch_mults:
14
+ - 1
15
+ - 2
16
+ - 2
17
+ - 2
18
+ attention_resolution: 16
19
+ num_groups: 32
20
+ num_att_heads: 4
21
+ dropout: 0.1
22
+ prior:
23
+ alpha: 0.01
24
+ type: gaussian
25
+ eps: 1.0e-06
26
+ train:
27
+ batch_size: 128
28
+ low_precision: false
29
+ gradient_accumulation_steps: 1
30
+ iterations: 20
31
+ prior_iterations: 5
32
+ inner_iterations: 20000
33
+ use_mini_batch: true
34
+ ce_loss_coeff: 0.001
35
+ kl_loss_coeff: 1
36
+ mse_loss_coeff: 0
37
+ ema_decay: 0.9999
38
+ optimizer:
39
+ lr: 0.0002
40
+ betas:
41
+ - 0.95
42
+ - 0.99
43
+ eval:
44
+ freq: 1000
45
+ num_samples: 25
46
+ num_trajectories: 4
47
+ num_translations: 2
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_25/checkpoints/backward_3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb36e26158e2cba3b5b65ba35b7ebeba0817a568f3d9b8eca9d413fc752e4713
3
+ size 139416416
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_25/checkpoints/forward_3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f44afb3a90eef04257fc4ea5067386bba3a5713cc667dc1f6e7e305b8d92e540
3
+ size 139416416
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_25/config.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ type: images
3
+ dataset: cmnist
4
+ dim: 32
5
+ num_categories: 256
6
+ num_timesteps: 25
7
+ num_skip_steps: 4
8
+ coupling_type: independent
9
+ model:
10
+ in_channels: 3
11
+ num_channels: 64
12
+ num_layers: 2
13
+ ch_mults:
14
+ - 1
15
+ - 2
16
+ - 2
17
+ - 2
18
+ attention_resolution: 16
19
+ num_groups: 32
20
+ num_att_heads: 4
21
+ dropout: 0.1
22
+ prior:
23
+ alpha: 0.01
24
+ type: gaussian
25
+ eps: 1.0e-06
26
+ train:
27
+ batch_size: 128
28
+ low_precision: false
29
+ gradient_accumulation_steps: 1
30
+ iterations: 20
31
+ prior_iterations: 5
32
+ inner_iterations: 20000
33
+ use_mini_batch: true
34
+ ce_loss_coeff: 0.001
35
+ kl_loss_coeff: 1
36
+ mse_loss_coeff: 0
37
+ ema_decay: 0.9999
38
+ optimizer:
39
+ lr: 0.0002
40
+ betas:
41
+ - 0.95
42
+ - 0.99
43
+ eval:
44
+ freq: 1000
45
+ num_samples: 25
46
+ num_trajectories: 4
47
+ num_translations: 2
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_4/checkpoints/backward_3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba218a2b25c2e7682efee04227fe1c736d6f24905d7adda7cb254ef6422787ec
3
+ size 139416416
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_4/checkpoints/forward_3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10d5cb66e1770eb54c2d872c24d584125563facf67af25799bf561edf0c64ac1
3
+ size 139416416
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_4/config.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ type: images
3
+ dataset: cmnist
4
+ dim: 32
5
+ num_categories: 256
6
+ num_timesteps: 4
7
+ num_skip_steps: 25
8
+ coupling_type: independent
9
+ model:
10
+ in_channels: 3
11
+ num_channels: 64
12
+ num_layers: 2
13
+ ch_mults:
14
+ - 1
15
+ - 2
16
+ - 2
17
+ - 2
18
+ attention_resolution: 16
19
+ num_groups: 32
20
+ num_att_heads: 4
21
+ dropout: 0.1
22
+ prior:
23
+ alpha: 0.01
24
+ type: gaussian
25
+ eps: 1.0e-06
26
+ train:
27
+ batch_size: 128
28
+ low_precision: false
29
+ gradient_accumulation_steps: 1
30
+ iterations: 20
31
+ prior_iterations: 5
32
+ inner_iterations: 20000
33
+ use_mini_batch: true
34
+ ce_loss_coeff: 0.001
35
+ kl_loss_coeff: 1
36
+ mse_loss_coeff: 0
37
+ ema_decay: 0.9999
38
+ optimizer:
39
+ lr: 0.0002
40
+ betas:
41
+ - 0.95
42
+ - 0.99
43
+ eval:
44
+ freq: 1000
45
+ num_samples: 25
46
+ num_trajectories: 4
47
+ num_translations: 2
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_50/checkpoints/backward_3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2390760bcb242299d49d881e2c6c869739ead285f4067634ac9c1da8a9eaa2c4
3
+ size 139416416
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_50/checkpoints/forward_3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57aad3763421081c13d9435c2dc80ace6aaf68397d3a8fc512f9ad19d85e660b
3
+ size 139416416
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_50/config.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ type: images
3
+ dataset: cmnist
4
+ dim: 32
5
+ num_categories: 256
6
+ num_timesteps: 50
7
+ num_skip_steps: 2
8
+ coupling_type: independent
9
+ model:
10
+ in_channels: 3
11
+ num_channels: 64
12
+ num_layers: 2
13
+ ch_mults:
14
+ - 1
15
+ - 2
16
+ - 2
17
+ - 2
18
+ attention_resolution: 16
19
+ num_groups: 32
20
+ num_att_heads: 4
21
+ dropout: 0.1
22
+ prior:
23
+ alpha: 0.01
24
+ type: gaussian
25
+ eps: 1.0e-06
26
+ train:
27
+ batch_size: 128
28
+ low_precision: false
29
+ gradient_accumulation_steps: 1
30
+ iterations: 20
31
+ prior_iterations: 10
32
+ inner_iterations: 20000
33
+ use_mini_batch: true
34
+ ce_loss_coeff: 0.001
35
+ kl_loss_coeff: 1
36
+ mse_loss_coeff: 0
37
+ ema_decay: 0.9999
38
+ optimizer:
39
+ lr: 0.0002
40
+ betas:
41
+ - 0.95
42
+ - 0.99
43
+ eval:
44
+ freq: 1000
45
+ num_samples: 25
46
+ num_trajectories: 4
47
+ num_translations: 2
checkpoints /images/cmnist/uniform/dim_32_aplha_0.01/checkpoints/backward_3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:125c3702db67b5d8c46969e8772f8ffed3fdbc399808968a317086e2d9fe34e1
3
+ size 139416416
checkpoints /images/cmnist/uniform/dim_32_aplha_0.01/checkpoints/forward_3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f743431c8174d039c6f3a48a662725893eee873331a149dc0c788f3138d914d6
3
+ size 139416416
checkpoints /images/cmnist/uniform/dim_32_aplha_0.01/config.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ type: images
3
+ dataset: cmnist
4
+ dim: 32
5
+ num_categories: 256
6
+ num_timesteps: 25
7
+ num_skip_steps: 4
8
+ coupling_type: independent
9
+ model:
10
+ in_channels: 3
11
+ num_channels: 64
12
+ num_layers: 2
13
+ ch_mults:
14
+ - 1
15
+ - 2
16
+ - 2
17
+ - 2
18
+ attention_resolution: 16
19
+ num_groups: 32
20
+ num_att_heads: 4
21
+ dropout: 0.1
22
+ prior:
23
+ alpha: 0.01
24
+ type: uniform
25
+ eps: 1.0e-06
26
+ train:
27
+ batch_size: 128
28
+ low_precision: false
29
+ gradient_accumulation_steps: 1
30
+ iterations: 20
31
+ prior_iterations: 5
32
+ inner_iterations: 20000
33
+ use_mini_batch: false
34
+ ce_loss_coeff: 0.001
35
+ kl_loss_coeff: 1
36
+ mse_loss_coeff: 0
37
+ ema_decay: 0.9999
38
+ optimizer:
39
+ lr: 0.0002
40
+ betas:
41
+ - 0.95
42
+ - 0.99
43
+ eval:
44
+ freq: 1000
45
+ num_samples: 25
46
+ num_trajectories: 4
47
+ num_translations: 2
checkpoints /images/cmnist/uniform/dim_32_aplha_0.05/checkpoints/backward_3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7fdbd67065bb6ef3933b370fee02c4d935103d8e32a8d7d358f27fb4ba78946c
3
+ size 139416416
checkpoints /images/cmnist/uniform/dim_32_aplha_0.05/checkpoints/forward_3/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f329e98da91a5c305b820d013ec4e816d83ce54bd3c33cf42e8b70698c937a2b
3
+ size 139416416
checkpoints /images/cmnist/uniform/dim_32_aplha_0.05/config.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ type: images
3
+ dataset: cmnist
4
+ dim: 32
5
+ num_categories: 256
6
+ num_timesteps: 25
7
+ num_skip_steps: 4
8
+ coupling_type: independent
9
+ model:
10
+ in_channels: 3
11
+ num_channels: 64
12
+ num_layers: 2
13
+ ch_mults:
14
+ - 1
15
+ - 2
16
+ - 2
17
+ - 2
18
+ attention_resolution: 16
19
+ num_groups: 32
20
+ num_att_heads: 4
21
+ dropout: 0.1
22
+ prior:
23
+ alpha: 0.05
24
+ type: uniform
25
+ eps: 1.0e-06
26
+ train:
27
+ batch_size: 128
28
+ low_precision: false
29
+ gradient_accumulation_steps: 1
30
+ iterations: 20
31
+ prior_iterations: 5
32
+ inner_iterations: 20000
33
+ use_mini_batch: false
34
+ ce_loss_coeff: 0.001
35
+ kl_loss_coeff: 1
36
+ mse_loss_coeff: 0
37
+ ema_decay: 0.9999
38
+ optimizer:
39
+ lr: 0.0002
40
+ betas:
41
+ - 0.95
42
+ - 0.99
43
+ eval:
44
+ freq: 1000
45
+ num_samples: 25
46
+ num_trajectories: 4
47
+ num_translations: 2
checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.005/checkpoints/backward_4/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fee703324c0024cfa87965f8ce98ecac4d9e199b19b7f07e56189e96bdae4483
3
+ size 372373768
checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.005/checkpoints/forward_4/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f8b9bd34a828628b91a1a32ef82c66860e7d00f6cd017ded46af500d15a35a5
3
+ size 372373768
checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.005/config.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ type: quantized_images
3
+ dataset: celeba
4
+ dim: 128
5
+ latent_dim: 16
6
+ num_categories: 1024
7
+ num_timesteps: 100
8
+ num_skip_steps: 1
9
+ coupling_type: independent
10
+ train_test_split: 0.9
11
+ model:
12
+ hidden_dim: 256
13
+ num_channels: 4
14
+ num_layers: 18
15
+ num_att_heads: 16
16
+ dropout: 0
17
+ codec:
18
+ ckpt_path: checkpoints/vqgan_celeba_f8_1024.ckpt
19
+ config_path: configs/vqgan_celeba_f8_1024.yaml
20
+ prior:
21
+ alpha: 0.005
22
+ type: uniform
23
+ eps: 1e-6
24
+ train:
25
+ batch_size: 32
26
+ low_precision: false
27
+ gradient_accumulation_steps: 1
28
+ iterations: 20
29
+ prior_iterations: 20
30
+ inner_iterations: 20000
31
+ use_mini_batch: false
32
+ ce_loss_coeff: 0.001
33
+ kl_loss_coeff: 1
34
+ mse_loss_coeff: 0
35
+ ema_decay: 0.999
36
+ optimizer:
37
+ lr: 0.0004
38
+ betas:
39
+ - 0.95
40
+ - 0.99
41
+ eval:
42
+ freq: 1000
43
+ num_samples: 25
44
+ num_trajectories: 4
45
+ num_translations: 1
checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.01/checkpoints/backward_4/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08630e97e5a15fcdb4b51b2ad17f1a6362b48d9b8bc502cd6e517c483ef53c3d
3
+ size 372373768
checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.01/checkpoints/forward_4/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:814c99bd6124b3ec34af0b68c395f8c6842c79ec8c01a034314e023c4b754fa0
3
+ size 372373768
checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.01/config.yaml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ type: quantized_images # choices: "toy", "images", "quantized_images", "texts"
3
+ dataset: celeba
4
+ dim: 128
5
+ latent_dim: 16
6
+ num_categories: 1024
7
+ num_timesteps: 100
8
+ num_skip_steps: 1
9
+ coupling_type: independent # choices: "independent", "prior"
10
+ train_test_split: 0.9
11
+ model:
12
+ hidden_dim: 256
13
+ num_channels: 4
14
+ num_layers: 18
15
+ num_att_heads: 16
16
+ dropout: 0
17
+ codec:
18
+ ckpt_path: checkpoints/vqgan_celeba_f8_1024.ckpt
19
+ config_path: configs/vqgan_celeba_f8_1024.yaml
20
+ prior:
21
+ alpha: 0.01
22
+ type: uniform # choices: "uniform", "gaussian", "centroid_gaussian", "von_mises"
23
+ eps: 1e-6
24
+ train:
25
+ batch_size: 32
26
+ low_precision: false
27
+ gradient_accumulation_steps: 4
28
+ iterations: 20
29
+ prior_iterations: 20
30
+ inner_iterations: 20000
31
+ use_mini_batch: false
32
+ ce_loss_coeff: 0.001
33
+ kl_loss_coeff: 1
34
+ mse_loss_coeff: 0
35
+ ema_decay: 0.999
36
+ optimizer:
37
+ lr: 0.0004
38
+ betas: [0.95, 0.99]
39
+ # weight_decay: 4.5e-2
40
+ # scheduler:
41
+ # factor: 0.5
42
+ # patience: 1000
43
+ # min_lr: 1.0e-6
44
+ # threshold: 1.0e-1
45
+ # threshold_mode: rel
46
+ # warmup_lr: 4.5e-4 # the lr to be touched after warmup
47
+ # warmup: 10000
48
+ eval:
49
+ freq: 1000
50
+ num_samples: 25
51
+ num_trajectories: 4 # How many trajecotries
52
+ num_translations: 1 # How many times sample trajecotry from single point
checkpoints /texts/amazon/uniform/dim_100_aplha_0.005/checkpoints/backward_5/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0bf711f8a55bf39b769f4541affe5d6551975fc68f9628199ae593843b6205e
3
+ size 399311828
checkpoints /texts/amazon/uniform/dim_100_aplha_0.005/checkpoints/forward_5/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:945bd9efd2d46e284aed78db5beaf11504f67602d9dc5f0d9c159d7829753dc1
3
+ size 399311828
checkpoints /texts/amazon/uniform/dim_100_aplha_0.005/config.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ type: texts
3
+ dataset: amazon
4
+ dim: 100
5
+ num_categories: 8192
6
+ num_timesteps: 100
7
+ num_skip_steps: 1
8
+ coupling_type: independent
9
+ train_test_split: null
10
+ model:
11
+ config:
12
+ name: small
13
+ type: ddit
14
+ hidden_size: 768
15
+ cond_dim: 128
16
+ length: 1024
17
+ n_blocks: 12
18
+ n_heads: 12
19
+ scale_by_sigma: true
20
+ dropout: 0.1
21
+ tie_word_embeddings: false
22
+ tokenizer:
23
+ path: checkpoints/tokenizer_amazon.json
24
+ prior:
25
+ alpha: 0.005
26
+ type: uniform
27
+ eps: 1.0e-20
28
+ train:
29
+ batch_size: 32
30
+ low_precision: true
31
+ gradient_accumulation_steps: 1
32
+ iterations: 20
33
+ prior_iterations: 20
34
+ inner_iterations: 20000
35
+ use_mini_batch: false
36
+ ce_loss_coeff: 0.001
37
+ kl_loss_coeff: 1
38
+ mse_loss_coeff: 0
39
+ ema_decay: 0.999
40
+ optimizer:
41
+ lr: 0.0004
42
+ betas:
43
+ - 0.95
44
+ - 0.99
45
+ eval:
46
+ freq: 1000
47
+ num_samples: 32
48
+ num_trajectories: 1
49
+ num_translations: 1
checkpoints /texts/amazon/uniform/dim_100_aplha_0.01/checkpoints/backward_5/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1cac9cc955db1256fa5cf9d80da014c34d30a80dedb256e47dae4de486b12c62
3
+ size 399311828
checkpoints /texts/amazon/uniform/dim_100_aplha_0.01/checkpoints/forward_5/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6820f1d460f3a023f00c1ca89d0dbf6ab8965906e3647ec2d77418661104138c
3
+ size 399311828
checkpoints /texts/amazon/uniform/dim_100_aplha_0.01/config.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ type: texts
3
+ dataset: amazon
4
+ dim: 100
5
+ num_categories: 8192
6
+ num_timesteps: 100
7
+ num_skip_steps: 1
8
+ coupling_type: independent
9
+ train_test_split: null
10
+ model:
11
+ config:
12
+ name: small
13
+ type: ddit
14
+ hidden_size: 768
15
+ cond_dim: 128
16
+ length: 1024
17
+ n_blocks: 12
18
+ n_heads: 12
19
+ scale_by_sigma: true
20
+ dropout: 0.1
21
+ tie_word_embeddings: false
22
+ tokenizer:
23
+ path: checkpoints/tokenizer_amazon.json
24
+ prior:
25
+ alpha: 0.01
26
+ type: uniform
27
+ eps: 1.0e-20
28
+ train:
29
+ batch_size: 32
30
+ low_precision: true
31
+ gradient_accumulation_steps: 1
32
+ iterations: 20
33
+ prior_iterations: 20
34
+ inner_iterations: 20000
35
+ use_mini_batch: false
36
+ ce_loss_coeff: 0.001
37
+ kl_loss_coeff: 1
38
+ mse_loss_coeff: 0
39
+ ema_decay: 0.999
40
+ optimizer:
41
+ lr: 0.0004
42
+ betas:
43
+ - 0.95
44
+ - 0.99
45
+ eval:
46
+ freq: 1000
47
+ num_samples: 32
48
+ num_trajectories: 1
49
+ num_translations: 1
checkpoints /tokenizer_amazon.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52d20f72ce3968e1af9b0f0fbea672c7915db6db5fa1ef639d32ee95c5d53a22
3
+ size 309914
checkpoints /vqgan_celeba_f8_1024.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59e325381f46e6b43a5e4ded5d651928d52bc8910ae8a4114be4c871e827fd56
3
+ size 936814106