Zhizhou Zhong
committed on
feat: data preprocessing and training (#294)
* docs: update readme
* feat: training codes
* feat: data preprocess
* docs: release training
- .gitignore +5 -1
- README.md +74 -3
- configs/training/gpu.yaml +21 -0
- configs/training/preprocess.yaml +31 -0
- configs/training/stage1.yaml +89 -0
- configs/training/stage2.yaml +89 -0
- configs/training/syncnet.yaml +19 -0
- inference.sh +1 -1
- musetalk/data/audio.py +168 -0
- musetalk/data/dataset.py +607 -0
- musetalk/data/sample_method.py +233 -0
- musetalk/loss/basic_loss.py +81 -0
- musetalk/loss/conv.py +44 -0
- musetalk/loss/discriminator.py +145 -0
- musetalk/loss/resnet.py +152 -0
- musetalk/loss/syncnet.py +95 -0
- musetalk/loss/vgg_face.py +237 -0
- musetalk/models/syncnet.py +240 -0
- musetalk/utils/training_utils.py +337 -0
- musetalk/utils/utils.py +250 -1
- scripts/preprocess.py +322 -0
- train.py +580 -0
- train.sh +34 -0
.gitignore
CHANGED
@@ -8,4 +8,8 @@ results/
 ./models
 **/__pycache__/
 *.py[cod]
-*$py.class
+*$py.class
+dataset/
+ffmpeg*
+debug
+exp_out
README.md
CHANGED
@@ -130,8 +130,9 @@ https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde
 - [x] codes for real-time inference.
 - [x] [technical report](https://arxiv.org/abs/2410.10122v2).
 - [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
-- [x] realtime inference code for 1.5 version
-- [
+- [x] realtime inference code for 1.5 version.
+- [x] training and data preprocessing codes.
+- [ ] **always** welcome to submit issues and PRs to improve this repository! 😊


 # Getting Started
@@ -187,6 +188,7 @@ huggingface-cli download TMElyralab/MuseTalk --local-dir models/
 - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
 - [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
 - [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
+- [syncnet](https://huggingface.co/ByteDance/LatentSync/tree/main)


 Finally, these weights should be organized in `models` as follows:
@@ -198,6 +200,8 @@ Finally, these weights should be organized in `models` as follows:
 ├── musetalkV15
 │   └── musetalk.json
 │   └── unet.pth
+├── syncnet
+│   └── latentsync_syncnet.pt
 ├── dwpose
 │   └── dw-ll_ucoco_384.pth
 ├── face-parse-bisent
@@ -265,6 +269,73 @@ For faster generation without saving images, you can use:
 python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
 ```

+## Training
+
+### Data Preparation
+To train MuseTalk, you need to prepare your dataset following these steps:
+
+1. **Place your source videos**
+
+   For example, if you're using the HDTF dataset, place all your video files in `./dataset/HDTF/source`.
+
+2. **Run the preprocessing script**
+   ```bash
+   python -m scripts.preprocess --config ./configs/training/preprocess.yaml
+   ```
+   This script will:
+   - Extract frames from videos
+   - Detect and align faces
+   - Generate audio features
+   - Create the necessary data structure for training
+
+### Training Process
+After data preprocessing, you can start the training process:
+
+1. **First Stage**
+   ```bash
+   sh train.sh stage1
+   ```
+
+2. **Second Stage**
+   ```bash
+   sh train.sh stage2
+   ```
+
+### Configuration Adjustment
+Before starting the training, you should adjust the configuration files according to your hardware and requirements:
+
+1. **GPU Configuration** (`configs/training/gpu.yaml`):
+   - `gpu_ids`: Specify the GPU IDs you want to use (e.g., "0,1,2,3")
+   - `num_processes`: Set this to match the number of GPUs you're using
+
+2. **Stage 1 Configuration** (`configs/training/stage1.yaml`):
+   - `data.train_bs`: Adjust batch size based on your GPU memory (default: 32)
+   - `data.n_sample_frames`: Number of sampled frames per video (default: 1)
+
+3. **Stage 2 Configuration** (`configs/training/stage2.yaml`):
+   - `random_init_unet`: Must be set to `False` to use the model from stage 1
+   - `data.train_bs`: Smaller batch size due to high GPU memory cost (default: 2)
+   - `data.n_sample_frames`: Higher value for temporal consistency (default: 16)
+   - `solver.gradient_accumulation_steps`: Increase to simulate larger batch sizes (default: 8)
+
+
+### GPU Memory Requirements
+Based on our testing on a machine with 8 NVIDIA H20 GPUs:
+
+#### Stage 1 Memory Usage
+| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
+|:----------:|:----------------------:|:--------------:|:--------------:|
+| 8 | 1 | ~32GB | |
+| 16 | 1 | ~45GB | |
+| 32 | 1 | ~74GB | ✓ |
+
+#### Stage 2 Memory Usage
+| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
+|:----------:|:----------------------:|:--------------:|:--------------:|
+| 1 | 8 | ~54GB | |
+| 2 | 2 | ~80GB | |
+| 2 | 8 | ~85GB | ✓ |
+
 ## TestCases For 1.0
 <table class="center">
 <tr style="font-weight: bolder;text-align:center;">
@@ -368,7 +439,7 @@ python -m scripts.inference --inference_config configs/inference/test.yaml --bbo
 As a complete solution to virtual human generation, you are suggested to first apply [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).

 # Acknowledgement
-1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
+1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch) and [LatentSync](https://huggingface.co/ByteDance/LatentSync/tree/main).
 1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).
 1. MuseTalk has been built on [HDTF](https://github.com/MRzzm/HDTF) datasets.
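Note on the tables above: the batch-size, frame-count, and gradient-accumulation settings interact multiplicatively. The following is a back-of-the-envelope sketch (plain arithmetic, not code from this repository) of how many frames contribute to one optimizer step under the default configs on the 8-GPU machine mentioned above:

```python
# Rough guide only: frames per optimizer step under data-parallel training
# = train_bs * n_sample_frames * num_gpus * gradient_accumulation_steps.
def frames_per_optimizer_step(train_bs, n_sample_frames, num_gpus, grad_accum_steps):
    return train_bs * n_sample_frames * num_gpus * grad_accum_steps

print("stage1:", frames_per_optimizer_step(32, 1, 8, 1))   # 256
print("stage2:", frames_per_optimizer_step(2, 16, 8, 8))   # 2048
```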
configs/training/gpu.yaml
ADDED
@@ -0,0 +1,21 @@
compute_environment: LOCAL_MACHINE
debug: True
deepspeed_config:
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: False
  zero_stage: 2

distributed_type: DEEPSPEED
downcast_bf16: 'no'
gpu_ids: "5, 7" # modify this according to your GPU number
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 2 # it should be the same as the number of GPUs
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
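The two inline comments in this file are the settings users edit most often. The snippet below is a minimal sanity-check sketch, not part of the repository, assuming PyYAML is available; it verifies that `num_processes` matches the number of entries in `gpu_ids` before a launch:

```python
# Hypothetical helper: check gpu_ids and num_processes in configs/training/gpu.yaml agree.
import yaml

with open("configs/training/gpu.yaml") as f:
    cfg = yaml.safe_load(f)

gpu_ids = [g.strip() for g in str(cfg["gpu_ids"]).split(",") if g.strip()]
assert cfg["num_processes"] == len(gpu_ids), (
    f"num_processes={cfg['num_processes']} but gpu_ids lists {len(gpu_ids)} GPUs")
print(f"OK: {cfg['num_processes']} processes on GPUs {gpu_ids}")
```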
configs/training/preprocess.yaml
ADDED
@@ -0,0 +1,31 @@
clip_len_second: 30 # the length of the video clip
video_root_raw: "./dataset/HDTF/source/" # the path of the original video
val_list_hdtf:
  - RD_Radio7_000
  - RD_Radio8_000
  - RD_Radio9_000
  - WDA_TinaSmith_000
  - WDA_TomCarper_000
  - WDA_TomPerez_000
  - WDA_TomUdall_000
  - WDA_VeronicaEscobar0_000
  - WDA_VeronicaEscobar1_000
  - WDA_WhipJimClyburn_000
  - WDA_XavierBecerra_000
  - WDA_XavierBecerra_001
  - WDA_XavierBecerra_002
  - WDA_ZoeLofgren_000
  - WRA_SteveScalise1_000
  - WRA_TimScott_000
  - WRA_ToddYoung_000
  - WRA_TomCotton_000
  - WRA_TomPrice_000
  - WRA_VickyHartzler_000

# following dir will be automatically generated
video_root_25fps: "./dataset/HDTF/video_root_25fps/"
video_file_list: "./dataset/HDTF/video_file_list.txt"
video_audio_clip_root: "./dataset/HDTF/video_audio_clip_root/"
meta_root: "./dataset/HDTF/meta/"
video_clip_file_list_train: "./dataset/HDTF/train.txt"
video_clip_file_list_val: "./dataset/HDTF/val.txt"
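The `val_list_hdtf` entries name the HDTF source videos reserved for validation, and `scripts/preprocess.py` (not reproduced in this excerpt) writes `train.txt` and `val.txt` itself. The snippet below is only an illustration of a prefix-based split of this kind; the clip names are made up:

```python
# Illustrative only: route clip names to train/val by matching the validation
# video prefixes from preprocess.yaml. Clip names here are hypothetical.
import yaml

with open("configs/training/preprocess.yaml") as f:
    cfg = yaml.safe_load(f)

val_prefixes = tuple(cfg["val_list_hdtf"])
clips = ["RD_Radio7_000_clip00", "WDA_TomCarper_000_clip03", "WRA_TomCotton_001_clip01"]

train, val = [], []
for clip in clips:
    (val if clip.startswith(val_prefixes) else train).append(clip)
print("train:", train)
print("val:", val)
```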
configs/training/stage1.yaml
ADDED
@@ -0,0 +1,89 @@
exp_name: 'test' # Name of the experiment
output_dir: './exp_out/stage1/' # Directory to save experiment outputs
unet_sub_folder: musetalk # Subfolder name for UNet model
random_init_unet: True # Whether to randomly initialize UNet (stage1) or use pretrained weights (stage2)
whisper_path: "./models/whisper" # Path to the Whisper model
pretrained_model_name_or_path: "./models" # Path to pretrained models
resume_from_checkpoint: True # Whether to resume training from a checkpoint
padding_pixel_mouth: 10 # Number of pixels to pad around the mouth region
vae_type: "sd-vae" # Type of VAE model to use
# Validation parameters
num_images_to_keep: 8 # Number of validation images to keep
ref_dropout_rate: 0 # Dropout rate for reference images
syncnet_config_path: "./configs/training/syncnet.yaml" # Path to SyncNet configuration
use_adapted_weight: False # Whether to use adapted weights for loss calculation
cropping_jaw2edge_margin_mean: 10 # Mean margin for jaw-to-edge cropping
cropping_jaw2edge_margin_std: 10 # Standard deviation for jaw-to-edge cropping
crop_type: "crop_resize" # Type of cropping method
random_margin_method: "normal" # Method for random margin generation
num_backward_frames: 16 # Number of frames to use for backward pass in SyncNet

data:
  dataset_key: "HDTF" # Dataset to use for training
  train_bs: 32 # Training batch size (actual batch size is train_bs*n_sample_frames)
  image_size: 256 # Size of input images
  n_sample_frames: 1 # Number of frames to sample per batch
  num_workers: 8 # Number of data loading workers
  audio_padding_length_left: 2 # Left padding length for audio features
  audio_padding_length_right: 2 # Right padding length for audio features
  sample_method: pose_similarity_and_mouth_dissimilarity # Method for sampling frames
  top_k_ratio: 0.51 # Ratio for top-k sampling
  contorl_face_min_size: True # Whether to control minimum face size
  min_face_size: 150 # Minimum face size in pixels

loss_params:
  l1_loss: 1.0 # Weight for L1 loss
  vgg_loss: 0.01 # Weight for VGG perceptual loss
  vgg_layer_weight: [1, 1, 1, 1, 1] # Weights for different VGG layers
  pyramid_scale: [1, 0.5, 0.25, 0.125] # Scales for image pyramid
  gan_loss: 0 # Weight for GAN loss
  fm_loss: [1.0, 1.0, 1.0, 1.0] # Weights for feature matching loss
  sync_loss: 0 # Weight for sync loss
  mouth_gan_loss: 0 # Weight for mouth-specific GAN loss

model_params:
  discriminator_params:
    scales: [1] # Scales for discriminator
    block_expansion: 32 # Expansion factor for discriminator blocks
    max_features: 512 # Maximum number of features in discriminator
    num_blocks: 4 # Number of blocks in discriminator
    sn: True # Whether to use spectral normalization
    image_channel: 3 # Number of image channels
    estimate_jacobian: False # Whether to estimate Jacobian

discriminator_train_params:
  lr: 0.000005 # Learning rate for discriminator
  eps: 0.00000001 # Epsilon for optimizer
  weight_decay: 0.01 # Weight decay for optimizer
  patch_size: 1 # Size of patches for discriminator
  betas: [0.5, 0.999] # Beta parameters for Adam optimizer
  epochs: 10000 # Number of training epochs
  start_gan: 1000 # Step to start GAN training

solver:
  gradient_accumulation_steps: 1 # Number of steps for gradient accumulation
  uncond_steps: 10 # Number of unconditional steps
  mixed_precision: 'fp32' # Precision mode for training
  enable_xformers_memory_efficient_attention: True # Whether to use memory efficient attention
  gradient_checkpointing: True # Whether to use gradient checkpointing
  max_train_steps: 250000 # Maximum number of training steps
  max_grad_norm: 1.0 # Maximum gradient norm for clipping
  # Learning rate parameters
  learning_rate: 2.0e-5 # Base learning rate
  scale_lr: False # Whether to scale learning rate
  lr_warmup_steps: 1000 # Number of warmup steps for learning rate
  lr_scheduler: "linear" # Type of learning rate scheduler
  # Optimizer parameters
  use_8bit_adam: False # Whether to use 8-bit Adam optimizer
  adam_beta1: 0.5 # Beta1 parameter for Adam optimizer
  adam_beta2: 0.999 # Beta2 parameter for Adam optimizer
  adam_weight_decay: 1.0e-2 # Weight decay for Adam optimizer
  adam_epsilon: 1.0e-8 # Epsilon for Adam optimizer

total_limit: 10 # Maximum number of checkpoints to keep
save_model_epoch_interval: 250000 # Interval between model saves
checkpointing_steps: 10000 # Number of steps between checkpoints
val_freq: 2000 # Frequency of validation

seed: 41 # Random seed for reproducibility
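The nested keys above (for example `data.train_bs` and `solver.gradient_accumulation_steps`, which the README tells you to tune) are easy to inspect programmatically. `train.py` is not shown in this excerpt, so the loader below is only a sketch; OmegaConf is an assumption, chosen because it is the usual companion of diffusers-style training code:

```python
# Sketch: load the stage-1 config and print the knobs the README says to adjust.
# OmegaConf is assumed; the actual loader used by train.py is not shown in this diff.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/training/stage1.yaml")
print(cfg.exp_name, "->", cfg.output_dir)
print("per-GPU batch:", cfg.data.train_bs, "x", cfg.data.n_sample_frames, "frame(s)")
print("grad accumulation:", cfg.solver.gradient_accumulation_steps,
      "| lr:", cfg.solver.learning_rate, "| max steps:", cfg.solver.max_train_steps)
```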
configs/training/stage2.yaml
ADDED
@@ -0,0 +1,89 @@
exp_name: 'test' # Name of the experiment
output_dir: './exp_out/stage2/' # Directory to save experiment outputs
unet_sub_folder: musetalk # Subfolder name for UNet model
random_init_unet: False # Whether to randomly initialize UNet (stage1) or use pretrained weights (stage2)
whisper_path: "./models/whisper" # Path to the Whisper model
pretrained_model_name_or_path: "./models" # Path to pretrained models
resume_from_checkpoint: True # Whether to resume training from a checkpoint
padding_pixel_mouth: 10 # Number of pixels to pad around the mouth region
vae_type: "sd-vae" # Type of VAE model to use
# Validation parameters
num_images_to_keep: 8 # Number of validation images to keep
ref_dropout_rate: 0 # Dropout rate for reference images
syncnet_config_path: "./configs/training/syncnet.yaml" # Path to SyncNet configuration
use_adapted_weight: False # Whether to use adapted weights for loss calculation
cropping_jaw2edge_margin_mean: 10 # Mean margin for jaw-to-edge cropping
cropping_jaw2edge_margin_std: 10 # Standard deviation for jaw-to-edge cropping
crop_type: "dynamic_margin_crop_resize" # Type of cropping method
random_margin_method: "normal" # Method for random margin generation
num_backward_frames: 16 # Number of frames to use for backward pass in SyncNet

data:
  dataset_key: "HDTF" # Dataset to use for training
  train_bs: 2 # Training batch size (actual batch size is train_bs*n_sample_frames)
  image_size: 256 # Size of input images
  n_sample_frames: 16 # Number of frames to sample per batch
  num_workers: 8 # Number of data loading workers
  audio_padding_length_left: 2 # Left padding length for audio features
  audio_padding_length_right: 2 # Right padding length for audio features
  sample_method: pose_similarity_and_mouth_dissimilarity # Method for sampling frames
  top_k_ratio: 0.51 # Ratio for top-k sampling
  contorl_face_min_size: True # Whether to control minimum face size
  min_face_size: 200 # Minimum face size in pixels

loss_params:
  l1_loss: 1.0 # Weight for L1 loss
  vgg_loss: 0.01 # Weight for VGG perceptual loss
  vgg_layer_weight: [1, 1, 1, 1, 1] # Weights for different VGG layers
  pyramid_scale: [1, 0.5, 0.25, 0.125] # Scales for image pyramid
  gan_loss: 0.01 # Weight for GAN loss
  fm_loss: [1.0, 1.0, 1.0, 1.0] # Weights for feature matching loss
  sync_loss: 0.05 # Weight for sync loss
  mouth_gan_loss: 0.01 # Weight for mouth-specific GAN loss

model_params:
  discriminator_params:
    scales: [1] # Scales for discriminator
    block_expansion: 32 # Expansion factor for discriminator blocks
    max_features: 512 # Maximum number of features in discriminator
    num_blocks: 4 # Number of blocks in discriminator
    sn: True # Whether to use spectral normalization
    image_channel: 3 # Number of image channels
    estimate_jacobian: False # Whether to estimate Jacobian

discriminator_train_params:
  lr: 0.000005 # Learning rate for discriminator
  eps: 0.00000001 # Epsilon for optimizer
  weight_decay: 0.01 # Weight decay for optimizer
  patch_size: 1 # Size of patches for discriminator
  betas: [0.5, 0.999] # Beta parameters for Adam optimizer
  epochs: 10000 # Number of training epochs
  start_gan: 1000 # Step to start GAN training

solver:
  gradient_accumulation_steps: 8 # Number of steps for gradient accumulation
  uncond_steps: 10 # Number of unconditional steps
  mixed_precision: 'fp32' # Precision mode for training
  enable_xformers_memory_efficient_attention: True # Whether to use memory efficient attention
  gradient_checkpointing: True # Whether to use gradient checkpointing
  max_train_steps: 250000 # Maximum number of training steps
  max_grad_norm: 1.0 # Maximum gradient norm for clipping
  # Learning rate parameters
  learning_rate: 5.0e-6 # Base learning rate
  scale_lr: False # Whether to scale learning rate
  lr_warmup_steps: 1000 # Number of warmup steps for learning rate
  lr_scheduler: "linear" # Type of learning rate scheduler
  # Optimizer parameters
  use_8bit_adam: False # Whether to use 8-bit Adam optimizer
  adam_beta1: 0.5 # Beta1 parameter for Adam optimizer
  adam_beta2: 0.999 # Beta2 parameter for Adam optimizer
  adam_weight_decay: 1.0e-2 # Weight decay for Adam optimizer
  adam_epsilon: 1.0e-8 # Epsilon for Adam optimizer

total_limit: 10 # Maximum number of checkpoints to keep
save_model_epoch_interval: 250000 # Interval between model saves
checkpointing_steps: 2000 # Number of steps between checkpoints
val_freq: 2000 # Frequency of validation

seed: 41 # Random seed for reproducibility
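Compared with stage 1, stage 2 switches on the adversarial and lip-sync terms (`gan_loss`, `mouth_gan_loss`, `sync_loss`). The sketch below only illustrates how such weights are typically combined into a single objective, using placeholder loss values; the real loss implementations live in `musetalk/loss/` and `train.py`, which are not reproduced here:

```python
# Sketch of a weighted total loss using the stage-2 weights; the individual
# terms are placeholder scalars, not the repository's actual loss functions.
import torch

weights = {"l1_loss": 1.0, "vgg_loss": 0.01, "gan_loss": 0.01,
           "sync_loss": 0.05, "mouth_gan_loss": 0.01}
loss_terms = {name: torch.rand(()) for name in weights}  # placeholder scalar losses
# In training, each term would carry gradients and .backward() would run on the weighted sum.
total_loss = sum(w * loss_terms[name] for name, w in weights.items())
print({k: round(float(v), 3) for k, v in loss_terms.items()}, "total:", float(total_loss))
```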
configs/training/syncnet.yaml
ADDED
@@ -0,0 +1,19 @@
# This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/configs/training/syncnet_16_pixel.yaml).
model:
  audio_encoder: # input (1, 80, 52)
    in_channels: 1
    block_out_channels: [32, 64, 128, 256, 512, 1024, 2048]
    downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]]
    attn_blocks: [0, 0, 0, 0, 0, 0, 0]
    dropout: 0.0
  visual_encoder: # input (48, 128, 256)
    in_channels: 48
    block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048]
    downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
    attn_blocks: [0, 0, 0, 0, 0, 0, 0, 0]
    dropout: 0.0

ckpt:
  resume_ckpt_path: ""
  inference_ckpt_path: ./models/syncnet/latentsync_syncnet.pt # this pretrained model is from LatentSync (https://huggingface.co/ByteDance/LatentSync/tree/main)
  save_ckpt_steps: 2500
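The `# input (1, 80, 52)` comment on the audio encoder is consistent with the mel parameters used elsewhere in this commit: 80 mel bins, and 52 mel frames for a 16-frame video window, the same value as `syncnet_mel_step_size` in `musetalk/data/dataset.py`. The arithmetic:

```python
# 16 kHz audio with hop 200 gives 80 mel frames per second (see musetalk/data/audio.py);
# a SyncNet window of 16 video frames at 25 fps therefore spans ceil(80 * 16 / 25) = 52 frames.
import math

mel_frames_per_second = 16000 / 200      # sample_rate / hop_size = 80.0
window_seconds = 16 / 25                 # 16 video frames at 25 fps
print(math.ceil(mel_frames_per_second * window_seconds))  # 52
print(math.ceil(16 / 5 * 16))  # 52 -- the syncnet_mel_step_size formula in dataset.py
```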
inference.sh
CHANGED
@@ -59,7 +59,7 @@ cmd_args="--inference_config $config_path \
     --result_dir $result_dir \
     --unet_model_path $unet_model_path \
     --unet_config $unet_config \
-    --version $
+    --version $version_arg"

 # Add realtime-specific arguments if in realtime mode
 if [ "$mode" = "realtime" ]; then
musetalk/data/audio.py
ADDED
@@ -0,0 +1,168 @@
import librosa
import librosa.filters
import numpy as np
from scipy import signal
from scipy.io import wavfile

class HParams:
    # copy from wav2lip
    def __init__(self):
        self.n_fft = 800
        self.hop_size = 200
        self.win_size = 800
        self.sample_rate = 16000
        self.frame_shift_ms = None
        self.signal_normalization = True

        self.allow_clipping_in_normalization = True
        self.symmetric_mels = True
        self.max_abs_value = 4.0
        self.preemphasize = True
        self.preemphasis = 0.97
        self.min_level_db = -100
        self.ref_level_db = 20
        self.fmin = 55
        self.fmax = 7600

        self.use_lws = False
        self.num_mels = 80  # Number of mel-spectrogram channels and local conditioning dimensionality
        self.rescale = True  # Whether to rescale audio prior to preprocessing
        self.rescaling_max = 0.9  # Rescaling value
        self.use_lws = False


hp = HParams()

def load_wav(path, sr):
    return librosa.core.load(path, sr=sr)[0]
#def load_wav(path, sr):
#    audio, sr_native = sf.read(path)
#    if sr != sr_native:
#        audio = librosa.resample(audio.T, sr_native, sr).T
#    return audio

def save_wav(wav, path, sr):
    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
    # proposed by @dsmiller
    wavfile.write(path, sr, wav.astype(np.int16))

def save_wavenet_wav(wav, path, sr):
    librosa.output.write_wav(path, wav, sr=sr)

def preemphasis(wav, k, preemphasize=True):
    if preemphasize:
        return signal.lfilter([1, -k], [1], wav)
    return wav

def inv_preemphasis(wav, k, inv_preemphasize=True):
    if inv_preemphasize:
        return signal.lfilter([1], [1, -k], wav)
    return wav

def get_hop_size():
    hop_size = hp.hop_size
    if hop_size is None:
        assert hp.frame_shift_ms is not None
        hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
    return hop_size

def linearspectrogram(wav):
    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
    S = _amp_to_db(np.abs(D)) - hp.ref_level_db

    if hp.signal_normalization:
        return _normalize(S)
    return S

def melspectrogram(wav):
    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
    S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db

    if hp.signal_normalization:
        return _normalize(S)
    return S

def _lws_processor():
    import lws
    return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")

def _stft(y):
    if hp.use_lws:
        return _lws_processor(hp).stft(y).T
    else:
        return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)

##########################################################
# Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
def num_frames(length, fsize, fshift):
    """Compute number of time frames of spectrogram
    """
    pad = (fsize - fshift)
    if length % fshift == 0:
        M = (length + pad * 2 - fsize) // fshift + 1
    else:
        M = (length + pad * 2 - fsize) // fshift + 2
    return M


def pad_lr(x, fsize, fshift):
    """Compute left and right padding
    """
    M = num_frames(len(x), fsize, fshift)
    pad = (fsize - fshift)
    T = len(x) + 2 * pad
    r = (M - 1) * fshift + fsize - T
    return pad, pad + r
##########################################################
# Librosa correct padding
def librosa_pad_lr(x, fsize, fshift):
    return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]

# Conversions
_mel_basis = None

def _linear_to_mel(spectogram):
    global _mel_basis
    if _mel_basis is None:
        _mel_basis = _build_mel_basis()
    return np.dot(_mel_basis, spectogram)

def _build_mel_basis():
    assert hp.fmax <= hp.sample_rate // 2
    return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
                               fmin=hp.fmin, fmax=hp.fmax)

def _amp_to_db(x):
    min_level = np.exp(hp.min_level_db / 20 * np.log(10))
    return 20 * np.log10(np.maximum(min_level, x))

def _db_to_amp(x):
    return np.power(10.0, (x) * 0.05)

def _normalize(S):
    if hp.allow_clipping_in_normalization:
        if hp.symmetric_mels:
            return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
                           -hp.max_abs_value, hp.max_abs_value)
        else:
            return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)

    assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
    if hp.symmetric_mels:
        return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
    else:
        return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))

def _denormalize(D):
    if hp.allow_clipping_in_normalization:
        if hp.symmetric_mels:
            return (((np.clip(D, -hp.max_abs_value,
                              hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
                    + hp.min_level_db)
        else:
            return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)

    if hp.symmetric_mels:
        return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
    else:
        return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
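This module is essentially Wav2Lip's audio front end with fixed 16 kHz hyperparameters. A minimal usage sketch (the wav path is hypothetical):

```python
# Load a 16 kHz waveform and compute the normalized mel spectrogram used for SyncNet.
from musetalk.data import audio

wav = audio.load_wav("dataset/HDTF/video_audio_clip_root/example.wav", sr=16000)  # hypothetical path
mel = audio.melspectrogram(wav)  # shape (80, T); roughly T ~= len(wav) / 200 + 1 frames
print(mel.shape, mel.min(), mel.max())  # values are clipped to [-4, 4] (hp.max_abs_value)
```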
musetalk/data/dataset.py
ADDED
@@ -0,0 +1,607 @@
import os
import numpy as np
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, ConcatDataset
import torchvision.transforms as transforms
from transformers import AutoFeatureExtractor
import librosa
import time
import json
import math
from decord import AudioReader, VideoReader
from decord.ndarray import cpu

from musetalk.data.sample_method import get_src_idx, shift_landmarks_to_face_coordinates, resize_landmark
from musetalk.data import audio

syncnet_mel_step_size = math.ceil(16 / 5 * 16)  # latentsync


class FaceDataset(Dataset):
    """Dataset class for loading and processing video data

    Each video can be represented as:
    - Concatenated frame images
    - '.mp4' or '.gif' files
    - Folder containing all frames
    """
    def __init__(self,
                 cfg,
                 list_paths,
                 root_path='./dataset/',
                 repeats=None):
        # Initialize dataset paths
        meta_paths = []
        if repeats is None:
            repeats = [1] * len(list_paths)
        assert len(repeats) == len(list_paths)

        # Load data list
        for list_path, repeat_time in zip(list_paths, repeats):
            with open(list_path, 'r') as f:
                num = 0
                f.readline()  # Skip header line
                for line in f.readlines():
                    line_info = line.strip()
                    meta = line_info.split()
                    meta = meta[0]
                    meta_paths.extend([os.path.join(root_path, meta)] * repeat_time)
                    num += 1
                print(f'{list_path}: {num} x {repeat_time} = {num * repeat_time} samples')

        # Set basic attributes
        self.meta_paths = meta_paths
        self.root_path = root_path
        self.image_size = cfg['image_size']
        self.min_face_size = cfg['min_face_size']
        self.T = cfg['T']
        self.sample_method = cfg['sample_method']
        self.top_k_ratio = cfg['top_k_ratio']
        self.max_attempts = 200
        self.padding_pixel_mouth = cfg['padding_pixel_mouth']

        # Cropping related parameters
        self.crop_type = cfg['crop_type']
        self.jaw2edge_margin_mean = cfg['cropping_jaw2edge_margin_mean']
        self.jaw2edge_margin_std = cfg['cropping_jaw2edge_margin_std']
        self.random_margin_method = cfg['random_margin_method']

        # Image transformations
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.pose_to_tensor = transforms.Compose([
            transforms.ToTensor(),
        ])

        # Feature extractor
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(cfg['whisper_path'])
        self.contorl_face_min_size = cfg["contorl_face_min_size"]

        print("The sample method is: ", self.sample_method)
        print(f"only use face size > {self.min_face_size}", self.contorl_face_min_size)

    def generate_random_value(self):
        """Generate random value

        Returns:
            float: Generated random value
        """
        if self.random_margin_method == "uniform":
            random_value = np.random.uniform(
                self.jaw2edge_margin_mean - self.jaw2edge_margin_std,
                self.jaw2edge_margin_mean + self.jaw2edge_margin_std
            )
        elif self.random_margin_method == "normal":
            random_value = np.random.normal(
                loc=self.jaw2edge_margin_mean,
                scale=self.jaw2edge_margin_std
            )
            random_value = np.clip(
                random_value,
                self.jaw2edge_margin_mean - self.jaw2edge_margin_std,
                self.jaw2edge_margin_mean + self.jaw2edge_margin_std,
            )
        else:
            raise ValueError(f"Invalid random margin method: {self.random_margin_method}")
        return max(0, random_value)

    def dynamic_margin_crop(self, img, original_bbox, extra_margin=None):
        """Dynamically crop image with dynamic margin

        Args:
            img: Input image
            original_bbox: Original bounding box
            extra_margin: Extra margin

        Returns:
            tuple: (x1, y1, x2, y2, extra_margin)
        """
        if extra_margin is None:
            extra_margin = self.generate_random_value()
        w, h = img.size
        x1, y1, x2, y2 = original_bbox
        y2 = min(y2 + int(extra_margin), h)
        return x1, y1, x2, y2, extra_margin

    def crop_resize_img(self, img, bbox, crop_type='crop_resize', extra_margin=None):
        """Crop and resize image

        Args:
            img: Input image
            bbox: Bounding box
            crop_type: Type of cropping
            extra_margin: Extra margin

        Returns:
            tuple: (Processed image, extra_margin, mask_scaled_factor)
        """
        mask_scaled_factor = 1.
        if crop_type == 'crop_resize':
            x1, y1, x2, y2 = bbox
            img = img.crop((x1, y1, x2, y2))
            img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
        elif crop_type == 'dynamic_margin_crop_resize':
            x1, y1, x2, y2, extra_margin = self.dynamic_margin_crop(img, bbox, extra_margin)
            w_original, _ = img.size
            img = img.crop((x1, y1, x2, y2))
            w_cropped, _ = img.size
            mask_scaled_factor = w_cropped / w_original
            img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
        elif crop_type == 'resize':
            w, h = img.size
            scale = np.sqrt(self.image_size ** 2 / (h * w))
            new_w = int(w * scale) / 64 * 64
            new_h = int(h * scale) / 64 * 64
            img = img.resize((new_w, new_h), Image.LANCZOS)
        return img, extra_margin, mask_scaled_factor

    def get_audio_file(self, wav_path, start_index):
        """Get audio file features

        Args:
            wav_path: Audio file path
            start_index: Starting index

        Returns:
            tuple: (Audio features, start index)
        """
        if not os.path.exists(wav_path):
            return None
        audio_input_librosa, sampling_rate = librosa.load(wav_path, sr=16000)
        assert sampling_rate == 16000

        while start_index >= 25 * 30:
            audio_input = audio_input_librosa[16000*30:]
            start_index -= 25 * 30
        if start_index + 2 * 25 >= 25 * 30:
            start_index -= 4 * 25
            audio_input = audio_input_librosa[16000*4:16000*34]
        else:
            audio_input = audio_input_librosa[:16000*30]

        assert 2 * (start_index) >= 0
        assert 2 * (start_index + 2 * 25) <= 1500

        audio_input = self.feature_extractor(
            audio_input,
            return_tensors="pt",
            sampling_rate=sampling_rate
        ).input_features
        return audio_input, start_index

    def get_audio_file_mel(self, wav_path, start_index):
        """Get mel spectrogram of audio file

        Args:
            wav_path: Audio file path
            start_index: Starting index

        Returns:
            tuple: (Mel spectrogram, start index)
        """
        if not os.path.exists(wav_path):
            return None

        audio_input, sampling_rate = librosa.load(wav_path, sr=16000)
        assert sampling_rate == 16000

        audio_input = self.mel_feature_extractor(audio_input)
        return audio_input, start_index

    def mel_feature_extractor(self, audio_input):
        """Extract mel spectrogram features

        Args:
            audio_input: Input audio

        Returns:
            ndarray: Mel spectrogram features
        """
        orig_mel = audio.melspectrogram(audio_input)
        return orig_mel.T

    def crop_audio_window(self, spec, start_frame_num, fps=25):
        """Crop audio window

        Args:
            spec: Spectrogram
            start_frame_num: Starting frame number
            fps: Frames per second

        Returns:
            ndarray: Cropped spectrogram
        """
        start_idx = int(80. * (start_frame_num / float(fps)))
        end_idx = start_idx + syncnet_mel_step_size
        return spec[start_idx: end_idx, :]

    def get_syncnet_input(self, video_path):
        """Get SyncNet input features

        Args:
            video_path: Video file path

        Returns:
            ndarray: SyncNet input features
        """
        ar = AudioReader(video_path, sample_rate=16000)
        original_mel = audio.melspectrogram(ar[:].asnumpy().squeeze(0))
        return original_mel.T

    def get_resized_mouth_mask(
            self,
            img_resized,
            landmark_array,
            face_shape,
            padding_pixel_mouth=0,
            image_size=256,
            crop_margin=0
    ):
        landmark_array = np.array(landmark_array)
        resized_landmark = resize_landmark(
            landmark_array, w=face_shape[0], h=face_shape[1], new_w=image_size, new_h=image_size)

        landmark_array = np.array(resized_landmark[48 : 67])  # the lip landmarks in 68 landmarks format
        min_x, min_y = np.min(landmark_array, axis=0)
        max_x, max_y = np.max(landmark_array, axis=0)
        min_x = min_x - padding_pixel_mouth
        max_x = max_x + padding_pixel_mouth

        # Calculate x-axis length and use it for y-axis
        width = max_x - min_x

        # Calculate old center point
        center_y = (max_y + min_y) / 2

        # Determine new min_y and max_y based on width
        min_y = center_y - width / 4
        max_y = center_y + width / 4

        # Adjust mask position for dynamic crop, shift y-axis
        min_y = min_y - crop_margin
        max_y = max_y - crop_margin

        # Prevent out of bounds
        min_x = max(min_x, 0)
        min_y = max(min_y, 0)
        max_x = min(max_x, face_shape[0])
        max_y = min(max_y, face_shape[1])

        mask = np.zeros_like(np.array(img_resized))
        mask[round(min_y):round(max_y), round(min_x):round(max_x)] = 255
        return Image.fromarray(mask)

    def __len__(self):
        return 100000

    def __getitem__(self, idx):
        attempts = 0
        while attempts < self.max_attempts:
            try:
                meta_path = random.sample(self.meta_paths, k=1)[0]
                with open(meta_path, 'r') as f:
                    meta_data = json.load(f)
            except Exception as e:
                print(f"meta file error:{meta_path}")
                print(e)
                attempts += 1
                time.sleep(0.1)
                continue

            video_path = meta_data["mp4_path"]
            wav_path = meta_data["wav_path"]
            bbox_list = meta_data["face_list"]
            landmark_list = meta_data["landmark_list"]
            T = self.T

            s = 0
            e = meta_data["frames"]
            len_valid_clip = e - s

            if len_valid_clip < T * 10:
                attempts += 1
                print(f"video {video_path} has less than {T * 10} frames")
                continue

            try:
                cap = VideoReader(video_path, fault_tol=1, ctx=cpu(0))
                total_frames = len(cap)
                assert total_frames == len(landmark_list)
                assert total_frames == len(bbox_list)
                landmark_shape = np.array(landmark_list).shape
                if landmark_shape != (total_frames, 68, 2):
                    attempts += 1
                    print(f"video {video_path} has invalid landmark shape: {landmark_shape}, expected: {(total_frames, 68, 2)}")  # we use 68 landmarks
                    continue
            except Exception as e:
                print(f"video file error:{video_path}")
                print(e)
                attempts += 1
                time.sleep(0.1)
                continue

            shift_landmarks, bbox_list_union, face_shapes = shift_landmarks_to_face_coordinates(
                landmark_list,
                bbox_list
            )
            if self.contorl_face_min_size and face_shapes[0][0] < self.min_face_size:
                print(f"video {video_path} has face size {face_shapes[0][0]} less than minimum required {self.min_face_size}")
                attempts += 1
                continue

            step = 1
            drive_idx_start = random.randint(s, e - T * step)
            drive_idx_list = list(
                range(drive_idx_start, drive_idx_start + T * step, step))
            assert len(drive_idx_list) == T

            src_idx_list = []
            list_index_out_of_range = False
            for drive_idx in drive_idx_list:
                src_idx = get_src_idx(
                    drive_idx, T, self.sample_method, shift_landmarks, face_shapes, self.top_k_ratio)
                if src_idx is None:
                    list_index_out_of_range = True
                    break
                src_idx = min(src_idx, e - 1)
                src_idx = max(src_idx, s)
                src_idx_list.append(src_idx)

            if list_index_out_of_range:
                attempts += 1
                print(f"video {video_path} has invalid source index for drive frames")
                continue

            ref_face_valid_flag = True
            extra_margin = self.generate_random_value()

            # Get reference images
            ref_imgs = []
            for src_idx in src_idx_list:
                imSrc = Image.fromarray(cap[src_idx].asnumpy())
                bbox_s = bbox_list_union[src_idx]
                imSrc, _, _ = self.crop_resize_img(
                    imSrc,
                    bbox_s,
                    self.crop_type,
                    extra_margin=None
                )
                if self.contorl_face_min_size and min(imSrc.size[0], imSrc.size[1]) < self.min_face_size:
                    ref_face_valid_flag = False
                    break
                ref_imgs.append(imSrc)

            if not ref_face_valid_flag:
                attempts += 1
                print(f"video {video_path} has reference face size smaller than minimum required {self.min_face_size}")
                continue

            # Get target images and masks
            imSameIDs = []
            bboxes = []
            face_masks = []
            face_mask_valid = True
            target_face_valid_flag = True

            for drive_idx in drive_idx_list:
                imSameID = Image.fromarray(cap[drive_idx].asnumpy())
                bbox_s = bbox_list_union[drive_idx]
                imSameID, _, mask_scaled_factor = self.crop_resize_img(
                    imSameID,
                    bbox_s,
                    self.crop_type,
                    extra_margin=extra_margin
                )
                if self.contorl_face_min_size and min(imSameID.size[0], imSameID.size[1]) < self.min_face_size:
                    target_face_valid_flag = False
                    break
                crop_margin = extra_margin * mask_scaled_factor
                face_mask = self.get_resized_mouth_mask(
                    imSameID,
                    shift_landmarks[drive_idx],
                    face_shapes[drive_idx],
                    self.padding_pixel_mouth,
                    self.image_size,
                    crop_margin=crop_margin
                )
                if np.count_nonzero(face_mask) == 0:
                    face_mask_valid = False
                    break

                if face_mask.size[1] == 0 or face_mask.size[0] == 0:
                    print(f"video {video_path} has invalid face mask size at frame {drive_idx}")
                    face_mask_valid = False
                    break

                imSameIDs.append(imSameID)
                bboxes.append(bbox_s)
                face_masks.append(face_mask)

            if not face_mask_valid:
                attempts += 1
                print(f"video {video_path} has invalid face mask")
                continue

            if not target_face_valid_flag:
                attempts += 1
                print(f"video {video_path} has target face size smaller than minimum required {self.min_face_size}")
                continue

            # Process audio features
            audio_offset = drive_idx_list[0]
            audio_step = step
            fps = 25.0 / step

            try:
                audio_feature, audio_offset = self.get_audio_file(wav_path, audio_offset)
                _, audio_offset = self.get_audio_file_mel(wav_path, audio_offset)
                audio_feature_mel = self.get_syncnet_input(video_path)
            except Exception as e:
                print(f"audio file error:{wav_path}")
                print(e)
                attempts += 1
                time.sleep(0.1)
                continue

            mel = self.crop_audio_window(audio_feature_mel, audio_offset)
            if mel.shape[0] != syncnet_mel_step_size:
                attempts += 1
                print(f"video {video_path} has invalid mel spectrogram shape: {mel.shape}, expected: {syncnet_mel_step_size}")
                continue

            mel = torch.FloatTensor(mel.T).unsqueeze(0)

            # Build sample dictionary
            sample = dict(
                pixel_values_vid=torch.stack(
                    [self.to_tensor(imSameID) for imSameID in imSameIDs], dim=0),
                pixel_values_ref_img=torch.stack(
                    [self.to_tensor(ref_img) for ref_img in ref_imgs], dim=0),
                pixel_values_face_mask=torch.stack(
                    [self.pose_to_tensor(face_mask) for face_mask in face_masks], dim=0),
                audio_feature=audio_feature[0],
                audio_offset=audio_offset,
                audio_step=audio_step,
                mel=mel,
                wav_path=wav_path,
                fps=fps,
            )

            return sample

        raise ValueError("Unable to find a valid sample after maximum attempts.")

class HDTFDataset(FaceDataset):
    """HDTF dataset class"""
    def __init__(self, cfg):
        root_path = './dataset/HDTF/meta'
        list_paths = [
            './dataset/HDTF/train.txt',
        ]

        repeats = [10]
        super().__init__(cfg, list_paths, root_path, repeats)
        print('HDTFDataset: ', len(self))

class VFHQDataset(FaceDataset):
    """VFHQ dataset class"""
    def __init__(self, cfg):
        root_path = './dataset/VFHQ/meta'
        list_paths = [
            './dataset/VFHQ/train.txt',
        ]
        repeats = [1]
        super().__init__(cfg, list_paths, root_path, repeats)
        print('VFHQDataset: ', len(self))

def PortraitDataset(cfg=None):
    """Return dataset based on configuration

    Args:
        cfg: Configuration dictionary

    Returns:
        Dataset: Combined dataset
    """
    if cfg["dataset_key"] == "HDTF":
        return ConcatDataset([HDTFDataset(cfg)])
    elif cfg["dataset_key"] == "VFHQ":
        return ConcatDataset([VFHQDataset(cfg)])
    else:
        print("############ use all dataset ############ ")
        return ConcatDataset([HDTFDataset(cfg), VFHQDataset(cfg)])


if __name__ == '__main__':
    # Set random seeds for reproducibility
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Create dataset with configuration parameters
    dataset = PortraitDataset(cfg={
        'T': 1,  # Number of frames to process at once
        'random_margin_method': "normal",  # Method for generating random margins: "normal" or "uniform"
        'dataset_key': "HDTF",  # Dataset to use: "HDTF", "VFHQ", or None for both
        'image_size': 256,  # Size of processed images (height and width)
        'sample_method': 'pose_similarity_and_mouth_dissimilarity',  # Method for selecting reference frames
        'top_k_ratio': 0.51,  # Ratio for top-k selection in reference frame sampling
        'contorl_face_min_size': True,  # Whether to enforce minimum face size
        'padding_pixel_mouth': 10,  # Padding pixels around mouth region in mask
        'min_face_size': 200,  # Minimum face size requirement for dataset
        'whisper_path': "./models/whisper",  # Path to Whisper model
        'cropping_jaw2edge_margin_mean': 10,  # Mean margin for jaw-to-edge cropping
        'cropping_jaw2edge_margin_std': 10,  # Standard deviation for jaw-to-edge cropping
        'crop_type': "dynamic_margin_crop_resize",  # Type of cropping: "crop_resize", "dynamic_margin_crop_resize", or "resize"
    })
    print(len(dataset))

    import torchvision
    os.makedirs('debug', exist_ok=True)
    for i in range(10):  # Check 10 samples
        sample = dataset[0]
        print(f"processing {i}")

        # Get images and mask
        ref_img = (sample['pixel_values_ref_img'] + 1.0) / 2  # (b, c, h, w)
        target_img = (sample['pixel_values_vid'] + 1.0) / 2
        face_mask = sample['pixel_values_face_mask']

        # Print dimension information
        print(f"ref_img shape: {ref_img.shape}")
        print(f"target_img shape: {target_img.shape}")
        print(f"face_mask shape: {face_mask.shape}")

        # Create visualization images
        b, c, h, w = ref_img.shape

        # Apply mask only to target image
        target_mask = face_mask
|
| 588 |
+
|
| 589 |
+
# Keep reference image unchanged
|
| 590 |
+
ref_with_mask = ref_img.clone()
|
| 591 |
+
|
| 592 |
+
# Create mask overlay for target image
|
| 593 |
+
target_with_mask = target_img.clone()
|
| 594 |
+
target_with_mask = target_with_mask * (1 - target_mask) + target_mask # Apply mask only to target
|
| 595 |
+
|
| 596 |
+
# Save original images, mask, and overlay results
|
| 597 |
+
# First row: original images
|
| 598 |
+
# Second row: mask
|
| 599 |
+
# Third row: overlay effect
|
| 600 |
+
concatenated_img = torch.cat((
|
| 601 |
+
ref_img, target_img, # Original images
|
| 602 |
+
torch.zeros_like(ref_img), target_mask, # Mask (black for ref)
|
| 603 |
+
ref_with_mask, target_with_mask # Overlay effect
|
| 604 |
+
), dim=3)
|
| 605 |
+
|
| 606 |
+
torchvision.utils.save_image(
|
| 607 |
+
concatenated_img, f'debug/mask_check_{i}.jpg', nrow=2)
|
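For training, `PortraitDataset` is normally consumed through a standard `torch.utils.data.DataLoader`. A minimal sketch, assuming the same `cfg` keys as the `__main__` block above and an already-preprocessed `./dataset/HDTF` tree; the shape comments follow the `sample` dict built in `__getitem__` (the mask channel count depends on `pose_to_tensor`):

```python
from torch.utils.data import DataLoader

from musetalk.data.dataset import PortraitDataset


def build_dataloader(cfg, batch_size=2, num_workers=4):
    # Hypothetical helper, not part of the repository: wraps the dataset above.
    dataset = PortraitDataset(cfg)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )

# Each batch then carries the keys assembled in __getitem__:
#   batch['pixel_values_vid']       -> (B, T, 3, image_size, image_size)
#   batch['pixel_values_ref_img']   -> (B, T, 3, image_size, image_size)
#   batch['pixel_values_face_mask'] -> (B, T, C_mask, image_size, image_size)
#   batch['mel'], batch['audio_feature'], batch['audio_offset'], batch['fps'], ...
```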
musetalk/data/sample_method.py ADDED
@@ -0,0 +1,233 @@
import numpy as np
import random

def summarize_tensor(x):
    return f"\033[34m{str(tuple(x.shape)).ljust(24)}\033[0m (\033[31mmin {x.min().item():+.4f}\033[0m / \033[32mmean {x.mean().item():+.4f}\033[0m / \033[33mmax {x.max().item():+.4f}\033[0m)"

def calculate_mouth_open_similarity(landmarks_list, select_idx, top_k=50, ascending=True):
    num_landmarks = len(landmarks_list)
    mouth_open_ratios = np.zeros(num_landmarks)  # Initialize as a numpy array
    print(np.shape(landmarks_list))
    ## Calculate mouth opening ratios
    for i, landmarks in enumerate(landmarks_list):
        # Assuming landmarks are in the format [x, y] and accessible by index
        mouth_top = landmarks[165]  # Adjust index according to your landmarks format
        mouth_bottom = landmarks[147]  # Adjust index according to your landmarks format
        mouth_open_ratio = np.linalg.norm(mouth_top - mouth_bottom)
        mouth_open_ratios[i] = mouth_open_ratio

    # Calculate differences matrix
    differences_matrix = np.abs(mouth_open_ratios[:, np.newaxis] - mouth_open_ratios[select_idx])
    differences_matrix_with_signs = mouth_open_ratios[:, np.newaxis] - mouth_open_ratios[select_idx]
    print(differences_matrix.shape)
    # Find top_k similar indices for each landmark set
    if ascending:
        top_indices = np.argsort(differences_matrix[i])[:top_k]
    else:
        top_indices = np.argsort(-differences_matrix[i])[:top_k]
    similar_landmarks_indices = top_indices.tolist()
    similar_landmarks_distances = differences_matrix_with_signs[i].tolist()  # note: do not sort here

    return similar_landmarks_indices, similar_landmarks_distances
#############################################################################################
def get_closed_mouth(landmarks_list, ascending=True, top_k=50):
    num_landmarks = len(landmarks_list)

    mouth_open_ratios = np.zeros(num_landmarks)  # Initialize as a numpy array
    ## Calculate mouth opening ratios
    #print("landmarks shape",np.shape(landmarks_list))
    for i, landmarks in enumerate(landmarks_list):
        # Assuming landmarks are in the format [x, y] and accessible by index
        #print(landmarks[165])
        mouth_top = np.array(landmarks[165])  # Adjust index according to your landmarks format
        mouth_bottom = np.array(landmarks[147])  # Adjust index according to your landmarks format
        mouth_open_ratio = np.linalg.norm(mouth_top - mouth_bottom)
        mouth_open_ratios[i] = mouth_open_ratio

    # Find top_k similar indices for each landmark set
    if ascending:
        top_indices = np.argsort(mouth_open_ratios)[:top_k]
    else:
        top_indices = np.argsort(-mouth_open_ratios)[:top_k]
    return top_indices

def calculate_landmarks_similarity(selected_idx, landmarks_list, image_shapes, start_index, end_index, top_k=50, ascending=True):
    """
    Calculate the similarity between sets of facial landmarks and return the indices of the most similar faces.

    Parameters:
    landmarks_list (list): A list containing sets of facial landmarks, each element is a set of landmarks.
    image_shapes (list): A list containing the shape of each image, each element is a (width, height) tuple.
    start_index (int): The starting index of the facial landmarks.
    end_index (int): The ending index of the facial landmarks.
    top_k (int): The number of most similar landmark sets to return. Default is 50.
    ascending (bool): Controls the sorting order. If True, sort in ascending order; If False, sort in descending order. Default is True.

    Returns:
    similar_landmarks_indices (list): A list containing the indices of the most similar facial landmarks for each face.
    resized_landmarks (list): A list containing the resized facial landmarks.
    """
    num_landmarks = len(landmarks_list)
    resized_landmarks = []

    # Preprocess landmarks
    for i in range(num_landmarks):
        landmark_array = np.array(landmarks_list[i])
        selected_landmarks = landmark_array[start_index:end_index]
        resized_landmark = resize_landmark(selected_landmarks, w=image_shapes[i][0], h=image_shapes[i][1], new_w=256, new_h=256)
        resized_landmarks.append(resized_landmark)

    resized_landmarks_array = np.array(resized_landmarks)  # Convert list to array for easier manipulation

    # Calculate similarity
    distances = np.linalg.norm(resized_landmarks_array - resized_landmarks_array[selected_idx][np.newaxis, :], axis=2)
    overall_distances = np.mean(distances, axis=1)  # Calculate mean distance for each set of landmarks

    if ascending:
        sorted_indices = np.argsort(overall_distances)
        similar_landmarks_indices = sorted_indices[1:top_k+1].tolist()  # Exclude self and take top_k
    else:
        sorted_indices = np.argsort(-overall_distances)
        similar_landmarks_indices = sorted_indices[0:top_k].tolist()

    return similar_landmarks_indices

def process_bbox_musetalk(face_array, landmark_array):
    x_min_face, y_min_face, x_max_face, y_max_face = map(int, face_array)
    x_min_lm = min([int(x) for x, y in landmark_array])
    y_min_lm = min([int(y) for x, y in landmark_array])
    x_max_lm = max([int(x) for x, y in landmark_array])
    y_max_lm = max([int(y) for x, y in landmark_array])
    x_min = min(x_min_face, x_min_lm)
    y_min = min(y_min_face, y_min_lm)
    x_max = max(x_max_face, x_max_lm)
    y_max = max(y_max_face, y_max_lm)

    x_min = max(x_min, 0)
    y_min = max(y_min, 0)

    return [x_min, y_min, x_max, y_max]

def shift_landmarks_to_face_coordinates(landmark_list, face_list):
    """
    Translates the data in landmark_list to the coordinates of the cropped larger face.

    Parameters:
    landmark_list (list): A list containing multiple sets of facial landmarks.
    face_list (list): A list containing multiple facial images.

    Returns:
    landmark_list_shift (list): The list of translated landmarks.
    bbox_union (list): The list of union bounding boxes.
    face_shapes (list): The list of facial shapes.
    """
    landmark_list_shift = []
    bbox_union = []
    face_shapes = []

    for i in range(len(face_list)):
        landmark_array = np.array(landmark_list[i])  # Convert to a numpy array (creates a copy)
        face_array = face_list[i]
        f_landmark_bbox = process_bbox_musetalk(face_array, landmark_array)
        x_min, y_min, x_max, y_max = f_landmark_bbox
        landmark_array[:, 0] = landmark_array[:, 0] - f_landmark_bbox[0]
        landmark_array[:, 1] = landmark_array[:, 1] - f_landmark_bbox[1]
        landmark_list_shift.append(landmark_array)
        bbox_union.append(f_landmark_bbox)
        face_shapes.append((x_max - x_min, y_max - y_min))

    return landmark_list_shift, bbox_union, face_shapes

def resize_landmark(landmark, w, h, new_w, new_h):
    landmark_norm = landmark / [w, h]
    landmark_resized = landmark_norm * [new_w, new_h]

    return landmark_resized

def get_src_idx(drive_idx, T, sample_method, landmarks_list, image_shapes, top_k_ratio):
    """
    Calculate the source index (src_idx) based on the given drive index, T, s, e, and sampling method.

    Parameters:
    - drive_idx (int): The current drive index.
    - T (int): Total number of frames or a specific range limit.
    - sample_method (str): Sampling method, which can be "random" or other methods.
    - landmarks_list (list): List of facial landmarks.
    - image_shapes (list): List of image shapes.
    - top_k_ratio (float): Ratio for selecting top k similar frames.

    Returns:
    - src_idx (int): The calculated source index.
    """
    if sample_method == "random":
        src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
    elif sample_method == "pose_similarity":
        top_k = int(top_k_ratio*len(landmarks_list))
        try:
            top_k = int(top_k_ratio*len(landmarks_list))
            # facial contour
            landmark_start_idx = 0
            landmark_end_idx = 16
            pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list, image_shapes, landmark_start_idx, landmark_end_idx, top_k=top_k, ascending=True)
            src_idx = random.choice(pose_similarity_list)
            while abs(src_idx-drive_idx) < 5:
                src_idx = random.choice(pose_similarity_list)
        except Exception as e:
            print(e)
            return None
    elif sample_method == "pose_similarity_and_closed_mouth":
        # facial contour
        landmark_start_idx = 0
        landmark_end_idx = 16
        try:
            top_k = int(top_k_ratio*len(landmarks_list))
            closed_mouth_list = get_closed_mouth(landmarks_list, ascending=True, top_k=top_k)
            #print("closed_mouth_list",closed_mouth_list)
            pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list, image_shapes, landmark_start_idx, landmark_end_idx, top_k=top_k, ascending=True)
            #print("pose_similarity_list",pose_similarity_list)
            common_list = list(set(closed_mouth_list).intersection(set(pose_similarity_list)))
            if len(common_list) == 0:
                src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
            else:
                src_idx = random.choice(common_list)

            while abs(src_idx-drive_idx) < 5:
                src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)

        except Exception as e:
            print(e)
            return None

    elif sample_method == "pose_similarity_and_mouth_dissimilarity":
        top_k = int(top_k_ratio*len(landmarks_list))
        try:
            top_k = int(top_k_ratio*len(landmarks_list))

            # facial contour for 68 landmarks format
            landmark_start_idx = 0
            landmark_end_idx = 16

            pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list, image_shapes, landmark_start_idx, landmark_end_idx, top_k=top_k, ascending=True)

            # Mouth inner contour for 68 landmarks format
            landmark_start_idx = 60
            landmark_end_idx = 67

            mouth_dissimilarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list, image_shapes, landmark_start_idx, landmark_end_idx, top_k=top_k, ascending=False)

            common_list = list(set(pose_similarity_list).intersection(set(mouth_dissimilarity_list)))
            if len(common_list) == 0:
                src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
            else:
                src_idx = random.choice(common_list)

            while abs(src_idx-drive_idx) < 5:
                src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)

        except Exception as e:
            print(e)
            return None

    else:
        raise ValueError(f"Unknown sample_method: {sample_method}")
    return src_idx
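The entry point of this module is `get_src_idx`, which picks a reference frame for a given driving frame. A synthetic smoke test of the call signature, using random 68-point landmarks in place of the landmarks produced by the preprocessing step:

```python
import random

import numpy as np

from musetalk.data.sample_method import get_src_idx

random.seed(0)
np.random.seed(0)

# 100 frames of fake 68-point landmarks spread over a 256x256 image.
num_frames = 100
landmarks_list = [np.random.rand(68, 2) * 256 for _ in range(num_frames)]
image_shapes = [(256, 256)] * num_frames

src_idx = get_src_idx(
    drive_idx=50,
    T=1,
    sample_method="pose_similarity_and_mouth_dissimilarity",
    landmarks_list=landmarks_list,
    image_shapes=image_shapes,
    top_k_ratio=0.51,
)
# A frame index with a similar head pose but a different mouth shape,
# at least 5 frames away from the driving frame (or None on failure).
print(src_idx)
```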
musetalk/loss/basic_loss.py ADDED
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
from omegaconf import OmegaConf
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel
import musetalk.loss.vgg_face as vgg_face

class Interpolate(nn.Module):
    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
        super(Interpolate, self).__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, input):
        return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)

def set_requires_grad(net, requires_grad=False):
    if net is not None:
        for param in net.parameters():
            param.requires_grad = requires_grad

if __name__ == "__main__":
    cfg = OmegaConf.load("config/audio_adapter/E7.yaml")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    pyramid_scale = [1, 0.5, 0.25, 0.125]
    vgg_IN = vgg_face.Vgg19().to(device)
    pyramid = vgg_face.ImagePyramide(cfg.loss_params.pyramid_scale, 3).to(device)
    vgg_IN.eval()
    downsampler = Interpolate(size=(224, 224), mode='bilinear', align_corners=False)

    image = torch.rand(8, 3, 256, 256).to(device)
    image_pred = torch.rand(8, 3, 256, 256).to(device)
    pyramide_real = pyramid(downsampler(image))
    pyramide_generated = pyramid(downsampler(image_pred))


    loss_IN = 0
    for scale in cfg.loss_params.pyramid_scale:
        x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
        y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
        for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
            loss_IN += weight * value
    loss_IN /= sum(cfg.loss_params.vgg_layer_weight)  # average over the VGG layers; the pyramid loss is summed over the scales
    print(loss_IN)

    #print(cfg.model_params.discriminator_params)

    discriminator = MultiScaleDiscriminator(**cfg.model_params.discriminator_params).to(device)
    discriminator_full = DiscriminatorFullModel(discriminator)
    disc_scales = cfg.model_params.discriminator_params.scales
    # Prepare optimizer and loss function
    optimizer_D = optim.AdamW(discriminator.parameters(),
                              lr=cfg.discriminator_train_params.lr,
                              weight_decay=cfg.discriminator_train_params.weight_decay,
                              betas=cfg.discriminator_train_params.betas,
                              eps=cfg.discriminator_train_params.eps)
    scheduler_D = CosineAnnealingLR(optimizer_D,
                                    T_max=cfg.discriminator_train_params.epochs,
                                    eta_min=1e-6)

    discriminator.train()

    set_requires_grad(discriminator, False)

    loss_G = 0.
    discriminator_maps_generated = discriminator(pyramide_generated)
    discriminator_maps_real = discriminator(pyramide_real)

    for scale in disc_scales:
        key = 'prediction_map_%s' % scale
        value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
        loss_G += value

    print(loss_G)
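The `__main__` block above depends on an OmegaConf file that is not part of this commit (`config/audio_adapter/E7.yaml`). A config-free sketch of the same pyramid VGG perceptual loss, with placeholder scales and equal layer weights (assumed values, not necessarily the ones used by the released training configs):

```python
import torch

from musetalk.loss import vgg_face


def vgg_pyramid_l1_loss(image_pred, image_gt,
                        scales=(1, 0.5, 0.25, 0.125),
                        layer_weights=(1.0, 1.0, 1.0, 1.0, 1.0),
                        device="cpu"):
    # Assumed hyper-parameters; only the computation pattern mirrors the block above.
    vgg = vgg_face.Vgg19().to(device).eval()
    pyramid = vgg_face.ImagePyramide(list(scales), 3).to(device)

    pyr_real = pyramid(image_gt.to(device))
    pyr_fake = pyramid(image_pred.to(device))

    loss = 0.0
    for scale in scales:
        x_vgg = vgg(pyr_fake['prediction_' + str(scale)])
        y_vgg = vgg(pyr_real['prediction_' + str(scale)])
        for i, weight in enumerate(layer_weights):
            loss = loss + weight * torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
    return loss / sum(layer_weights)


loss = vgg_pyramid_l1_loss(torch.rand(2, 3, 256, 256), torch.rand(2, 3, 256, 256))
print(loss.item())
```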
musetalk/loss/conv.py ADDED
@@ -0,0 +1,44 @@
import torch
from torch import nn
from torch.nn import functional as F

class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act(out)

class nonorm_Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
        )
        self.act = nn.LeakyReLU(0.01, inplace=True)

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)

class Conv2dTranspose(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)
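A quick shape check for the wrappers above; note that `residual=True` only makes sense when `cin == cout`, since `out += x` must broadcast:

```python
import torch

from musetalk.loss.conv import Conv2d, Conv2dTranspose

x = torch.randn(2, 64, 48, 96)
block = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)
down = Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
up = Conv2dTranspose(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)

y = block(x)   # (2, 64, 48, 96): residual add, then ReLU
z = down(y)    # (2, 128, 24, 48): strided conv halves the spatial size
w = up(z)      # (2, 64, 48, 96): transposed conv restores it
print(y.shape, z.shape, w.shape)
```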
musetalk/loss/discriminator.py ADDED
@@ -0,0 +1,145 @@
from torch import nn
import torch.nn.functional as F
import torch
from musetalk.loss.vgg_face import ImagePyramide

class DownBlock2d(nn.Module):
    """
    Simple block for processing video (encoder).
    """

    def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
        super(DownBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)

        if sn:
            self.conv = nn.utils.spectral_norm(self.conv)

        if norm:
            self.norm = nn.InstanceNorm2d(out_features, affine=True)
        else:
            self.norm = None
        self.pool = pool

    def forward(self, x):
        out = x
        out = self.conv(out)
        if self.norm:
            out = self.norm(out)
        out = F.leaky_relu(out, 0.2)
        if self.pool:
            out = F.avg_pool2d(out, (2, 2))
        return out


class Discriminator(nn.Module):
    """
    Discriminator similar to Pix2Pix
    """

    def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
                 sn=False, **kwargs):
        super(Discriminator, self).__init__()

        down_blocks = []
        for i in range(num_blocks):
            down_blocks.append(
                DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)),
                            min(max_features, block_expansion * (2 ** (i + 1))),
                            norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))

        self.down_blocks = nn.ModuleList(down_blocks)
        self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
        if sn:
            self.conv = nn.utils.spectral_norm(self.conv)

    def forward(self, x):
        feature_maps = []
        out = x

        for down_block in self.down_blocks:
            feature_maps.append(down_block(out))
            out = feature_maps[-1]
        prediction_map = self.conv(out)

        return feature_maps, prediction_map


class MultiScaleDiscriminator(nn.Module):
    """
    Multi-scale (scale) discriminator
    """

    def __init__(self, scales=(), **kwargs):
        super(MultiScaleDiscriminator, self).__init__()
        self.scales = scales
        discs = {}
        for scale in scales:
            discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
        self.discs = nn.ModuleDict(discs)

    def forward(self, x):
        out_dict = {}
        for scale, disc in self.discs.items():
            scale = str(scale).replace('-', '.')
            key = 'prediction_' + scale
            #print(key)
            #print(x)
            feature_maps, prediction_map = disc(x[key])
            out_dict['feature_maps_' + scale] = feature_maps
            out_dict['prediction_map_' + scale] = prediction_map
        return out_dict



class DiscriminatorFullModel(torch.nn.Module):
    """
    Merge all discriminator related updates into single model for better multi-gpu usage
    """

    def __init__(self, discriminator):
        super(DiscriminatorFullModel, self).__init__()
        self.discriminator = discriminator
        self.scales = self.discriminator.scales
        print("scales", self.scales)
        self.pyramid = ImagePyramide(self.scales, 3)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.zero_tensor = None

    def get_zero_tensor(self, input):
        if self.zero_tensor is None:
            self.zero_tensor = torch.FloatTensor(1).fill_(0).cuda()
            self.zero_tensor.requires_grad_(False)
        return self.zero_tensor.expand_as(input)

    def forward(self, x, generated, gan_mode='ls'):
        pyramide_real = self.pyramid(x)
        pyramide_generated = self.pyramid(generated.detach())

        discriminator_maps_generated = self.discriminator(pyramide_generated)
        discriminator_maps_real = self.discriminator(pyramide_real)

        value_total = 0
        for scale in self.scales:
            key = 'prediction_map_%s' % scale
            if gan_mode == 'hinge':
                value = -torch.mean(torch.min(discriminator_maps_real[key]-1, self.get_zero_tensor(discriminator_maps_real[key]))) - torch.mean(torch.min(-discriminator_maps_generated[key]-1, self.get_zero_tensor(discriminator_maps_generated[key])))
            elif gan_mode == 'ls':
                value = ((1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2).mean()
            else:
                raise ValueError('Unexpected gan_mode {}'.format(self.train_params['gan_mode']))

            value_total += value

        return value_total

def main():
    discriminator = MultiScaleDiscriminator(scales=[1],
                                            block_expansion=32,
                                            max_features=512,
                                            num_blocks=4,
                                            sn=True,
                                            image_channel=3,
                                            estimate_jacobian=False)
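How the two halves fit together in a training step: `DiscriminatorFullModel` computes the discriminator-side loss on an image pyramid, while the generator-side LSGAN term indexes the `prediction_map_<scale>` outputs directly. A sketch with the same hyper-parameters as the `main()` stub above (not necessarily the released training settings):

```python
import torch

from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel
from musetalk.loss.vgg_face import ImagePyramide

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
scales = [1]

disc = MultiScaleDiscriminator(scales=scales, block_expansion=32,
                               max_features=512, num_blocks=4, sn=True).to(device)
disc_full = DiscriminatorFullModel(disc)     # builds its own ImagePyramide internally
pyramid = ImagePyramide(scales, 3).to(device)

real = torch.rand(2, 3, 256, 256, device=device)
fake = torch.rand(2, 3, 256, 256, device=device)  # stand-in for generator output

# Discriminator update: LSGAN loss on real vs. detached generated images.
loss_D = disc_full(real, fake, gan_mode='ls')

# Generator update: push D(fake) towards 1 at every scale.
maps_fake = disc(pyramid(fake))
loss_G = sum(((1 - maps_fake['prediction_map_%s' % s]) ** 2).mean() for s in scales)
print(loss_D.item(), loss_G.item())
```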
musetalk/loss/resnet.py ADDED
@@ -0,0 +1,152 @@
import torch.nn as nn
import math

__all__ = ['ResNet', 'resnet50']

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, include_top=True):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.include_top = include_top

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = x * 255.
        x = x.flip(1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        if not self.include_top:
            return x

        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

def resnet50(**kwargs):
    """Constructs a ResNet-50 model.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    return model
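This ResNet is the VGGFace2 backbone loaded by `vggface2()` in `musetalk/loss/vgg_face.py` below; its forward pass multiplies by 255 and flips the channel order, so it expects RGB inputs in [0, 1], and the fixed 7x7 average pool requires 224x224 inputs. A minimal shape check with randomly initialised weights (the pretrained pickle is loaded separately):

```python
import torch

from musetalk.loss.resnet import resnet50

model = resnet50(num_classes=8631, include_top=False)
model.eval()

with torch.no_grad():
    x = torch.rand(2, 3, 224, 224)   # RGB images in [0, 1]
    feats = model(x)                 # (2, 2048, 1, 1) pooled identity features
print(feats.shape)
```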
musetalk/loss/syncnet.py ADDED
@@ -0,0 +1,95 @@
import torch
from torch import nn
from torch.nn import functional as F

from .conv import Conv2d

logloss = nn.BCELoss(reduction="none")
def cosine_loss(a, v, y):
    d = nn.functional.cosine_similarity(a, v)
    d = d.clamp(0, 1)  # cosine_similarity lies in [-1, 1]; negative inputs make BCE raise "RuntimeError: CUDA error: device-side assert triggered"
    loss = logloss(d.unsqueeze(1), y).squeeze()
    loss = loss.mean()
    return loss, d

def get_sync_loss(
    audio_embed,
    gt_frames,
    pred_frames,
    syncnet,
    adapted_weight,
    frames_left_index=0,
    frames_right_index=16,
):
    # Splice the predicted frames into the ground-truth frames (random insertion/swap) to save GPU memory
    assert pred_frames.shape[1] == (frames_right_index - frames_left_index) * 3
    # 3-channel images
    frames_sync_loss = torch.cat(
        [gt_frames[:, :3 * frames_left_index, ...], pred_frames, gt_frames[:, 3 * frames_right_index:, ...]],
        axis=1
    )
    vision_embed = syncnet.get_image_embed(frames_sync_loss)
    y = torch.ones(frames_sync_loss.size(0), 1).float().to(audio_embed.device)
    loss, score = cosine_loss(audio_embed, vision_embed, y)
    return loss, score

class SyncNet_color(nn.Module):
    def __init__(self):
        super(SyncNet_color, self).__init__()

        self.face_encoder = nn.Sequential(
            Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),

            Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

    def forward(self, audio_sequences, face_sequences):  # audio_sequences := (B, dim, T)
        face_embedding = self.face_encoder(face_sequences)
        audio_embedding = self.audio_encoder(audio_sequences)

        audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
        face_embedding = face_embedding.view(face_embedding.size(0), -1)

        audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
        face_embedding = F.normalize(face_embedding, p=2, dim=1)


        return audio_embedding, face_embedding
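`cosine_loss` scores how well an audio embedding matches a visual embedding: the cosine similarity is clamped to [0, 1] and treated as the probability of a match under a BCE loss, so identical unit vectors give a loss near 0. A tiny sanity check on random embeddings (the embedding size of 512 is only illustrative):

```python
import torch
import torch.nn.functional as F

from musetalk.loss.syncnet import cosine_loss

a = F.normalize(torch.randn(4, 512), p=2, dim=1)
b = F.normalize(torch.randn(4, 512), p=2, dim=1)
y = torch.ones(4, 1)  # "in sync" labels

loss_same, d_same = cosine_loss(a, a, y)  # d_same ~ 1.0 -> loss ~ 0
loss_rand, d_rand = cosine_loss(a, b, y)  # unrelated pairs score much worse
print(loss_same.item(), loss_rand.item())
```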
musetalk/loss/vgg_face.py ADDED
@@ -0,0 +1,237 @@
'''
This part of code contains a pretrained vgg_face model.
ref link: https://github.com/prlz77/vgg-face.pytorch
'''
import torch
import torch.nn.functional as F
import torch.utils.model_zoo
import pickle
from musetalk.loss import resnet as ResNet


MODEL_URL = "https://github.com/claudio-unipv/vggface-pytorch/releases/download/v0.1/vggface-9d491dd7c30312.pth"
VGG_FACE_PATH = '/apdcephfs_cq8/share_1367250/zhentaoyu/Driving/00_VASA/00_data/models/pretrain_models/resnet50_ft_weight.pkl'

# It was 93.5940, 104.7624, 129.1863 before dividing by 255
MEAN_RGB = [
    0.367035294117647,
    0.41083294117647057,
    0.5066129411764705
]
def load_state_dict(model, fname):
    """
    Set parameters converted from Caffe models authors of VGGFace2 provide.
    See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/.

    Arguments:
        model: model
        fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle.
    """
    with open(fname, 'rb') as f:
        weights = pickle.load(f, encoding='latin1')

    own_state = model.state_dict()
    for name, param in weights.items():
        if name in own_state:
            try:
                own_state[name].copy_(torch.from_numpy(param))
            except Exception:
                raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '\
                                   'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size()))
        else:
            raise KeyError('unexpected key "{}" in state_dict'.format(name))


def vggface2(pretrained=True):
    vggface = ResNet.resnet50(num_classes=8631, include_top=True)
    load_state_dict(vggface, VGG_FACE_PATH)
    return vggface

def vggface(pretrained=False, **kwargs):
    """VGGFace model.

    Args:
        pretrained (bool): If True, returns pre-trained model
    """
    model = VggFace(**kwargs)
    if pretrained:
        state = torch.utils.model_zoo.load_url(MODEL_URL)
        model.load_state_dict(state)
    return model


class VggFace(torch.nn.Module):
    def __init__(self, classes=2622):
        """VGGFace model.

        Face recognition network. It takes as input a Bx3x224x224
        batch of face images and gives as output a BxC score vector
        (C is the number of identities).
        Input images need to be scaled in the 0-1 range and then
        normalized with respect to the mean RGB used during training.

        Args:
            classes (int): number of identities recognized by the
            network

        """
        super().__init__()
        self.conv1 = _ConvBlock(3, 64, 64)
        self.conv2 = _ConvBlock(64, 128, 128)
        self.conv3 = _ConvBlock(128, 256, 256, 256)
        self.conv4 = _ConvBlock(256, 512, 512, 512)
        self.conv5 = _ConvBlock(512, 512, 512, 512)
        self.dropout = torch.nn.Dropout(0.5)
        self.fc1 = torch.nn.Linear(7 * 7 * 512, 4096)
        self.fc2 = torch.nn.Linear(4096, 4096)
        self.fc3 = torch.nn.Linear(4096, classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.fc3(x)
        return x


class _ConvBlock(torch.nn.Module):
    """A Convolutional block."""

    def __init__(self, *units):
        """Create a block with len(units) - 1 convolutions.

        convolution number i transforms the number of channels from
        units[i - 1] to units[i] channels.

        """
        super().__init__()
        self.convs = torch.nn.ModuleList([
            torch.nn.Conv2d(in_, out, 3, 1, 1)
            for in_, out in zip(units[:-1], units[1:])
        ])

    def forward(self, x):
        # Each convolution is followed by a ReLU, then the block is
        # concluded by a max pooling.
        for c in self.convs:
            x = F.relu(c(x))
        return F.max_pool2d(x, 2, 2, 0, ceil_mode=True)



import numpy as np
from torchvision import models
class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])

        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        X = (X - self.mean) / self.std
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


from torch import nn
class AntiAliasInterpolation2d(nn.Module):
    """
    Band-limited downsampling, for better preservation of the input signal.
    """
    def __init__(self, channels, scale):
        super(AntiAliasInterpolation2d, self).__init__()
        sigma = (1 / scale - 1) / 2
        kernel_size = 2 * round(sigma * 4) + 1
        self.ka = kernel_size // 2
        self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka

        kernel_size = [kernel_size, kernel_size]
        sigma = [sigma, sigma]
        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)
        # Reshape to depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel)
        self.groups = channels
        self.scale = scale
        inv_scale = 1 / scale
        self.int_inv_scale = int(inv_scale)

    def forward(self, input):
        if self.scale == 1.0:
            return input

        out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
        out = F.conv2d(out, weight=self.weight, groups=self.groups)
        out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]

        return out


class ImagePyramide(torch.nn.Module):
    """
    Create image pyramide for computing pyramide perceptual loss.
    """
    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        downs = {}
        for scale in scales:
            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(downs)

    def forward(self, x):
        out_dict = {}
        for scale, down_module in self.downs.items():
            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
        return out_dict
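`ImagePyramide` is the glue between the losses above: it emits one anti-aliased, strided copy of the input per scale, keyed `prediction_<scale>`, and those keys are exactly what `MultiScaleDiscriminator` and the pyramid perceptual loss index into. A quick look at its output:

```python
import torch

from musetalk.loss.vgg_face import ImagePyramide

pyramid = ImagePyramide([1, 0.5, 0.25], num_channels=3)
x = torch.rand(2, 3, 256, 256)

for key, value in pyramid(x).items():
    print(key, tuple(value.shape))
# prediction_1    (2, 3, 256, 256)
# prediction_0.5  (2, 3, 128, 128)
# prediction_0.25 (2, 3, 64, 64)
```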
musetalk/models/syncnet.py ADDED
@@ -0,0 +1,240 @@
"""
This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/models/stable_syncnet.py).
"""

import torch
from torch import nn
from einops import rearrange
from torch.nn import functional as F

import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.attention import Attention as CrossAttention, FeedForward
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange


class SyncNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.audio_encoder = DownEncoder2D(
            in_channels=config["audio_encoder"]["in_channels"],
            block_out_channels=config["audio_encoder"]["block_out_channels"],
            downsample_factors=config["audio_encoder"]["downsample_factors"],
            dropout=config["audio_encoder"]["dropout"],
            attn_blocks=config["audio_encoder"]["attn_blocks"],
        )

        self.visual_encoder = DownEncoder2D(
            in_channels=config["visual_encoder"]["in_channels"],
            block_out_channels=config["visual_encoder"]["block_out_channels"],
            downsample_factors=config["visual_encoder"]["downsample_factors"],
            dropout=config["visual_encoder"]["dropout"],
            attn_blocks=config["visual_encoder"]["attn_blocks"],
        )

        self.eval()

    def forward(self, image_sequences, audio_sequences):
        vision_embeds = self.visual_encoder(image_sequences)  # (b, c, 1, 1)
        audio_embeds = self.audio_encoder(audio_sequences)  # (b, c, 1, 1)

        vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1)  # (b, c)
        audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1)  # (b, c)

        # Make them unit vectors
        vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
        audio_embeds = F.normalize(audio_embeds, p=2, dim=1)

        return vision_embeds, audio_embeds

    def get_image_embed(self, image_sequences):
        vision_embeds = self.visual_encoder(image_sequences)  # (b, c, 1, 1)

        vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1)  # (b, c)

        # Make them unit vectors
        vision_embeds = F.normalize(vision_embeds, p=2, dim=1)

        return vision_embeds

    def get_audio_embed(self, audio_sequences):
        audio_embeds = self.audio_encoder(audio_sequences)  # (b, c, 1, 1)

        audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1)  # (b, c)

        audio_embeds = F.normalize(audio_embeds, p=2, dim=1)

        return audio_embeds

class ResnetBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        eps: float = 1e-6,
        act_fn: str = "silu",
        downsample_factor=2,
    ):
        super().__init__()

        self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if act_fn == "relu":
            self.act_fn = nn.ReLU()
        elif act_fn == "silu":
            self.act_fn = nn.SiLU()

        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None

        if isinstance(downsample_factor, list):
            downsample_factor = tuple(downsample_factor)

        if downsample_factor == 1:
            self.downsample_conv = None
        else:
            self.downsample_conv = nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
            )
            self.pad = (0, 1, 0, 1)
            if isinstance(downsample_factor, tuple):
                if downsample_factor[0] == 1:
                    self.pad = (0, 1, 1, 1)  # The padding order is from back to front
                elif downsample_factor[1] == 1:
                    self.pad = (1, 1, 0, 1)

    def forward(self, input_tensor):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.act_fn(hidden_states)

        hidden_states = self.conv1(hidden_states)
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.act_fn(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)
|
| 132 |
+
|
| 133 |
+
hidden_states += input_tensor
|
| 134 |
+
|
| 135 |
+
if self.downsample_conv is not None:
|
| 136 |
+
hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
|
| 137 |
+
hidden_states = self.downsample_conv(hidden_states)
|
| 138 |
+
|
| 139 |
+
return hidden_states
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class AttentionBlock2D(nn.Module):
|
| 143 |
+
def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
|
| 144 |
+
super().__init__()
|
| 145 |
+
if not is_xformers_available():
|
| 146 |
+
raise ModuleNotFoundError(
|
| 147 |
+
"You have to install xformers to enable memory efficient attetion", name="xformers"
|
| 148 |
+
)
|
| 149 |
+
# inner_dim = dim_head * heads
|
| 150 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
|
| 151 |
+
self.norm2 = nn.LayerNorm(query_dim)
|
| 152 |
+
self.norm3 = nn.LayerNorm(query_dim)
|
| 153 |
+
|
| 154 |
+
self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")
|
| 155 |
+
|
| 156 |
+
self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
|
| 157 |
+
self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
|
| 158 |
+
|
| 159 |
+
self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
|
| 160 |
+
self.attn._use_memory_efficient_attention_xformers = True
|
| 161 |
+
|
| 162 |
+
def forward(self, hidden_states):
|
| 163 |
+
assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."
|
| 164 |
+
|
| 165 |
+
batch, channel, height, width = hidden_states.shape
|
| 166 |
+
residual = hidden_states
|
| 167 |
+
|
| 168 |
+
hidden_states = self.norm1(hidden_states)
|
| 169 |
+
hidden_states = self.conv_in(hidden_states)
|
| 170 |
+
hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
|
| 171 |
+
|
| 172 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 173 |
+
hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
|
| 174 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
| 175 |
+
|
| 176 |
+
hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width)
|
| 177 |
+
hidden_states = self.conv_out(hidden_states)
|
| 178 |
+
|
| 179 |
+
hidden_states = hidden_states + residual
|
| 180 |
+
return hidden_states
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class DownEncoder2D(nn.Module):
|
| 184 |
+
def __init__(
|
| 185 |
+
self,
|
| 186 |
+
in_channels=4 * 16,
|
| 187 |
+
block_out_channels=[64, 128, 256, 256],
|
| 188 |
+
downsample_factors=[2, 2, 2, 2],
|
| 189 |
+
layers_per_block=2,
|
| 190 |
+
norm_num_groups=32,
|
| 191 |
+
attn_blocks=[1, 1, 1, 1],
|
| 192 |
+
dropout: float = 0.0,
|
| 193 |
+
act_fn="silu",
|
| 194 |
+
):
|
| 195 |
+
super().__init__()
|
| 196 |
+
self.layers_per_block = layers_per_block
|
| 197 |
+
|
| 198 |
+
# in
|
| 199 |
+
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
| 200 |
+
|
| 201 |
+
# down
|
| 202 |
+
self.down_blocks = nn.ModuleList([])
|
| 203 |
+
|
| 204 |
+
output_channels = block_out_channels[0]
|
| 205 |
+
for i, block_out_channel in enumerate(block_out_channels):
|
| 206 |
+
input_channels = output_channels
|
| 207 |
+
output_channels = block_out_channel
|
| 208 |
+
# is_final_block = i == len(block_out_channels) - 1
|
| 209 |
+
|
| 210 |
+
down_block = ResnetBlock2D(
|
| 211 |
+
in_channels=input_channels,
|
| 212 |
+
out_channels=output_channels,
|
| 213 |
+
downsample_factor=downsample_factors[i],
|
| 214 |
+
norm_num_groups=norm_num_groups,
|
| 215 |
+
dropout=dropout,
|
| 216 |
+
act_fn=act_fn,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
self.down_blocks.append(down_block)
|
| 220 |
+
|
| 221 |
+
if attn_blocks[i] == 1:
|
| 222 |
+
attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
|
| 223 |
+
self.down_blocks.append(attention_block)
|
| 224 |
+
|
| 225 |
+
# out
|
| 226 |
+
self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
| 227 |
+
self.act_fn_out = nn.ReLU()
|
| 228 |
+
|
| 229 |
+
def forward(self, hidden_states):
|
| 230 |
+
hidden_states = self.conv_in(hidden_states)
|
| 231 |
+
|
| 232 |
+
# down
|
| 233 |
+
for down_block in self.down_blocks:
|
| 234 |
+
hidden_states = down_block(hidden_states)
|
| 235 |
+
|
| 236 |
+
# post-process
|
| 237 |
+
hidden_states = self.norm_out(hidden_states)
|
| 238 |
+
hidden_states = self.act_fn_out(hidden_states)
|
| 239 |
+
|
| 240 |
+
return hidden_states
|
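`SyncNet.forward` returns L2-normalized vision and audio embeddings, so the natural way to score lip-sync is a cosine similarity between the two. The loss actually used in training comes from `musetalk/loss/syncnet.py` (`get_sync_loss`, not shown in this excerpt); the snippet below is only a generic, self-contained sketch of how such normalized embeddings are typically compared:

```python
import torch
import torch.nn.functional as F

def cosine_sync_score(vision_embeds: torch.Tensor, audio_embeds: torch.Tensor) -> torch.Tensor:
    # Both inputs are (b, c) unit vectors, so a dot product is the cosine similarity.
    return (vision_embeds * audio_embeds).sum(dim=1)

def cosine_bce_loss(vision_embeds, audio_embeds, labels):
    # Hypothetical contrastive objective: labels are 1 for in-sync pairs, 0 otherwise.
    prob = (cosine_sync_score(vision_embeds, audio_embeds) + 1) / 2   # map [-1, 1] -> [0, 1]
    return F.binary_cross_entropy(prob.clamp(1e-6, 1 - 1e-6), labels.float())

# Toy usage with random unit vectors standing in for SyncNet outputs:
v = F.normalize(torch.randn(4, 512), dim=1)
a = F.normalize(torch.randn(4, 512), dim=1)
print(cosine_sync_score(v, a).shape)   # torch.Size([4])
```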
musetalk/utils/training_utils.py
ADDED
|
@@ -0,0 +1,337 @@
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.optim as optim
|
| 7 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 8 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel
|
| 9 |
+
from transformers import WhisperModel
|
| 10 |
+
from diffusers.optimization import get_scheduler
|
| 11 |
+
from omegaconf import OmegaConf
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
|
| 14 |
+
from musetalk.models.syncnet import SyncNet
|
| 15 |
+
from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel
|
| 16 |
+
from musetalk.loss.basic_loss import Interpolate
|
| 17 |
+
import musetalk.loss.vgg_face as vgg_face
|
| 18 |
+
from musetalk.data.dataset import PortraitDataset
|
| 19 |
+
from musetalk.utils.utils import (
|
| 20 |
+
get_image_pred,
|
| 21 |
+
process_audio_features,
|
| 22 |
+
process_and_save_images
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
class Net(nn.Module):
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
unet: UNet2DConditionModel,
|
| 29 |
+
):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.unet = unet
|
| 32 |
+
|
| 33 |
+
def forward(
|
| 34 |
+
self,
|
| 35 |
+
input_latents,
|
| 36 |
+
timesteps,
|
| 37 |
+
audio_prompts,
|
| 38 |
+
):
|
| 39 |
+
model_pred = self.unet(
|
| 40 |
+
input_latents,
|
| 41 |
+
timesteps,
|
| 42 |
+
encoder_hidden_states=audio_prompts
|
| 43 |
+
).sample
|
| 44 |
+
return model_pred
|
| 45 |
+
|
| 46 |
+
logger = logging.getLogger(__name__)
|
| 47 |
+
|
| 48 |
+
def initialize_models_and_optimizers(cfg, accelerator, weight_dtype):
|
| 49 |
+
"""Initialize models and optimizers"""
|
| 50 |
+
model_dict = {
|
| 51 |
+
'vae': None,
|
| 52 |
+
'unet': None,
|
| 53 |
+
'net': None,
|
| 54 |
+
'wav2vec': None,
|
| 55 |
+
'optimizer': None,
|
| 56 |
+
'lr_scheduler': None,
|
| 57 |
+
'scheduler_max_steps': None,
|
| 58 |
+
'trainable_params': None
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
model_dict['vae'] = AutoencoderKL.from_pretrained(
|
| 62 |
+
cfg.pretrained_model_name_or_path,
|
| 63 |
+
subfolder=cfg.vae_type,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
unet_config_file = os.path.join(
|
| 67 |
+
cfg.pretrained_model_name_or_path,
|
| 68 |
+
cfg.unet_sub_folder + "/musetalk.json"
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
with open(unet_config_file, 'r') as f:
|
| 72 |
+
unet_config = json.load(f)
|
| 73 |
+
model_dict['unet'] = UNet2DConditionModel(**unet_config)
|
| 74 |
+
|
| 75 |
+
if not cfg.random_init_unet:
|
| 76 |
+
pretrained_unet_path = os.path.join(cfg.pretrained_model_name_or_path, cfg.unet_sub_folder, "pytorch_model.bin")
|
| 77 |
+
print(f"### Loading existing unet weights from {pretrained_unet_path}. ###")
|
| 78 |
+
checkpoint = torch.load(pretrained_unet_path, map_location=accelerator.device)
|
| 79 |
+
model_dict['unet'].load_state_dict(checkpoint)
|
| 80 |
+
|
| 81 |
+
unet_params = [p.numel() for n, p in model_dict['unet'].named_parameters()]
|
| 82 |
+
logger.info(f"unet {sum(unet_params) / 1e6}M-parameter")
|
| 83 |
+
|
| 84 |
+
model_dict['vae'].requires_grad_(False)
|
| 85 |
+
model_dict['unet'].requires_grad_(True)
|
| 86 |
+
|
| 87 |
+
model_dict['vae'].to(accelerator.device, dtype=weight_dtype)
|
| 88 |
+
|
| 89 |
+
model_dict['net'] = Net(model_dict['unet'])
|
| 90 |
+
|
| 91 |
+
model_dict['wav2vec'] = WhisperModel.from_pretrained(cfg.whisper_path).to(
|
| 92 |
+
device="cuda", dtype=weight_dtype).eval()
|
| 93 |
+
model_dict['wav2vec'].requires_grad_(False)
|
| 94 |
+
|
| 95 |
+
if cfg.solver.gradient_checkpointing:
|
| 96 |
+
model_dict['unet'].enable_gradient_checkpointing()
|
| 97 |
+
|
| 98 |
+
if cfg.solver.scale_lr:
|
| 99 |
+
learning_rate = (
|
| 100 |
+
cfg.solver.learning_rate
|
| 101 |
+
* cfg.solver.gradient_accumulation_steps
|
| 102 |
+
* cfg.data.train_bs
|
| 103 |
+
* accelerator.num_processes
|
| 104 |
+
)
|
| 105 |
+
else:
|
| 106 |
+
learning_rate = cfg.solver.learning_rate
|
| 107 |
+
|
| 108 |
+
if cfg.solver.use_8bit_adam:
|
| 109 |
+
try:
|
| 110 |
+
import bitsandbytes as bnb
|
| 111 |
+
except ImportError:
|
| 112 |
+
raise ImportError(
|
| 113 |
+
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
| 114 |
+
)
|
| 115 |
+
optimizer_cls = bnb.optim.AdamW8bit
|
| 116 |
+
else:
|
| 117 |
+
optimizer_cls = torch.optim.AdamW
|
| 118 |
+
|
| 119 |
+
model_dict['trainable_params'] = list(filter(lambda p: p.requires_grad, model_dict['net'].parameters()))
|
| 120 |
+
if accelerator.is_main_process:
|
| 121 |
+
print('trainable params')
|
| 122 |
+
for n, p in model_dict['net'].named_parameters():
|
| 123 |
+
if p.requires_grad:
|
| 124 |
+
print(n)
|
| 125 |
+
|
| 126 |
+
model_dict['optimizer'] = optimizer_cls(
|
| 127 |
+
model_dict['trainable_params'],
|
| 128 |
+
lr=learning_rate,
|
| 129 |
+
betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
|
| 130 |
+
weight_decay=cfg.solver.adam_weight_decay,
|
| 131 |
+
eps=cfg.solver.adam_epsilon,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
model_dict['scheduler_max_steps'] = cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps
|
| 135 |
+
model_dict['lr_scheduler'] = get_scheduler(
|
| 136 |
+
cfg.solver.lr_scheduler,
|
| 137 |
+
optimizer=model_dict['optimizer'],
|
| 138 |
+
num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
|
| 139 |
+
num_training_steps=model_dict['scheduler_max_steps'],
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
return model_dict
|
| 143 |
+
|
| 144 |
+
def initialize_dataloaders(cfg):
|
| 145 |
+
"""Initialize training and validation dataloaders"""
|
| 146 |
+
dataloader_dict = {
|
| 147 |
+
'train_dataset': None,
|
| 148 |
+
'val_dataset': None,
|
| 149 |
+
'train_dataloader': None,
|
| 150 |
+
'val_dataloader': None
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
dataloader_dict['train_dataset'] = PortraitDataset(cfg={
|
| 154 |
+
'image_size': cfg.data.image_size,
|
| 155 |
+
'T': cfg.data.n_sample_frames,
|
| 156 |
+
"sample_method": cfg.data.sample_method,
|
| 157 |
+
'top_k_ratio': cfg.data.top_k_ratio,
|
| 158 |
+
"contorl_face_min_size": cfg.data.contorl_face_min_size,
|
| 159 |
+
"dataset_key": cfg.data.dataset_key,
|
| 160 |
+
"padding_pixel_mouth": cfg.padding_pixel_mouth,
|
| 161 |
+
"whisper_path": cfg.whisper_path,
|
| 162 |
+
"min_face_size": cfg.data.min_face_size,
|
| 163 |
+
"cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
|
| 164 |
+
"cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
|
| 165 |
+
"crop_type": cfg.crop_type,
|
| 166 |
+
"random_margin_method": cfg.random_margin_method,
|
| 167 |
+
})
|
| 168 |
+
|
| 169 |
+
dataloader_dict['train_dataloader'] = torch.utils.data.DataLoader(
|
| 170 |
+
dataloader_dict['train_dataset'],
|
| 171 |
+
batch_size=cfg.data.train_bs,
|
| 172 |
+
shuffle=True,
|
| 173 |
+
num_workers=cfg.data.num_workers,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
dataloader_dict['val_dataset'] = PortraitDataset(cfg={
|
| 177 |
+
'image_size': cfg.data.image_size,
|
| 178 |
+
'T': cfg.data.n_sample_frames,
|
| 179 |
+
"sample_method": cfg.data.sample_method,
|
| 180 |
+
'top_k_ratio': cfg.data.top_k_ratio,
|
| 181 |
+
"contorl_face_min_size": cfg.data.contorl_face_min_size,
|
| 182 |
+
"dataset_key": cfg.data.dataset_key,
|
| 183 |
+
"padding_pixel_mouth": cfg.padding_pixel_mouth,
|
| 184 |
+
"whisper_path": cfg.whisper_path,
|
| 185 |
+
"min_face_size": cfg.data.min_face_size,
|
| 186 |
+
"cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
|
| 187 |
+
"cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
|
| 188 |
+
"crop_type": cfg.crop_type,
|
| 189 |
+
"random_margin_method": cfg.random_margin_method,
|
| 190 |
+
})
|
| 191 |
+
|
| 192 |
+
dataloader_dict['val_dataloader'] = torch.utils.data.DataLoader(
|
| 193 |
+
dataloader_dict['val_dataset'],
|
| 194 |
+
batch_size=cfg.data.train_bs,
|
| 195 |
+
shuffle=True,
|
| 196 |
+
num_workers=1,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
return dataloader_dict
|
| 200 |
+
|
| 201 |
+
def initialize_loss_functions(cfg, accelerator, scheduler_max_steps):
|
| 202 |
+
"""Initialize loss functions and discriminators"""
|
| 203 |
+
loss_dict = {
|
| 204 |
+
'L1_loss': nn.L1Loss(reduction='mean'),
|
| 205 |
+
'discriminator': None,
|
| 206 |
+
'mouth_discriminator': None,
|
| 207 |
+
'optimizer_D': None,
|
| 208 |
+
'mouth_optimizer_D': None,
|
| 209 |
+
'scheduler_D': None,
|
| 210 |
+
'mouth_scheduler_D': None,
|
| 211 |
+
'disc_scales': None,
|
| 212 |
+
'discriminator_full': None,
|
| 213 |
+
'mouth_discriminator_full': None
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
if cfg.loss_params.gan_loss > 0:
|
| 217 |
+
loss_dict['discriminator'] = MultiScaleDiscriminator(
|
| 218 |
+
**cfg.model_params.discriminator_params).to(accelerator.device)
|
| 219 |
+
loss_dict['discriminator_full'] = DiscriminatorFullModel(loss_dict['discriminator'])
|
| 220 |
+
loss_dict['disc_scales'] = cfg.model_params.discriminator_params.scales
|
| 221 |
+
loss_dict['optimizer_D'] = optim.AdamW(
|
| 222 |
+
loss_dict['discriminator'].parameters(),
|
| 223 |
+
lr=cfg.discriminator_train_params.lr,
|
| 224 |
+
weight_decay=cfg.discriminator_train_params.weight_decay,
|
| 225 |
+
betas=cfg.discriminator_train_params.betas,
|
| 226 |
+
eps=cfg.discriminator_train_params.eps)
|
| 227 |
+
loss_dict['scheduler_D'] = CosineAnnealingLR(
|
| 228 |
+
loss_dict['optimizer_D'],
|
| 229 |
+
T_max=scheduler_max_steps,
|
| 230 |
+
eta_min=1e-6
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
if cfg.loss_params.mouth_gan_loss > 0:
|
| 234 |
+
loss_dict['mouth_discriminator'] = MultiScaleDiscriminator(
|
| 235 |
+
**cfg.model_params.discriminator_params).to(accelerator.device)
|
| 236 |
+
loss_dict['mouth_discriminator_full'] = DiscriminatorFullModel(loss_dict['mouth_discriminator'])
|
| 237 |
+
loss_dict['mouth_optimizer_D'] = optim.AdamW(
|
| 238 |
+
loss_dict['mouth_discriminator'].parameters(),
|
| 239 |
+
lr=cfg.discriminator_train_params.lr,
|
| 240 |
+
weight_decay=cfg.discriminator_train_params.weight_decay,
|
| 241 |
+
betas=cfg.discriminator_train_params.betas,
|
| 242 |
+
eps=cfg.discriminator_train_params.eps)
|
| 243 |
+
loss_dict['mouth_scheduler_D'] = CosineAnnealingLR(
|
| 244 |
+
loss_dict['mouth_optimizer_D'],
|
| 245 |
+
T_max=scheduler_max_steps,
|
| 246 |
+
eta_min=1e-6
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
return loss_dict
|
| 250 |
+
|
| 251 |
+
def initialize_syncnet(cfg, accelerator, weight_dtype):
|
| 252 |
+
"""Initialize SyncNet model"""
|
| 253 |
+
if cfg.loss_params.sync_loss > 0 or cfg.use_adapted_weight:
|
| 254 |
+
if cfg.data.n_sample_frames != 16:
|
| 255 |
+
raise ValueError(
|
| 256 |
+
f"Invalid n_sample_frames {cfg.data.n_sample_frames} for sync_loss, it should be 16."
|
| 257 |
+
)
|
| 258 |
+
syncnet_config = OmegaConf.load(cfg.syncnet_config_path)
|
| 259 |
+
syncnet = SyncNet(OmegaConf.to_container(
|
| 260 |
+
syncnet_config.model)).to(accelerator.device)
|
| 261 |
+
print(
|
| 262 |
+
f"Load SyncNet checkpoint from: {syncnet_config.ckpt.inference_ckpt_path}")
|
| 263 |
+
checkpoint = torch.load(
|
| 264 |
+
syncnet_config.ckpt.inference_ckpt_path, map_location=accelerator.device)
|
| 265 |
+
syncnet.load_state_dict(checkpoint["state_dict"])
|
| 266 |
+
syncnet.to(dtype=weight_dtype)
|
| 267 |
+
syncnet.requires_grad_(False)
|
| 268 |
+
syncnet.eval()
|
| 269 |
+
return syncnet
|
| 270 |
+
return None
|
| 271 |
+
|
| 272 |
+
def initialize_vgg(cfg, accelerator):
|
| 273 |
+
"""Initialize VGG model"""
|
| 274 |
+
if cfg.loss_params.vgg_loss > 0:
|
| 275 |
+
vgg_IN = vgg_face.Vgg19().to(accelerator.device,)
|
| 276 |
+
pyramid = vgg_face.ImagePyramide(
|
| 277 |
+
cfg.loss_params.pyramid_scale, 3).to(accelerator.device)
|
| 278 |
+
vgg_IN.eval()
|
| 279 |
+
downsampler = Interpolate(
|
| 280 |
+
size=(224, 224), mode='bilinear', align_corners=False).to(accelerator.device)
|
| 281 |
+
return vgg_IN, pyramid, downsampler
|
| 282 |
+
return None, None, None
|
| 283 |
+
|
| 284 |
+
def validation(
|
| 285 |
+
cfg,
|
| 286 |
+
val_dataloader,
|
| 287 |
+
net,
|
| 288 |
+
vae,
|
| 289 |
+
wav2vec,
|
| 290 |
+
accelerator,
|
| 291 |
+
save_dir,
|
| 292 |
+
global_step,
|
| 293 |
+
weight_dtype,
|
| 294 |
+
syncnet_score=1,
|
| 295 |
+
):
|
| 296 |
+
"""Validation function for model evaluation"""
|
| 297 |
+
net.eval() # Set the model to evaluation mode
|
| 298 |
+
for batch in val_dataloader:
|
| 299 |
+
# The same ref_latents
|
| 300 |
+
ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
|
| 301 |
+
accelerator.device, non_blocking=True
|
| 302 |
+
)
|
| 303 |
+
pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
|
| 304 |
+
accelerator.device, non_blocking=True
|
| 305 |
+
)
|
| 306 |
+
bsz, num_frames, c, h, w = ref_pixel_values.shape
|
| 307 |
+
|
| 308 |
+
audio_prompts = process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype)
|
| 309 |
+
# audio feature for unet
|
| 310 |
+
audio_prompts = rearrange(
|
| 311 |
+
audio_prompts,
|
| 312 |
+
'b f c h w-> (b f) c h w'
|
| 313 |
+
)
|
| 314 |
+
audio_prompts = rearrange(
|
| 315 |
+
audio_prompts,
|
| 316 |
+
'(b f) c h w -> (b f) (c h) w',
|
| 317 |
+
b=bsz
|
| 318 |
+
)
|
| 319 |
+
# different masked_latents
|
| 320 |
+
image_pred_train = get_image_pred(
|
| 321 |
+
pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
|
| 322 |
+
image_pred_infer = get_image_pred(
|
| 323 |
+
ref_pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
|
| 324 |
+
|
| 325 |
+
process_and_save_images(
|
| 326 |
+
batch,
|
| 327 |
+
image_pred_train,
|
| 328 |
+
image_pred_infer,
|
| 329 |
+
save_dir,
|
| 330 |
+
global_step,
|
| 331 |
+
accelerator,
|
| 332 |
+
cfg.num_images_to_keep,
|
| 333 |
+
syncnet_score
|
| 334 |
+
)
|
| 335 |
+
# only infer 1 image in validation
|
| 336 |
+
break
|
| 337 |
+
net.train() # Set the model back to training mode
|
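These helpers pull everything they need from a single OmegaConf config. As a reading aid, the skeleton below lists only the keys the functions above actually access, with placeholder values; it is inferred from the attribute accesses, not copied from `configs/training/stage1.yaml`, and the discriminator-related groups are only required once the corresponding loss weights are non-zero.

```python
# Inferred config skeleton (placeholder values) covering the keys read above.
from omegaconf import OmegaConf

cfg_skeleton = OmegaConf.create({
    "pretrained_model_name_or_path": "./models",      # placeholder path
    "vae_type": "sd-vae",                              # VAE subfolder name (placeholder)
    "unet_sub_folder": "musetalkV15",                  # holds musetalk.json / pytorch_model.bin
    "random_init_unet": False,
    "whisper_path": "./models/whisper",
    "syncnet_config_path": "./configs/training/syncnet.yaml",
    "use_adapted_weight": False,
    # cropping-related keys passed straight into PortraitDataset (placeholders):
    "padding_pixel_mouth": 10,
    "cropping_jaw2edge_margin_mean": 10,
    "cropping_jaw2edge_margin_std": 10,
    "crop_type": "crop_resize",
    "random_margin_method": "uniform",
    "solver": {
        "learning_rate": 1e-5, "scale_lr": False, "gradient_accumulation_steps": 1,
        "gradient_checkpointing": True, "use_8bit_adam": False,
        "adam_beta1": 0.9, "adam_beta2": 0.999, "adam_weight_decay": 1e-2, "adam_epsilon": 1e-8,
        "max_train_steps": 100000, "lr_scheduler": "cosine", "lr_warmup_steps": 500,
    },
    "data": {
        "train_bs": 2, "num_workers": 4, "image_size": 256, "n_sample_frames": 16,
        "sample_method": "random", "top_k_ratio": 0.5,
        "contorl_face_min_size": True, "min_face_size": 200, "dataset_key": "hdtf",
        "audio_padding_length_left": 2, "audio_padding_length_right": 2,
    },
    "loss_params": {"gan_loss": 0.0, "mouth_gan_loss": 0.0, "sync_loss": 0.0, "vgg_loss": 0.0},
    # model_params.discriminator_params and discriminator_train_params are additionally
    # required once gan_loss / mouth_gan_loss are enabled.
})
```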
musetalk/utils/utils.py
CHANGED
|
@@ -2,6 +2,11 @@ import os
|
|
| 2 |
import cv2
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
| 5 |
|
| 6 |
ffmpeg_path = os.getenv('FFMPEG_PATH')
|
| 7 |
if ffmpeg_path is None:
|
|
@@ -11,7 +16,6 @@ elif ffmpeg_path not in os.getenv('PATH'):
|
|
| 11 |
os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
|
| 12 |
|
| 13 |
|
| 14 |
-
from musetalk.whisper.audio2feature import Audio2Feature
|
| 15 |
from musetalk.models.vae import VAE
|
| 16 |
from musetalk.models.unet import UNet,PositionalEncoding
|
| 17 |
|
|
@@ -76,3 +80,248 @@ def datagen(
|
|
| 76 |
latent_batch = torch.cat(latent_batch, dim=0)
|
| 77 |
|
| 78 |
yield whisper_batch.to(device), latent_batch.to(device)
|
| 2 |
import cv2
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
| 5 |
+
from typing import Union, List
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
import shutil
|
| 9 |
+
import os.path as osp
|
| 10 |
|
| 11 |
ffmpeg_path = os.getenv('FFMPEG_PATH')
|
| 12 |
if ffmpeg_path is None:
|
|
|
|
| 16 |
os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
|
| 17 |
|
| 18 |
|
|
|
|
| 19 |
from musetalk.models.vae import VAE
|
| 20 |
from musetalk.models.unet import UNet,PositionalEncoding
|
| 21 |
|
|
|
|
| 80 |
latent_batch = torch.cat(latent_batch, dim=0)
|
| 81 |
|
| 82 |
yield whisper_batch.to(device), latent_batch.to(device)
|
| 83 |
+
|
| 84 |
+
def cast_training_params(
|
| 85 |
+
model: Union[torch.nn.Module, List[torch.nn.Module]],
|
| 86 |
+
dtype=torch.float32,
|
| 87 |
+
):
|
| 88 |
+
if not isinstance(model, list):
|
| 89 |
+
model = [model]
|
| 90 |
+
for m in model:
|
| 91 |
+
for param in m.parameters():
|
| 92 |
+
# only upcast trainable parameters into fp32
|
| 93 |
+
if param.requires_grad:
|
| 94 |
+
param.data = param.to(dtype)
|
| 95 |
+
|
| 96 |
+
def rand_log_normal(
|
| 97 |
+
shape,
|
| 98 |
+
loc=0.,
|
| 99 |
+
scale=1.,
|
| 100 |
+
device='cpu',
|
| 101 |
+
dtype=torch.float32,
|
| 102 |
+
generator=None
|
| 103 |
+
):
|
| 104 |
+
"""Draws samples from an lognormal distribution."""
|
| 105 |
+
rnd_normal = torch.randn(
|
| 106 |
+
shape, device=device, dtype=dtype, generator=generator) # N(0, I)
|
| 107 |
+
sigma = (rnd_normal * scale + loc).exp()
|
| 108 |
+
return sigma
|
| 109 |
+
|
| 110 |
+
def get_mouth_region(frames, image_pred, pixel_values_face_mask):
|
| 111 |
+
# Initialize lists to store the results for each image in the batch
|
| 112 |
+
mouth_real_list = []
|
| 113 |
+
mouth_generated_list = []
|
| 114 |
+
|
| 115 |
+
# Process each image in the batch
|
| 116 |
+
for b in range(frames.shape[0]):
|
| 117 |
+
# Find the non-zero area in the face mask
|
| 118 |
+
non_zero_indices = torch.nonzero(pixel_values_face_mask[b])
|
| 119 |
+
# If there are no non-zero indices, skip this image
|
| 120 |
+
if non_zero_indices.numel() == 0:
|
| 121 |
+
continue
|
| 122 |
+
|
| 123 |
+
min_y, max_y = torch.min(non_zero_indices[:, 1]), torch.max(
|
| 124 |
+
non_zero_indices[:, 1])
|
| 125 |
+
min_x, max_x = torch.min(non_zero_indices[:, 2]), torch.max(
|
| 126 |
+
non_zero_indices[:, 2])
|
| 127 |
+
|
| 128 |
+
# Crop the frames and image_pred according to the non-zero area
|
| 129 |
+
frames_cropped = frames[b, :, min_y:max_y, min_x:max_x]
|
| 130 |
+
image_pred_cropped = image_pred[b, :, min_y:max_y, min_x:max_x]
|
| 131 |
+
# Resize the cropped images to 256*256
|
| 132 |
+
frames_resized = F.interpolate(frames_cropped.unsqueeze(
|
| 133 |
+
0), size=(256, 256), mode='bilinear', align_corners=False)
|
| 134 |
+
image_pred_resized = F.interpolate(image_pred_cropped.unsqueeze(
|
| 135 |
+
0), size=(256, 256), mode='bilinear', align_corners=False)
|
| 136 |
+
|
| 137 |
+
# Append the resized images to the result lists
|
| 138 |
+
mouth_real_list.append(frames_resized)
|
| 139 |
+
mouth_generated_list.append(image_pred_resized)
|
| 140 |
+
|
| 141 |
+
# Convert the lists to tensors if they are not empty
|
| 142 |
+
mouth_real = torch.cat(mouth_real_list, dim=0) if mouth_real_list else None
|
| 143 |
+
mouth_generated = torch.cat(
|
| 144 |
+
mouth_generated_list, dim=0) if mouth_generated_list else None
|
| 145 |
+
|
| 146 |
+
return mouth_real, mouth_generated
|
| 147 |
+
|
| 148 |
+
def get_image_pred(pixel_values,
|
| 149 |
+
ref_pixel_values,
|
| 150 |
+
audio_prompts,
|
| 151 |
+
vae,
|
| 152 |
+
net,
|
| 153 |
+
weight_dtype):
|
| 154 |
+
with torch.no_grad():
|
| 155 |
+
bsz, num_frames, c, h, w = pixel_values.shape
|
| 156 |
+
|
| 157 |
+
masked_pixel_values = pixel_values.clone()
|
| 158 |
+
masked_pixel_values[:, :, :, h//2:, :] = -1
|
| 159 |
+
|
| 160 |
+
masked_frames = rearrange(
|
| 161 |
+
masked_pixel_values, 'b f c h w -> (b f) c h w')
|
| 162 |
+
masked_latents = vae.encode(masked_frames).latent_dist.mode()
|
| 163 |
+
masked_latents = masked_latents * vae.config.scaling_factor
|
| 164 |
+
masked_latents = masked_latents.float()
|
| 165 |
+
|
| 166 |
+
ref_frames = rearrange(ref_pixel_values, 'b f c h w-> (b f) c h w')
|
| 167 |
+
ref_latents = vae.encode(ref_frames).latent_dist.mode()
|
| 168 |
+
ref_latents = ref_latents * vae.config.scaling_factor
|
| 169 |
+
ref_latents = ref_latents.float()
|
| 170 |
+
|
| 171 |
+
input_latents = torch.cat([masked_latents, ref_latents], dim=1)
|
| 172 |
+
input_latents = input_latents.to(weight_dtype)
|
| 173 |
+
timesteps = torch.tensor([0], device=input_latents.device)
|
| 174 |
+
latents_pred = net(
|
| 175 |
+
input_latents,
|
| 176 |
+
timesteps,
|
| 177 |
+
audio_prompts,
|
| 178 |
+
)
|
| 179 |
+
latents_pred = (1 / vae.config.scaling_factor) * latents_pred
|
| 180 |
+
image_pred = vae.decode(latents_pred).sample
|
| 181 |
+
image_pred = image_pred.float()
|
| 182 |
+
|
| 183 |
+
return image_pred
|
| 184 |
+
|
| 185 |
+
def process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype):
|
| 186 |
+
with torch.no_grad():
|
| 187 |
+
audio_feature_length_per_frame = 2 * \
|
| 188 |
+
(cfg.data.audio_padding_length_left +
|
| 189 |
+
cfg.data.audio_padding_length_right + 1)
|
| 190 |
+
audio_feats = batch['audio_feature'].to(weight_dtype)
|
| 191 |
+
audio_feats = wav2vec.encoder(
|
| 192 |
+
audio_feats, output_hidden_states=True).hidden_states
|
| 193 |
+
audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype) # [B, T, 10, 5, 384]
|
| 194 |
+
|
| 195 |
+
start_ts = batch['audio_offset']
|
| 196 |
+
step_ts = batch['audio_step']
|
| 197 |
+
audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_left]),
|
| 198 |
+
audio_feats,
|
| 199 |
+
torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_right])], 1)
|
| 200 |
+
audio_prompts = []
|
| 201 |
+
for bb in range(bsz):
|
| 202 |
+
audio_feats_list = []
|
| 203 |
+
for f in range(num_frames):
|
| 204 |
+
cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
|
| 205 |
+
audio_clip = audio_feats[bb:bb+1,
|
| 206 |
+
cur_t: cur_t+audio_feature_length_per_frame]
|
| 207 |
+
|
| 208 |
+
audio_feats_list.append(audio_clip)
|
| 209 |
+
audio_feats_list = torch.stack(audio_feats_list, 1)
|
| 210 |
+
audio_prompts.append(audio_feats_list)
|
| 211 |
+
audio_prompts = torch.cat(audio_prompts) # B, T, 10, 5, 384
|
| 212 |
+
return audio_prompts
|
| 213 |
+
|
| 214 |
+
def save_checkpoint(model, save_dir, ckpt_num, name="appearance_net", total_limit=None, logger=None):
|
| 215 |
+
save_path = os.path.join(save_dir, f"{name}-{ckpt_num}.pth")
|
| 216 |
+
|
| 217 |
+
if total_limit is not None:
|
| 218 |
+
checkpoints = os.listdir(save_dir)
|
| 219 |
+
checkpoints = [d for d in checkpoints if d.endswith(".pth")]
|
| 220 |
+
checkpoints = [d for d in checkpoints if name in d]
|
| 221 |
+
checkpoints = sorted(
|
| 222 |
+
checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
if len(checkpoints) >= total_limit:
|
| 226 |
+
num_to_remove = len(checkpoints) - total_limit + 1
|
| 227 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 228 |
+
logger.info(
|
| 229 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
| 230 |
+
)
|
| 231 |
+
logger.info(
|
| 232 |
+
f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
| 233 |
+
|
| 234 |
+
for removing_checkpoint in removing_checkpoints:
|
| 235 |
+
removing_checkpoint = os.path.join(
|
| 236 |
+
save_dir, removing_checkpoint)
|
| 237 |
+
os.remove(removing_checkpoint)
|
| 238 |
+
|
| 239 |
+
state_dict = model.state_dict()
|
| 240 |
+
torch.save(state_dict, save_path)
|
| 241 |
+
|
| 242 |
+
def save_models(accelerator, net, save_dir, global_step, cfg, logger=None):
|
| 243 |
+
unwarp_net = accelerator.unwrap_model(net)
|
| 244 |
+
save_checkpoint(
|
| 245 |
+
unwarp_net.unet,
|
| 246 |
+
save_dir,
|
| 247 |
+
global_step,
|
| 248 |
+
name="unet",
|
| 249 |
+
total_limit=cfg.total_limit,
|
| 250 |
+
logger=logger
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
def delete_additional_ckpt(base_path, num_keep):
|
| 254 |
+
dirs = []
|
| 255 |
+
for d in os.listdir(base_path):
|
| 256 |
+
if d.startswith("checkpoint-"):
|
| 257 |
+
dirs.append(d)
|
| 258 |
+
num_tot = len(dirs)
|
| 259 |
+
if num_tot <= num_keep:
|
| 260 |
+
return
|
| 261 |
+
# ensure checkpoints are sorted and delete the earlier ones
|
| 262 |
+
del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
|
| 263 |
+
for d in del_dirs:
|
| 264 |
+
path_to_dir = osp.join(base_path, d)
|
| 265 |
+
if osp.exists(path_to_dir):
|
| 266 |
+
shutil.rmtree(path_to_dir)
|
| 267 |
+
|
| 268 |
+
def seed_everything(seed):
|
| 269 |
+
import random
|
| 270 |
+
|
| 271 |
+
import numpy as np
|
| 272 |
+
|
| 273 |
+
torch.manual_seed(seed)
|
| 274 |
+
torch.cuda.manual_seed_all(seed)
|
| 275 |
+
np.random.seed(seed % (2**32))
|
| 276 |
+
random.seed(seed)
|
| 277 |
+
|
| 278 |
+
def process_and_save_images(
|
| 279 |
+
batch,
|
| 280 |
+
image_pred,
|
| 281 |
+
image_pred_infer,
|
| 282 |
+
save_dir,
|
| 283 |
+
global_step,
|
| 284 |
+
accelerator,
|
| 285 |
+
num_images_to_keep=10,
|
| 286 |
+
syncnet_score=1
|
| 287 |
+
):
|
| 288 |
+
# Rearrange the tensors
|
| 289 |
+
print("image_pred.shape: ", image_pred.shape)
|
| 290 |
+
pixel_values_ref_img = rearrange(batch['pixel_values_ref_img'], "b f c h w -> (b f) c h w")
|
| 291 |
+
pixel_values = rearrange(batch["pixel_values_vid"], 'b f c h w -> (b f) c h w')
|
| 292 |
+
|
| 293 |
+
# Create masked pixel values
|
| 294 |
+
masked_pixel_values = batch["pixel_values_vid"].clone()
|
| 295 |
+
_, _, _, h, _ = batch["pixel_values_vid"].shape
|
| 296 |
+
masked_pixel_values[:, :, :, h//2:, :] = -1
|
| 297 |
+
masked_pixel_values = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
|
| 298 |
+
|
| 299 |
+
# Keep only the specified number of images
|
| 300 |
+
pixel_values = pixel_values[:num_images_to_keep, :, :, :]
|
| 301 |
+
masked_pixel_values = masked_pixel_values[:num_images_to_keep, :, :, :]
|
| 302 |
+
pixel_values_ref_img = pixel_values_ref_img[:num_images_to_keep, :, :, :]
|
| 303 |
+
image_pred = image_pred.detach()[:num_images_to_keep, :, :, :]
|
| 304 |
+
image_pred_infer = image_pred_infer.detach()[:num_images_to_keep, :, :, :]
|
| 305 |
+
|
| 306 |
+
# Concatenate images
|
| 307 |
+
concat = torch.cat([
|
| 308 |
+
masked_pixel_values * 0.5 + 0.5,
|
| 309 |
+
pixel_values_ref_img * 0.5 + 0.5,
|
| 310 |
+
image_pred * 0.5 + 0.5,
|
| 311 |
+
pixel_values * 0.5 + 0.5,
|
| 312 |
+
image_pred_infer * 0.5 + 0.5,
|
| 313 |
+
], dim=2)
|
| 314 |
+
print("concat.shape: ", concat.shape)
|
| 315 |
+
|
| 316 |
+
# Create the save directory if it doesn't exist
|
| 317 |
+
os.makedirs(f'{save_dir}/samples/', exist_ok=True)
|
| 318 |
+
|
| 319 |
+
# Try to save the concatenated image
|
| 320 |
+
try:
|
| 321 |
+
# Concatenate images horizontally and convert to numpy array
|
| 322 |
+
final_image = torch.cat([concat[i] for i in range(concat.shape[0])], dim=-1).permute(1, 2, 0).cpu().numpy()[:, :, [2, 1, 0]] * 255
|
| 323 |
+
# Save the image
|
| 324 |
+
cv2.imwrite(f'{save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg', final_image)
|
| 325 |
+
print(f"Image saved successfully: {save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg")
|
| 326 |
+
except Exception as e:
|
| 327 |
+
print(f"Failed to save image: {e}")
|
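`process_audio_features` slices a per-frame window out of the stacked Whisper hidden states: video runs at 25 fps while the Whisper encoder emits 50 feature steps per second, so every frame owns two steps, plus the configured left/right padding frames (the features are zero-padded accordingly at both ends). A small, self-contained sketch of the index arithmetic, assuming `audio_padding_length_left = audio_padding_length_right = 2`:

```python
# Index arithmetic only; mirrors the slicing done in process_audio_features above.
def audio_window(frame_idx, audio_offset, audio_step, pad_left=2, pad_right=2):
    length = 2 * (pad_left + pad_right + 1)               # feature steps per window
    start = (audio_offset + frame_idx * audio_step) * 2   # offset into the zero-padded features
    return start, start + length

print(audio_window(frame_idx=0, audio_offset=0, audio_step=1))   # (0, 10)
print(audio_window(frame_idx=5, audio_offset=3, audio_step=1))   # (16, 26)
```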
scripts/preprocess.py
ADDED
|
@@ -0,0 +1,322 @@
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import subprocess
|
| 4 |
+
from omegaconf import OmegaConf
|
| 5 |
+
from typing import Tuple, List, Union
|
| 6 |
+
import decord
|
| 7 |
+
import json
|
| 8 |
+
import cv2
|
| 9 |
+
from musetalk.utils.face_detection import FaceAlignment,LandmarksType
|
| 10 |
+
from mmpose.apis import inference_topdown, init_model
|
| 11 |
+
from mmpose.structures import merge_data_samples
|
| 12 |
+
import torch
|
| 13 |
+
import numpy as np
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
ffmpeg_path = "./ffmpeg-4.4-amd64-static/"
|
| 17 |
+
if ffmpeg_path not in os.getenv('PATH'):
|
| 18 |
+
print("add ffmpeg to path")
|
| 19 |
+
os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
|
| 20 |
+
|
| 21 |
+
class AnalyzeFace:
|
| 22 |
+
def __init__(self, device: Union[str, torch.device], config_file: str, checkpoint_file: str):
|
| 23 |
+
"""
|
| 24 |
+
Initialize the AnalyzeFace class with the given device, config file, and checkpoint file.
|
| 25 |
+
|
| 26 |
+
Parameters:
|
| 27 |
+
device (Union[str, torch.device]): The device to run the models on ('cuda' or 'cpu').
|
| 28 |
+
config_file (str): Path to the mmpose model configuration file.
|
| 29 |
+
checkpoint_file (str): Path to the mmpose model checkpoint file.
|
| 30 |
+
"""
|
| 31 |
+
self.device = device
|
| 32 |
+
self.dwpose = init_model(config_file, checkpoint_file, device=self.device)
|
| 33 |
+
self.facedet = FaceAlignment(LandmarksType._2D, flip_input=False, device=self.device)
|
| 34 |
+
|
| 35 |
+
def __call__(self, im: np.ndarray) -> Tuple[List[np.ndarray], np.ndarray]:
|
| 36 |
+
"""
|
| 37 |
+
Detect faces and keypoints in the given image.
|
| 38 |
+
|
| 39 |
+
Parameters:
|
| 40 |
+
im (np.ndarray): The input image.
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
Tuple[List[np.ndarray], np.ndarray]: A tuple containing the bounding boxes and keypoints.
|
| 45 |
+
"""
|
| 46 |
+
try:
|
| 47 |
+
# Ensure the input image has the correct shape
|
| 48 |
+
if im.ndim == 3:
|
| 49 |
+
im = np.expand_dims(im, axis=0)
|
| 50 |
+
elif im.ndim != 4 or im.shape[0] != 1:
|
| 51 |
+
raise ValueError("Input image must have shape (1, H, W, C)")
|
| 52 |
+
|
| 53 |
+
bbox = self.facedet.get_detections_for_batch(np.asarray(im))
|
| 54 |
+
results = inference_topdown(self.dwpose, np.asarray(im)[0])
|
| 55 |
+
results = merge_data_samples(results)
|
| 56 |
+
keypoints = results.pred_instances.keypoints
|
| 57 |
+
face_land_mark= keypoints[0][23:91]
|
| 58 |
+
face_land_mark = face_land_mark.astype(np.int32)
|
| 59 |
+
|
| 60 |
+
return face_land_mark, bbox
|
| 61 |
+
|
| 62 |
+
except Exception as e:
|
| 63 |
+
print(f"Error during face analysis: {e}")
|
| 64 |
+
return np.array([]),[]
|
| 65 |
+
|
| 66 |
+
def convert_video(org_path: str, dst_path: str, vid_list: List[str]) -> None:
|
| 67 |
+
|
| 68 |
+
"""
|
| 69 |
+
Re-encode video files to 25 FPS (H.264, yuv420p) and save them to the destination path.
|
| 70 |
+
|
| 71 |
+
Parameters:
|
| 72 |
+
org_path (str): The directory containing the original video files.
|
| 73 |
+
dst_path (str): The directory where the converted video files will be saved.
|
| 74 |
+
vid_list (List[str]): A list of video file names to process.
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
None
|
| 78 |
+
"""
|
| 79 |
+
for idx, vid in enumerate(vid_list):
|
| 80 |
+
if vid.endswith('.mp4'):
|
| 81 |
+
org_vid_path = os.path.join(org_path, vid)
|
| 82 |
+
dst_vid_path = os.path.join(dst_path, vid)
|
| 83 |
+
|
| 84 |
+
if org_vid_path != dst_vid_path:
|
| 85 |
+
cmd = [
|
| 86 |
+
"ffmpeg", "-hide_banner", "-y", "-i", org_vid_path,
|
| 87 |
+
"-r", "25", "-crf", "15", "-c:v", "libx264",
|
| 88 |
+
"-pix_fmt", "yuv420p", dst_vid_path
|
| 89 |
+
]
|
| 90 |
+
subprocess.run(cmd, check=True)
|
| 91 |
+
|
| 92 |
+
if idx % 1000 == 0:
|
| 93 |
+
print(f"### {idx} videos converted ###")
|
| 94 |
+
|
| 95 |
+
def segment_video(org_path: str, dst_path: str, vid_list: List[str], segment_duration: int = 30) -> None:
|
| 96 |
+
"""
|
| 97 |
+
Segment video files into smaller clips of specified duration.
|
| 98 |
+
|
| 99 |
+
Parameters:
|
| 100 |
+
org_path (str): The directory containing the original video files.
|
| 101 |
+
dst_path (str): The directory where the segmented video files will be saved.
|
| 102 |
+
vid_list (List[str]): A list of video file names to process.
|
| 103 |
+
segment_duration (int): The duration of each segment in seconds. Default is 30 seconds.
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
None
|
| 107 |
+
"""
|
| 108 |
+
for idx, vid in enumerate(vid_list):
|
| 109 |
+
if vid.endswith('.mp4'):
|
| 110 |
+
input_file = os.path.join(org_path, vid)
|
| 111 |
+
original_filename = os.path.basename(input_file)
|
| 112 |
+
|
| 113 |
+
command = [
|
| 114 |
+
'ffmpeg', '-i', input_file, '-c', 'copy', '-map', '0',
|
| 115 |
+
'-segment_time', str(segment_duration), '-f', 'segment',
|
| 116 |
+
'-reset_timestamps', '1',
|
| 117 |
+
os.path.join(dst_path, f'clip%03d_{original_filename}')
|
| 118 |
+
]
|
| 119 |
+
|
| 120 |
+
subprocess.run(command, check=True)
|
| 121 |
+
|
| 122 |
+
def extract_audio(org_path: str, dst_path: str, vid_list: List[str]) -> None:
|
| 123 |
+
"""
|
| 124 |
+
Extract audio from video files and save as WAV format.
|
| 125 |
+
|
| 126 |
+
Parameters:
|
| 127 |
+
org_path (str): The directory containing the original video files.
|
| 128 |
+
dst_path (str): The directory where the extracted audio files will be saved.
|
| 129 |
+
vid_list (List[str]): A list of video file names to process.
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
None
|
| 133 |
+
"""
|
| 134 |
+
for idx, vid in enumerate(vid_list):
|
| 135 |
+
if vid.endswith('.mp4'):
|
| 136 |
+
video_path = os.path.join(org_path, vid)
|
| 137 |
+
audio_output_path = os.path.join(dst_path, os.path.splitext(vid)[0] + ".wav")
|
| 138 |
+
try:
|
| 139 |
+
command = [
|
| 140 |
+
'ffmpeg', '-hide_banner', '-y', '-i', video_path,
|
| 141 |
+
'-vn', '-acodec', 'pcm_s16le', '-f', 'wav',
|
| 142 |
+
'-ar', '16000', '-ac', '1', audio_output_path,
|
| 143 |
+
]
|
| 144 |
+
|
| 145 |
+
subprocess.run(command, check=True)
|
| 146 |
+
print(f"Audio saved to: {audio_output_path}")
|
| 147 |
+
except subprocess.CalledProcessError as e:
|
| 148 |
+
print(f"Error extracting audio from {vid}: {e}")
|
| 149 |
+
|
| 150 |
+
def split_data(video_files: List[str], val_list_hdtf: List[str]) -> Tuple[List[str], List[str]]:
|
| 151 |
+
"""
|
| 152 |
+
Split video files into training and validation sets based on val_list_hdtf.
|
| 153 |
+
|
| 154 |
+
Parameters:
|
| 155 |
+
video_files (List[str]): A list of video file names.
|
| 156 |
+
val_list_hdtf (List[str]): A list of validation file identifiers.
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
(List[str], List[str]): A tuple containing the training and validation file lists.
|
| 160 |
+
"""
|
| 161 |
+
val_files = [f for f in video_files if any(val_id in f for val_id in val_list_hdtf)]
|
| 162 |
+
train_files = [f for f in video_files if f not in val_files]
|
| 163 |
+
return train_files, val_files
|
| 164 |
+
|
| 165 |
+
def save_list_to_file(file_path: str, data_list: List[str]) -> None:
|
| 166 |
+
"""
|
| 167 |
+
Save a list of strings to a file, each string on a new line.
|
| 168 |
+
|
| 169 |
+
Parameters:
|
| 170 |
+
file_path (str): The path to the file where the list will be saved.
|
| 171 |
+
data_list (List[str]): The list of strings to save.
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
None
|
| 175 |
+
"""
|
| 176 |
+
with open(file_path, 'w') as file:
|
| 177 |
+
for item in data_list:
|
| 178 |
+
file.write(f"{item}\n")
|
| 179 |
+
|
| 180 |
+
def generate_train_list(cfg):
|
| 181 |
+
train_file_path = cfg.video_clip_file_list_train
|
| 182 |
+
val_file_path = cfg.video_clip_file_list_val
|
| 183 |
+
val_list_hdtf = cfg.val_list_hdtf
|
| 184 |
+
|
| 185 |
+
meta_list = os.listdir(cfg.meta_root)
|
| 186 |
+
|
| 187 |
+
sorted_meta_list = sorted(meta_list)
|
| 188 |
+
train_files, val_files = split_data(meta_list, val_list_hdtf)
|
| 189 |
+
|
| 190 |
+
save_list_to_file(train_file_path, train_files)
|
| 191 |
+
save_list_to_file(val_file_path, val_files)
|
| 192 |
+
|
| 193 |
+
print(val_list_hdtf)
|
| 194 |
+
|
| 195 |
+
def analyze_video(org_path: str, dst_path: str, vid_list: List[str]) -> None:
|
| 196 |
+
"""
|
| 197 |
+
Run face detection and landmark estimation on every frame of each video and save the per-video metadata as a JSON file.
|
| 198 |
+
|
| 199 |
+
Parameters:
|
| 200 |
+
org_path (str): The directory containing the original video files.
|
| 201 |
+
dst_path (str): The directory where the meta json will be saved.
|
| 202 |
+
vid_list (List[str]): A list of video file names to process.
|
| 203 |
+
|
| 204 |
+
Returns:
|
| 205 |
+
None
|
| 206 |
+
"""
|
| 207 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 208 |
+
config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
|
| 209 |
+
checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth'
|
| 210 |
+
|
| 211 |
+
analyze_face = AnalyzeFace(device, config_file, checkpoint_file)
|
| 212 |
+
|
| 213 |
+
for vid in tqdm(vid_list, desc="Processing videos"):
|
| 214 |
+
#vid = "clip005_WDA_BernieSanders_000.mp4"
|
| 215 |
+
#print(vid)
|
| 216 |
+
if vid.endswith('.mp4'):
|
| 217 |
+
vid_path = os.path.join(org_path, vid)
|
| 218 |
+
wav_path = vid_path.replace(".mp4",".wav")
|
| 219 |
+
vid_meta = os.path.join(dst_path, os.path.splitext(vid)[0] + ".json")
|
| 220 |
+
if os.path.exists(vid_meta):
|
| 221 |
+
continue
|
| 222 |
+
print('process video {}'.format(vid))
|
| 223 |
+
|
| 224 |
+
total_bbox_list = []
|
| 225 |
+
total_pts_list = []
|
| 226 |
+
isvalid = True
|
| 227 |
+
|
| 228 |
+
# process
|
| 229 |
+
try:
|
| 230 |
+
cap = decord.VideoReader(vid_path, fault_tol=1)
|
| 231 |
+
except Exception as e:
|
| 232 |
+
print(e)
|
| 233 |
+
continue
|
| 234 |
+
|
| 235 |
+
total_frames = len(cap)
|
| 236 |
+
for frame_idx in range(total_frames):
|
| 237 |
+
frame = cap[frame_idx]
|
| 238 |
+
if frame_idx==0:
|
| 239 |
+
video_height,video_width,_ = frame.shape
|
| 240 |
+
frame_bgr = cv2.cvtColor(frame.asnumpy(), cv2.COLOR_BGR2RGB)
|
| 241 |
+
pts_list, bbox_list = analyze_face(frame_bgr)
|
| 242 |
+
|
| 243 |
+
if len(bbox_list)>0 and None not in bbox_list:
|
| 244 |
+
bbox = bbox_list[0]
|
| 245 |
+
else:
|
| 246 |
+
isvalid = False
|
| 247 |
+
bbox = []
|
| 248 |
+
print(f"set isvalid to False as broken img in {frame_idx} of {vid}")
|
| 249 |
+
break
|
| 250 |
+
|
| 251 |
+
#print(pts_list)
|
| 252 |
+
if len(pts_list)>0 and pts_list is not None:
|
| 253 |
+
pts = pts_list.tolist()
|
| 254 |
+
else:
|
| 255 |
+
isvalid = False
|
| 256 |
+
pts = []
|
| 257 |
+
break
|
| 258 |
+
|
| 259 |
+
if frame_idx==0:
|
| 260 |
+
x1,y1,x2,y2 = bbox
|
| 261 |
+
face_height, face_width = y2-y1,x2-x1
|
| 262 |
+
|
| 263 |
+
total_pts_list.append(pts)
|
| 264 |
+
total_bbox_list.append(bbox)
|
| 265 |
+
|
| 266 |
+
meta_data = {
|
| 267 |
+
"mp4_path": vid_path,
|
| 268 |
+
"wav_path": wav_path,
|
| 269 |
+
"video_size": [video_height, video_width],
|
| 270 |
+
"face_size": [face_height, face_width],
|
| 271 |
+
"frames": total_frames,
|
| 272 |
+
"face_list": total_bbox_list,
|
| 273 |
+
"landmark_list": total_pts_list,
|
| 274 |
+
"isvalid":isvalid,
|
| 275 |
+
}
|
| 276 |
+
with open(vid_meta, 'w') as f:
|
| 277 |
+
json.dump(meta_data, f, indent=4)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def main(cfg):
|
| 282 |
+
# Ensure all necessary directories exist
|
| 283 |
+
os.makedirs(cfg.video_root_25fps, exist_ok=True)
|
| 284 |
+
os.makedirs(cfg.video_audio_clip_root, exist_ok=True)
|
| 285 |
+
os.makedirs(cfg.meta_root, exist_ok=True)
|
| 286 |
+
os.makedirs(os.path.dirname(cfg.video_file_list), exist_ok=True)
|
| 287 |
+
os.makedirs(os.path.dirname(cfg.video_clip_file_list_train), exist_ok=True)
|
| 288 |
+
os.makedirs(os.path.dirname(cfg.video_clip_file_list_val), exist_ok=True)
|
| 289 |
+
|
| 290 |
+
vid_list = os.listdir(cfg.video_root_raw)
|
| 291 |
+
sorted_vid_list = sorted(vid_list)
|
| 292 |
+
|
| 293 |
+
# Save video file list
|
| 294 |
+
with open(cfg.video_file_list, 'w') as file:
|
| 295 |
+
for vid in sorted_vid_list:
|
| 296 |
+
file.write(vid + '\n')
|
| 297 |
+
|
| 298 |
+
# 1. Convert videos to 25 FPS
|
| 299 |
+
convert_video(cfg.video_root_raw, cfg.video_root_25fps, sorted_vid_list)
|
| 300 |
+
|
| 301 |
+
# 2. Segment videos into 30-second clips
|
| 302 |
+
segment_video(cfg.video_root_25fps, cfg.video_audio_clip_root, vid_list, segment_duration=cfg.clip_len_second)
|
| 303 |
+
|
| 304 |
+
# 3. Extract audio
|
| 305 |
+
clip_vid_list = os.listdir(cfg.video_audio_clip_root)
|
| 306 |
+
extract_audio(cfg.video_audio_clip_root, cfg.video_audio_clip_root, clip_vid_list)
|
| 307 |
+
|
| 308 |
+
# 4. Generate video metadata
|
| 309 |
+
analyze_video(cfg.video_audio_clip_root, cfg.meta_root, clip_vid_list)
|
| 310 |
+
|
| 311 |
+
# 5. Generate training and validation set lists
|
| 312 |
+
generate_train_list(cfg)
|
| 313 |
+
print("done")
|
| 314 |
+
|
| 315 |
+
if __name__ == "__main__":
|
| 316 |
+
parser = argparse.ArgumentParser()
|
| 317 |
+
parser.add_argument("--config", type=str, default="./configs/training/preprocess.yaml")
|
| 318 |
+
args = parser.parse_args()
|
| 319 |
+
config = OmegaConf.load(args.config)
|
| 320 |
+
|
| 321 |
+
main(config)
|
| 322 |
+
|
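Each processed clip ends up with one JSON file under `cfg.meta_root`, carrying the per-frame face boxes and the 68 dwpose landmarks (indices 23-91) presumably consumed by the training dataset. A small sketch of reading one back; the file name below is hypothetical:

```python
import json

# Hypothetical path; real files are written under cfg.meta_root by analyze_video.
with open("dataset/meta/clip000_example.json") as f:
    meta = json.load(f)

print(meta["mp4_path"], meta["wav_path"])
print(meta["video_size"], meta["face_size"], meta["frames"])
# meta["face_list"]     -> one [x1, y1, x2, y2] box per frame
# meta["landmark_list"] -> 68 (x, y) facial landmarks per frame
print("clip usable for training:", meta["isvalid"])
```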
train.py
ADDED
|
@@ -0,0 +1,580 @@
|
import argparse
import diffusers
import logging
import math
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
import warnings
import random

from accelerate import Accelerator
from accelerate.utils import LoggerType
from accelerate import InitProcessGroupKwargs
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs
from datetime import datetime
from datetime import timedelta

from diffusers.utils import check_min_version
from einops import rearrange
from omegaconf import OmegaConf
from tqdm.auto import tqdm

from musetalk.utils.utils import (
    delete_additional_ckpt,
    seed_everything,
    get_mouth_region,
    process_audio_features,
    save_models
)
from musetalk.loss.basic_loss import set_requires_grad
from musetalk.loss.syncnet import get_sync_loss
from musetalk.utils.training_utils import (
    initialize_models_and_optimizers,
    initialize_dataloaders,
    initialize_loss_functions,
    initialize_syncnet,
    initialize_vgg,
    validation
)

logger = get_logger(__name__, log_level="INFO")
warnings.filterwarnings("ignore")
check_min_version("0.10.0.dev0")

def main(cfg):
    exp_name = cfg.exp_name
    save_dir = f"{cfg.output_dir}/{exp_name}"
    os.makedirs(save_dir, exist_ok=True)

    kwargs = DistributedDataParallelKwargs()
    process_group_kwargs = InitProcessGroupKwargs(
        timeout=timedelta(seconds=5400))
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
        log_with=["tensorboard", LoggerType.TENSORBOARD],
        project_dir=os.path.join(save_dir, "./tensorboard"),
        kwargs_handlers=[kwargs, process_group_kwargs],
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if cfg.seed is not None:
        print('cfg.seed', cfg.seed, accelerator.process_index)
        seed_everything(cfg.seed + accelerator.process_index)

    weight_dtype = torch.float32

    model_dict = initialize_models_and_optimizers(cfg, accelerator, weight_dtype)
    dataloader_dict = initialize_dataloaders(cfg)
    loss_dict = initialize_loss_functions(cfg, accelerator, model_dict['scheduler_max_steps'])
    syncnet = initialize_syncnet(cfg, accelerator, weight_dtype)
    vgg_IN, pyramid, downsampler = initialize_vgg(cfg, accelerator)

    # Prepare everything with our `accelerator`.
    model_dict['net'], model_dict['optimizer'], model_dict['lr_scheduler'], dataloader_dict['train_dataloader'], dataloader_dict['val_dataloader'] = accelerator.prepare(
        model_dict['net'], model_dict['optimizer'], model_dict['lr_scheduler'], dataloader_dict['train_dataloader'], dataloader_dict['val_dataloader']
    )
    print("length train/val", len(dataloader_dict['train_dataloader']), len(dataloader_dict['val_dataloader']))

    # Calculate training steps and epochs
    num_update_steps_per_epoch = math.ceil(
        len(dataloader_dict['train_dataloader']) / cfg.solver.gradient_accumulation_steps
    )
    num_train_epochs = math.ceil(
        cfg.solver.max_train_steps / num_update_steps_per_epoch
    )

    # Initialize trackers on the main process
    if accelerator.is_main_process:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        accelerator.init_trackers(
            cfg.exp_name,
            init_kwargs={"mlflow": {"run_name": run_time}},
        )

    # Calculate total batch size
    total_batch_size = (
        cfg.data.train_bs
        * accelerator.num_processes
        * cfg.solver.gradient_accumulation_steps
    )

    # Log training information
    logger.info("***** Running training *****")
    logger.info(f"Num Epochs = {num_train_epochs}")
    logger.info(f"Instantaneous batch size per device = {cfg.data.train_bs}")
    logger.info(
        f"Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f"Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}")
    logger.info(f"Total optimization steps = {cfg.solver.max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Load checkpoint if resuming training
    if cfg.resume_from_checkpoint:
        resume_dir = save_dir
        dirs = os.listdir(resume_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        if len(dirs) > 0:
            path = dirs[-1]
            accelerator.load_state(os.path.join(resume_dir, path))
            accelerator.print(f"Resuming from checkpoint {path}")
            global_step = int(path.split("-")[1])
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = global_step % num_update_steps_per_epoch

    # Initialize progress bar
    progress_bar = tqdm(
        range(global_step, cfg.solver.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    # Log model types
    print("log type of models")
    print("unet", model_dict['unet'].dtype)
    print("vae", model_dict['vae'].dtype)
    print("wav2vec", model_dict['wav2vec'].dtype)

    def get_ganloss_weight(step):
        """Calculate GAN loss weight based on training step"""
        if step < cfg.discriminator_train_params.start_gan:
            return 0.0
        else:
            return 1.0

    # Training loop
    for epoch in range(first_epoch, num_train_epochs):
        # Set models to training mode
        model_dict['unet'].train()
        if cfg.loss_params.gan_loss > 0:
            loss_dict['discriminator'].train()
        if cfg.loss_params.mouth_gan_loss > 0:
            loss_dict['mouth_discriminator'].train()

        # Initialize loss accumulators
        train_loss = 0.0
        train_loss_D = 0.0
        train_loss_D_mouth = 0.0
        l1_loss_accum = 0.0
        vgg_loss_accum = 0.0
        gan_loss_accum = 0.0
        gan_loss_accum_mouth = 0.0
        fm_loss_accum = 0.0
        sync_loss_accum = 0.0
        adapted_weight_accum = 0.0

        t_data_start = time.time()
        for step, batch in enumerate(dataloader_dict['train_dataloader']):
            t_data = time.time() - t_data_start
            t_model_start = time.time()

            with torch.no_grad():
                # Process input data
                pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
                    accelerator.device,
                    non_blocking=True
                )
                bsz, num_frames, c, h, w = pixel_values.shape

                # Process reference images
                ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
                    accelerator.device,
                    non_blocking=True
                )

                # Get face mask for GAN
                pixel_values_face_mask = batch['pixel_values_face_mask']

                # Process audio features
                audio_prompts = process_audio_features(cfg, batch, model_dict['wav2vec'], bsz, num_frames, weight_dtype)

                # Initialize adapted weight
                adapted_weight = 1

                # Process sync loss if enabled
                if cfg.loss_params.sync_loss > 0:
                    mels = batch['mel']
                    # Prepare frames for latentsync (combine channels and frames)
                    gt_frames = rearrange(pixel_values, 'b f c h w-> b (f c) h w')
                    # Use lower half of face for latentsync
                    height = gt_frames.shape[2]
                    gt_frames = gt_frames[:, :, height // 2:, :]

                    # Get audio embeddings
                    audio_embed = syncnet.get_audio_embed(mels)

                    # Calculate adapted weight based on audio-visual similarity
                    if cfg.use_adapted_weight:
                        vision_embed_gt = syncnet.get_vision_embed(gt_frames)
                        image_audio_sim_gt = F.cosine_similarity(
                            audio_embed,
                            vision_embed_gt,
                            dim=1
                        )[0]

                        if image_audio_sim_gt < 0.05 or image_audio_sim_gt > 0.65:
                            if cfg.adapted_weight_type == "cut_off":
                                adapted_weight = 0.0  # Skip this batch
                                print(
                                    f"\nThe i-a similarity in step {global_step} is {image_audio_sim_gt}, set adapted_weight to {adapted_weight}.")
                            elif cfg.adapted_weight_type == "linear":
                                adapted_weight = image_audio_sim_gt
                            else:
                                print(f"unknown adapted_weight_type: {cfg.adapted_weight_type}")
                                adapted_weight = 1

                    # Random frame selection for memory efficiency
                    max_start = 16 - cfg.num_backward_frames
                    frames_left_index = random.randint(0, max_start) if max_start > 0 else 0
                    frames_right_index = frames_left_index + cfg.num_backward_frames
                else:
                    frames_left_index = 0
                    frames_right_index = cfg.data.n_sample_frames

                # Extract frames for backward pass
                pixel_values_backward = pixel_values[:, frames_left_index:frames_right_index, ...]
                ref_pixel_values_backward = ref_pixel_values[:, frames_left_index:frames_right_index, ...]
                pixel_values_face_mask_backward = pixel_values_face_mask[:, frames_left_index:frames_right_index, ...]
                audio_prompts_backward = audio_prompts[:, frames_left_index:frames_right_index, ...]

                # Encode target images
                frames = rearrange(pixel_values_backward, 'b f c h w-> (b f) c h w')
                latents = model_dict['vae'].encode(frames).latent_dist.mode()
                latents = latents * model_dict['vae'].config.scaling_factor
                latents = latents.float()

                # Create masked images
                masked_pixel_values = pixel_values_backward.clone()
                masked_pixel_values[:, :, :, h//2:, :] = -1
                masked_frames = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
                masked_latents = model_dict['vae'].encode(masked_frames).latent_dist.mode()
                masked_latents = masked_latents * model_dict['vae'].config.scaling_factor
                masked_latents = masked_latents.float()

                # Encode reference images
                ref_frames = rearrange(ref_pixel_values_backward, 'b f c h w-> (b f) c h w')
                ref_latents = model_dict['vae'].encode(ref_frames).latent_dist.mode()
                ref_latents = ref_latents * model_dict['vae'].config.scaling_factor
                ref_latents = ref_latents.float()

                # Prepare face mask and audio features
                pixel_values_face_mask_backward = rearrange(
                    pixel_values_face_mask_backward,
                    "b f c h w -> (b f) c h w"
                )
                audio_prompts_backward = rearrange(
                    audio_prompts_backward,
                    'b f c h w-> (b f) c h w'
                )
                audio_prompts_backward = rearrange(
                    audio_prompts_backward,
                    '(b f) c h w -> (b f) (c h) w',
                    b=bsz
                )

                # Apply reference dropout (currently inactive)
                dropout = nn.Dropout(p=cfg.ref_dropout_rate)
                ref_latents = dropout(ref_latents)

            # Prepare model inputs
            input_latents = torch.cat([masked_latents, ref_latents], dim=1)
            input_latents = input_latents.to(weight_dtype)
            timesteps = torch.tensor([0], device=input_latents.device)

            # Forward pass
            latents_pred = model_dict['net'](
                input_latents,
                timesteps,
                audio_prompts_backward,
            )
            latents_pred = (1 / model_dict['vae'].config.scaling_factor) * latents_pred
            image_pred = model_dict['vae'].decode(latents_pred).sample

            # Convert to float
            image_pred = image_pred.float()
            frames = frames.float()

            # Calculate L1 loss
            l1_loss = loss_dict['L1_loss'](frames, image_pred)
            l1_loss_accum += l1_loss.item()
            loss = cfg.loss_params.l1_loss * l1_loss * adapted_weight

            # Process mouth GAN loss if enabled
            if cfg.loss_params.mouth_gan_loss > 0:
                frames_mouth, image_pred_mouth = get_mouth_region(
                    frames,
                    image_pred,
                    pixel_values_face_mask_backward
                )
                pyramide_real_mouth = pyramid(downsampler(frames_mouth))
                pyramide_generated_mouth = pyramid(downsampler(image_pred_mouth))

            # Process VGG loss if enabled
            if cfg.loss_params.vgg_loss > 0:
                pyramide_real = pyramid(downsampler(frames))
                pyramide_generated = pyramid(downsampler(image_pred))

                loss_IN = 0
                for scale in cfg.loss_params.pyramid_scale:
                    x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
                    y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
                    for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
                        value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                        loss_IN += weight * value
                loss_IN /= sum(cfg.loss_params.vgg_layer_weight)
                loss += loss_IN * cfg.loss_params.vgg_loss * adapted_weight
                vgg_loss_accum += loss_IN.item()

            # Process GAN loss if enabled
            if cfg.loss_params.gan_loss > 0:
                set_requires_grad(loss_dict['discriminator'], False)
                loss_G = 0.
                discriminator_maps_generated = loss_dict['discriminator'](pyramide_generated)
                discriminator_maps_real = loss_dict['discriminator'](pyramide_real)

                for scale in loss_dict['disc_scales']:
                    key = 'prediction_map_%s' % scale
                    value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
                    loss_G += value
                gan_loss_accum += loss_G.item()

                loss += loss_G * cfg.loss_params.gan_loss * get_ganloss_weight(global_step) * adapted_weight

                # Process feature matching loss if enabled
                if cfg.loss_params.fm_loss[0] > 0:
                    L_feature_matching = 0.
                    for scale in loss_dict['disc_scales']:
                        key = 'feature_maps_%s' % scale
                        for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
                            value = torch.abs(a - b).mean()
                            L_feature_matching += value * cfg.loss_params.fm_loss[i]
                    loss += L_feature_matching * adapted_weight
                    fm_loss_accum += L_feature_matching.item()

            # Process mouth GAN loss if enabled
            if cfg.loss_params.mouth_gan_loss > 0:
                set_requires_grad(loss_dict['mouth_discriminator'], False)
                loss_G = 0.
                mouth_discriminator_maps_generated = loss_dict['mouth_discriminator'](pyramide_generated_mouth)
                mouth_discriminator_maps_real = loss_dict['mouth_discriminator'](pyramide_real_mouth)

                for scale in loss_dict['disc_scales']:
                    key = 'prediction_map_%s' % scale
                    value = ((1 - mouth_discriminator_maps_generated[key]) ** 2).mean()
                    loss_G += value
                gan_loss_accum_mouth += loss_G.item()

                loss += loss_G * cfg.loss_params.mouth_gan_loss * get_ganloss_weight(global_step) * adapted_weight

                # Process feature matching loss for mouth if enabled
                if cfg.loss_params.fm_loss[0] > 0:
                    L_feature_matching = 0.
                    for scale in loss_dict['disc_scales']:
                        key = 'feature_maps_%s' % scale
                        for i, (a, b) in enumerate(zip(mouth_discriminator_maps_real[key], mouth_discriminator_maps_generated[key])):
                            value = torch.abs(a - b).mean()
                            L_feature_matching += value * cfg.loss_params.fm_loss[i]
                    loss += L_feature_matching * adapted_weight
                    fm_loss_accum += L_feature_matching.item()

            # Process sync loss if enabled
            if cfg.loss_params.sync_loss > 0:
                pred_frames = rearrange(
                    image_pred, '(b f) c h w-> b (f c) h w', f=pixel_values_backward.shape[1])
                pred_frames = pred_frames[:, :, height // 2:, :]
                sync_loss, image_audio_sim_pred = get_sync_loss(
                    audio_embed,
                    gt_frames,
                    pred_frames,
                    syncnet,
                    adapted_weight,
                    frames_left_index=frames_left_index,
                    frames_right_index=frames_right_index,
                )
                sync_loss_accum += sync_loss.item()
                loss += sync_loss * cfg.loss_params.sync_loss * adapted_weight

            # Backward pass
            avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
            train_loss += avg_loss.item()
            accelerator.backward(loss)

            # Train discriminator if GAN loss is enabled
            if cfg.loss_params.gan_loss > 0:
                set_requires_grad(loss_dict['discriminator'], True)
                loss_D = loss_dict['discriminator_full'](frames, image_pred.detach())
                avg_loss_D = accelerator.gather(loss_D.repeat(cfg.data.train_bs)).mean()
                train_loss_D += avg_loss_D.item() / 1
                loss_D = loss_D * get_ganloss_weight(global_step) * adapted_weight
                accelerator.backward(loss_D)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        loss_dict['discriminator'].parameters(), cfg.solver.max_grad_norm)
                if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
                    loss_dict['optimizer_D'].step()
                    loss_dict['scheduler_D'].step()
                    loss_dict['optimizer_D'].zero_grad()

            # Train mouth discriminator if mouth GAN loss is enabled
            if cfg.loss_params.mouth_gan_loss > 0:
                set_requires_grad(loss_dict['mouth_discriminator'], True)
                mouth_loss_D = loss_dict['mouth_discriminator_full'](
                    frames_mouth, image_pred_mouth.detach())
                avg_mouth_loss_D = accelerator.gather(
                    mouth_loss_D.repeat(cfg.data.train_bs)).mean()
                train_loss_D_mouth += avg_mouth_loss_D.item() / 1
                mouth_loss_D = mouth_loss_D * get_ganloss_weight(global_step) * adapted_weight
                accelerator.backward(mouth_loss_D)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        loss_dict['mouth_discriminator'].parameters(), cfg.solver.max_grad_norm)
                if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
                    loss_dict['mouth_optimizer_D'].step()
                    loss_dict['mouth_scheduler_D'].step()
                    loss_dict['mouth_optimizer_D'].zero_grad()

            # Update main model
            if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        model_dict['trainable_params'],
                        cfg.solver.max_grad_norm,
                    )
                model_dict['optimizer'].step()
                model_dict['lr_scheduler'].step()
                model_dict['optimizer'].zero_grad()

            # Update progress and log metrics
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({
                    "train_loss": train_loss,
                    "train_loss_D": train_loss_D,
                    "train_loss_D_mouth": train_loss_D_mouth,
                    "l1_loss": l1_loss_accum,
                    "vgg_loss": vgg_loss_accum,
                    "gan_loss": gan_loss_accum,
                    "fm_loss": fm_loss_accum,
                    "sync_loss": sync_loss_accum,
                    "adapted_weight": adapted_weight_accum,
                    "lr": model_dict['lr_scheduler'].get_last_lr()[0],
                }, step=global_step)

                # Reset loss accumulators
                train_loss = 0.0
                l1_loss_accum = 0.0
                vgg_loss_accum = 0.0
                gan_loss_accum = 0.0
                fm_loss_accum = 0.0
                sync_loss_accum = 0.0
                adapted_weight_accum = 0.0
                train_loss_D = 0.0
                train_loss_D_mouth = 0.0

                # Run validation if needed
                if global_step % cfg.val_freq == 0 or global_step == 10:
                    try:
                        validation(
                            cfg,
                            dataloader_dict['val_dataloader'],
                            model_dict['net'],
                            model_dict['vae'],
                            model_dict['wav2vec'],
                            accelerator,
                            save_dir,
                            global_step,
                            weight_dtype,
                            syncnet_score=adapted_weight,
                        )
                    except Exception as e:
                        print(f"An error occurred during validation: {e}")

                # Save checkpoint if needed
                if global_step % cfg.checkpointing_steps == 0:
                    save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
                    try:
                        start_time = time.time()
                        if accelerator.is_main_process:
                            save_models(
                                accelerator,
                                model_dict['net'],
                                save_dir,
                                global_step,
                                cfg,
                                logger=logger
                            )
                            delete_additional_ckpt(save_dir, cfg.total_limit)
                        elapsed_time = time.time() - start_time
                        if elapsed_time > 300:
                            print(f"Skipping storage as it took too long in step {global_step}.")
                        else:
                            print(f"Resume states saved at {save_dir} successfully in {elapsed_time}s.")
                    except Exception as e:
                        print(f"Error when saving model in step {global_step}:", e)

            # Update progress bar
            t_model = time.time() - t_model_start
            logs = {
                "step_loss": loss.detach().item(),
                "lr": model_dict['lr_scheduler'].get_last_lr()[0],
                "td": f"{t_data:.2f}s",
                "tm": f"{t_model:.2f}s",
            }
            t_data_start = time.time()
            progress_bar.set_postfix(**logs)

            if global_step >= cfg.solver.max_train_steps:
                break

        # Save model after each epoch
        if (epoch + 1) % cfg.save_model_epoch_interval == 0:
            try:
                start_time = time.time()
                if accelerator.is_main_process:
                    save_models(accelerator, model_dict['net'], save_dir, global_step, cfg)
                    accelerator.save_state(save_path)
                elapsed_time = time.time() - start_time
                if elapsed_time > 120:
                    print(f"Skipping storage as it took too long in step {global_step}.")
                else:
                    print(f"Model saved successfully in {elapsed_time}s.")
            except Exception as e:
                print(f"Error when saving model in step {global_step}:", e)
        accelerator.wait_for_everyone()

    # End training
    accelerator.end_training()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/training/stage2.yaml")
    args = parser.parse_args()
    config = OmegaConf.load(args.config)
    main(config)
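Two pieces of the training loop above are easy to miss in the long listing: GAN and feature-matching terms are switched on only after `cfg.discriminator_train_params.start_gan` steps (`get_ganloss_weight`), and, when `cfg.use_adapted_weight` is enabled, the SyncNet audio-visual cosine similarity of the ground-truth clip is used to skip ("cut_off") or linearly re-weight ("linear") the whole batch. The snippet below is a minimal standalone sketch of that gating logic, not part of the diff; the 512-dimensional random embeddings are placeholders for the real SyncNet outputs.

```python
import torch
import torch.nn.functional as F

def ganloss_weight(step: int, start_gan: int) -> float:
    """GAN and feature-matching terms stay off until `start_gan` optimization steps."""
    return 0.0 if step < start_gan else 1.0

def adapted_weight(audio_embed: torch.Tensor, vision_embed: torch.Tensor,
                   mode: str = "cut_off") -> float:
    """Down-weight or skip batches whose ground-truth audio-visual sync looks unreliable."""
    sim = F.cosine_similarity(audio_embed, vision_embed, dim=1)[0].item()
    if sim < 0.05 or sim > 0.65:      # similarity outside the trusted band
        if mode == "cut_off":
            return 0.0                # effectively skip this batch
        elif mode == "linear":
            return sim                # scale all losses by the similarity itself
    return 1.0                        # trusted batch: full weight

# Placeholder embeddings; the real ones come from SyncNet, shapes are illustrative.
audio = torch.randn(1, 512)
vision = torch.randn(1, 512)
print(ganloss_weight(step=1000, start_gan=5000))  # 0.0 -> GAN loss still disabled
print(adapted_weight(audio, vision))
```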
train.sh
ADDED
|
@@ -0,0 +1,34 @@
#!/bin/bash

# MuseTalk Training Script
# This script combines both training stages for the MuseTalk model
# Usage: sh train.sh [stage1|stage2]
# Example: sh train.sh stage1  # To run stage 1 training
# Example: sh train.sh stage2  # To run stage 2 training

# Check if stage argument is provided
if [ $# -ne 1 ]; then
    echo "Error: Please specify the training stage"
    echo "Usage: ./train.sh [stage1|stage2]"
    exit 1
fi

STAGE=$1

# Validate stage argument
if [ "$STAGE" != "stage1" ] && [ "$STAGE" != "stage2" ]; then
    echo "Error: Invalid stage. Must be either 'stage1' or 'stage2'"
    exit 1
fi

# Launch distributed training using accelerate
# --config_file: Path to the GPU configuration file
# --main_process_port: Port number for the main process, used for distributed training communication
# train.py: Training script
# --config: Path to the training configuration file
echo "Starting $STAGE training..."
accelerate launch --config_file ./configs/training/gpu.yaml \
    --main_process_port 29502 \
    train.py --config ./configs/training/$STAGE.yaml

echo "Training completed for $STAGE"
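One detail worth noting when choosing an entry point: if `train.py` is run directly without `--config`, it falls back to the argparse default `./configs/training/stage2.yaml`, so `train.sh` is the safer way to select a stage; the accelerate launcher settings (number of processes and related options) come from `./configs/training/gpu.yaml` rather than from this script.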