Gregory Ksenofontov
commited on
Commit
·
c67cc3f
1
Parent(s):
1eb7106
Init commit
Browse files- .gitattributes +2 -34
- README.md +67 -3
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_10/checkpoints/backward_3/model.safetensors +3 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_10/checkpoints/forward_3/model.safetensors +3 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_10/config.yaml +47 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_100/checkpoints/backward_3/model.safetensors +3 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_100/checkpoints/forward_3/model.safetensors +3 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_100/config.yaml +47 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_2/checkpoints/backward_3/model.safetensors +3 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_2/checkpoints/forward_3/model.safetensors +3 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_2/config.yaml +47 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_25/checkpoints/backward_3/model.safetensors +3 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_25/checkpoints/forward_3/model.safetensors +3 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_25/config.yaml +47 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_4/checkpoints/backward_3/model.safetensors +3 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_4/checkpoints/forward_3/model.safetensors +3 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_4/config.yaml +47 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_50/checkpoints/backward_3/model.safetensors +3 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_50/checkpoints/forward_3/model.safetensors +3 -0
- checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_50/config.yaml +47 -0
- checkpoints /images/cmnist/uniform/dim_32_aplha_0.01/checkpoints/backward_3/model.safetensors +3 -0
- checkpoints /images/cmnist/uniform/dim_32_aplha_0.01/checkpoints/forward_3/model.safetensors +3 -0
- checkpoints /images/cmnist/uniform/dim_32_aplha_0.01/config.yaml +47 -0
- checkpoints /images/cmnist/uniform/dim_32_aplha_0.05/checkpoints/backward_3/model.safetensors +3 -0
- checkpoints /images/cmnist/uniform/dim_32_aplha_0.05/checkpoints/forward_3/model.safetensors +3 -0
- checkpoints /images/cmnist/uniform/dim_32_aplha_0.05/config.yaml +47 -0
- checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.005/checkpoints/backward_4/model.safetensors +3 -0
- checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.005/checkpoints/forward_4/model.safetensors +3 -0
- checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.005/config.yaml +45 -0
- checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.01/checkpoints/backward_4/model.safetensors +3 -0
- checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.01/checkpoints/forward_4/model.safetensors +3 -0
- checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.01/config.yaml +52 -0
- checkpoints /texts/amazon/uniform/dim_100_aplha_0.005/checkpoints/backward_5/model.safetensors +3 -0
- checkpoints /texts/amazon/uniform/dim_100_aplha_0.005/checkpoints/forward_5/model.safetensors +3 -0
- checkpoints /texts/amazon/uniform/dim_100_aplha_0.005/config.yaml +49 -0
- checkpoints /texts/amazon/uniform/dim_100_aplha_0.01/checkpoints/backward_5/model.safetensors +3 -0
- checkpoints /texts/amazon/uniform/dim_100_aplha_0.01/checkpoints/forward_5/model.safetensors +3 -0
- checkpoints /texts/amazon/uniform/dim_100_aplha_0.01/config.yaml +49 -0
- checkpoints /tokenizer_amazon.json +3 -0
- checkpoints /vqgan_celeba_f8_1024.ckpt +3 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,3 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
|
| 27 |
-
*.
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.json filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -1,3 +1,67 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: mit
|
| 3 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
---
|
| 4 |
+
<div align="center">
|
| 5 |
+
|
| 6 |
+
# Categorical Schrödinger Bridge Matching (CSBM)
|
| 7 |
+
|
| 8 |
+
[Grigoriy Ksenofontov](https://scholar.google.com/citations?user=e0mirzYAAAAJ),
|
| 9 |
+
[Alexander Korotin](https://scholar.google.ru/citations?user=1rIIvjAAAAAJ)
|
| 10 |
+
|
| 11 |
+
[](https://arxiv.org/abs/2502.01416)
|
| 12 |
+
[](https://openreview.net/forum?id=RBly0nOr2h)
|
| 13 |
+
[](https://github.com/gregkseno/csbm)
|
| 14 |
+

|
| 15 |
+
[](https://wandb.ai/gregkseno/csbm)
|
| 16 |
+
|
| 17 |
+
</div>
|
| 18 |
+
|
| 19 |
+
This repository hosts the official checkpoints for the paper "Categorical Schrödinger Bridge Matching", accepted at ICML 2025.
|
| 20 |
+
|
| 21 |
+
## 📌 TL;DR
|
| 22 |
+
|
| 23 |
+
This paper extends the Schrödinger Bridge problem to work with discrete time and spaces.
|
| 24 |
+
|
| 25 |
+
<!--  -->
|
| 26 |
+
|
| 27 |
+
## 💾 Checkpoints
|
| 28 |
+
|
| 29 |
+
### CSBM
|
| 30 |
+
|
| 31 |
+
| Dataset | Reference Process | $\alpha$ | $N$ | Saved Iteration |
|
| 32 |
+
| ------------- | ----------------- | ----------- | --------------------- | --------------- |
|
| 33 |
+
| Colored MNIST | **gaussian** | 0.01 | 2, 4, 10, 25, 50, 100 | 3 |
|
| 34 |
+
| Colored MNIST | **uniform** | 0.01, 0.05 | 25 | 3 |
|
| 35 |
+
| CelebA | **uniform** | 0.01, 0.005 | 100 | 4 |
|
| 36 |
+
| Amazon Review | **uniform** | 0.01, 0.005 | 100 | 5 |
|
| 37 |
+
|
| 38 |
+
> [!NOTE]
|
| 39 |
+
> Each experiment directory includes a `config.yaml` file with the full training configuration.
|
| 40 |
+
|
| 41 |
+
### Additional Components
|
| 42 |
+
|
| 43 |
+
1. `vqgan_celeba_f8_1024.ckpt` — **VQ-GAN** pretrained on the CelebA dataset
|
| 44 |
+
2. `tokenizer_amazon.json` — **Tokenizer** trained on the Amazon Reviews dataset
|
| 45 |
+
|
| 46 |
+
## 🎓 Citation
|
| 47 |
+
|
| 48 |
+
```bibtex
|
| 49 |
+
@article{ksenofontov2025categorical,
|
| 50 |
+
title={Categorical {Schr\"odinger} Bridge Matching},
|
| 51 |
+
author={Ksenofontov, Grigoriy and Korotin, Alexander},
|
| 52 |
+
journal={arXiv preprint arXiv:2502.01416},
|
| 53 |
+
year={2025}
|
| 54 |
+
}
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## 🙏 Credits
|
| 58 |
+
|
| 59 |
+
- [Weights & Biases](https://wandb.ai) — experiment-tracking and visualization toolkit;
|
| 60 |
+
- [Hugging Face](https://huggingface.co) — Tokenizers and Accelerate libraries for tokenizer implementation, parallel training, and checkpoint hosting on the Hub;
|
| 61 |
+
- [D3PM](https://github.com/google-research/google-research/tree/master/d3pm) — reference implementation of discrete-diffusion models;
|
| 62 |
+
- [Taming Transformers](https://github.com/CompVis/taming-transformers) — original VQ-GAN codebase;
|
| 63 |
+
- [VQ-Diffusion](https://github.com/microsoft/VQ-Diffusion) — vector-quantized diffusion architecture;
|
| 64 |
+
- [MDLM](https://github.com/kuleshov-group/mdlm) — diffusion architecture for text-generation experiments;
|
| 65 |
+
- [ASBM](https://arxiv.org/abs/2405.14449) — evaluation metrics and baseline models for CelebA face transfer;
|
| 66 |
+
- [Balancing the Style-Content Trade-Off in Sentiment Transfer Using Polarity-Aware Denoising](https://arxiv.org/abs/2312.14708) — processed Amazon Reviews dataset and sentiment-transfer baselines;
|
| 67 |
+
- [Inkscape](https://inkscape.org/) — an excellent open-source editor for vector graphics.
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_10/checkpoints/backward_3/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d64537aa661f8ee3e8516b525dbc0b961eff87cecba8101296bbff5117ae445e
|
| 3 |
+
size 139416416
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_10/checkpoints/forward_3/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b5870abf0ad490bbdfcc568ddc0aa82e5b087d46de7420bcb26589e68207dafd
|
| 3 |
+
size 139416416
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_10/config.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
type: images
|
| 3 |
+
dataset: cmnist
|
| 4 |
+
dim: 32
|
| 5 |
+
num_categories: 256
|
| 6 |
+
num_timesteps: 10
|
| 7 |
+
num_skip_steps: 10
|
| 8 |
+
coupling_type: independent
|
| 9 |
+
model:
|
| 10 |
+
in_channels: 3
|
| 11 |
+
num_channels: 64
|
| 12 |
+
num_layers: 2
|
| 13 |
+
ch_mults:
|
| 14 |
+
- 1
|
| 15 |
+
- 2
|
| 16 |
+
- 2
|
| 17 |
+
- 2
|
| 18 |
+
attention_resolution: 16
|
| 19 |
+
num_groups: 32
|
| 20 |
+
num_att_heads: 4
|
| 21 |
+
dropout: 0.1
|
| 22 |
+
prior:
|
| 23 |
+
alpha: 0.01
|
| 24 |
+
type: gaussian
|
| 25 |
+
eps: 1.0e-06
|
| 26 |
+
train:
|
| 27 |
+
batch_size: 128
|
| 28 |
+
low_precision: false
|
| 29 |
+
gradient_accumulation_steps: 1
|
| 30 |
+
iterations: 20
|
| 31 |
+
prior_iterations: 5
|
| 32 |
+
inner_iterations: 20000
|
| 33 |
+
use_mini_batch: true
|
| 34 |
+
ce_loss_coeff: 0.001
|
| 35 |
+
kl_loss_coeff: 1
|
| 36 |
+
mse_loss_coeff: 0
|
| 37 |
+
ema_decay: 0.9999
|
| 38 |
+
optimizer:
|
| 39 |
+
lr: 0.0002
|
| 40 |
+
betas:
|
| 41 |
+
- 0.95
|
| 42 |
+
- 0.99
|
| 43 |
+
eval:
|
| 44 |
+
freq: 1000
|
| 45 |
+
num_samples: 25
|
| 46 |
+
num_trajectories: 4
|
| 47 |
+
num_translations: 2
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_100/checkpoints/backward_3/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f5896a264f6e29453ea4cbba2c625487ad3423ddfda8f93d9a94acaf7f36e145
|
| 3 |
+
size 139416416
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_100/checkpoints/forward_3/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e99e791d3d37ff2b4bfe1f88461526b2216cc1e5dc785917beb8c4e973e3d48d
|
| 3 |
+
size 139416416
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_100/config.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
type: images
|
| 3 |
+
dataset: cmnist
|
| 4 |
+
dim: 32
|
| 5 |
+
num_categories: 256
|
| 6 |
+
num_timesteps: 100
|
| 7 |
+
num_skip_steps: 1
|
| 8 |
+
coupling_type: independent
|
| 9 |
+
model:
|
| 10 |
+
in_channels: 3
|
| 11 |
+
num_channels: 64
|
| 12 |
+
num_layers: 2
|
| 13 |
+
ch_mults:
|
| 14 |
+
- 1
|
| 15 |
+
- 2
|
| 16 |
+
- 2
|
| 17 |
+
- 2
|
| 18 |
+
attention_resolution: 16
|
| 19 |
+
num_groups: 32
|
| 20 |
+
num_att_heads: 4
|
| 21 |
+
dropout: 0.1
|
| 22 |
+
prior:
|
| 23 |
+
alpha: 0.01
|
| 24 |
+
type: gaussian
|
| 25 |
+
eps: 1.0e-06
|
| 26 |
+
train:
|
| 27 |
+
batch_size: 128
|
| 28 |
+
low_precision: false
|
| 29 |
+
gradient_accumulation_steps: 1
|
| 30 |
+
iterations: 20
|
| 31 |
+
prior_iterations: 10
|
| 32 |
+
inner_iterations: 20000
|
| 33 |
+
use_mini_batch: true
|
| 34 |
+
ce_loss_coeff: 0.001
|
| 35 |
+
kl_loss_coeff: 1
|
| 36 |
+
mse_loss_coeff: 0
|
| 37 |
+
ema_decay: 0.9999
|
| 38 |
+
optimizer:
|
| 39 |
+
lr: 0.0002
|
| 40 |
+
betas:
|
| 41 |
+
- 0.95
|
| 42 |
+
- 0.99
|
| 43 |
+
eval:
|
| 44 |
+
freq: 1000
|
| 45 |
+
num_samples: 25
|
| 46 |
+
num_trajectories: 4
|
| 47 |
+
num_translations: 2
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_2/checkpoints/backward_3/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ca3e655484abba1645b826534921224c05db3126fb92d352ee11a89085547a3d
|
| 3 |
+
size 139416416
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_2/checkpoints/forward_3/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b8747b4c6f4cea35f13596dc81992b70750f9a0fe9d8c7ee8c080da52bf3451
|
| 3 |
+
size 139416416
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_2/config.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
type: images
|
| 3 |
+
dataset: cmnist
|
| 4 |
+
dim: 32
|
| 5 |
+
num_categories: 256
|
| 6 |
+
num_timesteps: 2
|
| 7 |
+
num_skip_steps: 50
|
| 8 |
+
coupling_type: independent
|
| 9 |
+
model:
|
| 10 |
+
in_channels: 3
|
| 11 |
+
num_channels: 64
|
| 12 |
+
num_layers: 2
|
| 13 |
+
ch_mults:
|
| 14 |
+
- 1
|
| 15 |
+
- 2
|
| 16 |
+
- 2
|
| 17 |
+
- 2
|
| 18 |
+
attention_resolution: 16
|
| 19 |
+
num_groups: 32
|
| 20 |
+
num_att_heads: 4
|
| 21 |
+
dropout: 0.1
|
| 22 |
+
prior:
|
| 23 |
+
alpha: 0.01
|
| 24 |
+
type: gaussian
|
| 25 |
+
eps: 1.0e-06
|
| 26 |
+
train:
|
| 27 |
+
batch_size: 128
|
| 28 |
+
low_precision: false
|
| 29 |
+
gradient_accumulation_steps: 1
|
| 30 |
+
iterations: 20
|
| 31 |
+
prior_iterations: 5
|
| 32 |
+
inner_iterations: 20000
|
| 33 |
+
use_mini_batch: true
|
| 34 |
+
ce_loss_coeff: 0.001
|
| 35 |
+
kl_loss_coeff: 1
|
| 36 |
+
mse_loss_coeff: 0
|
| 37 |
+
ema_decay: 0.9999
|
| 38 |
+
optimizer:
|
| 39 |
+
lr: 0.0002
|
| 40 |
+
betas:
|
| 41 |
+
- 0.95
|
| 42 |
+
- 0.99
|
| 43 |
+
eval:
|
| 44 |
+
freq: 1000
|
| 45 |
+
num_samples: 25
|
| 46 |
+
num_trajectories: 4
|
| 47 |
+
num_translations: 2
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_25/checkpoints/backward_3/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:eb36e26158e2cba3b5b65ba35b7ebeba0817a568f3d9b8eca9d413fc752e4713
|
| 3 |
+
size 139416416
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_25/checkpoints/forward_3/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f44afb3a90eef04257fc4ea5067386bba3a5713cc667dc1f6e7e305b8d92e540
|
| 3 |
+
size 139416416
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_25/config.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
type: images
|
| 3 |
+
dataset: cmnist
|
| 4 |
+
dim: 32
|
| 5 |
+
num_categories: 256
|
| 6 |
+
num_timesteps: 25
|
| 7 |
+
num_skip_steps: 4
|
| 8 |
+
coupling_type: independent
|
| 9 |
+
model:
|
| 10 |
+
in_channels: 3
|
| 11 |
+
num_channels: 64
|
| 12 |
+
num_layers: 2
|
| 13 |
+
ch_mults:
|
| 14 |
+
- 1
|
| 15 |
+
- 2
|
| 16 |
+
- 2
|
| 17 |
+
- 2
|
| 18 |
+
attention_resolution: 16
|
| 19 |
+
num_groups: 32
|
| 20 |
+
num_att_heads: 4
|
| 21 |
+
dropout: 0.1
|
| 22 |
+
prior:
|
| 23 |
+
alpha: 0.01
|
| 24 |
+
type: gaussian
|
| 25 |
+
eps: 1.0e-06
|
| 26 |
+
train:
|
| 27 |
+
batch_size: 128
|
| 28 |
+
low_precision: false
|
| 29 |
+
gradient_accumulation_steps: 1
|
| 30 |
+
iterations: 20
|
| 31 |
+
prior_iterations: 5
|
| 32 |
+
inner_iterations: 20000
|
| 33 |
+
use_mini_batch: true
|
| 34 |
+
ce_loss_coeff: 0.001
|
| 35 |
+
kl_loss_coeff: 1
|
| 36 |
+
mse_loss_coeff: 0
|
| 37 |
+
ema_decay: 0.9999
|
| 38 |
+
optimizer:
|
| 39 |
+
lr: 0.0002
|
| 40 |
+
betas:
|
| 41 |
+
- 0.95
|
| 42 |
+
- 0.99
|
| 43 |
+
eval:
|
| 44 |
+
freq: 1000
|
| 45 |
+
num_samples: 25
|
| 46 |
+
num_trajectories: 4
|
| 47 |
+
num_translations: 2
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_4/checkpoints/backward_3/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ba218a2b25c2e7682efee04227fe1c736d6f24905d7adda7cb254ef6422787ec
|
| 3 |
+
size 139416416
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_4/checkpoints/forward_3/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:10d5cb66e1770eb54c2d872c24d584125563facf67af25799bf561edf0c64ac1
|
| 3 |
+
size 139416416
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_4/config.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
type: images
|
| 3 |
+
dataset: cmnist
|
| 4 |
+
dim: 32
|
| 5 |
+
num_categories: 256
|
| 6 |
+
num_timesteps: 4
|
| 7 |
+
num_skip_steps: 25
|
| 8 |
+
coupling_type: independent
|
| 9 |
+
model:
|
| 10 |
+
in_channels: 3
|
| 11 |
+
num_channels: 64
|
| 12 |
+
num_layers: 2
|
| 13 |
+
ch_mults:
|
| 14 |
+
- 1
|
| 15 |
+
- 2
|
| 16 |
+
- 2
|
| 17 |
+
- 2
|
| 18 |
+
attention_resolution: 16
|
| 19 |
+
num_groups: 32
|
| 20 |
+
num_att_heads: 4
|
| 21 |
+
dropout: 0.1
|
| 22 |
+
prior:
|
| 23 |
+
alpha: 0.01
|
| 24 |
+
type: gaussian
|
| 25 |
+
eps: 1.0e-06
|
| 26 |
+
train:
|
| 27 |
+
batch_size: 128
|
| 28 |
+
low_precision: false
|
| 29 |
+
gradient_accumulation_steps: 1
|
| 30 |
+
iterations: 20
|
| 31 |
+
prior_iterations: 5
|
| 32 |
+
inner_iterations: 20000
|
| 33 |
+
use_mini_batch: true
|
| 34 |
+
ce_loss_coeff: 0.001
|
| 35 |
+
kl_loss_coeff: 1
|
| 36 |
+
mse_loss_coeff: 0
|
| 37 |
+
ema_decay: 0.9999
|
| 38 |
+
optimizer:
|
| 39 |
+
lr: 0.0002
|
| 40 |
+
betas:
|
| 41 |
+
- 0.95
|
| 42 |
+
- 0.99
|
| 43 |
+
eval:
|
| 44 |
+
freq: 1000
|
| 45 |
+
num_samples: 25
|
| 46 |
+
num_trajectories: 4
|
| 47 |
+
num_translations: 2
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_50/checkpoints/backward_3/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2390760bcb242299d49d881e2c6c869739ead285f4067634ac9c1da8a9eaa2c4
|
| 3 |
+
size 139416416
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_50/checkpoints/forward_3/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:57aad3763421081c13d9435c2dc80ace6aaf68397d3a8fc512f9ad19d85e660b
|
| 3 |
+
size 139416416
|
checkpoints /images/cmnist/gaussian/dim_32_aplha_0.01_num_timesteps_50/config.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
type: images
|
| 3 |
+
dataset: cmnist
|
| 4 |
+
dim: 32
|
| 5 |
+
num_categories: 256
|
| 6 |
+
num_timesteps: 50
|
| 7 |
+
num_skip_steps: 2
|
| 8 |
+
coupling_type: independent
|
| 9 |
+
model:
|
| 10 |
+
in_channels: 3
|
| 11 |
+
num_channels: 64
|
| 12 |
+
num_layers: 2
|
| 13 |
+
ch_mults:
|
| 14 |
+
- 1
|
| 15 |
+
- 2
|
| 16 |
+
- 2
|
| 17 |
+
- 2
|
| 18 |
+
attention_resolution: 16
|
| 19 |
+
num_groups: 32
|
| 20 |
+
num_att_heads: 4
|
| 21 |
+
dropout: 0.1
|
| 22 |
+
prior:
|
| 23 |
+
alpha: 0.01
|
| 24 |
+
type: gaussian
|
| 25 |
+
eps: 1.0e-06
|
| 26 |
+
train:
|
| 27 |
+
batch_size: 128
|
| 28 |
+
low_precision: false
|
| 29 |
+
gradient_accumulation_steps: 1
|
| 30 |
+
iterations: 20
|
| 31 |
+
prior_iterations: 10
|
| 32 |
+
inner_iterations: 20000
|
| 33 |
+
use_mini_batch: true
|
| 34 |
+
ce_loss_coeff: 0.001
|
| 35 |
+
kl_loss_coeff: 1
|
| 36 |
+
mse_loss_coeff: 0
|
| 37 |
+
ema_decay: 0.9999
|
| 38 |
+
optimizer:
|
| 39 |
+
lr: 0.0002
|
| 40 |
+
betas:
|
| 41 |
+
- 0.95
|
| 42 |
+
- 0.99
|
| 43 |
+
eval:
|
| 44 |
+
freq: 1000
|
| 45 |
+
num_samples: 25
|
| 46 |
+
num_trajectories: 4
|
| 47 |
+
num_translations: 2
|
checkpoints /images/cmnist/uniform/dim_32_aplha_0.01/checkpoints/backward_3/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:125c3702db67b5d8c46969e8772f8ffed3fdbc399808968a317086e2d9fe34e1
|
| 3 |
+
size 139416416
|
checkpoints /images/cmnist/uniform/dim_32_aplha_0.01/checkpoints/forward_3/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f743431c8174d039c6f3a48a662725893eee873331a149dc0c788f3138d914d6
|
| 3 |
+
size 139416416
|
checkpoints /images/cmnist/uniform/dim_32_aplha_0.01/config.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
type: images
|
| 3 |
+
dataset: cmnist
|
| 4 |
+
dim: 32
|
| 5 |
+
num_categories: 256
|
| 6 |
+
num_timesteps: 25
|
| 7 |
+
num_skip_steps: 4
|
| 8 |
+
coupling_type: independent
|
| 9 |
+
model:
|
| 10 |
+
in_channels: 3
|
| 11 |
+
num_channels: 64
|
| 12 |
+
num_layers: 2
|
| 13 |
+
ch_mults:
|
| 14 |
+
- 1
|
| 15 |
+
- 2
|
| 16 |
+
- 2
|
| 17 |
+
- 2
|
| 18 |
+
attention_resolution: 16
|
| 19 |
+
num_groups: 32
|
| 20 |
+
num_att_heads: 4
|
| 21 |
+
dropout: 0.1
|
| 22 |
+
prior:
|
| 23 |
+
alpha: 0.01
|
| 24 |
+
type: uniform
|
| 25 |
+
eps: 1.0e-06
|
| 26 |
+
train:
|
| 27 |
+
batch_size: 128
|
| 28 |
+
low_precision: false
|
| 29 |
+
gradient_accumulation_steps: 1
|
| 30 |
+
iterations: 20
|
| 31 |
+
prior_iterations: 5
|
| 32 |
+
inner_iterations: 20000
|
| 33 |
+
use_mini_batch: false
|
| 34 |
+
ce_loss_coeff: 0.001
|
| 35 |
+
kl_loss_coeff: 1
|
| 36 |
+
mse_loss_coeff: 0
|
| 37 |
+
ema_decay: 0.9999
|
| 38 |
+
optimizer:
|
| 39 |
+
lr: 0.0002
|
| 40 |
+
betas:
|
| 41 |
+
- 0.95
|
| 42 |
+
- 0.99
|
| 43 |
+
eval:
|
| 44 |
+
freq: 1000
|
| 45 |
+
num_samples: 25
|
| 46 |
+
num_trajectories: 4
|
| 47 |
+
num_translations: 2
|
checkpoints /images/cmnist/uniform/dim_32_aplha_0.05/checkpoints/backward_3/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7fdbd67065bb6ef3933b370fee02c4d935103d8e32a8d7d358f27fb4ba78946c
|
| 3 |
+
size 139416416
|
checkpoints /images/cmnist/uniform/dim_32_aplha_0.05/checkpoints/forward_3/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f329e98da91a5c305b820d013ec4e816d83ce54bd3c33cf42e8b70698c937a2b
|
| 3 |
+
size 139416416
|
checkpoints /images/cmnist/uniform/dim_32_aplha_0.05/config.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
type: images
|
| 3 |
+
dataset: cmnist
|
| 4 |
+
dim: 32
|
| 5 |
+
num_categories: 256
|
| 6 |
+
num_timesteps: 25
|
| 7 |
+
num_skip_steps: 4
|
| 8 |
+
coupling_type: independent
|
| 9 |
+
model:
|
| 10 |
+
in_channels: 3
|
| 11 |
+
num_channels: 64
|
| 12 |
+
num_layers: 2
|
| 13 |
+
ch_mults:
|
| 14 |
+
- 1
|
| 15 |
+
- 2
|
| 16 |
+
- 2
|
| 17 |
+
- 2
|
| 18 |
+
attention_resolution: 16
|
| 19 |
+
num_groups: 32
|
| 20 |
+
num_att_heads: 4
|
| 21 |
+
dropout: 0.1
|
| 22 |
+
prior:
|
| 23 |
+
alpha: 0.05
|
| 24 |
+
type: uniform
|
| 25 |
+
eps: 1.0e-06
|
| 26 |
+
train:
|
| 27 |
+
batch_size: 128
|
| 28 |
+
low_precision: false
|
| 29 |
+
gradient_accumulation_steps: 1
|
| 30 |
+
iterations: 20
|
| 31 |
+
prior_iterations: 5
|
| 32 |
+
inner_iterations: 20000
|
| 33 |
+
use_mini_batch: false
|
| 34 |
+
ce_loss_coeff: 0.001
|
| 35 |
+
kl_loss_coeff: 1
|
| 36 |
+
mse_loss_coeff: 0
|
| 37 |
+
ema_decay: 0.9999
|
| 38 |
+
optimizer:
|
| 39 |
+
lr: 0.0002
|
| 40 |
+
betas:
|
| 41 |
+
- 0.95
|
| 42 |
+
- 0.99
|
| 43 |
+
eval:
|
| 44 |
+
freq: 1000
|
| 45 |
+
num_samples: 25
|
| 46 |
+
num_trajectories: 4
|
| 47 |
+
num_translations: 2
|
checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.005/checkpoints/backward_4/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fee703324c0024cfa87965f8ce98ecac4d9e199b19b7f07e56189e96bdae4483
|
| 3 |
+
size 372373768
|
checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.005/checkpoints/forward_4/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1f8b9bd34a828628b91a1a32ef82c66860e7d00f6cd017ded46af500d15a35a5
|
| 3 |
+
size 372373768
|
checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.005/config.yaml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
type: quantized_images
|
| 3 |
+
dataset: celeba
|
| 4 |
+
dim: 128
|
| 5 |
+
latent_dim: 16
|
| 6 |
+
num_categories: 1024
|
| 7 |
+
num_timesteps: 100
|
| 8 |
+
num_skip_steps: 1
|
| 9 |
+
coupling_type: independent
|
| 10 |
+
train_test_split: 0.9
|
| 11 |
+
model:
|
| 12 |
+
hidden_dim: 256
|
| 13 |
+
num_channels: 4
|
| 14 |
+
num_layers: 18
|
| 15 |
+
num_att_heads: 16
|
| 16 |
+
dropout: 0
|
| 17 |
+
codec:
|
| 18 |
+
ckpt_path: checkpoints/vqgan_celeba_f8_1024.ckpt
|
| 19 |
+
config_path: configs/vqgan_celeba_f8_1024.yaml
|
| 20 |
+
prior:
|
| 21 |
+
alpha: 0.005
|
| 22 |
+
type: uniform
|
| 23 |
+
eps: 1e-6
|
| 24 |
+
train:
|
| 25 |
+
batch_size: 32
|
| 26 |
+
low_precision: false
|
| 27 |
+
gradient_accumulation_steps: 1
|
| 28 |
+
iterations: 20
|
| 29 |
+
prior_iterations: 20
|
| 30 |
+
inner_iterations: 20000
|
| 31 |
+
use_mini_batch: false
|
| 32 |
+
ce_loss_coeff: 0.001
|
| 33 |
+
kl_loss_coeff: 1
|
| 34 |
+
mse_loss_coeff: 0
|
| 35 |
+
ema_decay: 0.999
|
| 36 |
+
optimizer:
|
| 37 |
+
lr: 0.0004
|
| 38 |
+
betas:
|
| 39 |
+
- 0.95
|
| 40 |
+
- 0.99
|
| 41 |
+
eval:
|
| 42 |
+
freq: 1000
|
| 43 |
+
num_samples: 25
|
| 44 |
+
num_trajectories: 4
|
| 45 |
+
num_translations: 1
|
checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.01/checkpoints/backward_4/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:08630e97e5a15fcdb4b51b2ad17f1a6362b48d9b8bc502cd6e517c483ef53c3d
|
| 3 |
+
size 372373768
|
checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.01/checkpoints/forward_4/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:814c99bd6124b3ec34af0b68c395f8c6842c79ec8c01a034314e023c4b754fa0
|
| 3 |
+
size 372373768
|
checkpoints /quantized_images/celeba/uniform/dim_128_aplha_0.01/config.yaml
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
type: quantized_images # choices: "toy", "images", "quantized_images", "texts"
|
| 3 |
+
dataset: celeba
|
| 4 |
+
dim: 128
|
| 5 |
+
latent_dim: 16
|
| 6 |
+
num_categories: 1024
|
| 7 |
+
num_timesteps: 100
|
| 8 |
+
num_skip_steps: 1
|
| 9 |
+
coupling_type: independent # choices: "independent", "prior"
|
| 10 |
+
train_test_split: 0.9
|
| 11 |
+
model:
|
| 12 |
+
hidden_dim: 256
|
| 13 |
+
num_channels: 4
|
| 14 |
+
num_layers: 18
|
| 15 |
+
num_att_heads: 16
|
| 16 |
+
dropout: 0
|
| 17 |
+
codec:
|
| 18 |
+
ckpt_path: checkpoints/vqgan_celeba_f8_1024.ckpt
|
| 19 |
+
config_path: configs/vqgan_celeba_f8_1024.yaml
|
| 20 |
+
prior:
|
| 21 |
+
alpha: 0.01
|
| 22 |
+
type: uniform # choices: "uniform", "gaussian", "centroid_gaussian", "von_mises"
|
| 23 |
+
eps: 1e-6
|
| 24 |
+
train:
|
| 25 |
+
batch_size: 32
|
| 26 |
+
low_precision: false
|
| 27 |
+
gradient_accumulation_steps: 4
|
| 28 |
+
iterations: 20
|
| 29 |
+
prior_iterations: 20
|
| 30 |
+
inner_iterations: 20000
|
| 31 |
+
use_mini_batch: false
|
| 32 |
+
ce_loss_coeff: 0.001
|
| 33 |
+
kl_loss_coeff: 1
|
| 34 |
+
mse_loss_coeff: 0
|
| 35 |
+
ema_decay: 0.999
|
| 36 |
+
optimizer:
|
| 37 |
+
lr: 0.0004
|
| 38 |
+
betas: [0.95, 0.99]
|
| 39 |
+
# weight_decay: 4.5e-2
|
| 40 |
+
# scheduler:
|
| 41 |
+
# factor: 0.5
|
| 42 |
+
# patience: 1000
|
| 43 |
+
# min_lr: 1.0e-6
|
| 44 |
+
# threshold: 1.0e-1
|
| 45 |
+
# threshold_mode: rel
|
| 46 |
+
# warmup_lr: 4.5e-4 # the lr to be touched after warmup
|
| 47 |
+
# warmup: 10000
|
| 48 |
+
eval:
|
| 49 |
+
freq: 1000
|
| 50 |
+
num_samples: 25
|
| 51 |
+
num_trajectories: 4 # How many trajecotries
|
| 52 |
+
num_translations: 1 # How many times sample trajecotry from single point
|
checkpoints /texts/amazon/uniform/dim_100_aplha_0.005/checkpoints/backward_5/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f0bf711f8a55bf39b769f4541affe5d6551975fc68f9628199ae593843b6205e
|
| 3 |
+
size 399311828
|
checkpoints /texts/amazon/uniform/dim_100_aplha_0.005/checkpoints/forward_5/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:945bd9efd2d46e284aed78db5beaf11504f67602d9dc5f0d9c159d7829753dc1
|
| 3 |
+
size 399311828
|
checkpoints /texts/amazon/uniform/dim_100_aplha_0.005/config.yaml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
type: texts
|
| 3 |
+
dataset: amazon
|
| 4 |
+
dim: 100
|
| 5 |
+
num_categories: 8192
|
| 6 |
+
num_timesteps: 100
|
| 7 |
+
num_skip_steps: 1
|
| 8 |
+
coupling_type: independent
|
| 9 |
+
train_test_split: null
|
| 10 |
+
model:
|
| 11 |
+
config:
|
| 12 |
+
name: small
|
| 13 |
+
type: ddit
|
| 14 |
+
hidden_size: 768
|
| 15 |
+
cond_dim: 128
|
| 16 |
+
length: 1024
|
| 17 |
+
n_blocks: 12
|
| 18 |
+
n_heads: 12
|
| 19 |
+
scale_by_sigma: true
|
| 20 |
+
dropout: 0.1
|
| 21 |
+
tie_word_embeddings: false
|
| 22 |
+
tokenizer:
|
| 23 |
+
path: checkpoints/tokenizer_amazon.json
|
| 24 |
+
prior:
|
| 25 |
+
alpha: 0.005
|
| 26 |
+
type: uniform
|
| 27 |
+
eps: 1.0e-20
|
| 28 |
+
train:
|
| 29 |
+
batch_size: 32
|
| 30 |
+
low_precision: true
|
| 31 |
+
gradient_accumulation_steps: 1
|
| 32 |
+
iterations: 20
|
| 33 |
+
prior_iterations: 20
|
| 34 |
+
inner_iterations: 20000
|
| 35 |
+
use_mini_batch: false
|
| 36 |
+
ce_loss_coeff: 0.001
|
| 37 |
+
kl_loss_coeff: 1
|
| 38 |
+
mse_loss_coeff: 0
|
| 39 |
+
ema_decay: 0.999
|
| 40 |
+
optimizer:
|
| 41 |
+
lr: 0.0004
|
| 42 |
+
betas:
|
| 43 |
+
- 0.95
|
| 44 |
+
- 0.99
|
| 45 |
+
eval:
|
| 46 |
+
freq: 1000
|
| 47 |
+
num_samples: 32
|
| 48 |
+
num_trajectories: 1
|
| 49 |
+
num_translations: 1
|
checkpoints /texts/amazon/uniform/dim_100_aplha_0.01/checkpoints/backward_5/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1cac9cc955db1256fa5cf9d80da014c34d30a80dedb256e47dae4de486b12c62
|
| 3 |
+
size 399311828
|
checkpoints /texts/amazon/uniform/dim_100_aplha_0.01/checkpoints/forward_5/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6820f1d460f3a023f00c1ca89d0dbf6ab8965906e3647ec2d77418661104138c
|
| 3 |
+
size 399311828
|
checkpoints /texts/amazon/uniform/dim_100_aplha_0.01/config.yaml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
type: texts
|
| 3 |
+
dataset: amazon
|
| 4 |
+
dim: 100
|
| 5 |
+
num_categories: 8192
|
| 6 |
+
num_timesteps: 100
|
| 7 |
+
num_skip_steps: 1
|
| 8 |
+
coupling_type: independent
|
| 9 |
+
train_test_split: null
|
| 10 |
+
model:
|
| 11 |
+
config:
|
| 12 |
+
name: small
|
| 13 |
+
type: ddit
|
| 14 |
+
hidden_size: 768
|
| 15 |
+
cond_dim: 128
|
| 16 |
+
length: 1024
|
| 17 |
+
n_blocks: 12
|
| 18 |
+
n_heads: 12
|
| 19 |
+
scale_by_sigma: true
|
| 20 |
+
dropout: 0.1
|
| 21 |
+
tie_word_embeddings: false
|
| 22 |
+
tokenizer:
|
| 23 |
+
path: checkpoints/tokenizer_amazon.json
|
| 24 |
+
prior:
|
| 25 |
+
alpha: 0.01
|
| 26 |
+
type: uniform
|
| 27 |
+
eps: 1.0e-20
|
| 28 |
+
train:
|
| 29 |
+
batch_size: 32
|
| 30 |
+
low_precision: true
|
| 31 |
+
gradient_accumulation_steps: 1
|
| 32 |
+
iterations: 20
|
| 33 |
+
prior_iterations: 20
|
| 34 |
+
inner_iterations: 20000
|
| 35 |
+
use_mini_batch: false
|
| 36 |
+
ce_loss_coeff: 0.001
|
| 37 |
+
kl_loss_coeff: 1
|
| 38 |
+
mse_loss_coeff: 0
|
| 39 |
+
ema_decay: 0.999
|
| 40 |
+
optimizer:
|
| 41 |
+
lr: 0.0004
|
| 42 |
+
betas:
|
| 43 |
+
- 0.95
|
| 44 |
+
- 0.99
|
| 45 |
+
eval:
|
| 46 |
+
freq: 1000
|
| 47 |
+
num_samples: 32
|
| 48 |
+
num_trajectories: 1
|
| 49 |
+
num_translations: 1
|
checkpoints /tokenizer_amazon.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:52d20f72ce3968e1af9b0f0fbea672c7915db6db5fa1ef639d32ee95c5d53a22
|
| 3 |
+
size 309914
|
checkpoints /vqgan_celeba_f8_1024.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:59e325381f46e6b43a5e4ded5d651928d52bc8910ae8a4114be4c871e827fd56
|
| 3 |
+
size 936814106
|